在Pyplot中显示预测的分段输出数据

问题描述 投票:0回答:1

我正在尝试从U-net绘制输出数据数组。该数组包含经过单次热编码以进行图像分割的mnist数据。

它的形状是:(28,28,11)

因此,对于原始图像中像素值为0的每个位置,单次热编码将放置[0 0 0 0 0 0 0 0 0 0 0 1 1]的数组,指示该像素为空白。

另一方面,如果像素值> 0,则一个热阵列将显示整个图像的分类。

EX:如果mnist图像为2,则值> 0的每个像素都将变成数组[0 0 1 0 0 0 0 0 0 0 0 0]。

我想知道是否有一种方法可以显示这样的数组,该数组的每个元素都由一个热数组组成。

我试图在数据上仅使用plt.imshow,但是,它抛出一个错误,指出“ TypeError:图像数据的尺寸无效”

这是我正在使用的代码

from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
import skimage.transform
import cv2
import sys
from keras import Input
from keras import backend as K
from keras.utils import np_utils
from keras.models import Sequential, Model 
from keras.utils import to_categorical
from keras.losses import categorical_crossentropy
from keras.optimizers import adam
from keras.layers import Conv2D, Dense, MaxPooling2D, Flatten, Dropout, GlobalAveragePooling2D
from keras.datasets import cifar10
from keras.datasets import mnist
from keras.utils import np_utils
import matplotlib.pyplot as plt
import numpy as np
np.set_printoptions(threshold=sys.maxsize)

(x_train, y_train), (x_test, y_test) = mnist.load_data()

y_train = y_train[:10]

data = np.random.choice(255, (10,128,128))


## do you calculation of brightness here
## and expand it to one row per pixel
arr = data.reshape(-1,1)/255
## repeat labels to match the expanded pixel
labels = y_train.repeat(128*128).reshape(-1,1)

ind_row = np.arange(len(arr))
ind_col = np.where(arr>0, labels, 10).ravel()

one_hot_coded_arr = np.zeros((len(arr), 11))
one_hot_coded_arr[ind_row,ind_col]=1

## convert back to desired shape
one_hot_coded_arr = one_hot_coded_arr.reshape(-1, 128,128,11)
#print(one_hot_coded_arr[:28,:])
print(one_hot_coded_arr.shape)


plt.imshow(one_hot_coded_arr, interpolation='nearest')
plt.axis("off")
plt.show()

我想显示这样的图像:https://documentation.sas.com/api/docsets/casdlpg/8.4/content/images/mnistout2.png

但是我一直遇到错误“ TypeError:图像数据的尺寸无效”

任何帮助都会很棒,谢谢!

python numpy mnist imshow
1个回答
0
投票

您有太多尺寸。 matplotlib.pyplot仅绘制2个尺寸(x,y)。

因此,您首先应该选择要显示的图像,即output[n]。接下来,由于它是一种热编码,请使用np.argmax(output[n], axis=-1)函数“取消编码”。

让我知道是否可行。

© www.soinside.com 2019 - 2024. All rights reserved.