Building a GPT - companion notebook annotated

Companion notebook from Karpathy's video on building a minimal GPT, annotated by Cursor's LLM, with a summary from Gemini.
Author

Andrej Karpathy

Published

April 15, 2025

Building a GPT

Companion notebook to the Zero To Hero video on GPT. Downloaded from https://github.com/karpathy/nanoGPT

download the tiny shakespeare dataset

# Download the tiny shakespeare dataset
#!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
# print the length of the dataset
print("length of dataset in characters: ", len(text))
length of dataset in characters:  1115394
# let's look at the first 1000 characters
print(text[:1000])
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65

mapping characters to integers and vice versa

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))
[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there
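Here is a quick, self-contained check of the character-level scheme. It runs on a made-up toy string rather than input.txt. The toy text and the toy_* names are assumptions for illustration, chosen so they do not overwrite the notebook's stoi/itos/encode/decode. The check shows three things: each character maps to exactly one integer, decode inverts encode, and a character absent from the vocabulary has no dictionary entry, so encoding it raises a KeyError.

```python
# Toy corpus standing in for input.txt (assumption for illustration)
toy_text = "hii there"
toy_chars = sorted(set(toy_text))                      # unique characters, sorted
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
toy_itos = {i: ch for i, ch in enumerate(toy_chars)}
toy_encode = lambda s: [toy_stoi[c] for c in s]
toy_decode = lambda l: ''.join(toy_itos[i] for i in l)

print(toy_chars)                  # [' ', 'e', 'h', 'i', 'r', 't']
print(toy_encode("three"))        # [5, 2, 4, 1, 1]
print(toy_decode(toy_encode(toy_text)) == toy_text)   # True

# A character that never occurs in the corpus has no integer id
try:
    toy_encode("z")
    unseen_raises = False
except KeyError:
    unseen_raises = True
print(unseen_raises)              # True
```

The same applies to the real vocabulary: stoi only covers the 65 characters that occur in tiny shakespeare.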

encode the data into a torch tensor

# let's now encode the entire text dataset and store it into a torch.Tensor
import torch # we use PyTorch: https://pytorch.org
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earlier will look like this to the GPT
torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 46, 47, 43, 44,  1, 43, 52, 43, 51, 63,
         1, 58, 53,  1, 58, 46, 43,  1, 54, 43, 53, 54, 50, 43,  8,  0,  0, 13,
        50, 50, 10,  0, 35, 43,  1, 49, 52, 53, 61,  5, 58,  6,  1, 61, 43,  1,
        49, 52, 53, 61,  5, 58,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47, 58,
        47, 64, 43, 52, 10,  0, 24, 43, 58,  1, 59, 57,  1, 49, 47, 50, 50,  1,
        46, 47, 51,  6,  1, 39, 52, 42,  1, 61, 43,  5, 50, 50,  1, 46, 39, 60,
        43,  1, 41, 53, 56, 52,  1, 39, 58,  1, 53, 59, 56,  1, 53, 61, 52,  1,
        54, 56, 47, 41, 43,  8,  0, 21, 57,  5, 58,  1, 39,  1, 60, 43, 56, 42,
        47, 41, 58, 12,  0,  0, 13, 50, 50, 10,  0, 26, 53,  1, 51, 53, 56, 43,
         1, 58, 39, 50, 49, 47, 52, 45,  1, 53, 52,  5, 58, 11,  1, 50, 43, 58,
         1, 47, 58,  1, 40, 43,  1, 42, 53, 52, 43, 10,  1, 39, 61, 39, 63,  6,
         1, 39, 61, 39, 63,  2,  0,  0, 31, 43, 41, 53, 52, 42,  1, 15, 47, 58,
        47, 64, 43, 52, 10,  0, 27, 52, 43,  1, 61, 53, 56, 42,  6,  1, 45, 53,
        53, 42,  1, 41, 47, 58, 47, 64, 43, 52, 57,  8,  0,  0, 18, 47, 56, 57,
        58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 35, 43,  1, 39, 56, 43,  1,
        39, 41, 41, 53, 59, 52, 58, 43, 42,  1, 54, 53, 53, 56,  1, 41, 47, 58,
        47, 64, 43, 52, 57,  6,  1, 58, 46, 43,  1, 54, 39, 58, 56, 47, 41, 47,
        39, 52, 57,  1, 45, 53, 53, 42,  8,  0, 35, 46, 39, 58,  1, 39, 59, 58,
        46, 53, 56, 47, 58, 63,  1, 57, 59, 56, 44, 43, 47, 58, 57,  1, 53, 52,
         1, 61, 53, 59, 50, 42,  1, 56, 43, 50, 47, 43, 60, 43,  1, 59, 57, 10,
         1, 47, 44,  1, 58, 46, 43, 63,  0, 61, 53, 59, 50, 42,  1, 63, 47, 43,
        50, 42,  1, 59, 57,  1, 40, 59, 58,  1, 58, 46, 43,  1, 57, 59, 54, 43,
        56, 44, 50, 59, 47, 58, 63,  6,  1, 61, 46, 47, 50, 43,  1, 47, 58,  1,
        61, 43, 56, 43,  0, 61, 46, 53, 50, 43, 57, 53, 51, 43,  6,  1, 61, 43,
         1, 51, 47, 45, 46, 58,  1, 45, 59, 43, 57, 57,  1, 58, 46, 43, 63,  1,
        56, 43, 50, 47, 43, 60, 43, 42,  1, 59, 57,  1, 46, 59, 51, 39, 52, 43,
        50, 63, 11,  0, 40, 59, 58,  1, 58, 46, 43, 63,  1, 58, 46, 47, 52, 49,
         1, 61, 43,  1, 39, 56, 43,  1, 58, 53, 53,  1, 42, 43, 39, 56, 10,  1,
        58, 46, 43,  1, 50, 43, 39, 52, 52, 43, 57, 57,  1, 58, 46, 39, 58,  0,
        39, 44, 44, 50, 47, 41, 58, 57,  1, 59, 57,  6,  1, 58, 46, 43,  1, 53,
        40, 48, 43, 41, 58,  1, 53, 44,  1, 53, 59, 56,  1, 51, 47, 57, 43, 56,
        63,  6,  1, 47, 57,  1, 39, 57,  1, 39, 52,  0, 47, 52, 60, 43, 52, 58,
        53, 56, 63,  1, 58, 53,  1, 54, 39, 56, 58, 47, 41, 59, 50, 39, 56, 47,
        57, 43,  1, 58, 46, 43, 47, 56,  1, 39, 40, 59, 52, 42, 39, 52, 41, 43,
        11,  1, 53, 59, 56,  0, 57, 59, 44, 44, 43, 56, 39, 52, 41, 43,  1, 47,
        57,  1, 39,  1, 45, 39, 47, 52,  1, 58, 53,  1, 58, 46, 43, 51,  1, 24,
        43, 58,  1, 59, 57,  1, 56, 43, 60, 43, 52, 45, 43,  1, 58, 46, 47, 57,
         1, 61, 47, 58, 46,  0, 53, 59, 56,  1, 54, 47, 49, 43, 57,  6,  1, 43,
        56, 43,  1, 61, 43,  1, 40, 43, 41, 53, 51, 43,  1, 56, 39, 49, 43, 57,
        10,  1, 44, 53, 56,  1, 58, 46, 43,  1, 45, 53, 42, 57,  1, 49, 52, 53,
        61,  1, 21,  0, 57, 54, 43, 39, 49,  1, 58, 46, 47, 57,  1, 47, 52,  1,
        46, 59, 52, 45, 43, 56,  1, 44, 53, 56,  1, 40, 56, 43, 39, 42,  6,  1,
        52, 53, 58,  1, 47, 52,  1, 58, 46, 47, 56, 57, 58,  1, 44, 53, 56,  1,
        56, 43, 60, 43, 52, 45, 43,  8,  0,  0])

split up the data into train and validation sets

# split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
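The split is a single contiguous cut, not a random shuffle, so the validation set is the last ~10% of the play. The sketch below makes this concrete with a stand-in tensor of positions that has the dataset's length. Using torch.arange in place of the real encoded text is an assumption, made so that no file is needed.

```python
import torch

# Stand-in for the encoded corpus: same length as tiny shakespeare (1115394 characters);
# each value is just its own position, so we can see where the cut lands.
data_demo = torch.arange(1115394)
n_demo = int(0.9 * len(data_demo))          # same rule as above: first 90% is train
train_demo, val_demo = data_demo[:n_demo], data_demo[n_demo:]

print(n_demo, len(train_demo), len(val_demo))        # 1003854 1003854 111540
# Contiguous cut: validation starts right where training ends
print(train_demo[-1].item(), val_demo[0].item())     # 1003853 1003854
```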

define the block size

block_size = 8
train_data[:block_size+1]
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

define the context and target: one chunk of block_size+1 characters contains 8 training examples

x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")
when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58

define the batch size and get the batch

torch.manual_seed(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")
inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----
when input is [24] the target: 43
when input is [24, 43] the target: 58
when input is [24, 43, 58] the target: 5
when input is [24, 43, 58, 5] the target: 57
when input is [24, 43, 58, 5, 57] the target: 1
when input is [24, 43, 58, 5, 57, 1] the target: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
when input is [44] the target: 53
when input is [44, 53] the target: 56
when input is [44, 53, 56] the target: 1
when input is [44, 53, 56, 1] the target: 58
when input is [44, 53, 56, 1, 58] the target: 46
when input is [44, 53, 56, 1, 58, 46] the target: 39
when input is [44, 53, 56, 1, 58, 46, 39] the target: 58
when input is [44, 53, 56, 1, 58, 46, 39, 58] the target: 1
when input is [52] the target: 58
when input is [52, 58] the target: 1
when input is [52, 58, 1] the target: 58
when input is [52, 58, 1, 58] the target: 46
when input is [52, 58, 1, 58, 46] the target: 39
when input is [52, 58, 1, 58, 46, 39] the target: 58
when input is [52, 58, 1, 58, 46, 39, 58] the target: 1
when input is [52, 58, 1, 58, 46, 39, 58, 1] the target: 46
when input is [25] the target: 17
when input is [25, 17] the target: 27
when input is [25, 17, 27] the target: 10
when input is [25, 17, 27, 10] the target: 0
when input is [25, 17, 27, 10, 0] the target: 21
when input is [25, 17, 27, 10, 0, 21] the target: 1
when input is [25, 17, 27, 10, 0, 21, 1] the target: 54
when input is [25, 17, 27, 10, 0, 21, 1, 54] the target: 39
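The two list comprehensions in get_batch gather the batch_size windows one at a time. As a sketch of an alternative (not the video's code), the same windows can be gathered in one indexing operation by adding an offset grid to the random start positions. get_batch_vectorized and toy_data are hypothetical names used only for illustration.

```python
import torch

def get_batch_vectorized(data, block_size, batch_size, generator=None):
    """Hypothetical alternative to get_batch: gathers all windows with one indexing op."""
    ix = torch.randint(len(data) - block_size, (batch_size,), generator=generator)  # start positions
    offsets = torch.arange(block_size)                 # 0, 1, ..., block_size-1
    idx = ix[:, None] + offsets[None, :]               # (batch_size, block_size) via broadcasting
    x = data[idx]                                      # inputs
    y = data[idx + 1]                                  # targets: the same windows shifted by one
    return x, y, ix

toy_data = torch.arange(100) * 10                      # hypothetical toy "corpus"
g = torch.Generator().manual_seed(0)                   # local generator, global RNG untouched
xv, yv, ix = get_batch_vectorized(toy_data, block_size=8, batch_size=4, generator=g)

# The loop-and-stack construction from get_batch, applied to the same start positions
xs = torch.stack([toy_data[i:i+8] for i in ix])
ys = torch.stack([toy_data[i+1:i+8+1] for i in ix])
print(torch.equal(xv, xs), torch.equal(yv, ys))        # True True
```

For batch_size = 4 the speed difference is negligible. The indexing form mainly shows that y is x shifted by one position.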

start with a simple model: the bigram language model

# define the bigram language model
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

cross entropy loss

\text{Loss} = -\sum_{i} y_i \log(p_i)

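where:

  • y_i = actual probability for the i-th class (1 for the correct character, 0 otherwise)
  • p_i = predicted probability for the i-th class
  • \sum = sum over all classes (characters)

Because y is one-hot, the sum reduces to a single term: the loss is -\log(p_{\text{target}}), the negative log-probability the model assigns to the correct next character.

This is the loss for a single token prediction. The total loss reported by F.cross_entropy is the average loss across all B*T tokens in the batch, where:

  • B = batch_size
  • T = block_size (sequence length)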
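The sketch below checks the formula numerically against F.cross_entropy on random logits. The shapes mirror the bigram demo; the random values and the ce_* names are illustrative assumptions. A local generator is used so the notebook's global RNG state is not changed.

```python
import torch
from torch.nn import functional as F

g = torch.Generator().manual_seed(0)
B, T, C = 4, 8, 65                                   # same shapes as the bigram demo
ce_logits = torch.randn(B, T, C, generator=g)
ce_targets = torch.randint(0, C, (B, T), generator=g)

# F.cross_entropy expects (N, C) logits and (N,) class indices, hence N = B*T
loss_builtin = F.cross_entropy(ce_logits.view(B*T, C), ce_targets.view(B*T))

# Manual version: -log p of the correct class at every position, averaged over all B*T positions
log_probs = F.log_softmax(ce_logits, dim=-1)                            # (B, T, C)
nll = -log_probs.gather(-1, ce_targets.unsqueeze(-1)).squeeze(-1)       # (B, T)
loss_manual = nll.mean()

print(loss_builtin.item(), loss_manual.item())
print(torch.allclose(loss_builtin, loss_manual, atol=1e-6))             # True
```

This flattening is also why the bigram model's forward() reshapes its logits to (B*T, C) before calling F.cross_entropy.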

Before training, we would ideally expect the model to predict the next character from a uniform distribution (random guessing). The probability for the correct character would be 1 / \text{vocab\_size}.

\text{Expected initial loss} \approx -\log(1 / \text{vocab\_size}) = \log(\text{vocab\_size}) = \log(65) \approx 4.1744
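Here is a quick numerical check of this baseline. It assumes perfectly uniform predictions, modeled as all-zero logits (an illustrative choice). Under that assumption the loss equals ln(65), whatever the targets are.

```python
import math
import torch
from torch.nn import functional as F

vocab = 65
uniform_logits = torch.zeros(32, vocab)          # equal logits -> softmax is exactly uniform
some_targets = torch.randint(0, vocab, (32,), generator=torch.Generator().manual_seed(0))
loss_uniform = F.cross_entropy(uniform_logits, some_targets)
print(loss_uniform.item(), math.log(vocab))      # both ≈ 4.1744
```

The untrained bigram model below starts higher than this (4.8786 in this run). This is likely because nn.Embedding initializes its weights from a standard normal distribution, so its first predictions are random rather than uniform.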

initialize the model and compute the loss

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb) # xb/yb are from the previous cell (B=4, T=8)
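print(logits.shape) # (B*T, C) = (32, 65): forward() flattens the logits when targets are given
print(loss) # ideally about ln(65) ≈ 4.17; the randomly initialized embedding is not uniform, so it starts higher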
torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

generate text

print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ

choose AdamW as the optimizer

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

train the model

batch_size = 32 # Redefine batch size for training
for steps in range(100): # increase the number of steps for better results
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())
4.65630578994751

generate text starting from index 0 (the newline character \n) as the initial context

print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))

oTo.JUZ!!zqe!
xBP qbs$Gy'AcOmrLwwt
p$x;Seh-onQbfM?OjKbn'NwUAW -Np3fkz$FVwAUEa-wzWC -wQo-R!v -Mj?,SPiTyZ;o-opr$mOiPJEYD-CfigkzD3p3?zvS;ADz;.y?o,ivCuC'zqHxcVT cHA
rT'Fd,SBMZyOslg!NXeF$sBe,juUzLq?w-wzP-h
ERjjxlgJzPbHxf$ q,q,KCDCU fqBOQT
SV&CW:xSVwZv'DG'NSPypDhKStKzC -$hslxIVzoivnp ,ethA:NCCGoi
tN!ljjP3fwJMwNelgUzzPGJlgihJ!d?q.d
pSPYgCuCJrIFtb
jQXg
pA.P LP,SPJi
DBcuBM:CixjJ$Jzkq,OLf3KLQLMGph$O 3DfiPHnXKuHMlyjxEiyZib3FaHV-oJa!zoc'XSP :CKGUhd?lgCOF$;;DTHZMlvvcmZAm;:iv'MMgO&Ywbc;BLCUd&vZINLIzkuTGZa
D.?

The mathematical trick in self-attention

toy example illustrating how matrix multiplication can be used for a “weighted aggregation”

torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3)) # Lower triangular matrix of 1s
a = a / torch.sum(a, 1, keepdim=True) # Normalize rows to sum to 1
b = torch.randint(0,10,(3,2)).float() # Some data
c = a @ b # Matrix multiply
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)
a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
# consider the following toy example:
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape
torch.Size([4, 8, 2])

version 1: using a for loop to compute the weighted aggregation

# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C)) # x bag-of-words (running average)
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # Select vectors from start up to time t: shape (t+1, C)
        xbow[b,t] = torch.mean(xprev, 0) # Compute mean along the time dimension (dim 0)

version 2: using matrix multiply for a weighted aggregation

# Create the averaging weight matrix
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True) # Normalize rows to sum to 1
# Perform batched matrix multiplication
xbow2 = wei @ x # (T, T) @ (B, T, C) -> (B, T, C) via broadcasting
torch.allclose(xbow, xbow2) # Check if results are identical
True
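For the plain running mean there is also a third route, sketched here as a cross-check (not from the video): take a cumulative sum along the time dimension and divide by the number of vectors summed so far. x_demo and the *_demo/*_loop names are illustrative, chosen so they don't overwrite x, wei, or xbow.

```python
import torch

g = torch.Generator().manual_seed(1337)
B, T, C = 4, 8, 2
x_demo = torch.randn(B, T, C, generator=g)

# Version 1: explicit loop over batch and time
xbow_loop = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xbow_loop[b, t] = x_demo[b, :t+1].mean(0)

# Cumulative sum along time, divided by how many vectors have been summed (1, 2, ..., T)
counts = torch.arange(1, T + 1, dtype=x_demo.dtype).view(1, T, 1)   # broadcasts over B and C
xbow_cumsum = x_demo.cumsum(dim=1) / counts

# Version 2: averaging matrix
wei_demo = torch.tril(torch.ones(T, T))
wei_demo = wei_demo / wei_demo.sum(1, keepdim=True)
xbow_matmul = wei_demo @ x_demo

print(torch.allclose(xbow_loop, xbow_cumsum, atol=1e-6),
      torch.allclose(xbow_loop, xbow_matmul, atol=1e-6))            # True True
```

The cumulative-sum trick only works for uniform weights. The wei @ x form accepts any (T, T) weight matrix, and that generality is what lets self-attention swap in data-dependent weights later.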

version 3: use Softmax

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
# Mask out future positions by setting them to -infinity before softmax
wei = wei.masked_fill(tril == 0, float('-inf'))
# Apply softmax to get row-wise probability distributions (weights)
wei = F.softmax(wei, dim=-1)
# Perform weighted aggregation
xbow3 = wei @ x
torch.allclose(xbow, xbow3) # Check if results are identical
True

softmax function

softmax(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}
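The sketch below shows how softmax is usually computed in practice. softmax_stable is a hypothetical name, not a PyTorch function. Subtracting each row's maximum before exponentiating leaves the result mathematically unchanged, because the factor e^{-max} cancels between numerator and denominator, but it avoids overflow for large inputs. The sketch also shows why the -inf masking in version 3 works: e^{-inf} = 0, so masked positions get exactly zero weight.

```python
import torch

def softmax_stable(z, dim=-1):
    # Shift by the row max: same result, but exp() never sees large positive numbers
    z_max = z.max(dim=dim, keepdim=True).values
    e = torch.exp(z - z_max)
    return e / e.sum(dim=dim, keepdim=True)

z = torch.tensor([[1.0, 2.0, 3.0],
                  [1000.0, 1001.0, 1002.0],                # naive exp(1000.0) overflows to inf
                  [0.0, float('-inf'), float('-inf')]])    # masked entries
p = softmax_stable(z)
print(p)                                                   # rows 0 and 1 are identical; row 2 is [1, 0, 0]
print(torch.allclose(p, torch.softmax(z, dim=-1)))         # True
```

A row that is entirely -inf would still produce NaN here. The causal mask never creates one, because every position can at least attend to itself.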

version 4: self-attention

torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels (embedding dimension)
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, head_size)
q = query(x) # (B, T, head_size)
# Compute attention scores ("affinities")
wei = q @ k.transpose(-2, -1) # (B, T, hs) @ (B, hs, T) ---> (B, T, T)

# Scale the scores
# Note: this multiplies by C**-0.5, i.e. 1/sqrt(embedding_dim), as the Head class in the full code does.
# The standard Transformer scales by 1/sqrt(head_size) instead.
wei = wei * (C**-0.5)

# Apply causal mask
tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T)) # This line is commented out in original, was from softmax demo
wei = wei.masked_fill(tril == 0, float('-inf')) # Mask future tokens

# Apply softmax to get attention weights
wei = F.softmax(wei, dim=-1) # (B, T, T)

# Perform weighted aggregation of Values
v = value(x) # (B, T, head_size)
out = wei @ v # (B, T, T) @ (B, T, hs) ---> (B, T, hs)
#out = wei @ x # This would aggregate original x, not the projected values 'v'

out.shape # Expected: (B, T, head_size) = (4, 8, 16)
torch.Size([4, 8, 16])
wei[0] # Show attention weights for the first sequence in the batch
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4264, 0.5736, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3151, 0.3022, 0.3827, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3007, 0.2272, 0.2467, 0.2253, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1635, 0.2048, 0.1776, 0.1616, 0.2926, 0.0000, 0.0000, 0.0000],
        [0.1403, 0.2272, 0.1454, 0.1244, 0.2678, 0.0949, 0.0000, 0.0000],
        [0.1554, 0.1815, 0.1224, 0.1213, 0.1428, 0.1603, 0.1164, 0.0000],
        [0.0952, 0.1217, 0.1130, 0.1453, 0.1137, 0.1180, 0.1467, 0.1464]],
       grad_fn=<SelectBackward0>)
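As a cross-check of the single head above, the same causal attention can be compared against PyTorch's fused F.scaled_dot_product_attention, available since PyTorch 2.0. This sketch makes two changes from the cell above. It uses random projection matrices in place of the nn.Linear layers. It also scales by 1/sqrt(head_size), which is that function's default, rather than by C**-0.5. The *_att names are illustrative.

```python
import torch
from torch.nn import functional as F

g = torch.Generator().manual_seed(0)          # local generator: the notebook's global RNG is untouched
B, T, C, hs = 4, 8, 32, 16
x_att = torch.randn(B, T, C, generator=g)
Wq, Wk, Wv = (torch.randn(C, hs, generator=g) * C**-0.5 for _ in range(3))
q_att, k_att, v_att = x_att @ Wq, x_att @ Wk, x_att @ Wv      # each (B, T, hs)

# Manual causal attention, scaled by 1/sqrt(head_size)
scores = q_att @ k_att.transpose(-2, -1) * hs**-0.5           # (B, T, T)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
scores = scores.masked_fill(~mask, float('-inf'))
out_manual = F.softmax(scores, dim=-1) @ v_att                 # (B, T, hs)

# Fused kernel; is_causal=True applies the same lower-triangular mask
out_fused = F.scaled_dot_product_attention(q_att, k_att, v_att, is_causal=True)
print(out_manual.shape, torch.allclose(out_manual, out_fused, atol=1e-5))   # torch.Size([4, 8, 16]) True
```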

Check (in R) that X X^T/(C-1) is the correlation matrix when each row (token) of X is standardized


nC = 64
X = matrix(rnorm(4*64), nrow=4, ncol=nC)
## make it so that the second token (row 2) is similar to the last one
X[2,] = X[4,]*0.5 + X[2,]*0.5
## normalize X
X = t(scale(t(X)))

q = X
k = X
v = X

qkt = q %*% t(k)/(nC-1)
xcor = cor(t(q),t(k))
dim(xcor)
dim(qkt)
cat("xcor\n")
xcor
cat("---\n qkt\n")
qkt

cat("are xcor and qkt equal?")
all.equal(xcor, qkt)

par(mar=c(5, 6, 4, 2) + 0.1)  # increase left margin to avoid cutting off the y label
par(pty="s")  # Set plot type to "square"
plot(c(xcor), c(qkt),cex=3,cex.lab=3,cex.axis=2,cex.main=2,cex.sub=2); abline(0,1)
par(pty="m")  # Reset to default plot type
par(mar=c(5, 4, 4, 2) + 0.1)  # Reset to default margins
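The same check in Python, as a sketch that mirrors the R cell. The seed and the float64 dtype are assumptions, chosen for a deterministic and precise comparison; torch.corrcoef requires PyTorch 1.10 or newer. torch.std uses the n-1 denominator by default, like R's scale(), so dividing by nC - 1 reproduces the Pearson correlation exactly.

```python
import torch

g = torch.Generator().manual_seed(0)
nC = 64
X = torch.randn(4, nC, generator=g, dtype=torch.float64)
X[1] = 0.5 * X[3] + 0.5 * X[1]          # make the second token resemble the last one (R's X[2,] = ...)
# Standardize each row (token): mean 0, sd 1 with the n-1 denominator, like R's scale()
X = (X - X.mean(dim=1, keepdim=True)) / X.std(dim=1, keepdim=True)

q = k = X                                # q, k are the same standardized tokens, as in the R cell
qkt = q @ k.T / (nC - 1)                 # scaled dot products between tokens
xcor = torch.corrcoef(X)                 # Pearson correlation between rows
print(torch.allclose(qkt, xcor))         # True
print(qkt[1, 3])                         # large, since token 2 was built from token 4
```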

Notes:

Attention is a communication mechanism. It can be seen as nodes in a directed graph looking at each other and aggregating information via a weighted sum over all nodes that point to them, with data-dependent weights.

  • There is no notion of space. Attention simply acts over a set of vectors, which is why we need to positionally encode tokens. For example, “the cat sat on the mat” should be different from “the mat sat on the cat”.
  • Each example across the batch dimension is processed completely independently; examples never “talk” to each other.
  • In an “encoder” attention block, just delete the single line that does masking with tril, allowing all tokens to communicate. This block here is called a “decoder” attention block because it has triangular masking, and it is usually used in autoregressive settings like language modeling.
  • “self-attention” just means that the keys and values are produced from the same source as queries (all come from x). In “cross-attention”, the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)

why scaled attention?

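“Scaled” attention additionally multiplies wei by 1/sqrt(head_size), i.e. divides it by sqrt(head_size). The score q·k sums head_size products, so with unit-variance Q and K the unscaled scores have a variance of about head_size. The scaling brings wei back to unit variance, so softmax stays diffuse and does not saturate toward one-hot. Illustration below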
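The variance claim follows from a short derivation. It assumes the components q_i and k_i are independent with mean 0 and variance 1, and it writes d_k for head_size:

```latex
q \cdot k = \sum_{i=1}^{d_k} q_i k_i, \qquad
\operatorname{Var}(q_i k_i) = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - \left(\mathbb{E}[q_i]\,\mathbb{E}[k_i]\right)^2 = 1 \cdot 1 - 0 = 1

\operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = d_k
\quad\Longrightarrow\quad
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1
```

With head_size = 16 this predicts an unscaled variance of about 16, in line with the measured 17.4690.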

# Demonstrate variance without scaling
k_unscaled = torch.randn(B,T,head_size)
q_unscaled = torch.randn(B,T,head_size)
wei_unscaled = q_unscaled @ k_unscaled.transpose(-2, -1)
print(f"k var: {k_unscaled.var():.4f}, q var: {q_unscaled.var():.4f}, wei (unscaled) var: {wei_unscaled.var():.4f}")

# Demonstrate variance *with* scaling (using head_size for illustration)
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5 # Scale by sqrt(head_size)
print(f"k var: {k.var():.4f}, q var: {q.var():.4f}, wei (scaled) var: {wei.var():.4f}") # Variance should be closer to 1
k var: 1.0449, q var: 1.0700, wei (unscaled) var: 17.4690
k var: 0.9006, q var: 1.0037, wei (scaled) var: 0.9957
k.var() # Should be close to 1
tensor(0.9006)
q.var() # Should be close to 1
tensor(1.0037)
wei.var() # With scaling, should be closer to 1 than head_size (16)
tensor(0.9957)
# Softmax with small inputs (diffuse distribution)
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)
tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
# Softmax with large inputs (simulating unscaled attention scores) -> peaks
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # gets too peaky, converges to one-hot
tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])

LayerNorm1d

class LayerNorm1d: # (used to be BatchNorm1d)
    def __init__(self, dim, eps=1e-5, momentum=0.1): # Momentum is not used in typical LayerNorm
        self.eps = eps
        # Learnable scale and shift parameters, initialized to 1 and 0
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x):
        # calculate the forward pass
        # Statistics are taken over the *last* dimension (features/embedding): one mean/variance per example
        xmean = x.mean(-1, keepdim=True) # shape (B, 1) for a (B, dim) input, (B, T, 1) for (B, T, dim)
        xvar = x.var(-1, keepdim=True)   # unbiased variance (n-1); nn.LayerNorm uses the biased (n) estimate
        # Normalize each feature vector independently
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
        # Apply scale and shift
        self.out = self.gamma * xhat + self.beta
        return self.out

    def parameters(self):
        # Expose gamma and beta as learnable parameters
        return [self.gamma, self.beta]

torch.manual_seed(1337)
module = LayerNorm1d(100) # Create LayerNorm for 100 features
x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors
x = module(x)
x.shape # Should be (32, 100)
torch.Size([32, 100])

Explanation of layernorm

Input shape: (B, T, C), where:

  • B = batch size
  • T = sequence length (number of tokens)
  • C = embedding dimension (features of each token)

For each token in the sequence (each position along T), LayerNorm:

  • takes its embedding vector of size C,
  • calculates the mean and standard deviation of just that vector,
  • normalizes that vector by subtracting its mean and dividing by its standard deviation,
  • applies the learnable scale (gamma) and shift (beta) parameters.

So if you have a sequence like “The cat sat”, and each word is represented by a 64-dimensional embedding vector, LayerNorm normalizes the 64-dimensional vector for “The”, then the one for “cat”, then the one for “sat”. Each token’s vector is normalized independently of the others. This is different from BatchNorm, which normalizes each feature across the batch dimension (i.e., looking at the same feature across different examples in the batch). This per-token normalization helps maintain stable gradients during training and keeps the vectors entering each attention layer on a consistent scale.
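Below is a sketch of the per-token normalization on a (B, T, C) tensor, compared against PyTorch's nn.LayerNorm; the sizes and random data are illustrative. nn.LayerNorm divides by C rather than C-1 when computing the variance, so the manual version uses unbiased=False. The LayerNorm1d class above uses the unbiased default, so its outputs differ very slightly from nn.LayerNorm.

```python
import torch
import torch.nn as nn

g = torch.Generator().manual_seed(0)
n_feat = 64                                              # C: embedding size of each token
x_ln = torch.randn(2, 3, n_feat, generator=g) * 3 + 1    # (B=2 sequences, T=3 tokens, C=64)

ln = nn.LayerNorm(n_feat)                                # gamma starts at 1, beta at 0

# Manual LayerNorm: one mean/variance per token, taken over the last (feature) dimension
tok_mean = x_ln.mean(dim=-1, keepdim=True)                    # (2, 3, 1)
tok_var = x_ln.var(dim=-1, keepdim=True, unbiased=False)      # (2, 3, 1), biased like nn.LayerNorm
manual = (x_ln - tok_mean) / torch.sqrt(tok_var + ln.eps)

print(torch.allclose(manual, ln(x_ln), atol=1e-5))            # True
print(manual.mean(dim=-1).abs().max())                        # ~0 for every token

# BatchNorm-style statistics, for contrast: one mean per feature, pooled over batch and time
bn_mean = x_ln.mean(dim=(0, 1))                               # (64,)
print(tok_mean.shape, bn_mean.shape)
```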

# Mean and std of the first feature *across the batch*. Not expected to be 0 and 1.
x[:,0].mean(), x[:,0].std()
(tensor(0.1469), tensor(0.8803))
# Mean and std *across features* for the first item in the batch. Expected to be ~0 and ~1.
x[0,:].mean(), x[0,:].std()
(tensor(2.3842e-09), tensor(1.0000))

French to English translation example:

# <--------- ENCODE ------------------><--------------- DECODE ----------------->
# les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>

Full finished code, for reference

# Import necessary PyTorch modules
import torch
import torch.nn as nn
from torch.nn import functional as F

# ===== HYPERPARAMETERS =====
batch_size = 16       # Number of sequences per batch (Smaller than Bigram training)
block_size = 32       # Context length (Larger than Bigram demo)
max_iters = 5000      # Total training iterations (More substantial training)
eval_interval = 100   # How often to check validation loss
learning_rate = 1e-3  # Optimizer learning rate
eval_iters = 200      # Number of batches to average for validation loss estimate
n_embd = 64           # Embedding dimension (Size of token vectors)
n_head = 4            # Number of attention heads
n_layer = 4           # Number of Transformer blocks (layers)
dropout = 0.0         # Dropout probability (0.0 means no dropout here)
# ==========================

# Device selection: MPS (Apple Silicon) > CUDA > CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")   # Apple Silicon GPU
elif torch.cuda.is_available():
    device = torch.device("cuda")  # NVIDIA GPU
else:
    device = torch.device("cpu")   # CPU fallback
print(f"Using device: {device}")

# Set random seed for reproducibility
torch.manual_seed(1337)
if device.type == 'cuda':
    torch.cuda.manual_seed(1337)
elif device.type == 'mps':
    torch.mps.manual_seed(1337)

# Load and read the training text (assuming input.txt is available)
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# ===== DATA PREPROCESSING =====
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }   # string to index
itos = { i:ch for i,ch in enumerate(chars) }   # index to string
encode = lambda s: [stoi[c] for c in s]   # convert string to list of integers
decode = lambda l: ''.join([itos[i] for i in l])   # convert list of integers to string

# Split data into training and validation sets
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))   # first 90% for training
train_data = data[:n]
val_data = data[n:]
# =============================

# ===== DATA LOADING FUNCTION =====
def get_batch(split):
    """Generate a batch of data for training or validation."""
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device) # Move data to the target device
    return x, y
# ================================

# ===== LOSS ESTIMATION FUNCTION =====
@torch.no_grad()   # Disable gradient calculation for efficiency
def estimate_loss():
    """Estimate the loss on training and validation sets."""
    out = {}
    model.eval()   # Set model to evaluation mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()  # Set model back to training mode
    return out
# ===================================

# ===== ATTENTION HEAD IMPLEMENTATION =====
class Head(nn.Module):
    """Single head of self-attention."""
    
    def __init__(self, head_size):
        super().__init__()
        # Linear projections for Key, Query, Value
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Causal mask (tril). 'register_buffer' makes it part of the model state but not a parameter to be trained.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        # Dropout layer (applied after softmax)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape # C here is n_embd
        # Project input to K, Q, V
        k = self.key(x)   # (B,T,head_size)
        q = self.query(x) # (B,T,head_size)
        # Compute attention scores, scale, mask, softmax
        # Note the scaling by C**-0.5, i.e. 1/sqrt(n_embd), as discussed before (the standard Transformer uses head_size**-0.5)
        wei = q @ k.transpose(-2,-1) * C**-0.5   # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))   # Use dynamic slicing [:T, :T] for flexibility if T < block_size
        wei = F.softmax(wei, dim=-1)   # (B, T, T)
        wei = self.dropout(wei) # Apply dropout to attention weights
        # Weighted aggregation of values
        v = self.value(x) # (B,T,head_size)
        out = wei @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
        return out
# ========================================

# ===== MULTI-HEAD ATTENTION =====
class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel."""
    
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # Linear layer after concatenating heads
        self.proj = nn.Linear(n_embd, n_embd) # Projects back to n_embd dimension
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Compute attention for each head and concatenate results
        out = torch.cat([h(x) for h in self.heads], dim=-1) # Shape (B, T, num_heads * head_size) = (B, T, n_embd)
        # Apply final projection and dropout
        out = self.dropout(self.proj(out))
        return out
# ===============================

# ===== FEED-FORWARD NETWORK =====
class FeedFoward(nn.Module):
    """Simple position-wise feed-forward network with one hidden layer."""
    
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),   # Expand dimension (common practice)
            nn.ReLU(),                      # Non-linearity
            nn.Linear(4 * n_embd, n_embd),   # Project back to original dimension
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
# ==============================

# ===== TRANSFORMER BLOCK =====
class Block(nn.Module):
    """Transformer block: communication (attention) followed by computation (FFN)."""
    
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head   # Calculate size for each head
        self.sa = MultiHeadAttention(n_head, head_size) # Self-Attention layer
        self.ffwd = FeedFoward(n_embd) # Feed-Forward layer
        self.ln1 = nn.LayerNorm(n_embd) # LayerNorm for Attention input
        self.ln2 = nn.LayerNorm(n_embd) # LayerNorm for FFN input

    def forward(self, x):
        # Pre-Normalization variant: Norm -> Sublayer -> Residual
        x = x + self.sa(self.ln1(x))  # Attention block
        x = x + self.ffwd(self.ln2(x)) # Feed-forward block
        return x
# ============================

# ===== LANGUAGE MODEL =====
class BigramLanguageModel(nn.Module):
    """GPT-like language model built from Transformer blocks (the class name is carried over from the earlier bigram model)."""
    
    def __init__(self):
        super().__init__()
        # Token Embedding Table: Maps character index to embedding vector. (vocab_size, n_embd)
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # Position Embedding Table: Maps position index (0 to block_size-1) to embedding vector. (block_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # Sequence of Transformer Blocks
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        # Final Layer Normalization (applied after blocks)
        self.ln_f = nn.LayerNorm(n_embd)   # Final layer norm
        # Linear Head: Maps final embedding back to vocabulary size to get logits. (n_embd, vocab_size)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        # Get token embeddings from indices: (B, T) -> (B, T, n_embd)
        tok_emb = self.token_embedding_table(idx)
        # Get position embeddings: Create indices 0..T-1, look up embeddings -> (T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        # Combine token and position embeddings by addition: (B, T, n_embd). Broadcasting handles the addition.
        x = tok_emb + pos_emb   # (B,T,C)
        # Pass through Transformer blocks: (B, T, n_embd) -> (B, T, n_embd)
        x = self.blocks(x)
        # Apply final LayerNorm
        x = self.ln_f(x)
        # Map to vocabulary logits: (B, T, n_embd) -> (B, T, vocab_size)
        logits = self.lm_head(x)

        # Calculate loss if targets are provided (same as before)
        if targets is None:
            loss = None
        else:
            # Reshape for cross_entropy: (B*T, vocab_size) and (B*T)
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Generate new text given a starting sequence."""
        for _ in range(max_new_tokens):
            # Crop context `idx` to the last `block_size` tokens. Required because the position embedding table only has block_size entries (positions 0..block_size-1).
            idx_cond = idx[:, -block_size:]
            # Get predictions (logits) from the model
            logits, loss = self(idx_cond)
            # Focus on the logits for the *last* time step: (B, C)
            logits = logits[:, -1, :]
            # Convert logits to probabilities via softmax
            probs = F.softmax(logits, dim=-1)   # (B, C)
            # Sample next token index from the probability distribution
            idx_next = torch.multinomial(probs, num_samples=1)   # (B, 1)
            # Append the sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)   # (B, T+1)
        return idx
# =========================

# ===== MODEL INITIALIZATION AND TRAINING =====
# Create model instance and move it to the selected device
model = BigramLanguageModel()
m = model.to(device)
# Print number of parameters (useful for understanding model size)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters') # Calculate and print M parameters

# Create optimizer (AdamW again)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Training loop
for iter in range(max_iters):
    # Evaluate loss periodically
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss() # Get train/val loss using the helper function
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") # Print losses

    # Sample a batch of data
    xb, yb = get_batch('train')

    # Forward pass: Evaluate loss
    logits, loss = model(xb, yb)
    # Backward pass: Calculate gradients
    optimizer.zero_grad(set_to_none=True) # Zero gradients
    loss.backward() # Backpropagation
    # Update parameters
    optimizer.step() # Optimizer step

# Generate text from the trained model
context = torch.zeros((1, 1), dtype=torch.long, device=device) # Starting context: [[0]], where index 0 decodes to the newline character '\n'
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))
# ============================================
Using device: mps
0.209729 M parameters
step 0: train loss 4.4116, val loss 4.4022
step 100: train loss 2.6568, val loss 2.6670
step 200: train loss 2.5091, val loss 2.5059
step 300: train loss 2.4194, val loss 2.4336
step 400: train loss 2.3499, val loss 2.3563
step 500: train loss 2.2963, val loss 2.3126
step 600: train loss 2.2411, val loss 2.2501
step 700: train loss 2.2053, val loss 2.2188
step 800: train loss 2.1645, val loss 2.1882
step 900: train loss 2.1238, val loss 2.1498
step 1000: train loss 2.1027, val loss 2.1297
step 1100: train loss 2.0699, val loss 2.1186
step 1200: train loss 2.0394, val loss 2.0806
step 1300: train loss 2.0255, val loss 2.0644
step 1400: train loss 1.9924, val loss 2.0376
step 1500: train loss 1.9697, val loss 2.0303
step 1600: train loss 1.9644, val loss 2.0482
step 1700: train loss 1.9413, val loss 2.0122
step 1800: train loss 1.9087, val loss 1.9949
step 1900: train loss 1.9106, val loss 1.9898
step 2000: train loss 1.8858, val loss 1.9993
step 2100: train loss 1.8722, val loss 1.9762
step 2200: train loss 1.8602, val loss 1.9636
step 2300: train loss 1.8577, val loss 1.9551
step 2400: train loss 1.8442, val loss 1.9467
step 2500: train loss 1.8153, val loss 1.9439
step 2600: train loss 1.8224, val loss 1.9363
step 2700: train loss 1.8125, val loss 1.9370
step 2800: train loss 1.8054, val loss 1.9250
step 2900: train loss 1.8045, val loss 1.9336
step 3000: train loss 1.7950, val loss 1.9202
step 3100: train loss 1.7707, val loss 1.9197
step 3200: train loss 1.7545, val loss 1.9107
step 3300: train loss 1.7569, val loss 1.9075
step 3400: train loss 1.7533, val loss 1.8942
step 3500: train loss 1.7374, val loss 1.8960
step 3600: train loss 1.7268, val loss 1.8909
step 3700: train loss 1.7277, val loss 1.8814
step 3800: train loss 1.7188, val loss 1.8889
step 3900: train loss 1.7194, val loss 1.8714
step 4000: train loss 1.7127, val loss 1.8636
step 4100: train loss 1.7073, val loss 1.8710
step 4200: train loss 1.7022, val loss 1.8597
step 4300: train loss 1.6994, val loss 1.8488
step 4400: train loss 1.7048, val loss 1.8664
step 4500: train loss 1.6860, val loss 1.8461
step 4600: train loss 1.6854, val loss 1.8304
step 4700: train loss 1.6841, val loss 1.8469
step 4800: train loss 1.6655, val loss 1.8454
step 4900: train loss 1.6713, val loss 1.8387
step 4999: train loss 1.6656, val loss 1.8277

Foast.

MENENIUS:
Praviely your niews? I cank, CORiced aggele;
Or heave worth sunt bone Ammiod, Lord,
Who is make thy batted oub! servilings
Toke as lihtch you basw to see swife,
Is letsts lown'd us; to lace and though mistrair took
And the proply enstriaghte for a shien.
Why, they foul tlead,
up is later and
behoy cried men as thou beatt his you.

HERRY VI:
There, you weaks mirre and all was imper, Then death, doth those I will read;
Weas sul't is King me, I what lady so not this dire.

ROMEO:
O, upon to death! him not this bornorow-prove.

MUCIOND:
Why leave ye no you?

DUCUCHESTEH:
But one thyies, if will the save your blages wore I mong father you hast;
Alaitle not arm thither crown tow doth.

FROM WTARDit't me reven.

WARWICK:
Or, as extress womb voishmas!
Good me you; and incaes up done! make,
Or I serigh to emmequerel, to speak, herse to supomet?

LUCIO:
The like, But twast on was theirs
poor of thou do
As hath lay but so bredaint, forweet of For which his lictless me,
That while fumseriands thy unclity,
Wheree I wam my broth? am the too to virsant, whould enterfuly,
All there, ontreman one his him;
When whom to Luvinge one the rews,
Warwixt kill himfined me the bights the with and
Thost will in him,
Mor Sonme man, make to men, Must took.

Server:
Is aid the underer you: if
The I holseld at most lost! Comioli his but a bedrip thy lord,
And then you pringent, and what you kingle is a gestreface is ears.
But take me. Tis basdeh,--
Cendom to nie,
You lordone turn to mine hath dels in woo forth.
Poy devisecity, Ineed and encont
Onking, pleasiness, here's me?
What the have of the doet.

ClaytAM:
Now tweett, cour is plose,
Ostate, and you raint this made untu
With ould to Warwith that me bone;
Will him drown the have wesest: doth,
Are goody gent yours the pot opings, time same, that BI thirself have gative. I' cown love this mind,
Nether if thou her fortune they have fight my ftlair aggainst for him burry.

BRUTUS:
Whoth lost for for leth
And, being eyes
And if for
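The printed figure of 0.209729 M parameters can be checked by hand. The sketch below counts parameters layer by layer for the classes defined above. The hyperparameter values (vocab_size=65, n_embd=64, n_head=4, n_layer=4, block_size=32) are assumptions, because they are set in an earlier cell outside this excerpt; they are simply values that reproduce the printed count. The sketch also assumes that the key, query and value projections inside `Head` are created with `bias=False`, since the `Head` constructor is not shown here. The helper names are hypothetical.

```python
# Hand count of trainable parameters for the GPT defined above.
# Assumed hyperparameters (set in an earlier cell, not shown in this excerpt).

def linear_params(n_in, n_out, bias=True):
    """Parameters of nn.Linear(n_in, n_out): weight matrix plus optional bias."""
    return n_in * n_out + (n_out if bias else 0)


def layernorm_params(n):
    """nn.LayerNorm(n) has a learnable scale and a learnable shift, each of size n."""
    return 2 * n


def gpt_param_count(vocab_size, n_embd, n_head, n_layer, block_size):
    head_size = n_embd // n_head
    # Head: key, query and value projections (assumed bias=False)
    per_head = 3 * linear_params(n_embd, head_size, bias=False)
    # MultiHeadAttention: all heads plus the output projection self.proj
    mha = n_head * per_head + linear_params(n_embd, n_embd)
    # FeedFoward: expand to 4 * n_embd, then project back to n_embd
    ffwd = linear_params(n_embd, 4 * n_embd) + linear_params(4 * n_embd, n_embd)
    # Block: attention + feed-forward + ln1 + ln2
    block = mha + ffwd + 2 * layernorm_params(n_embd)
    return (
        vocab_size * n_embd                   # token_embedding_table
        + block_size * n_embd                 # position_embedding_table
        + n_layer * block                     # blocks
        + layernorm_params(n_embd)            # ln_f
        + linear_params(n_embd, vocab_size)   # lm_head
    )


total = gpt_param_count(vocab_size=65, n_embd=64, n_head=4, n_layer=4, block_size=32)
print(total / 1e6, 'M parameters')  # prints: 0.209729 M parameters
```

Under these assumptions each block holds 49,792 parameters, and the feed-forward layers account for about two thirds of them (33,088).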
Note

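After 5000 iterations (final validation loss 1.8277), the model produces text that imitates the surface form of the training data: speaker names, mostly in capitals, followed by a colon, short verse-like lines, and a mix of real English words and invented ones. Most of the output is still misspelled, and the passages do not make sense.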
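Two quick checks help in reading the loss log. At initialization the model spreads probability roughly evenly over the vocabulary, so its cross-entropy should sit near ln(vocab_size); the logged step-0 loss of about 4.40 is a little higher, which is typical when randomly initialized logits are not perfectly flat. Exponentiating a loss gives the per-character perplexity, the effective number of equally likely choices the model is hedging between. This sketch assumes vocab_size = 65 (the number of distinct characters in tiny shakespeare); the loss values are copied from the log above.

```python
import math

vocab_size = 65  # assumed: number of distinct characters in tiny shakespeare

# Cross-entropy (in nats) of a model that assigns probability 1/vocab_size to every character
uniform_loss = math.log(vocab_size)

# Validation losses copied from the training log above
val_loss_step0 = 4.4022
val_loss_final = 1.8277


def perplexity(loss):
    """Per-token perplexity: exp of the mean cross-entropy measured in nats."""
    return math.exp(loss)


print(f"uniform-guess loss:      {uniform_loss:.4f}")
print(f"perplexity at step 0:    {perplexity(val_loss_step0):.1f}")
print(f"perplexity at step 4999: {perplexity(val_loss_final):.2f}")
```

With these numbers the uniform baseline is about 4.17 nats, the step-0 perplexity (about 82) is slightly worse than a uniform guess over 65 characters, and the final validation perplexity of roughly 6.2 means the model is, on average, about as uncertain as a choice among six equally likely characters.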

© HakyImLab and Listed Authors - CC BY 4.0 License