import torch
from torch import nn, optim
import torchvision.transforms as transforms
import torchvision
import matplotlib.pyplot as plt
# GAN example on MNIST (notebook)
# Custom Reshape layer
class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        return x.view(x.shape[0], *self.shape)
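# For example (illustrative check, not part of the training flow):
# Reshape(256, 7, 7) maps a flat batch of shape (N, 7*7*256) back to
# feature maps of shape (N, 256, 7, 7).
# assert Reshape(256, 7, 7)(torch.randn(2, 7*7*256)).shape == (2, 256, 7, 7)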
# Define Generator network with convolutional layers
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # First layer
            nn.Linear(z_dim, 7*7*256),
            nn.BatchNorm1d(7*7*256),
            nn.ReLU(True),
            # Reshape to start convolutions
            Reshape(256, 7, 7),
            # Convolution layers
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)
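# Shape walkthrough (illustrative): z (N, 100) -> Linear -> (N, 12544)
# -> Reshape -> (N, 256, 7, 7) -> ConvTranspose2d -> (N, 128, 14, 14)
# -> ConvTranspose2d -> (N, 64, 28, 28) -> ConvTranspose2d + Tanh -> (N, 1, 28, 28),
# i.e. MNIST-sized images in [-1, 1]. A quick sanity check (eval mode so
# BatchNorm layers accept any batch size):
# assert Generator().eval()(torch.randn(2, 100)).shape == (2, 1, 28, 28)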
# Define Discriminator network with convolutional layers
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # First conv layer
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.3),
            # Second conv layer
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.3),
            # Flatten and dense layers
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)
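# Shape walkthrough (illustrative): (N, 1, 28, 28) -> Conv2d -> (N, 64, 14, 14)
# -> Conv2d -> (N, 128, 7, 7) -> Flatten -> (N, 6272) -> Linear -> (N, 1024)
# -> Linear + Sigmoid -> (N, 1), a probability that the input is real.
# assert Discriminator().eval()(torch.randn(2, 1, 28, 28)).shape == (2, 1)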
# Hyperparameters
z_dim = 100
lr = 2e-4         # Slightly adjusted learning rate
beta1 = 0.5       # Beta1 for Adam optimizer
batch_size = 128  # Increased batch size
epochs = 20       # Increased epochs
# Set up device (Apple Silicon GPU via MPS, with CPU fallback)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Data loading with normalization to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
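# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to
# [-1, 1] via (x - 0.5) / 0.5, matching the generator's Tanh output range.
# Illustrative check on one sample:
# x, _ = dataset[0]; assert x.min() >= -1.0 and x.max() <= 1.0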
# Initialize models and move to device
generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)

# Optimizers with beta1
optimG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Loss function
criterion = nn.BCELoss()
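# BCELoss computes -[y*log(p) + (1-y)*log(1-p)]. In the loop below, the
# discriminator sees real images with label 1 and detached fakes with label 0
# (fake_images.detach() keeps generator gradients out of the discriminator step).
# The generator step then scores fakes against label_real, which gives the
# non-saturating generator loss -log(D(G(z))) instead of log(1 - D(G(z))),
# providing stronger gradients early in training.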
# Training loop with improved stability
for epoch in range(epochs):
    for batch_idx, (real_images, _) in enumerate(loader):
        batch_size_current = real_images.shape[0]

        # Move data to device
        real_images = real_images.to(device)

        # Train Discriminator
        discriminator.zero_grad()
        label_real = torch.ones(batch_size_current, 1).to(device)
        label_fake = torch.zeros(batch_size_current, 1).to(device)

        output_real = discriminator(real_images)
        d_loss_real = criterion(output_real, label_real)

        noise = torch.randn(batch_size_current, z_dim).to(device)
        fake_images = generator(noise)
        output_fake = discriminator(fake_images.detach())
        d_loss_fake = criterion(output_fake, label_fake)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimD.step()

        # Train Generator
        generator.zero_grad()
        output_fake = discriminator(fake_images)
        g_loss = criterion(output_fake, label_real)
        g_loss.backward()
        optimG.step()

        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch}/{epochs}] Batch [{batch_idx}/{len(loader)}] '
                  f'd_loss: {d_loss.item():.4f} g_loss: {g_loss.item():.4f}')
    # Generate and display sample images every 5 epochs
    if (epoch + 1) % 5 == 0:
        with torch.no_grad():
            test_noise = torch.randn(5, z_dim).to(device)
            generated_images = generator(test_noise)

        plt.figure(figsize=(10, 2))
        for i in range(5):
            plt.subplot(1, 5, i + 1)
            # Move tensor back to CPU for plotting
            plt.imshow(generated_images[i].cpu().squeeze().numpy(), cmap='gray')
            plt.axis('off')
        plt.show()
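# To actually save the samples to disk as well, one option is to write the
# figure out before plt.show(), e.g. (filename pattern is illustrative):
# plt.savefig(f'samples_epoch_{epoch + 1}.png', bbox_inches='tight')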
Epoch [0/20] Batch [0/469] d_loss: 1.4587 g_loss: 0.7364
Epoch [0/20] Batch [100/469] d_loss: 1.3710 g_loss: 0.7146
Epoch [0/20] Batch [200/469] d_loss: 1.1290 g_loss: 0.8700
Epoch [0/20] Batch [300/469] d_loss: 0.9860 g_loss: 1.1322
Epoch [0/20] Batch [400/469] d_loss: 0.4725 g_loss: 1.9018
Epoch [1/20] Batch [0/469] d_loss: 0.5324 g_loss: 1.8864
Epoch [1/20] Batch [100/469] d_loss: 0.3427 g_loss: 2.4291
Epoch [1/20] Batch [200/469] d_loss: 0.2379 g_loss: 2.6575
Epoch [1/20] Batch [300/469] d_loss: 0.2059 g_loss: 3.3945
Epoch [1/20] Batch [400/469] d_loss: 0.1537 g_loss: 3.2048
Epoch [2/20] Batch [0/469] d_loss: 0.3866 g_loss: 2.3630
Epoch [2/20] Batch [100/469] d_loss: 0.2228 g_loss: 2.8267
Epoch [2/20] Batch [200/469] d_loss: 0.1287 g_loss: 3.4258
Epoch [2/20] Batch [300/469] d_loss: 0.1049 g_loss: 3.9926
Epoch [2/20] Batch [400/469] d_loss: 0.2627 g_loss: 3.3500
Epoch [3/20] Batch [0/469] d_loss: 0.1071 g_loss: 3.8937
Epoch [3/20] Batch [100/469] d_loss: 0.1239 g_loss: 3.4904
Epoch [3/20] Batch [200/469] d_loss: 0.0526 g_loss: 4.2760
Epoch [3/20] Batch [300/469] d_loss: 0.1703 g_loss: 3.7801
Epoch [3/20] Batch [400/469] d_loss: 1.5493 g_loss: 0.7978
Epoch [4/20] Batch [0/469] d_loss: 1.4743 g_loss: 0.8553
Epoch [4/20] Batch [100/469] d_loss: 1.5439 g_loss: 0.8430
Epoch [4/20] Batch [200/469] d_loss: 1.5271 g_loss: 0.8607
Epoch [4/20] Batch [300/469] d_loss: 1.3586 g_loss: 0.7892
Epoch [4/20] Batch [400/469] d_loss: 1.3883 g_loss: 0.8830
Epoch [5/20] Batch [0/469] d_loss: 1.5093 g_loss: 0.9424
Epoch [5/20] Batch [100/469] d_loss: 1.4322 g_loss: 1.0056
Epoch [5/20] Batch [200/469] d_loss: 1.3403 g_loss: 0.9821
Epoch [5/20] Batch [300/469] d_loss: 1.3601 g_loss: 0.8536
Epoch [5/20] Batch [400/469] d_loss: 1.3104 g_loss: 0.7968
Epoch [6/20] Batch [0/469] d_loss: 1.1751 g_loss: 0.9554
Epoch [6/20] Batch [100/469] d_loss: 1.3790 g_loss: 0.9607
Epoch [6/20] Batch [200/469] d_loss: 1.2350 g_loss: 0.7746
Epoch [6/20] Batch [300/469] d_loss: 1.2191 g_loss: 0.8583
Epoch [6/20] Batch [400/469] d_loss: 1.2714 g_loss: 0.9669
Epoch [7/20] Batch [0/469] d_loss: 1.3526 g_loss: 0.8643
Epoch [7/20] Batch [100/469] d_loss: 1.3798 g_loss: 0.9228
Epoch [7/20] Batch [200/469] d_loss: 1.3123 g_loss: 0.8387
Epoch [7/20] Batch [300/469] d_loss: 1.2502 g_loss: 0.9403
Epoch [7/20] Batch [400/469] d_loss: 1.3042 g_loss: 0.9029
Epoch [8/20] Batch [0/469] d_loss: 1.2249 g_loss: 0.8839
Epoch [8/20] Batch [100/469] d_loss: 1.2266 g_loss: 0.8993
Epoch [8/20] Batch [200/469] d_loss: 1.2863 g_loss: 0.8482
Epoch [8/20] Batch [300/469] d_loss: 1.2024 g_loss: 0.9778
Epoch [8/20] Batch [400/469] d_loss: 1.1278 g_loss: 0.8528
Epoch [9/20] Batch [0/469] d_loss: 1.1930 g_loss: 0.9338
Epoch [9/20] Batch [100/469] d_loss: 1.3779 g_loss: 0.8452
Epoch [9/20] Batch [200/469] d_loss: 1.2220 g_loss: 1.0130
Epoch [9/20] Batch [300/469] d_loss: 1.2013 g_loss: 0.9121
Epoch [9/20] Batch [400/469] d_loss: 1.2612 g_loss: 1.0147
Epoch [10/20] Batch [0/469] d_loss: 1.1912 g_loss: 0.9411
Epoch [10/20] Batch [100/469] d_loss: 1.1474 g_loss: 0.9353
Epoch [10/20] Batch [200/469] d_loss: 1.2389 g_loss: 0.8523
Epoch [10/20] Batch [300/469] d_loss: 1.1961 g_loss: 0.7835
Epoch [10/20] Batch [400/469] d_loss: 1.1892 g_loss: 0.9633
Epoch [11/20] Batch [0/469] d_loss: 1.1884 g_loss: 0.8253
Epoch [11/20] Batch [100/469] d_loss: 1.2928 g_loss: 0.9628
Epoch [11/20] Batch [200/469] d_loss: 1.1756 g_loss: 0.9246
Epoch [11/20] Batch [300/469] d_loss: 1.2237 g_loss: 0.9206
Epoch [11/20] Batch [400/469] d_loss: 1.1077 g_loss: 0.9726
Epoch [12/20] Batch [0/469] d_loss: 1.2371 g_loss: 1.0076
Epoch [12/20] Batch [100/469] d_loss: 1.2275 g_loss: 0.9172
Epoch [12/20] Batch [200/469] d_loss: 1.2303 g_loss: 0.9334
Epoch [12/20] Batch [300/469] d_loss: 1.1339 g_loss: 0.9581
Epoch [12/20] Batch [400/469] d_loss: 1.2166 g_loss: 1.0264
Epoch [13/20] Batch [0/469] d_loss: 1.2645 g_loss: 0.8292
Epoch [13/20] Batch [100/469] d_loss: 1.2263 g_loss: 0.9388
Epoch [13/20] Batch [200/469] d_loss: 1.2474 g_loss: 0.9823
Epoch [13/20] Batch [300/469] d_loss: 1.2536 g_loss: 0.9741
Epoch [13/20] Batch [400/469] d_loss: 1.2097 g_loss: 0.9343
Epoch [14/20] Batch [0/469] d_loss: 1.2619 g_loss: 0.9633
Epoch [14/20] Batch [100/469] d_loss: 1.3268 g_loss: 0.9094
Epoch [14/20] Batch [200/469] d_loss: 1.2769 g_loss: 0.9404
Epoch [14/20] Batch [300/469] d_loss: 1.1858 g_loss: 0.9609
Epoch [14/20] Batch [400/469] d_loss: 1.2346 g_loss: 1.0119
Epoch [15/20] Batch [0/469] d_loss: 1.1835 g_loss: 0.9720
Epoch [15/20] Batch [100/469] d_loss: 1.2663 g_loss: 0.9126
Epoch [15/20] Batch [200/469] d_loss: 1.2762 g_loss: 0.8213
Epoch [15/20] Batch [300/469] d_loss: 1.1572 g_loss: 0.8374
Epoch [15/20] Batch [400/469] d_loss: 1.0717 g_loss: 0.9020
Epoch [16/20] Batch [0/469] d_loss: 1.2640 g_loss: 0.9940
Epoch [16/20] Batch [100/469] d_loss: 1.2792 g_loss: 1.0197
Epoch [16/20] Batch [200/469] d_loss: 1.3113 g_loss: 0.9563
Epoch [16/20] Batch [300/469] d_loss: 1.2178 g_loss: 0.9577
Epoch [16/20] Batch [400/469] d_loss: 1.2632 g_loss: 0.9852
Epoch [17/20] Batch [0/469] d_loss: 1.2279 g_loss: 0.9235
Epoch [17/20] Batch [100/469] d_loss: 1.3499 g_loss: 0.9113
Epoch [17/20] Batch [200/469] d_loss: 1.1605 g_loss: 0.9119
Epoch [17/20] Batch [300/469] d_loss: 1.2327 g_loss: 0.8470
Epoch [17/20] Batch [400/469] d_loss: 1.1778 g_loss: 0.9232
Epoch [18/20] Batch [0/469] d_loss: 1.4804 g_loss: 0.8219
Epoch [18/20] Batch [100/469] d_loss: 1.3487 g_loss: 0.8752
Epoch [18/20] Batch [200/469] d_loss: 1.2347 g_loss: 0.8730
Epoch [18/20] Batch [300/469] d_loss: 1.2076 g_loss: 0.9315
Epoch [18/20] Batch [400/469] d_loss: 1.2874 g_loss: 0.8813
Epoch [19/20] Batch [0/469] d_loss: 1.2302 g_loss: 1.0094
Epoch [19/20] Batch [100/469] d_loss: 1.2215 g_loss: 0.8352
Epoch [19/20] Batch [200/469] d_loss: 1.2204 g_loss: 0.9764
Epoch [19/20] Batch [300/469] d_loss: 1.3497 g_loss: 0.8585
Epoch [19/20] Batch [400/469] d_loss: 1.2276 g_loss: 0.9919