我正在使用这个公式处理 MNIST 数据集:设 I1 和 I2 是两个不同数字的图像,alpha 的插值 D((E(I1)alpha)+(E(I2)(1-alpha))) [0,1] 其中 D 表示解码器,E 表示编码器。
这是我的代码:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Define hyperparameters
batch_size = 128
learning_rate = 0.001
num_epochs = 15
latent_dim = 10,
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Define the device
# Define transforms to normalize the data
transform = transforms.Compose([
transforms.ToTensor()])
# Load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
# Define data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Define the encoder architecture
class Encoder(nn.Module):
def __init__(self, latent_dim):
super(Encoder, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
self.fc = nn.Linear(64*7*7, latent_dim)
def forward(self, x):
x = torch.relu(self.conv1(x))
x = torch.relu(self.conv2(x))
x = torch.relu(self.conv3(x))
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# Define the decoder architecture
class Decoder(nn.Module):
def __init__(self, latent_dim):
super(Decoder, self).__init__()
self.fc = nn.Linear(latent_dim, 64*7*7)
self.conv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
self.conv2 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
self.conv3 = nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.fc(x)
x = x.view(x.size(0), 64, 7, 7)
x = torch.relu(self.conv1(x))
x = torch.relu(self.conv2(x))
x = self.sigmoid(self.conv3(x))
return x
# Define the autoencoder model
class Autoencoder(nn.Module):
def __init__(self, latent_dim):
super(Autoencoder, self).__init__()
self.encoder = Encoder(latent_dim)
self.decoder = Decoder(latent_dim)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
# Define the loss function and optimizer
#criterion = nn.BCELoss()
#optimizer = optim.Adam(autoencoder.parameters(), lr=learning_rate)
# Instantiate the autoencoder model
latent_dim=25
autoencoder = Autoencoder(latent_dim)
# Load the trained weights
#checkpoint = torch.load('autoencoder_weights.pth') # Replace 'autoencoder_weights.pth' with the path to your saved weights file
#autoencoder.load_state_dict(checkpoint['model_state_dict'])
autoencoder.to(device) # Move the model to the device (GPU or CPU) for inference
autoencoder.eval() # Set the model to evaluation mode
# Perform interpolation in latent space and generate images
with torch.no_grad():
I1, I2 = test_dataset[0][0].unsqueeze(0).to(device), test_dataset[1][0].unsqueeze(0).to(device) # Example images of two different digits
alphas = torch.linspace(0, 1, steps=11).to(device) # Generate 11 alphas from 0 to 1
interpolated_images = []
for alpha in alphas:
latent1 = autoencoder.encoder(I1) # Encode I1 to obtain latent representation
latent2 = autoencoder.encoder(I2) # Encode I2 to obtain latent representation
interpolated_latent = (latent1 * alpha) + (latent2 * (1 - alpha)) # Interpolate in latent space
interpolated_image = autoencoder.decoder(interpolated_latent) # Decode interpolated latent representation to obtain image
interpolated_images.append(interpolated_image.squeeze().cpu().numpy()) # Convert image to numpy array and append to list
# Visualize the interpolated images
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, len(interpolated_images), figsize=(15, 15))
for i, image in enumerate(interpolated_images):
axes[i].imshow(image, cmap='gray')
axes[i].axis('off')
axes[i].set_title(f'Alpha: {alphas[i]:.1f}')
plt.show()
我想知道我的代码哪里出错了,我该如何让它工作?