我是新来的,我正在尝试使用 pytorch 学习 python 进行深度学习。我想使用 pytorch 编写服装数据集。我首先尝试了一个简单的任务,我想读取文件夹中的图像,然后将它们可视化
这是我尝试过的
class msr_data(Dataset):
def __init__(self, image_path, transform=None):
self.image_path = image_path
self.transform = transform
def __len__(self):
return len(self.msr_data)
def __getitem__(self, image_path):
images = []
for filename in os.listdir(image_path):
img = cv2.imread(os.path.join(image_path,filename))
im_rgb = img[:, :, ::-1]
images.append(im_rgb)
for i in images:
data=np.array(im_rgb)
return data
dataset = msr_data(image_path='files',transform=None)
for i,sample in enumerate(dataset):
print(i)
plt.imshow(sample)
plt.show()
但是我收到了上述错误。这是为什么?我该如何解决它?
每个 pytorch
Dataset
或此类的继承者都必须实现 __getitem__
函数,该函数需要整数索引作为输入。您应该使用此整数来索引数据集及其属性,并在 __init__
中初始化。
您的数据集有问题。他们是:
image_path
不用于在 __init__
中执行任何操作(例如计算数据示例的数量并存储这些路径)。您在 __getitem__
中枚举文件,这很浪费,因为每次从数据集中索引项目时都必须执行此操作。__len__
未返回正确的长度并且__len__
打电话给self.msr_data
。 msr_data
类没有属性 self.msr_data
。如果是这样,这将是一个无限递归调用。__getitem__
接收到一个字符串。这违反了父 pytorch.data.Dataset 类定义的函数原型,该类需要整数索引。纠正这些问题(变换除外):
class MSR_Data(Dataset): # Capitals preferred for class names but this is stylistic only
def __init__(self, image_path, transform=None):
# store all image paths in a list
image_files = [os.path.join(image_path,item) for item in os.listdir(image_path)]
self.images = []
for path in image_files:
img = cv2.imread(path)
im_rgb = img[:, :, ::-1]
self.images.append(im_rgb)
# store transform
self.transform = transform
def __len__(self):
return len(self.images)
def __getitem__(self,idx):
data=np.array(self.images[idx])
return data
数据集 = msr_data(image_path='文件',transform=None) 对于 i,枚举(数据集)中的样本: 打印(一) plt.imshow(示例) plt.show()
在这些更改之后,您应该拥有 pytorch 数据集的 MWE,尽管没有转换、标签、分区或任何其他动态增强,此类与简单列表相比没有任何优势(事实上,它只是围绕列表的包装类)图片)。我建议查看 pytorch 关于数据集对象的教程。