pytorch对象在保存图像时对于数组来说太深了

问题描述 投票:0回答:2

我正在尝试从以下github rep运行代码:

https://github.com/iamkrut/image_inpainting_resnet_unet

我没有更改代码中的任何内容,当代码尝试保存图像时,它导致ValueError,对象太深。错误似乎来自这两行。

images = img_tensor.cpu().detach().permute(0,2,3,1)
plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:3])

这里是错误说明

  File "train.py", line 205, in <module>
    data_dir=args.data_dir)
  File "train.py", line 94, in train_net
    plt.imsave(join(data_dir, 'samples', image), images[index,:,:,:]);
  File "C:\ProgramData\Anaconda3\envs\torch2\lib\site-packages\matplotlib\pyplot.py", line 2140, in imsave
    return matplotlib.image.imsave(fname, arr, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\torch2\lib\site-packages\matplotlib\image.py", line 1498, in imsave
    _png.write_png(rgba, fname, dpi=dpi)
ValueError: object too deep for desired array

任何人都知道可能是什么原因或如何解决?谢谢

python matplotlib pytorch
2个回答
0
投票
matplotlib软件包无法理解pytorch数据类型(张量)。您应该将张量数组转换为numpy数组,然后使用matplotlib函数。

a = torch.rand(10, 3, 20, 20) plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1)[0, ...]) # Error plt.imsave("test.jpg", a.cpu().detach().permute(0, 2, 3, 1).numpy()[0, ...])


0
投票
我设法通过将行更改为来修复代码

images=img_tensor.cpu().numpy()[0] images = np.transpose(images, (1,2,0)) plt.imsave(join(data_dir, 'samples', image), images)

仍不确定先前版本有什么问题。因此,如果有人知道,请告诉我。
© www.soinside.com 2019 - 2024. All rights reserved.