如何使用keras知道WGAN中生成的图像的标签

问题描述 投票:0回答:1

我正在研究医学图像分类问题,并且患有低数据集问题。所以想用WGAN生成图像。在给定代码中,WGAN代码示例采用MNIST数据集。在图像生成之后,很容易识别它们所属的类。但在医学图像的情况下,生成图像后,很难确定生成的图像属于哪个类别,因为它们从以下代码中保存在组中:

def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/mnist_%d.png" % epoch)
        plt.close()

所以我必须执行哪些更改才能获得生成图像的标签。

python tensorflow keras
1个回答
1
投票

vanilla版本的WGAN无法有条件地生成图像。因此,您训练过的WGAN只能生成图像,而不知道它们属于哪个类。

为了能够生成特定标签的图像,请结帐条件gans。 qazxsw poi是一篇让你入门的中篇文章。

备选方案是从原始训练数据中训练鉴别器,并使用该鉴别器帮助您手动对图像进行分类。

© www.soinside.com 2019 - 2024. All rights reserved.