Pytorch 数据集 - len(train_dataset) 返回零

问题描述 投票:0回答:1

我正在尝试在 pytorch 中创建自定义数据集和数据加载器,以微调 DONUT 模型。对于上下文,我的数据集组织如下:

dataset/
├── train/
│   ├── image1.jpg
│   ├── image2.jpg
│   ├── metadata.jsonl
│   └── ...
├── validation/
│   ├── image1.jpg
│   ├── image2.jpg
│   ├── metadata.jsonl
│   └── ...
└── ...

我已经编写了我的自定义数据集类:

class DonutOCRDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform):
        self.root_dir = root_dir
        self.transform = transform
        self.data = self.load_data()
    
    def load_data(self):
        data = []
        for folder in os.listdir(self.root_dir):  # Use self.root_dir here
            folder_path = os.path.join(self.root_dir, folder)
            if os.path.isdir(folder_path):
                metadata_path = os.path.join(folder_path, 'metadata.jsonl')
                with open(metadata_path, 'r') as f:
                    metadata = [json.loads(line) for line in f]
                data.extend([(item["file_name"], item["ground_truth"]) for item in metadata])
        return data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        img_path_full = os.path.join(self.root_dir, img_path)
        print(f"Loading image: {img_path_full}")
        img = Image.open(img_path_full).convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, label

下面我尝试定义我的转换,实例化数据集和数据加载器:

# Define root_dir
root_dir = r'C:\Users\Company\Documents\.....\240111_donut_1\dataset'


# Define your transformation
transform = transforms.Compose([
    transforms.Resize((640, 460)),
    transforms.ToTensor(),
])

# Instantiate the dataset
train_dataset = DonutOCRDataset(os.path.join(root_dir, 'train'), transform=transform)
val_dataset = DonutOCRDataset(os.path.join(root_dir, 'validation'), transform=transform)

batch_size = 8
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size, shuffle=False)

但是,我后来发现当我打印 len(train_dataset) 和 len(val_dataset) 时,它们都返回 0。

有人知道我的代码有什么问题吗?

pytorch dataset dataloader donut
1个回答
0
投票

DonutOCRDataset
训练对象中,你通过的
root_dir
dataset/train
。然后,在
load_data()
中,您正在该目录中查找子目录(位于
if os.path.isdir(folder_path)
),这些子目录似乎不存在于您的目录结构中。因此,if 条件可能永远不会满足,并且数据集对象中的
self.data
仍然是一个空列表,长度为零。删除
load_data()
中的 for 循环应该可以解决问题。

© www.soinside.com 2019 - 2024. All rights reserved.