我想使用 matplotlib imsave() 将包含单色图像数据的 2D numpy 数组保存到 png 文件,并使用颜色图 (viridis)。该数组包含应绘制为透明的 np.nan 值,imsave 不能直接处理该值,因此我在将其提供给 imsave 之前将其转换为 RGBA。只要我将 imsave() 上的 origin 选项保留在默认的“上部”上,这就可以完美地工作。如果我将其更改为“lower”,matplotlib 会抛出“ndarray 不是 C 连续的”错误。 (实际上它来自底层的 PIL 库)。
这里出了什么问题?
自包含示例:
import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
#create test data
image_no_nan = np.eye(100)
image = np.copy(image_no_nan)
image[image == 0] = np.nan
# to map the image to RGBA, scale to [0-255]
colmap = plt.get_cmap("viridis", 256)
lut = (colmap.colors[..., 0:4] * 255).astype(np.uint8)
rescaled = (
(image_no_nan.astype(float) - image_no_nan.min())
* 255
/ (image_no_nan.max() - image_no_nan.min())
).astype(np.uint8)
result = np.zeros((*rescaled.shape, 4), dtype=np.uint8)
# Take entries from RGB LUT according to greyscale values in image
result = np.take(lut, rescaled, axis=0, out=result)
# apply mask
mask = np.zeros((rescaled.shape), dtype=np.uint8)
mask[~np.isnan(image)] = 255
result[:,:,3]= mask
# try fixing the upcoming ndarray is not C-contiguous error
result = result.copy(order="C") # doesn't affect error
result = np.ascontiguousarray(result) # doesn't affect error
print(result.flags) # the ndarray is actually C-contiguous
plt.imsave(fname="test_upper.png", arr=result, format="png", origin="upper")# no problem
plt.imsave(fname="test_lower.png", arr=result, format="png", origin="lower")# error
对我来说,这看起来像是 matplotlib 中的一个错误。
以下是触发它所需的条件:
这是一个最小的复制案例:
import numpy as np
import matplotlib.pyplot as plt
result = np.zeros((100, 100, 4), dtype='uint8')
print(result.flags) # the ndarray is actually C-contiguous
plt.imsave(fname="test_upper.png", arr=result, format="png", origin="upper")# no problem
plt.imsave(fname="test_lower.png", arr=result, format="png", origin="lower")# error
研究 matplotlib 的代码,我发现了这个:
if origin == "lower":
arr = arr[::-1]
if (isinstance(arr, memoryview) and arr.format == "B"
and arr.ndim == 3 and arr.shape[-1] == 4):
# Such an ``arr`` would also be handled fine by sm.to_rgba below
# (after casting with asarray), but it is useful to special-case it
# because that's what backend_agg passes, and can be in fact used
# as is, saving a few operations.
rgba = arr
else:
sm = cm.ScalarMappable(cmap=cmap)
sm.set_clim(vmin, vmax)
rgba = sm.to_rgba(arr, bytes=True)
if pil_kwargs is None:
pil_kwargs = {}
else:
# we modify this below, so make a copy (don't modify caller's dict)
pil_kwargs = pil_kwargs.copy()
pil_shape = (rgba.shape[1], rgba.shape[0])
image = PIL.Image.frombuffer(
"RGBA", pil_shape, rgba, "raw", "RGBA", 0, 1)
代码链接。
如果
origin == "lower"
,则第一个数组以零拷贝方式反转。如果发生这种情况,则 arr
不再是 C 连续的。然后它使用 ScalarMappable 转换为 rgba。但是,如果输入已经是 rgba 格式,则不会复制它。因此,使用 RGB 可以掩盖该错误,因为副本将是 C 连续的。
然后调用
PIL.Image.frombuffer
,这似乎假设其输入是 C 连续的。 (Pillow 似乎没有记录这个假设,所以这实际上可能是一个 Pillow bug。)
作为解决方法,您可以通过以下代码避免使用
origin="lower"
:
result = np.ascontiguousarray(result[::-1])
plt.imsave(fname="test_upper.png", arr=result, format="png")
我建议你在 matplotlib 中打开一个问题,这样就可以为未来的用户解决这个问题。