我在加载构成数据集的数据时遇到问题。我以前的(工作)方法是使用 pandas DataFrame,但对于较大的数据集,训练过程会被终止,因为数据占用了太多内存。所以我决定使用 TensorFlow 的 Dataset 类来克服这个问题,但我无法加载单个文件。
具体来说,我尝试加载各个.npz文件的各个文件路径作为示例,然后使用Dataset类的map方法单独加载.npz文件的内容。每个 .npz 文件都是形状 (1, x, x, z) 的 numpy 数组,并且包含在指定其标签名称的文件夹中。
这是我用来加载数据集的方法:
IMAGE_SUPPORTED_EXTENSIONS = ('.jpg', '.jpeg', '.png')
def load_dataset(self):
data = []
for label in self.labels:
folder = self.main_folder / label
file_paths = [str(file_path) for file_path in folder.glob('*') if file_path.suffix in TENSOR_SUPPORTED_EXTENSIONS]
latenst_spaces = [DatasetLoader.load_tensor(file_path) for file_path in folder.glob('*') if file_path.suffix in TENSOR_SUPPORTED_EXTENSIONS]
dataset = tf.data.Dataset.from_tensor_slices(file_paths)
# Zip dataset with labels
dataset = dataset.map(lambda x: (x, label))
dataset = dataset.map(map_function)
data.append(dataset)
# Concatenate datasets from different labels
dataset = data[0]
for i in range(1, len(data)):
dataset = dataset.concatenate(data[i])
return dataset
这是传递给map方法的函数:
def map_function(element):
file_path, label = element
npz_data = DatasetLoader.load_tensor(file_path)
return (npz_data, label)
@staticmethod
def load_tensor(file_path):
file_path = tf.get_static_value(tf_tensor)
file_path = Path(file_path)
if file_path.suffix not in ('.npy', '.npz'):
raise ValueError(f"Extension {file_path.suffix} not suppported.")
try:
with np.load(file_path) as tensor:
if file_path.suffix == ".npz":
for _, item in tensor.items():
tensor = item
return np.array(tensor).squeeze()
except Exception as e:
print(f"Error loading {file_path.stem} file: {str(e)}.", "\nFile path: ", file_path)
raise RuntimeError(f"Error loading {file_path.stem} file: {str(e)}.") from e
这是一个使用 2 个数组创建
npz
,然后加载它们的示例:
In [9]: x,y = np.ones((2,3)), np.arange(5)
In [10]: np.savez('test.npz', **{'x':x, 'y':y})
负载:
In [12]: alist = []
...: with np.load('test.npz') as data:
...: for i,v in data.items():
...: print(i,v)
...: alist.append(v)
...:
x [[1. 1. 1.]
[1. 1. 1.]]
y [0 1 2 3 4]
In [13]: alist
Out[13]:
[array([[1., 1., 1.],
[1., 1., 1.]]),
array([0, 1, 2, 3, 4])]