我正在尝试在 pytorch 中创建自定义数据集和数据加载器,以微调 DONUT 模型。对于上下文,我的数据集组织如下:
dataset/
├── train/
│ ├── image1.jpg
│ ├── image2.jpg
│ ├── metadata.jsonl
│ └── ...
├── validation/
│ ├── image1.jpg
│ ├── image2.jpg
│ ├── metadata.jsonl
│ └── ...
└── ...
我已经编写了我的自定义数据集类:
class DonutOCRDataset(torch.utils.data.Dataset):
def __init__(self, root_dir, transform):
self.root_dir = root_dir
self.transform = transform
self.data = self.load_data()
def load_data(self):
data = []
for folder in os.listdir(self.root_dir): # Use self.root_dir here
folder_path = os.path.join(self.root_dir, folder)
if os.path.isdir(folder_path):
metadata_path = os.path.join(folder_path, 'metadata.jsonl')
with open(metadata_path, 'r') as f:
metadata = [json.loads(line) for line in f]
data.extend([(item["file_name"], item["ground_truth"]) for item in metadata])
return data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img_path, label = self.data[idx]
img_path_full = os.path.join(self.root_dir, img_path)
print(f"Loading image: {img_path_full}")
img = Image.open(img_path_full).convert('RGB')
if self.transform:
img = self.transform(img)
return img, label
下面我尝试定义我的转换,实例化数据集和数据加载器:
# Define root_dir
root_dir = r'C:\Users\Company\Documents\.....\240111_donut_1\dataset'
# Define your transformation
transform = transforms.Compose([
transforms.Resize((640, 460)),
transforms.ToTensor(),
])
# Instantiate the dataset
train_dataset = DonutOCRDataset(os.path.join(root_dir, 'train'), transform=transform)
val_dataset = DonutOCRDataset(os.path.join(root_dir, 'validation'), transform=transform)
batch_size = 8
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size, shuffle=False)
但是,我后来发现当我打印 len(train_dataset) 和 len(val_dataset) 时,它们都返回 0。
有人知道我的代码有什么问题吗?
在
DonutOCRDataset
训练对象中,你通过的root_dir
是dataset/train
。然后,在 load_data()
中,您正在该目录中查找子目录(位于 if os.path.isdir(folder_path)
),这些子目录似乎不存在于您的目录结构中。因此,if 条件可能永远不会满足,并且数据集对象中的 self.data
仍然是一个空列表,长度为零。删除 load_data()
中的 for 循环应该可以解决问题。