I want to load a batch of images of different resolutions and dynamically split them into equal-sized, non-overlapping patches to feed them to a ResNet18 model. Is there an existing transform class in PyTorch that can do this, and if not, how can I implement my own class?
Here is the code:
transform = transforms.Compose([
    ImageResizer(),                  # custom class: resizes a PIL image to the next multiple of 224 and returns a PIL image
    # Patch(patch_size=(224, 224)),  # custom class: splits a PIL image into 224x224 patches and returns a list of PIL images
    transforms.ToTensor(),
])
dataset = ImageFolder(root="<path>", transform=transform)
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
Here is what my ImageResizer code looks like:
import numpy as np

class ImageResizer:
    """
    A class to resize an image to the next multiple of 224, so that the image
    can be divided into 224x224 patches later.
    """

    @staticmethod
    def get_new_dimensions(width: int, height: int, patch_height: int = 224, patch_width: int = 224):
        """
        Get the new dimensions of the image after resizing.

        Parameters:
        - width: The width of the image.
        - height: The height of the image.
        - patch_height: The height of the patch.
        - patch_width: The width of the patch.

        Returns:
        - new_width: The new width of the image.
        - new_height: The new height of the image.
        """
        # np.ceil (rather than np.round) guarantees the *next* multiple,
        # so small images are never rounded down to a zero dimension.
        width_coef = int(np.ceil(width / patch_width))
        height_coef = int(np.ceil(height / patch_height))
        new_width = width_coef * patch_width
        new_height = height_coef * patch_height
        return new_width, new_height

    def __call__(self, image):
        """
        Resize the given image to the next multiple of 224.

        Parameters:
        - image: a PIL image.

        Returns:
        - resized_image: The resized PIL image.
        """
        width, height = image.size
        new_width, new_height = ImageResizer.get_new_dimensions(width, height)
        # Resize the image
        resized_image = image.resize((new_width, new_height))
        return resized_image
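One subtlety worth checking: rounding to the *nearest* multiple (np.round) can shrink small images, even down to a zero dimension, whereas rounding *up* (np.ceil) always yields a valid "next" multiple, as the class docstring promises. A minimal standalone sketch comparing the two (plain Python, no Pillow needed):

```python
import math

PATCH = 224  # assumed patch side, matching the transform above

def next_multiple(size: int, patch: int = PATCH) -> int:
    """Round up to the next multiple of `patch` (never zero for size >= 1)."""
    return math.ceil(size / patch) * patch

def nearest_multiple(size: int, patch: int = PATCH) -> int:
    """Round to the nearest multiple of `patch` (can collapse to zero)."""
    return round(size / patch) * patch

print(next_multiple(500), nearest_multiple(500))  # 672 448
print(next_multiple(100), nearest_multiple(100))  # 224 0
```

With a 100-pixel side, nearest-multiple rounding produces 0 and image.resize would fail, which is why rounding up is the safer choice here.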
transforms are expected to take a single data point (an image in this case) as input and to return a single transformed data point, so it is currently not possible to write a custom transform that patches an image and returns a list of patches.

A possible solution is to provide a custom implementation of the collate_fn function and pass it as an argument to the DataLoader class. collate_fn takes a list of tuples as input (the first element of each tuple is the data point, the second is the label) and returns a tuple of two tensors: the first representing a batch of images, the second the corresponding labels.

Below you can find a possible implementation of the functionality you want:
def make_patches(
    img: torch.Tensor,
    patch_height: int,
    patch_width: int,
) -> list[torch.Tensor]:
    # img has shape (C, H, W): dim 1 is height, dim 2 is width,
    # so each dimension is unfolded with its own patch size.
    patches = img \
        .unfold(1, patch_height, patch_height) \
        .unfold(2, patch_width, patch_width) \
        .flatten(1, 2) \
        .permute(1, 0, 2, 3)  # (num_patches, C, patch_height, patch_width)
    return list(patches)

def collate_fn(batch: list[tuple[torch.Tensor, int]]) -> tuple[torch.Tensor, torch.Tensor]:
    new_x = []
    new_y = []
    for x, y in batch:
        patches = make_patches(x, 224, 224)
        new_x.extend(patches)
        new_y.extend([y] * len(patches))  # repeat the label for every patch
    new_x = torch.stack(new_x)
    new_y = torch.tensor(new_y)
    return new_x, new_y
dataset = datasets.ImageFolder(root="<your-path>", transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
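To see how the collate step mixes patches from differently sized images into one batch, here is a standalone sketch (both helpers restated so the snippet runs on its own) with two fake "already resized" images:

```python
import torch

def make_patches(img: torch.Tensor, patch_height: int = 224, patch_width: int = 224) -> list[torch.Tensor]:
    # img: (C, H, W); H and W are assumed to be multiples of the patch sizes
    patches = (img
               .unfold(1, patch_height, patch_height)  # (C, H/ph, W, ph)
               .unfold(2, patch_width, patch_width)    # (C, H/ph, W/pw, ph, pw)
               .flatten(1, 2)                          # (C, N, ph, pw)
               .permute(1, 0, 2, 3))                   # (N, C, ph, pw)
    return list(patches)

def collate_fn(batch):
    new_x, new_y = [], []
    for x, y in batch:
        patches = make_patches(x)
        new_x.extend(patches)
        new_y.extend([y] * len(patches))  # one label per patch
    return torch.stack(new_x), torch.tensor(new_y)

# Two fake images with different resolutions (both multiples of 224)
batch = [
    (torch.randn(3, 448, 448), 0),  # 2x2 grid -> 4 patches
    (torch.randn(3, 224, 672), 1),  # 1x3 grid -> 3 patches
]
x, y = collate_fn(batch)
print(x.shape, y.tolist())  # torch.Size([7, 3, 224, 224]) [0, 0, 0, 0, 1, 1, 1]
```

Note that the effective batch size seen by the model is now the number of patches, not the number of images, and it varies from batch to batch when image resolutions differ.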
Alternatively, the patching can be packaged as a transform-style nn.Module that uses torch.Tensor.unfold on a batched image tensor:
import torch.nn as nn

class Patch(nn.Module):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        b, c, h, w = x.shape
        ph, pw = self.patch_size
        out = x.unfold(-2, ph, ph).unfold(-1, pw, pw)  # (b, c, h/ph, w/pw, ph, pw)
        out = out.contiguous().view(b, c, -1, ph, pw)  # (b, c, num_patches, ph, pw)
        out = out.permute(0, 2, 1, 3, 4)               # (b, num_patches, c, ph, pw)
        return out