如何将各个 .npz 文件作为样本加载到 Tensorflow 数据集中

问题描述 投票:0回答:1

我在加载构成数据集的数据时遇到问题。我以前的(工作)方法是使用 pandas DataFrame,但对于较大的数据集,训练过程会被终止,因为数据占用了太多内存。所以我决定使用 TensorFlow 的 Dataset 类来克服这个问题,但我无法加载单个文件。

具体来说,我尝试加载各个.npz文件的各个文件路径作为示例,然后使用Dataset类的map方法单独加载.npz文件的内容。每个 .npz 文件都是形状 (1, x, x, z) 的 numpy 数组,并且包含在指定其标签名称的文件夹中。

这是我用来加载数据集的方法:

IMAGE_SUPPORTED_EXTENSIONS = ('.jpg', '.jpeg', '.png')

def load_dataset(self):
data = []

        for label in self.labels:
            folder = self.main_folder / label
            file_paths = [str(file_path) for file_path in folder.glob('*') if file_path.suffix in TENSOR_SUPPORTED_EXTENSIONS]
            latenst_spaces = [DatasetLoader.load_tensor(file_path) for file_path in folder.glob('*') if file_path.suffix in TENSOR_SUPPORTED_EXTENSIONS]
            dataset = tf.data.Dataset.from_tensor_slices(file_paths)
            
            # Zip dataset with labels
            dataset = dataset.map(lambda x: (x, label))
            
            dataset = dataset.map(map_function)
            data.append(dataset)
        
        # Concatenate datasets from different labels
        dataset = data[0]
        for i in range(1, len(data)):
            dataset = dataset.concatenate(data[i])
        
        return dataset

这是传递给map方法的函数:

def map_function(element):
    file_path, label = element
    npz_data = DatasetLoader.load_tensor(file_path)
    return (npz_data, label)
    @staticmethod

def load_tensor(file_path):
        file_path = tf.get_static_value(tf_tensor)
        file_path = Path(file_path)
        if file_path.suffix not in ('.npy', '.npz'):
            raise ValueError(f"Extension {file_path.suffix} not suppported.")
        try:
            with np.load(file_path) as tensor:
                if file_path.suffix == ".npz":
                    for _, item in tensor.items():
                        tensor = item
            return np.array(tensor).squeeze()
        except Exception as e:
            print(f"Error loading {file_path.stem} file: {str(e)}.", "\nFile path: ", file_path)
            raise RuntimeError(f"Error loading {file_path.stem} file: {str(e)}.") from e
numpy tensorflow dataset tensorflow2.0 tensorflow-datasets
1个回答
0
投票

这是一个使用 2 个数组创建

npz
,然后加载它们的示例:

In [9]: x,y = np.ones((2,3)), np.arange(5)
In [10]: np.savez('test.npz', **{'x':x, 'y':y})

负载:

In [12]: alist = []
    ...: with np.load('test.npz') as data:
    ...:     for i,v in data.items():
    ...:         print(i,v)
    ...:         alist.append(v)    
    ...:         
x [[1. 1. 1.]
 [1. 1. 1.]]
y [0 1 2 3 4]

In [13]: alist
Out[13]: 
[array([[1., 1., 1.],
        [1., 1., 1.]]),
 array([0, 1, 2, 3, 4])]
© www.soinside.com 2019 - 2024. All rights reserved.