我目前有一个已使用加载的数据集
(ds_train, ds_test, ds_val), ds_info = tfds.load('speech_commands', split=splits, data_dir=Flags.data_dir, with_info=True)
已经被映射和操纵了几次。
我想在 Pytorch 中使用它作为 pytorch 数据集(带有数据加载器),但这要求它是可订阅的(带有
__getitem__
)。
实现此目的的一种方法是使用张量流 data_source 对象,如此处所述。
但是,这从构建器开始,而我已经从
tfds.load
获得了一些东西,它已经调用了 builder.as_dataset
而不是 builder.as_data_source
有什么方法可以将我当前的数据集
(ds_train, ds_test, ds_val)
转换成数据源供外部使用吗?
我认为可能的另一种方法是
tf.data.Dataset.as_numpy_iterator()
,但这又返回一个不可索引的迭代器。
TL;DR:您必须以支持随机访问的文件格式重新生成数据。
tfds.data_source
需要 array_record
数据格式,因为它支持随机访问。因此,您必须使用 file_format='array_record'
准备数据。当省略 file_format
时,它隐式默认为 tfrecord,这就是为什么您无法受益于随机访问。要解决此问题:
tfds.data_source(
'speech_commands',
split=splits,
data_dir=Flags.data_dir,
builder_kwargs={'file_format': 'array_record'}, # For random access
)
如果数据集已存在于
data_dir
中,则必须将其从 Flags.data_dir
中删除才能下载并使用正确的文件格式重新准备。或者,您可以在另一个data_dir
中下载并准备它。
file_format
可以与 tfds.data_source
一起使用(如上面的代码片段所示,这相当于 tfds.load)或与 tfds.builder
(tfds.builder(..., file_format='array_record')
) 一起使用。