import torch
from torch import nn, optim
import torchvision.transforms as transforms
import torchvision
import matplotlib.pyplot as plt
# GAN example: MNIST (notebook)
# Custom Reshape layer
class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        # Keep the batch dimension, view the rest into self.shape
        return x.view(x.shape[0], *self.shape)
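# Quick illustration (an added sketch, not in the original notebook): Reshape
# keeps the batch dimension and views the remaining elements into the given
# shape, e.g. (N, 256*7*7) -> (N, 256, 7, 7).
_probe = torch.randn(4, 256 * 7 * 7)
assert Reshape(256, 7, 7)(_probe).shape == (4, 256, 7, 7)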
# Define Generator network with convolutional layers
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # First layer: project the noise vector to a 7x7x256 feature map
            nn.Linear(z_dim, 7 * 7 * 256),
            nn.BatchNorm1d(7 * 7 * 256),
            nn.ReLU(True),
            # Reshape to start convolutions
            Reshape(256, 7, 7),
            # Convolution layers: upsample 7x7 -> 14x14 -> 28x28
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized data
        )

    def forward(self, x):
        return self.gen(x)
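# Shape sanity check (an added sketch, not part of the original notebook):
# the generator should map (N, z_dim) noise to (N, 1, 28, 28) images in
# [-1, 1]. eval() is used so BatchNorm relies on its running statistics.
with torch.no_grad():
    _g = Generator(z_dim=100).eval()
    _fake = _g(torch.randn(2, 100))
assert _fake.shape == (2, 1, 28, 28)
assert _fake.min() >= -1.0 and _fake.max() <= 1.0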
# Define Discriminator network with convolutional layers
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # First conv layer: 28x28 -> 14x14
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.3),
            # Second conv layer: 14x14 -> 7x7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.3),
            # Flatten and dense layers
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 1),
            nn.Sigmoid()  # probability that the input image is real
        )

    def forward(self, x):
        return self.disc(x)
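# Shape sanity check (an added sketch): the discriminator maps (N, 1, 28, 28)
# images to one probability per sample, shape (N, 1), with values in [0, 1]
# thanks to the final Sigmoid.
with torch.no_grad():
    _d = Discriminator().eval()
    _scores = _d(torch.randn(2, 1, 28, 28))
assert _scores.shape == (2, 1)
assert ((_scores >= 0) & (_scores <= 1)).all()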
# Hyperparameters
z_dim = 100
lr = 2e-4         # Slightly adjusted learning rate
beta1 = 0.5       # Beta1 for Adam optimizer
batch_size = 128  # Increased batch size
epochs = 20       # Increased epochs
# Set up device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Data loading with normalization [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
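# Optional peek at one batch (an added sketch): the transform above should
# yield (batch_size, 1, 28, 28) tensors scaled into [-1, 1], matching the
# range of the generator's Tanh output.
_images, _ = next(iter(loader))
print(_images.shape, _images.min().item(), _images.max().item())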
# Initialize models and move to MPS
generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)
# Optimizers with beta1
optimG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
# Loss function
criterion = nn.BCELoss()
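# Note on an alternative (not what this notebook does): removing the final
# nn.Sigmoid() from the discriminator and using nn.BCEWithLogitsLoss() in
# place of nn.BCELoss() computes the same objective more stably, since the
# sigmoid and log are fused into a single op:
#
#   criterion = nn.BCEWithLogitsLoss()  # discriminator would output raw logits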
# Training loop with improved stability
for epoch in range(epochs):
    for batch_idx, (real_images, _) in enumerate(loader):
        batch_size_current = real_images.shape[0]

        # Move data to the device
        real_images = real_images.to(device)

        # Train Discriminator: real images get label 1, fakes get label 0
        discriminator.zero_grad()
        label_real = torch.ones(batch_size_current, 1).to(device)
        label_fake = torch.zeros(batch_size_current, 1).to(device)

        output_real = discriminator(real_images)
        d_loss_real = criterion(output_real, label_real)

        noise = torch.randn(batch_size_current, z_dim).to(device)
        fake_images = generator(noise)
        # detach() so the discriminator's loss does not update the generator
        output_fake = discriminator(fake_images.detach())
        d_loss_fake = criterion(output_fake, label_fake)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimD.step()

        # Train Generator: try to get the fakes classified as real
        generator.zero_grad()
        output_fake = discriminator(fake_images)
        g_loss = criterion(output_fake, label_real)
        g_loss.backward()
        optimG.step()

        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch}/{epochs}] Batch [{batch_idx}/{len(loader)}] '
                  f'd_loss: {d_loss.item():.4f} g_loss: {g_loss.item():.4f}')

    # Generate and plot sample images every 5 epochs
    if (epoch + 1) % 5 == 0:
        with torch.no_grad():
            test_noise = torch.randn(5, z_dim).to(device)
            generated_images = generator(test_noise)

            plt.figure(figsize=(10, 2))
            for i in range(5):
                plt.subplot(1, 5, i + 1)
                # Move tensor back to CPU for plotting
                plt.imshow(generated_images[i].cpu().squeeze().numpy(), cmap='gray')
                plt.axis('off')
            plt.show()
Epoch [0/20] Batch [0/469] d_loss: 1.4239 g_loss: 0.7677
Epoch [0/20] Batch [100/469] d_loss: 1.3850 g_loss: 0.7145
Epoch [0/20] Batch [200/469] d_loss: 1.2254 g_loss: 0.8191
Epoch [0/20] Batch [300/469] d_loss: 0.8246 g_loss: 1.1075
Epoch [0/20] Batch [400/469] d_loss: 0.7940 g_loss: 1.4329
Epoch [1/20] Batch [0/469] d_loss: 0.6574 g_loss: 1.4599
Epoch [1/20] Batch [100/469] d_loss: 0.6153 g_loss: 1.8150
Epoch [1/20] Batch [200/469] d_loss: 1.1240 g_loss: 1.0195
Epoch [1/20] Batch [300/469] d_loss: 1.4135 g_loss: 0.8613
Epoch [1/20] Batch [400/469] d_loss: 1.4124 g_loss: 0.9036
Epoch [2/20] Batch [0/469] d_loss: 1.3630 g_loss: 0.8527
Epoch [2/20] Batch [100/469] d_loss: 1.2163 g_loss: 0.9640
Epoch [2/20] Batch [200/469] d_loss: 1.3488 g_loss: 0.7491
Epoch [2/20] Batch [300/469] d_loss: 1.2910 g_loss: 0.8418
Epoch [2/20] Batch [400/469] d_loss: 1.2164 g_loss: 0.8686
Epoch [3/20] Batch [0/469] d_loss: 1.2882 g_loss: 0.8355
Epoch [3/20] Batch [100/469] d_loss: 1.2517 g_loss: 0.9931
Epoch [3/20] Batch [200/469] d_loss: 1.2008 g_loss: 0.9636
Epoch [3/20] Batch [300/469] d_loss: 1.2776 g_loss: 1.1015
Epoch [3/20] Batch [400/469] d_loss: 1.2742 g_loss: 0.9338
Epoch [4/20] Batch [0/469] d_loss: 1.2320 g_loss: 0.8290
Epoch [4/20] Batch [100/469] d_loss: 1.1950 g_loss: 0.8534
Epoch [4/20] Batch [200/469] d_loss: 1.2451 g_loss: 0.8164
Epoch [4/20] Batch [300/469] d_loss: 1.1428 g_loss: 0.8949
Epoch [4/20] Batch [400/469] d_loss: 1.2496 g_loss: 0.9143
Epoch [5/20] Batch [0/469] d_loss: 1.3455 g_loss: 0.8387
Epoch [5/20] Batch [100/469] d_loss: 1.2845 g_loss: 0.8872
Epoch [5/20] Batch [200/469] d_loss: 1.2622 g_loss: 0.8158
Epoch [5/20] Batch [300/469] d_loss: 1.2124 g_loss: 0.8662
Epoch [5/20] Batch [400/469] d_loss: 1.2249 g_loss: 0.9399
Epoch [6/20] Batch [0/469] d_loss: 1.2702 g_loss: 0.9003
Epoch [6/20] Batch [100/469] d_loss: 1.2332 g_loss: 0.9916
Epoch [6/20] Batch [200/469] d_loss: 1.2569 g_loss: 0.9638
Epoch [6/20] Batch [300/469] d_loss: 1.1346 g_loss: 0.8724
Epoch [6/20] Batch [400/469] d_loss: 1.2020 g_loss: 0.8598
Epoch [7/20] Batch [0/469] d_loss: 1.3056 g_loss: 0.9416
Epoch [7/20] Batch [100/469] d_loss: 1.2384 g_loss: 0.9818
Epoch [7/20] Batch [200/469] d_loss: 1.2414 g_loss: 1.0582
Epoch [7/20] Batch [300/469] d_loss: 1.2354 g_loss: 0.9607
Epoch [7/20] Batch [400/469] d_loss: 1.1738 g_loss: 0.8974
Epoch [8/20] Batch [0/469] d_loss: 1.1154 g_loss: 0.8861
Epoch [8/20] Batch [100/469] d_loss: 1.2155 g_loss: 0.8860
Epoch [8/20] Batch [200/469] d_loss: 1.2716 g_loss: 0.8284
Epoch [8/20] Batch [300/469] d_loss: 1.2200 g_loss: 0.8618
Epoch [8/20] Batch [400/469] d_loss: 1.2214 g_loss: 0.9329
Epoch [9/20] Batch [0/469] d_loss: 1.1621 g_loss: 0.9471
Epoch [9/20] Batch [100/469] d_loss: 1.2423 g_loss: 1.0075
Epoch [9/20] Batch [200/469] d_loss: 1.1191 g_loss: 1.0043
Epoch [9/20] Batch [300/469] d_loss: 1.1531 g_loss: 0.9500
Epoch [9/20] Batch [400/469] d_loss: 1.2360 g_loss: 0.9542
Epoch [10/20] Batch [0/469] d_loss: 1.1941 g_loss: 0.8702
Epoch [10/20] Batch [100/469] d_loss: 1.2769 g_loss: 0.9748
Epoch [10/20] Batch [200/469] d_loss: 1.2864 g_loss: 0.9974
Epoch [10/20] Batch [300/469] d_loss: 1.2404 g_loss: 0.7765
Epoch [10/20] Batch [400/469] d_loss: 1.1253 g_loss: 0.8605
Epoch [11/20] Batch [0/469] d_loss: 1.3187 g_loss: 0.9972
Epoch [11/20] Batch [100/469] d_loss: 1.1936 g_loss: 1.0513
Epoch [11/20] Batch [200/469] d_loss: 1.1790 g_loss: 0.9714
Epoch [11/20] Batch [300/469] d_loss: 1.2461 g_loss: 1.0150
Epoch [11/20] Batch [400/469] d_loss: 1.3452 g_loss: 1.0451
Epoch [12/20] Batch [0/469] d_loss: 1.2775 g_loss: 0.7984
Epoch [12/20] Batch [100/469] d_loss: 1.2940 g_loss: 1.1000
Epoch [12/20] Batch [200/469] d_loss: 1.1843 g_loss: 1.0152
Epoch [12/20] Batch [300/469] d_loss: 1.2907 g_loss: 0.9341
Epoch [12/20] Batch [400/469] d_loss: 1.2473 g_loss: 0.9298
Epoch [13/20] Batch [0/469] d_loss: 1.1615 g_loss: 1.1194
Epoch [13/20] Batch [100/469] d_loss: 1.2015 g_loss: 0.8435
Epoch [13/20] Batch [200/469] d_loss: 1.2681 g_loss: 0.9003
Epoch [13/20] Batch [300/469] d_loss: 1.2993 g_loss: 0.9530
Epoch [13/20] Batch [400/469] d_loss: 1.1187 g_loss: 0.9645
Epoch [14/20] Batch [0/469] d_loss: 1.2024 g_loss: 0.9993
Epoch [14/20] Batch [100/469] d_loss: 1.2476 g_loss: 0.9358
Epoch [14/20] Batch [200/469] d_loss: 1.1476 g_loss: 1.0313
Epoch [14/20] Batch [300/469] d_loss: 1.1285 g_loss: 0.9711
Epoch [14/20] Batch [400/469] d_loss: 1.2464 g_loss: 0.9414
Epoch [15/20] Batch [0/469] d_loss: 1.1634 g_loss: 1.0355
Epoch [15/20] Batch [100/469] d_loss: 1.1517 g_loss: 0.9213
Epoch [15/20] Batch [200/469] d_loss: 1.3466 g_loss: 0.9449
Epoch [15/20] Batch [300/469] d_loss: 1.2163 g_loss: 0.8658
Epoch [15/20] Batch [400/469] d_loss: 1.2534 g_loss: 1.0694
Epoch [16/20] Batch [0/469] d_loss: 1.1684 g_loss: 0.9559
Epoch [16/20] Batch [100/469] d_loss: 1.1436 g_loss: 1.0252
Epoch [16/20] Batch [200/469] d_loss: 1.2818 g_loss: 0.8296
Epoch [16/20] Batch [300/469] d_loss: 1.2035 g_loss: 0.9181
Epoch [16/20] Batch [400/469] d_loss: 1.3543 g_loss: 0.9777
Epoch [17/20] Batch [0/469] d_loss: 1.1087 g_loss: 1.0090
Epoch [17/20] Batch [100/469] d_loss: 1.1419 g_loss: 1.0329
Epoch [17/20] Batch [200/469] d_loss: 1.1968 g_loss: 1.0202
Epoch [17/20] Batch [300/469] d_loss: 1.3013 g_loss: 1.0606
Epoch [17/20] Batch [400/469] d_loss: 1.3185 g_loss: 1.0476
Epoch [18/20] Batch [0/469] d_loss: 1.1677 g_loss: 0.9461
Epoch [18/20] Batch [100/469] d_loss: 1.1302 g_loss: 0.9498
Epoch [18/20] Batch [200/469] d_loss: 1.2141 g_loss: 1.0275
Epoch [18/20] Batch [300/469] d_loss: 1.2473 g_loss: 0.9661
Epoch [18/20] Batch [400/469] d_loss: 1.2429 g_loss: 1.0230
Epoch [19/20] Batch [0/469] d_loss: 1.1401 g_loss: 0.9990
Epoch [19/20] Batch [100/469] d_loss: 1.1564 g_loss: 0.8947
Epoch [19/20] Batch [200/469] d_loss: 1.1397 g_loss: 1.0643
Epoch [19/20] Batch [300/469] d_loss: 1.1948 g_loss: 1.1179
Epoch [19/20] Batch [400/469] d_loss: 1.2008 g_loss: 0.9639
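# Saving the trained generator (a hedged sketch, not part of the original run;
# the filename is illustrative):
#
#   torch.save(generator.state_dict(), 'generator_mnist.pt')
#   # later: g = Generator(z_dim); g.load_state_dict(torch.load('generator_mnist.pt')); g.eval()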