我正在试验https://github.com/astirn/IIC中的聚类模型(已经尝试与他联系)
[与大多数研究论文一样,它使用Mnist数据集。在这里,他们首先将数据集名称定义为“ mnist”,这足以使张量流从其标准在线数据集中导入mnist。然后,他使用tensorflow_dataset.load()函数加载数据集
我已经为我的数据集创建了一个tfrecord文件,现在我只需要替换前面脚本指向“ mnist”的部分(下面代码中的第1行),而不是指向我的本地数据集。
我是否只用第一行的文件路径替换'mnist'?? >>
来自实际训练模型文件的代码:
if __name__ == '__main__': # pick a data set DATA_SET = 'mnist' # define splits DS_CONFIG = { # mnist data set parameters 'mnist': { 'batch_size': 700, 'num_repeats': 5, 'mdl_input_dims': [24, 24, 1]} } # load the data set TRAIN_SET, TEST_SET, SET_INFO = load(data_set_name=DATA_SET, **DS_CONFIG[DATA_SET]) # configure the common model elements MDL_CONFIG = { # mist hyper-parameters 'mnist': { 'num_classes': SET_INFO.features['label'].num_classes, 'learning_rate': 1e-4, 'num_repeats': DS_CONFIG[DATA_SET]['num_repeats'], 'save_dir': None}, }
来自“数据准备文件”的代码,其中他使用tensorflor_dataset.load作为tfds.load调用数据集:
def load(data_set_name, **kwargs): """ :param data_set_name: data set name--call tfds.list_builders() for options :return: train_ds: TensorFlow Dataset object for the training data test_ds: TensorFlow Dataset object for the testing data info: data set info object """ # get data and its info ds, info = tfds.load(name=data_set_name, split=tfds.Split.ALL, with_info=True)
感谢您的帮助
我正在尝试使用https://github.com/astirn/IIC的聚类模型(已经尝试与他联系)。它像大多数研究论文一样使用Mnist数据集。他们首先在这里定义...
根据docs,您需要将download
参数用作False
,并将data_dir
与目录名称一起使用: