I'm working on a text summarization task and trying to add a .csv dataset to
tensorflow_datasets
(which is required to run a pre-trained Transformer). I'm following this tutorial https://www.tensorflow.org/datasets/add_dataset but I still can't figure out how to add it.
Here is what I have so far:
import tensorflow_datasets.public_api as tfds

# TODO(data.csv): BibTeX citation
_CITATION = """
"""

_HOMEPAGE = "https:..."

# TODO(data.csv):
_DESCRIPTION = """A textual corpus of ...
"""

_DOCUMENT = "text"
_SUMMARY = "summary"

manual_dir = './'


class new_dataset(tfds.core.GeneratorBasedBuilder):
  """TODO(data.csv): Short description of my dataset."""

  # TODO(data.csv): Set up version.
  VERSION = tfds.core.Version('0.1.0')

  def _info(self):
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            _DOCUMENT: tfds.features.Text(),
            _SUMMARY: tfds.features.Text()
        }),
        supervised_keys=(_DOCUMENT, _SUMMARY),
        homepage="https://...",
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager):
    """Returns SplitGenerators."""
    # TODO(data.csv): Download the data and define the splits.
    # dl_manager is a tfds.download.DownloadManager that can be used to
    # download and extract URLs.
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            # These kwargs will be passed to _generate_examples.
            gen_kwargs={},
        ),
    ]

  def _generate_examples(self):
    # Yields examples from the dataset.
    yield 'key', {}
If my dataset is a .csv file with 2 columns, "text" and "summary", how do I correctly define
def _split_generators
and def _generate_examples
? The Python file and my dataset are in the same directory.
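At its core, `_generate_examples` just iterates the csv and yields `(key, example)` pairs whose keys match the `FeaturesDict` above. A minimal standalone sketch of that loop, stripped of the tfds wiring so it runs anywhere (it assumes a header row `text,summary`; the row index serves as the example key):

```python
import csv
import io

_DOCUMENT = "text"
_SUMMARY = "summary"


def generate_examples(fileobj):
    """Yield (key, example) pairs, as tfds' _generate_examples must."""
    for idx, row in enumerate(csv.DictReader(fileobj)):
        yield idx, {_DOCUMENT: row[_DOCUMENT], _SUMMARY: row[_SUMMARY]}


# An in-memory csv standing in for data.csv:
sample = io.StringIO(
    "text,summary\n"
    "long article one,short one\n"
    "long article two,short two\n"
)
examples = list(generate_examples(sample))
```

Inside the real builder, the same loop would live in `_generate_examples(self, path)` with `path.open()` supplying the file object.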
I followed the instructions on TensorFlow: https://www.tensorflow.org/datasets/add_dataset.
You have to place the raw csv or txt data in a specific directory layout. I created a folder named dataset_DATA, then created 2 folders (train and test) inside dataset_DATA. I then put dataset_train.csv in ../dataset_DATA/train and dataset_test.csv in ../dataset_DATA/test. Then I zipped the dataset_DATA folder (zip -r dataset_DATA.zip dataset_DATA/).
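The same layout can be built programmatically. A minimal sketch using only the standard library, with placeholder csv contents standing in for the real dataset_train.csv / dataset_test.csv:

```python
import csv
import pathlib
import zipfile

# Recreate the dataset_DATA/{train,test}/ layout with placeholder csvs.
root = pathlib.Path("dataset_DATA")
for split in ("train", "test"):
    (root / split).mkdir(parents=True, exist_ok=True)
    with open(root / split / f"dataset_{split}.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["idx", "label", "sentence"])  # header row
        writer.writerow([0, 1, f"placeholder {split} sentence"])

# Equivalent of `zip -r dataset_DATA.zip dataset_DATA/`.
with zipfile.ZipFile("dataset_DATA.zip", "w") as zf:
    for p in sorted(root.rglob("*")):
        zf.write(p)
```

The archive paths inside the zip (dataset_DATA/train/dataset_train.csv, etc.) are what `_split_generators` will reference after extraction.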
Next I used the TFDS CLI and ran (tfds new dataset_name). It created a folder named dataset_name (they call it the project folder).
Next, I created two folders inside the dataset_name project folder (downloads, downloads/manual) to store manually downloaded data (mkdir /home/path/dataset_name/downloads, mkdir /home/path/dataset_name/downloads/manual). I then moved the data into the downloads/manual folder (mv /home/path/dataset_DATA.zip /home/path/dataset_name/downloads/manual).
Next, I modified the /home/path/dataset_name/dataset_name_dataset_builder.py file. Here is what I used for text.
# At the top of dataset_name_dataset_builder.py:
import csv

import tensorflow as tf
import tensorflow_datasets as tfds


def _info(self) -> tfds.core.DatasetInfo:
  return self.dataset_info_from_configs(
      features=tfds.features.FeaturesDict({
          'idx': tfds.features.Scalar(dtype=tf.int32),
          'label': tfds.features.ClassLabel(num_classes=3),
          'sentence': tfds.features.Text(),
      }),
      disable_shuffling=False,
      homepage='',
  )

def _split_generators(self, dl_manager: tfds.download.DownloadManager):
  # Way 0: Download from GCP Storage
  # extracted_path = dl_manager.download_and_extract(
  #     'https://storage.googleapis.com/BUCKET_NAME/dataset_DATA.zip')
  # OR
  # Way 1: Manually download the file to the manual directory, and extract it.
  archive_path = dl_manager.manual_dir / 'dataset_DATA.zip'
  extracted_path = dl_manager.extract(archive_path)
  return {
      'train': self._generate_examples(
          path=extracted_path / 'dataset_DATA/train/dataset_train.csv'),
      'test': self._generate_examples(
          path=extracted_path / 'dataset_DATA/test/dataset_test.csv'),
  }
def _generate_examples(self, path):
  with path.open() as f:
    for row in csv.DictReader(f):
      # My csv file has 3 columns ['idx', 'label', 'sentence'],
      # where the first row is the header ['idx', 'label', 'sentence'].
      # key = csv column 'idx' = [0, 1, 2, .., n]
      # csv column 'label' = [0, 1, or 2]
      # csv column 'sentence' = ['text0', 'text1', .., 'textn']
      key = row['idx']
      yield key, {'idx': row['idx'], 'label': row['label'], 'sentence': row['sentence']}
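The yield loop can be sanity-checked outside tfds by running it on a throwaway csv. Note that csv.DictReader returns every field as a string; tfds later coerces them into the declared feature dtypes (Scalar, ClassLabel). A hypothetical standalone check, not part of the builder:

```python
import csv
import io


def generate_examples(fileobj):
    # Same loop as _generate_examples, minus the tfds path object.
    for row in csv.DictReader(fileobj):
        key = row['idx']
        yield key, {'idx': row['idx'],
                    'label': row['label'],
                    'sentence': row['sentence']}


sample = io.StringIO(
    "idx,label,sentence\n"
    "0,2,text0\n"
    "1,0,text1\n"
)
rows = list(generate_examples(sample))
```

Here rows[0][0] is the string '0', not the integer 0, which is fine: tfds only requires keys to be unique and hashable.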
I ran the tfds build command (tfds build --overwrite --data_dir=/home/path/dataset_name/ --download_dir=/home/path/dataset_name/downloads/ --extract_dir=/home/path/dataset_name/downloads/extracted/ --manual_dir=/home/path/dataset_name/downloads/manual --file_format='tfrecord' --publish_dir=/home/path/dataset_name/dataset_name/ /home/path/dataset_name). Be sure to create the extract_dir and publish_dir directories before running tfds, otherwise you will get an error (i.e.: mkdir /home/path/dataset_name/downloads/extracted/).
The dataset should be saved in tfrecord format at /home/path/dataset_name/dataset_name/1.0.0. You can load the dataset with a builder (i.e.: builder = tfds.builder_from_directory('/home/path/dataset_name/dataset_name/1.0.0'), ds = builder.as_dataset(split='train[75%:]')).