mnist数据集Interpolation

问题描述 投票:0回答:0

我正在使用这个公式处理 MNIST 数据集:设 I1 和 I2 是两个不同数字的图像,alpha 的插值 D((E(I1)alpha)+(E(I2)(1-alpha))) [0,1] 其中 D 表示解码器,E 表示编码器。

这是我的代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Define hyperparameters
batch_size = 128
learning_rate = 0.001
num_epochs = 15
latent_dim = 10,
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Define the device
# Define transforms to normalize the data
transform = transforms.Compose([
    transforms.ToTensor()])


# Load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Define data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define the encoder architecture
class Encoder(nn.Module):
    def  __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.fc = nn.Linear(64*7*7, latent_dim)
        
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Define the decoder architecture
class Decoder(nn.Module):
    def  __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 64*7*7)
        self.conv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv2 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv3 = nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 64, 7, 7)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.sigmoid(self.conv3(x))
        return x

# Define the autoencoder model
class Autoencoder(nn.Module):
    def  __init__(self, latent_dim):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Define the loss function and optimizer
#criterion = nn.BCELoss()
#optimizer = optim.Adam(autoencoder.parameters(), lr=learning_rate)


# Instantiate the autoencoder model
latent_dim=25
autoencoder = Autoencoder(latent_dim)

# Load the trained weights
#checkpoint = torch.load('autoencoder_weights.pth') # Replace 'autoencoder_weights.pth' with the path to your saved weights file
#autoencoder.load_state_dict(checkpoint['model_state_dict'])
autoencoder.to(device) # Move the model to the device (GPU or CPU) for inference
autoencoder.eval() # Set the model to evaluation mode

# Perform interpolation in latent space and generate images
with torch.no_grad():
    I1, I2 = test_dataset[0][0].unsqueeze(0).to(device), test_dataset[1][0].unsqueeze(0).to(device) # Example images of two different digits
    alphas = torch.linspace(0, 1, steps=11).to(device) # Generate 11 alphas from 0 to 1
    interpolated_images = []
    for alpha in alphas:
        latent1 = autoencoder.encoder(I1) # Encode I1 to obtain latent representation
        latent2 = autoencoder.encoder(I2) # Encode I2 to obtain latent representation
        interpolated_latent = (latent1 * alpha) + (latent2 * (1 - alpha)) # Interpolate in latent space
        interpolated_image = autoencoder.decoder(interpolated_latent) # Decode interpolated latent representation to obtain image
        interpolated_images.append(interpolated_image.squeeze().cpu().numpy()) # Convert image to numpy array and append to list

# Visualize the interpolated images
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, len(interpolated_images), figsize=(15, 15))
for i, image in enumerate(interpolated_images):
    axes[i].imshow(image, cmap='gray')
    axes[i].axis('off')
    axes[i].set_title(f'Alpha: {alphas[i]:.1f}')

plt.show()

不幸的是,我尝试的每件事都会返回模糊的图像: enter image description here

我想知道我的代码哪里出错了,我该如何让它工作?

image conv-neural-network convolution mnist
© www.soinside.com 2019 - 2024. All rights reserved.