使用 PyTorch vmap 时出现“运行时错误:无法访问没有存储的张量的数据指针”

问题描述 投票:0回答:1

假设我有一个分割模型(

model
),我想将其预测批量转换为枕头图像。而且,为了简单起见,我们假设一切都是在 CPU 上完成的(不涉及 GPU)。

如果我这样做:

import torch
from torchvision.transforms import ToPILImage

transform = ToPILImage()
model.eval()
for i, (x, y) in enumerate(dataloader):
    y_hat = torch.sigmoid(model(x))  # returns a tensor (batch_size, 1, H, W)
    y_hat = (y_hat > 0.5).float()
    img = transform(y_hat)

我得到:

ValueError: pic should be 2/3 dimensional. Got 4 dimensions.

很公平。让我尝试使用

vmap
将其批量转换:

import torch
from torchvision.transforms import ToPILImage

transform = ToPILImage()
batch_transform = torch.func.vmap(transform)
model.eval()
for i, (x, y) in enumerate(dataloader):
    y_hat = torch.sigmoid(model(x))  # returns a tensor (batch_size, 1, H, W)
    y_hat = (y_hat > 0.5).float()
    img = batch_transform(y_hat)

这会产生以下错误:

RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

为什么会有这样的表现?它与我选择 vmap 的函数有什么关系吗?我遵循了文档中的模式,这应该可行。如何对一批图像执行此操作?

python pytorch image-segmentation torchvision
1个回答
0
投票

如错误消息所示,

ToPILImage
变换对 2D
(H,W)
或 4D
(C, H, W)
张量进行操作。这意味着您必须迭代批处理元素并应用转换:

imgs = [transform(t) for t in y_hat]

或者,您可以使用

torchvision.utils.make_grid
从张量列表构造网格:

img = transform(make_grid(y_hat))

有一个方便的

torchvision.utils.save_image
,可以调用make_grid,转换为
PIL.Image
,并保存到文件系统:

save_image(y_hat, 'pred.jpg')
© www.soinside.com 2019 - 2024. All rights reserved.