GAN example: MNIST

notebook

Published January 2, 2020

import torch
from torch import nn, optim
import torchvision.transforms as transforms
import torchvision
import matplotlib.pyplot as plt
# Custom Reshape layer
class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        return x.view(x.shape[0], *self.shape)
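
# Illustrative check (not part of the original notebook): Reshape(256, 7, 7)
# turns a flat (N, 256*7*7) tensor into an image-like (N, 256, 7, 7) tensor.
print(Reshape(256, 7, 7)(torch.zeros(2, 256 * 7 * 7)).shape)  # torch.Size([2, 256, 7, 7])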

# Define Generator network with convolutional layers
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # First layer
            nn.Linear(z_dim, 7*7*256),
            nn.BatchNorm1d(7*7*256),
            nn.ReLU(True),
            
            # Reshape to start convolutions
            Reshape(256, 7, 7),
            
            # Convolution layers
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)
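
# Quick shape check (added for illustration; not in the original notebook):
# a batch of noise vectors should map to 1x28x28 images in [-1, 1],
# via the 7x7 -> 14x14 -> 28x28 upsampling path above.
_g = Generator()
_g.eval()  # eval mode so BatchNorm uses its running statistics
print(_g(torch.randn(4, 100)).shape)  # expected: torch.Size([4, 1, 28, 28])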

# Define Discriminator network with convolutional layers
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # First conv layer
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.3),
            
            # Second conv layer
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.3),
            
            # Flatten and dense layers
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)
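
# Analogous shape check for the discriminator (illustrative only; not in the original notebook):
# a batch of 1x28x28 images should map to one probability per image.
_d = Discriminator()
_d.eval()
print(_d(torch.randn(4, 1, 28, 28)).shape)  # expected: torch.Size([4, 1])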

# Hyperparameters
z_dim = 100
lr = 2e-4      # Learning rate for both Adam optimizers
beta1 = 0.5    # Beta1 momentum term for Adam
batch_size = 128
epochs = 20

# Set up device (this run used Apple's MPS backend; fall back to CUDA or CPU if it is unavailable)
device = torch.device("mps" if torch.backends.mps.is_available()
                      else "cuda" if torch.cuda.is_available() else "cpu")

# Data loading with normalization [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
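
# Illustrative check (not in the original notebook): one batch should be
# (batch_size, 1, 28, 28) with pixel values normalized to [-1, 1],
# matching the Tanh output range of the generator.
_imgs, _ = next(iter(loader))
print(_imgs.shape, _imgs.min().item(), _imgs.max().item())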

# Initialize models and move them to the device
generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)

# Optimizers with beta1
optimG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Loss function
criterion = nn.BCELoss()

# Training loop with improved stability
for epoch in range(epochs):
    for batch_idx, (real_images, _) in enumerate(loader):
        batch_size_current = real_images.shape[0]
        
        # Move data to the device
        real_images = real_images.to(device)
        
        # Train Discriminator
        discriminator.zero_grad()
        label_real = torch.ones(batch_size_current, 1).to(device)
        label_fake = torch.zeros(batch_size_current, 1).to(device)
        
        output_real = discriminator(real_images)
        d_loss_real = criterion(output_real, label_real)
        
        noise = torch.randn(batch_size_current, z_dim).to(device)
        fake_images = generator(noise)
        output_fake = discriminator(fake_images.detach())
        d_loss_fake = criterion(output_fake, label_fake)
        
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimD.step()
        
        # Train Generator
        generator.zero_grad()
        output_fake = discriminator(fake_images)
        g_loss = criterion(output_fake, label_real)
        
        g_loss.backward()
        optimG.step()
        
        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch}/{epochs}] Batch [{batch_idx}/{len(loader)}] '
                  f'd_loss: {d_loss.item():.4f} g_loss: {g_loss.item():.4f}')

    # Generate and display sample images every 5 epochs
    if (epoch + 1) % 5 == 0:
        with torch.no_grad():
            test_noise = torch.randn(5, z_dim).to(device)
            generated_images = generator(test_noise)
            
            plt.figure(figsize=(10, 2))
            for i in range(5):
                plt.subplot(1, 5, i + 1)
                # Move tensor back to CPU for plotting
                plt.imshow(generated_images[i].cpu().squeeze().numpy(), cmap='gray')
                plt.axis('off')
            plt.show()
Epoch [0/20] Batch [0/469] d_loss: 1.4239 g_loss: 0.7677
Epoch [0/20] Batch [100/469] d_loss: 1.3850 g_loss: 0.7145
Epoch [0/20] Batch [200/469] d_loss: 1.2254 g_loss: 0.8191
Epoch [0/20] Batch [300/469] d_loss: 0.8246 g_loss: 1.1075
Epoch [0/20] Batch [400/469] d_loss: 0.7940 g_loss: 1.4329
Epoch [1/20] Batch [0/469] d_loss: 0.6574 g_loss: 1.4599
Epoch [1/20] Batch [100/469] d_loss: 0.6153 g_loss: 1.8150
Epoch [1/20] Batch [200/469] d_loss: 1.1240 g_loss: 1.0195
Epoch [1/20] Batch [300/469] d_loss: 1.4135 g_loss: 0.8613
Epoch [1/20] Batch [400/469] d_loss: 1.4124 g_loss: 0.9036
Epoch [2/20] Batch [0/469] d_loss: 1.3630 g_loss: 0.8527
Epoch [2/20] Batch [100/469] d_loss: 1.2163 g_loss: 0.9640
Epoch [2/20] Batch [200/469] d_loss: 1.3488 g_loss: 0.7491
Epoch [2/20] Batch [300/469] d_loss: 1.2910 g_loss: 0.8418
Epoch [2/20] Batch [400/469] d_loss: 1.2164 g_loss: 0.8686
Epoch [3/20] Batch [0/469] d_loss: 1.2882 g_loss: 0.8355
Epoch [3/20] Batch [100/469] d_loss: 1.2517 g_loss: 0.9931
Epoch [3/20] Batch [200/469] d_loss: 1.2008 g_loss: 0.9636
Epoch [3/20] Batch [300/469] d_loss: 1.2776 g_loss: 1.1015
Epoch [3/20] Batch [400/469] d_loss: 1.2742 g_loss: 0.9338
Epoch [4/20] Batch [0/469] d_loss: 1.2320 g_loss: 0.8290
Epoch [4/20] Batch [100/469] d_loss: 1.1950 g_loss: 0.8534
Epoch [4/20] Batch [200/469] d_loss: 1.2451 g_loss: 0.8164
Epoch [4/20] Batch [300/469] d_loss: 1.1428 g_loss: 0.8949
Epoch [4/20] Batch [400/469] d_loss: 1.2496 g_loss: 0.9143
[Figure: five generated sample digits]
Epoch [5/20] Batch [0/469] d_loss: 1.3455 g_loss: 0.8387
Epoch [5/20] Batch [100/469] d_loss: 1.2845 g_loss: 0.8872
Epoch [5/20] Batch [200/469] d_loss: 1.2622 g_loss: 0.8158
Epoch [5/20] Batch [300/469] d_loss: 1.2124 g_loss: 0.8662
Epoch [5/20] Batch [400/469] d_loss: 1.2249 g_loss: 0.9399
Epoch [6/20] Batch [0/469] d_loss: 1.2702 g_loss: 0.9003
Epoch [6/20] Batch [100/469] d_loss: 1.2332 g_loss: 0.9916
Epoch [6/20] Batch [200/469] d_loss: 1.2569 g_loss: 0.9638
Epoch [6/20] Batch [300/469] d_loss: 1.1346 g_loss: 0.8724
Epoch [6/20] Batch [400/469] d_loss: 1.2020 g_loss: 0.8598
Epoch [7/20] Batch [0/469] d_loss: 1.3056 g_loss: 0.9416
Epoch [7/20] Batch [100/469] d_loss: 1.2384 g_loss: 0.9818
Epoch [7/20] Batch [200/469] d_loss: 1.2414 g_loss: 1.0582
Epoch [7/20] Batch [300/469] d_loss: 1.2354 g_loss: 0.9607
Epoch [7/20] Batch [400/469] d_loss: 1.1738 g_loss: 0.8974
Epoch [8/20] Batch [0/469] d_loss: 1.1154 g_loss: 0.8861
Epoch [8/20] Batch [100/469] d_loss: 1.2155 g_loss: 0.8860
Epoch [8/20] Batch [200/469] d_loss: 1.2716 g_loss: 0.8284
Epoch [8/20] Batch [300/469] d_loss: 1.2200 g_loss: 0.8618
Epoch [8/20] Batch [400/469] d_loss: 1.2214 g_loss: 0.9329
Epoch [9/20] Batch [0/469] d_loss: 1.1621 g_loss: 0.9471
Epoch [9/20] Batch [100/469] d_loss: 1.2423 g_loss: 1.0075
Epoch [9/20] Batch [200/469] d_loss: 1.1191 g_loss: 1.0043
Epoch [9/20] Batch [300/469] d_loss: 1.1531 g_loss: 0.9500
Epoch [9/20] Batch [400/469] d_loss: 1.2360 g_loss: 0.9542
[Figure: five generated sample digits]
Epoch [10/20] Batch [0/469] d_loss: 1.1941 g_loss: 0.8702
Epoch [10/20] Batch [100/469] d_loss: 1.2769 g_loss: 0.9748
Epoch [10/20] Batch [200/469] d_loss: 1.2864 g_loss: 0.9974
Epoch [10/20] Batch [300/469] d_loss: 1.2404 g_loss: 0.7765
Epoch [10/20] Batch [400/469] d_loss: 1.1253 g_loss: 0.8605
Epoch [11/20] Batch [0/469] d_loss: 1.3187 g_loss: 0.9972
Epoch [11/20] Batch [100/469] d_loss: 1.1936 g_loss: 1.0513
Epoch [11/20] Batch [200/469] d_loss: 1.1790 g_loss: 0.9714
Epoch [11/20] Batch [300/469] d_loss: 1.2461 g_loss: 1.0150
Epoch [11/20] Batch [400/469] d_loss: 1.3452 g_loss: 1.0451
Epoch [12/20] Batch [0/469] d_loss: 1.2775 g_loss: 0.7984
Epoch [12/20] Batch [100/469] d_loss: 1.2940 g_loss: 1.1000
Epoch [12/20] Batch [200/469] d_loss: 1.1843 g_loss: 1.0152
Epoch [12/20] Batch [300/469] d_loss: 1.2907 g_loss: 0.9341
Epoch [12/20] Batch [400/469] d_loss: 1.2473 g_loss: 0.9298
Epoch [13/20] Batch [0/469] d_loss: 1.1615 g_loss: 1.1194
Epoch [13/20] Batch [100/469] d_loss: 1.2015 g_loss: 0.8435
Epoch [13/20] Batch [200/469] d_loss: 1.2681 g_loss: 0.9003
Epoch [13/20] Batch [300/469] d_loss: 1.2993 g_loss: 0.9530
Epoch [13/20] Batch [400/469] d_loss: 1.1187 g_loss: 0.9645
Epoch [14/20] Batch [0/469] d_loss: 1.2024 g_loss: 0.9993
Epoch [14/20] Batch [100/469] d_loss: 1.2476 g_loss: 0.9358
Epoch [14/20] Batch [200/469] d_loss: 1.1476 g_loss: 1.0313
Epoch [14/20] Batch [300/469] d_loss: 1.1285 g_loss: 0.9711
Epoch [14/20] Batch [400/469] d_loss: 1.2464 g_loss: 0.9414
[Figure: five generated sample digits]
Epoch [15/20] Batch [0/469] d_loss: 1.1634 g_loss: 1.0355
Epoch [15/20] Batch [100/469] d_loss: 1.1517 g_loss: 0.9213
Epoch [15/20] Batch [200/469] d_loss: 1.3466 g_loss: 0.9449
Epoch [15/20] Batch [300/469] d_loss: 1.2163 g_loss: 0.8658
Epoch [15/20] Batch [400/469] d_loss: 1.2534 g_loss: 1.0694
Epoch [16/20] Batch [0/469] d_loss: 1.1684 g_loss: 0.9559
Epoch [16/20] Batch [100/469] d_loss: 1.1436 g_loss: 1.0252
Epoch [16/20] Batch [200/469] d_loss: 1.2818 g_loss: 0.8296
Epoch [16/20] Batch [300/469] d_loss: 1.2035 g_loss: 0.9181
Epoch [16/20] Batch [400/469] d_loss: 1.3543 g_loss: 0.9777
Epoch [17/20] Batch [0/469] d_loss: 1.1087 g_loss: 1.0090
Epoch [17/20] Batch [100/469] d_loss: 1.1419 g_loss: 1.0329
Epoch [17/20] Batch [200/469] d_loss: 1.1968 g_loss: 1.0202
Epoch [17/20] Batch [300/469] d_loss: 1.3013 g_loss: 1.0606
Epoch [17/20] Batch [400/469] d_loss: 1.3185 g_loss: 1.0476
Epoch [18/20] Batch [0/469] d_loss: 1.1677 g_loss: 0.9461
Epoch [18/20] Batch [100/469] d_loss: 1.1302 g_loss: 0.9498
Epoch [18/20] Batch [200/469] d_loss: 1.2141 g_loss: 1.0275
Epoch [18/20] Batch [300/469] d_loss: 1.2473 g_loss: 0.9661
Epoch [18/20] Batch [400/469] d_loss: 1.2429 g_loss: 1.0230
Epoch [19/20] Batch [0/469] d_loss: 1.1401 g_loss: 0.9990
Epoch [19/20] Batch [100/469] d_loss: 1.1564 g_loss: 0.8947
Epoch [19/20] Batch [200/469] d_loss: 1.1397 g_loss: 1.0643
Epoch [19/20] Batch [300/469] d_loss: 1.1948 g_loss: 1.1179
Epoch [19/20] Batch [400/469] d_loss: 1.2008 g_loss: 0.9639
[Figure: five generated sample digits]
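
The notebook does not save the trained networks. A minimal sketch for persisting and reloading the generator afterwards (the file name here is an assumption, not from the original run):

# Save the trained generator weights (illustrative file name)
torch.save(generator.state_dict(), "generator_mnist.pt")

# Later: rebuild the architecture, load the weights, and switch to eval mode for sampling
generator = Generator(z_dim).to(device)
generator.load_state_dict(torch.load("generator_mnist.pt", map_location=device))
generator.eval()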
© HakyImLab and Listed Authors - CC BY 4.0 License