假设我有一个分割模型(
model
),我想将其预测批量转换为枕头图像。而且,为了简单起见,我们假设一切都是在 CPU 上完成的(不涉及 GPU)。
如果我这样做:
import torch
from torchvision.transforms import ToPILImage
transform = ToPILImage()
model.eval()
for i, (x, y) in enumerate(dataloader):
y_hat = torch.sigmoid(model(x)) # returns a tensor (batch_size, 1, H, W)
y_hat = (y_hat > 0.5).float()
img = transform(y_hat)
我得到:
ValueError: pic should be 2/3 dimensional. Got 4 dimensions.
很公平。让我尝试使用
vmap
将其批量转换:
import torch
from torchvision.transforms import ToPILImage
transform = ToPILImage()
batch_transform = torch.func.vmap(transform)
model.eval()
for i, (x, y) in enumerate(dataloader):
y_hat = torch.sigmoid(model(x)) # returns a tensor (batch_size, 1, H, W)
y_hat = (y_hat > 0.5).float()
img = batch_transform(y_hat)
这会产生以下错误:
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
为什么会有这样的表现?它与我选择 vmap 的函数有什么关系吗?我遵循了文档中的模式,这应该可行。如何对一批图像执行此操作?
ToPILImage
变换对 2D (H,W)
或 4D (C, H, W)
张量进行操作。这意味着您必须迭代批处理元素并应用转换:
imgs = [transform(t) for t in y_hat]
torchvision.utils.make_grid
从张量列表构造网格:
img = transform(make_grid(y_hat))
torchvision.utils.save_image
,可以调用make_grid,转换为PIL.Image
,并保存到文件系统:
save_image(y_hat, 'pred.jpg')