Python code: KeyError when loading a pickle file I created myself, although the same code works for the standard Caltech dataset


Reference code: https://github.com/med-air/FedBN/blob/master/federated/fed_office.py

I am testing the code that uses the office-caltech dataset. The dataset is stored in pickle format. When I read the pickle file, its content and format look like this:

(array(['office_caltech_10/webcam/laptop_computer/frame_0016.jpg', 'office_caltech_10/webcam/mouse/frame_0003.jpg', 'office_caltech_10/webcam/laptop_computer/frame_0020.jpg', .... 'office_caltech_10/webcam/back_pack/frame_0012.jpg'], dtype='
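A minimal sketch of the layout this output suggests, assuming the file holds a two-element tuple of (paths array, class-name array), as the loader's unpacking into `self.paths` and `self.text_labels` implies. The file name `demo_train.pkl` and the two sample entries are made up for illustration.

```python
import os
import pickle
import tempfile

import numpy as np

# Hypothetical sample entries, shaped like the Caltech pickle shown above.
paths = np.array([
    "office_caltech_10/webcam/laptop_computer/frame_0016.jpg",
    "office_caltech_10/webcam/mouse/frame_0003.jpg",
])
text_labels = np.array(["laptop_computer", "mouse"])

# Write a two-element tuple to a temporary pickle file.
pkl_path = os.path.join(tempfile.mkdtemp(), "demo_train.pkl")
with open(pkl_path, "wb") as f:
    pickle.dump((paths, text_labels), f)

# With allow_pickle=True, np.load falls back to pickle.load for files
# that are not in .npy/.npz format, so it returns the tuple itself.
loaded = np.load(pkl_path, allow_pickle=True)
print(type(loaded), len(loaded))

# Unpacking into two names works because there are exactly two elements.
loaded_paths, loaded_labels = loaded
print(loaded_paths[0], "->", loaded_labels[0])
```

Printing `type(loaded)` and `len(loaded)` like this on any pickle is a quick way to check whether it has the two-element shape the loader expects.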

Source: https://github.com/med-air/FedBN/blob/master/utils/data_utils.py

import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class OfficeDataset(Dataset):
    def __init__(self, base_path, site, train=True, transform=None):
        # Each pickle unpacks into two sequences: image paths and class names
        if train:
            self.paths, self.text_labels = np.load('../data/office_caltech_10/{}_train.pkl'.format(site), allow_pickle=True)
        else:
            self.paths, self.text_labels = np.load('../data/office_caltech_10/{}_test.pkl'.format(site), allow_pickle=True)

        # Map each class name to an integer label
        label_dict = {'back_pack': 0, 'bike': 1, 'calculator': 2, 'headphones': 3, 'keyboard': 4, 'laptop_computer': 5, 'monitor': 6, 'mouse': 7, 'mug': 8, 'projector': 9}
        self.labels = [label_dict[text] for text in self.text_labels]
        self.transform = transform
        self.base_path = base_path if base_path is not None else '../data'

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.base_path, self.paths[idx])
        label = self.labels[idx]
        image = Image.open(img_path)

        # Convert images that do not have 3 channels to 3-channel grayscale
        if len(image.split()) != 3:
            image = transforms.Grayscale(num_output_channels=3)(image)

        if self.transform is not None:
            image = self.transform(image)

        return image, label

I tried to create a pickle file from my own image dataset. Under /../../aData/SpFolder there are two folders, /train and /val. Both contain the same set of 6 subfolders, and each subfolder holds 20-40 .jpg files.

/train
    /brigade
        -- p1.jpg
    /navy
        -- p100.jpg
    /army
    /air
    /force
    /gator
I converted these two folders (/train and /val) into pickle files named aData_train.pkl and aData_test.pkl.

Source code: conversion to pickle format

  import os
  import numpy as np
  import pickle as pkl
  import imageio.v2 as imageio

  train_dirpath = r'/home/11/22/33/data/train'
  test_dirpath = r'/home/11/22/33/data/val'

  # Lists to store file paths
  Trainfiles = []
  Testfiles = []

  # Lists to store class names
  classNames_train = []
  classNames_test = []

  dirs = os.listdir(train_dirpath)
  total_dirs = len(dirs)
  file_count_train = 0  # counter must be initialized before it is incremented

  for i, file in enumerate(dirs):
      if os.path.isdir(train_dirpath):
          newpath = os.path.join(train_dirpath, file)
          for fname in os.listdir(newpath):
              if fname.endswith(".jpg"):
                  full_path = os.path.join(newpath, fname)
                  image = imageio.imread(full_path)
                  label = fname.split('.')

                  Trainfiles.append(full_path)
                  # Add the class name (i.e. b1/c1...g1) to this list
                  classNames_train.append(file)
                  file_count_train = file_count_train + 1

      # Add the data type of the image at the end of Trainfiles
      Trainfiles.append(image.dtype)
      # Add the class names at the end of Trainfiles
      Trainfiles.append(classNames_train)

  # The with block closes the file, so no explicit f.close() is needed
  with open('/home/11/22/33/data/aData_train.pkl', 'wb') as f:
      pkl.dump(Trainfiles, f)

When I read the .pkl file back, this is its content, which does not match the Caltech pickle format:

'/home/11/22/33/aData/SpFolder/train/brigade/p1.jpg', '/home/11/22/33/aData/SpFolder/train/brigade/p5.jpg', '/home/11/22/33/aData/SpFolder/train/brigade/Control.000.jpg', '/home/11/22/33/aData/SpFolder/train/brigade/1-Fail.000.jpg', '/home/11/22/33/aData/SpFolder/train/navy/n11-00-1.jpg', '/home/11/22/33/aData/SpFolder/train/navy/nT65.jpg', '/home/11/22/33/aData/SpFolder/train/navy/Control11.00.jpg', '/home/11/22/33/aData/SpFolder/train/navy/1-2.000.jpg', '/home/11/22/33/aData/SpFolder/train/army/a525.jpg', '/home/11/22/33/aData/SpFolder/train/army/GHT56-00.jpg', dtype('uint8'), ['brigade', 'navy', 'army']

I pass aData_train.pkl to the code below.

Source code:

     if train:
          self.paths, self.text_labels=np.load('../data/aData/aData_train.pkl'.format(site), allow_pickle=True)

The error is ValueError: too many values to unpack (expected 2)

I then changed the line to assign to a single variable:

  self.text_labels = np.load('../data/aData/aData_train.pkl'.format(site), allow_pickle=True)

When the code above runs, the error is

KeyError: '/home/11/22/33/aData/SpFolder/train/brigade/p1.jpg'
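Both errors can be reproduced without any files. The following is only a sketch, assuming the pickle holds one flat list in which paths, a dtype entry, and the class-name list sit side by side. The shortened paths, the string standing in for the dtype, and the two-class `label_dict` are illustrative, not taken from the real data.

```python
# Flat list shaped like the custom pickle: paths, then a dtype-like entry,
# then the list of class names, all as elements of one list.
flat = [
    "/data/train/brigade/p1.jpg",
    "/data/train/navy/p100.jpg",
    "uint8",
    ["brigade", "navy"],
]

# Hypothetical label mapping for the custom classes.
label_dict = {"brigade": 0, "navy": 1}

# 1) Unpacking four elements into two names fails.
try:
    paths, text_labels = flat
except ValueError as err:
    unpack_error = str(err)
    print("ValueError:", unpack_error)

# 2) Treating the whole list as labels looks up a file path in label_dict,
#    which only contains class names as keys.
try:
    labels = [label_dict[text] for text in flat]
except KeyError as err:
    missing_key = err.args[0]
    print("KeyError:", missing_key)
```

If this assumption about the file holds, the KeyError names the first path because that is the first element the label lookup sees.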

The whole code works for the Caltech dataset.

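How can I produce the same format as the Caltech pickle for my own dataset? I suspect the pickle conversion code is the problem, but I am not sure how to get the same format as Caltech.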
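One possible way to get that layout, sketched under assumptions: the loader expects a two-element (paths, text_labels) pair, each class is a subfolder of the split directory, and paths are stored relative to the loader's `base_path`. The function names `build_split` and `save_split` are hypothetical.

```python
import os
import pickle

import numpy as np


def build_split(split_dir, base_path):
    """Collect (paths, text_labels) for one split; paths are relative to base_path."""
    paths, text_labels = [], []
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files next to the class folders
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(".jpg"):
                full_path = os.path.join(class_dir, fname)
                paths.append(os.path.relpath(full_path, base_path))
                text_labels.append(class_name)  # the folder name is the label
    return np.array(paths), np.array(text_labels)


def save_split(split_dir, base_path, out_file):
    """Pickle the two arrays as one tuple so the loader can unpack two values."""
    with open(out_file, "wb") as f:
        pickle.dump(build_split(split_dir, base_path), f)
```

A call such as `save_split('/home/11/22/33/data/train', '/home/11/22/33/data', '/home/11/22/33/data/aData_train.pkl')` would then write the training pickle. For the labels to resolve, `label_dict` in `OfficeDataset` would also need the custom class names as keys, for example `{'brigade': 0, 'navy': 1, 'army': 2, 'air': 3, 'force': 4, 'gator': 5}`, and `base_path` would have to match the directory the paths are relative to.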

Environment: Python 3.8

python pickle