
Transformers

Self-Attention

The core mechanism: each token attends to all other tokens to compute a context-aware representation.

Given input embeddings X (sequence_length x d_model):

Q = X * W_Q    # queries  (n x d_k)
K = X * W_K    # keys     (n x d_k)
V = X * W_V    # values   (n x d_v)

Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V

The sqrt(d_k) scaling prevents dot products from growing large, which would push softmax into regions with tiny gradients.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = Q @ K.swapaxes(-2, -1) / np.sqrt(d_k)

    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)  # block masked positions

    weights = softmax(scores, axis=-1)
    return weights @ V

Complexity

  • Time: O(n^2 * d) where n = sequence length, d = dimension
  • Memory: O(n^2) for the attention matrix
  • This quadratic scaling is the primary bottleneck for long sequences

Multi-Head Attention

Run h attention heads in parallel, each with its own projections:

head_i = Attention(X * W_Q^i, X * W_K^i, X * W_V^i)
MultiHead(X) = Concat(head_1, ..., head_h) * W_O

Where W_Q^i, W_K^i: (d_model x d_k), W_V^i: (d_model x d_v), d_k = d_v = d_model / h.

Each head can learn different attention patterns: syntactic, semantic, positional, etc.

Parameter count: 4 * d_model^2, ignoring bias terms (Q, K, V projections + output projection).
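
The equations above can be sketched in NumPy. This assumes the common layout where all heads share single d_model x d_model projection matrices and each head reads a d_k-wide slice:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, h):
    """X: (n, d_model); W_Q, W_K, W_V, W_O: (d_model, d_model); h heads."""
    n, d_model = X.shape
    d_k = d_model // h
    # Project once, then split the last dimension into h heads of width d_k.
    Q = (X @ W_Q).reshape(n, h, d_k).transpose(1, 0, 2)   # (h, n, d_k)
    K = (X @ W_K).reshape(n, h, d_k).transpose(1, 0, 2)
    V = (X @ W_V).reshape(n, h, d_k).transpose(1, 0, 2)

    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)      # (h, n, n)
    heads = softmax(scores, axis=-1) @ V                  # (h, n, d_k)

    concat = heads.transpose(1, 0, 2).reshape(n, d_model) # Concat(head_1..head_h)
    return concat @ W_O
```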

Positional Encoding

Self-attention is permutation-equivariant -- it has no notion of position. Positional information must be injected.

Sinusoidal (Original Transformer)

PE(pos, 2i)   = sin(pos / 10000^{2i/d_model})
PE(pos, 2i+1) = cos(pos / 10000^{2i/d_model})

Properties:

  • Deterministic, no learned parameters
  • Can generalize to unseen sequence lengths (in theory)
  • Each dimension oscillates at a different frequency
  • PE(pos+k) can be expressed as a linear function of PE(pos)
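
The sinusoidal scheme is a few lines of NumPy; `sinusoidal_pe` is an illustrative name, and d_model is assumed even:

```python
import numpy as np

def sinusoidal_pe(max_len, d_model):
    pos = np.arange(max_len)[:, None]            # (max_len, 1)
    i = np.arange(d_model // 2)[None, :]         # (1, d_model/2)
    angles = pos / (10000 ** (2 * i / d_model))
    pe = np.empty((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)                 # even dims: sine
    pe[:, 1::2] = np.cos(angles)                 # odd dims: cosine
    return pe
```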

Rotary Position Embedding (RoPE)

Encode position by rotating the query and key vectors:

f(x, pos) = R(pos) * x

where R(pos) is a block-diagonal rotation matrix:
R(pos) = diag(R_1(pos), R_2(pos), ..., R_{d/2}(pos))

R_i(pos) = [[cos(pos*theta_i), -sin(pos*theta_i)],
             [sin(pos*theta_i),  cos(pos*theta_i)]]

theta_i = 10000^{-2i/d}

Key insight: the dot product q^T k depends on relative position (pos_q - pos_k), giving relative position encoding without explicit pairwise computation.

Used in LLaMA, PaLM, and most modern LLMs.
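
A minimal sketch of the rotation for a single vector, using the interleaved even/odd pairing from the block-diagonal R(pos) above (real implementations vectorize this over heads and positions):

```python
import numpy as np

def apply_rope(x, pos):
    """Rotate each pair (x[2i], x[2i+1]) by the angle pos * theta_i."""
    d = x.shape[-1]
    theta = 10000.0 ** (-2 * np.arange(d // 2) / d)  # one frequency per 2-D block
    angles = pos * theta
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out
```

The key property is directly checkable: rotating q at position 5 and k at position 3 gives the same dot product as positions 7 and 5, since only the offset of 2 matters.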

ALiBi (Attention with Linear Biases)

Add a linear bias to attention scores based on distance:

softmax(q_i^T k_j - m * |i - j|)

where m is a head-specific slope; for n heads, the slopes form the geometric sequence 2^{-8/n}, 2^{-16/n}, ..., 2^{-8}.

  • No positional embeddings at all -- bias directly in attention
  • Strong length extrapolation
  • Used in BLOOM, MPT
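
The per-head bias matrix can be precomputed once and added to the attention scores; a sketch using the slope sequence above:

```python
import numpy as np

def alibi_bias(n, n_heads):
    """Return (n_heads, n, n) biases -m_h * |i - j|, slopes 2^{-8h/n_heads} for h=1..n_heads."""
    slopes = 2.0 ** (-8.0 * np.arange(1, n_heads + 1) / n_heads)
    dist = np.abs(np.arange(n)[:, None] - np.arange(n)[None, :])  # |i - j|
    return -slopes[:, None, None] * dist[None, :, :]
```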

Transformer Architecture

Encoder Block

x = x + MultiHeadAttention(LayerNorm(x))    # Pre-norm variant
x = x + FFN(LayerNorm(x))

FFN(x) = GELU(x * W_1 + b_1) * W_2 + b_2   # d_model -> d_ff -> d_model

d_ff is typically 4 * d_model. The FFN acts as a per-position "memory" storing learned patterns.
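
A pre-norm block can be sketched as below. The attention sublayer is passed in as a callable (an identity stand-in works for shape-checking), and `layer_norm` omits the learned scale/shift parameters for brevity:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def ffn(x, W1, b1, W2, b2):
    return gelu(x @ W1 + b1) @ W2 + b2          # d_model -> d_ff -> d_model

def encoder_block(x, attn, W1, b1, W2, b2):
    x = x + attn(layer_norm(x))                 # pre-norm attention sublayer
    x = x + ffn(layer_norm(x), W1, b1, W2, b2)  # pre-norm FFN sublayer
    return x
```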

Decoder Block

Same as encoder but with:

  1. Causal mask: prevent attending to future positions (lower triangular mask)
  2. Cross-attention: queries from decoder, keys/values from encoder output

Pre-Norm vs Post-Norm

  • Post-norm (original): LayerNorm after residual addition. Harder to train, needs careful warmup.
  • Pre-norm: LayerNorm before attention/FFN. More stable training, slightly worse final performance without tuning.

Most modern models use pre-norm.

Key Architectures

BERT (Bidirectional Encoder Representations from Transformers)

  • Architecture: encoder-only transformer
  • Pretraining: masked language modeling (mask 15% of input tokens, predict them from context) + next sentence prediction
  • Fine-tuning: add task-specific head, fine-tune all parameters
  • Sizes: BERT-base (110M, 12 layers, 768 dim), BERT-large (340M, 24 layers, 1024 dim)

Best for: classification, NER, question answering, sentence similarity.

GPT (Generative Pre-trained Transformer)

  • Architecture: decoder-only transformer with causal masking
  • Pretraining: next-token prediction (autoregressive language modeling)
  • Inference: generate tokens one at a time, each conditioned on all previous
  • Scaling: GPT-2 (1.5B), GPT-3 (175B), GPT-4 (undisclosed)

The scaling laws (Kaplan et al., 2020): loss scales as power laws in model size, data size, and compute.

T5 (Text-to-Text Transfer Transformer)

  • Architecture: encoder-decoder
  • Approach: frame every NLP task as text-to-text (e.g., "translate English to French: ...")
  • Pretraining: span corruption (mask contiguous spans, predict them)

Vision Transformer (ViT)

Apply transformer to images by treating patches as tokens:

1. Split image into P x P patches (e.g., 16x16)
2. Flatten and linearly project each patch to d_model
3. Prepend [CLS] token, add positional embeddings
4. Apply standard transformer encoder
5. Classify using [CLS] token representation

Requires large datasets or strong regularization. With sufficient data/pretraining, matches or exceeds CNNs.
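
Step 1 (patch extraction) reduces to a reshape; a sketch for a single channels-last image, with `patchify` as an illustrative name:

```python
import numpy as np

def patchify(img, P):
    """Split an (H, W, C) image into flattened P x P patches: (H*W / P^2, P*P*C)."""
    H, W, C = img.shape
    # (H/P, P, W/P, P, C) -> (H/P, W/P, P, P, C): group rows/cols into patches
    patches = img.reshape(H // P, P, W // P, P, C).transpose(0, 2, 1, 3, 4)
    return patches.reshape(-1, P * P * C)
```

Each flattened patch would then be projected to d_model with a learned linear layer.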

Efficient Transformers

FlashAttention

Exact attention computation using tiling and kernel fusion:

  1. Divide Q, K, V into blocks that fit in SRAM (fast memory)
  2. Compute attention block-by-block using online softmax trick
  3. Never materialize the full n x n attention matrix in HBM (slow memory)

Standard attention:   O(n^2) memory, many HBM reads
FlashAttention:       O(n) memory, fewer HBM reads (2-4x faster)

FlashAttention-2/3 further optimize with better work partitioning across GPU warps.
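
The online softmax trick at the heart of step 2 can be illustrated for a single query row, processed one score block at a time; the tiling and kernel-fusion details are omitted:

```python
import numpy as np

def online_softmax_weighted_sum(score_blocks, v_blocks):
    """Accumulate softmax(scores) @ V block by block, never storing the full row."""
    m = -np.inf          # running max of all scores seen so far
    l = 0.0              # running sum of exp(score - m)
    acc = 0.0            # running weighted sum of values (rescaled as m grows)
    for s, v in zip(score_blocks, v_blocks):   # s: (block,), v: (block, d_v)
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)              # rescale old statistics to new max
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v
        m = m_new
    return acc / l
```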

Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)

  • MQA: all heads share one set of K, V projections. Reduces KV cache size during inference.
  • GQA: group heads, share K, V within groups. Compromise between MHA and MQA.
  • GQA with g groups: g sets of K, V shared across h/g heads each.
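
One way to see the sharing: g sets of K, V can be broadcast to h query heads by repetition (in practice the repeat is often implicit, precisely to keep the cache small):

```python
import numpy as np

def expand_kv_for_gqa(K, V, h, g):
    """K, V: (g, n, d_k), one per group; repeat each for the h // g query heads it serves."""
    reps = h // g
    return np.repeat(K, reps, axis=0), np.repeat(V, reps, axis=0)  # (h, n, d_k)
```

With g = 1 this degenerates to MQA (every head sees the same K, V); with g = h it is ordinary multi-head attention.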

KV Cache

During autoregressive generation, cache computed K, V for all previous tokens to avoid recomputation:

Memory per token: 2 * n_layers * d_model * bytes_per_param
For a 7B model: ~0.5 MB per token
At 128K context: ~64 GB just for KV cache

This motivates MQA/GQA and quantized KV caches.
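
The arithmetic above can be checked directly; the 32-layer / 4096-dim shape is a LLaMA-7B-like assumption, with 2-byte (fp16/bf16) cache entries and full multi-head attention (no GQA):

```python
def kv_cache_bytes(n_layers, d_model, n_tokens, bytes_per_param=2):
    """Factor 2 accounts for storing both K and V at every layer."""
    return 2 * n_layers * d_model * bytes_per_param * n_tokens

per_token = kv_cache_bytes(32, 4096, 1)             # 524288 bytes = 0.5 MB
full_ctx = kv_cache_bytes(32, 4096, 128 * 1024)     # 64 GiB at 128K context
```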

Other Efficiency Techniques

| Method           | Approach                              | Complexity     |
|------------------|---------------------------------------|----------------|
| Sparse attention | Attend to subset of positions         | O(n * sqrt(n)) |
| Linear attention | Replace softmax with kernel approx    | O(n * d^2)     |
| Sliding window   | Local attention within fixed window   | O(n * w)       |
| Ring attention   | Distribute long sequences across GPUs | O(n^2 / p)     |

Mixture of Experts (MoE)

Replace the dense FFN with a sparse mixture:

y = sum_{i=1}^{E} G(x)_i * Expert_i(x)

where G(x) is a gating network that routes each token to the top-k experts (typically k=1 or 2 out of E=8..64 experts).

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def moe_layer(x, experts, gate_W, top_k=2):
    """x: (batch, d); experts: list of callables d -> d; gate_W: (d, n_experts)."""
    gate_scores = softmax(x @ gate_W, axis=-1)                 # (batch, n_experts)
    top_k_indices = np.argsort(-gate_scores, axis=-1)[:, :top_k]

    output = np.zeros_like(x)
    for b in range(x.shape[0]):                                # route each token separately
        for idx in top_k_indices[b]:
            output[b] += gate_scores[b, idx] * experts[idx](x[b])
    return output

Benefits:

  • Scale model parameters without proportional compute increase
  • Each token uses only k/E of total parameters
  • Mixtral 8x7B: 47B total params but ~13B active per token

Challenges: load balancing (auxiliary loss to encourage uniform expert usage), training instability, communication overhead in distributed settings.

Training at Scale

Key Practices

  • Mixed precision (FP16/BF16 + FP32 master weights): 2x memory savings, faster matmuls
  • Gradient accumulation: simulate larger batch sizes
  • Data parallelism: replicate model, split data across GPUs
  • Tensor parallelism: split individual layers across GPUs
  • Pipeline parallelism: split layers across GPUs, micro-batch pipelining
  • ZeRO: shard optimizer states, gradients, and parameters across GPUs
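
Gradient accumulation, the simplest item on the list, reduces to summing micro-batch gradients before each optimizer step. A plain-SGD sketch, assuming the number of micro-batches is a multiple of `accum_steps`:

```python
import numpy as np

def sgd_with_accumulation(w, micro_batches, grad_fn, lr=0.1, accum_steps=4):
    """Average gradients over accum_steps micro-batches, then take one SGD step."""
    grad = np.zeros_like(w)
    for i, batch in enumerate(micro_batches, 1):
        grad += grad_fn(w, batch) / accum_steps   # average over the effective batch
        if i % accum_steps == 0:
            w = w - lr * grad                     # one step per accumulated batch
            grad = np.zeros_like(w)
    return w
```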

Tokenization

  • BPE (Byte Pair Encoding): iteratively merge most frequent character pairs
  • WordPiece: similar to BPE but uses likelihood-based merging
  • SentencePiece: language-agnostic; operates on raw text with no pre-tokenization, implementing BPE or unigram LM
  • Typical vocab size: 32K-128K tokens
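
The BPE merge loop fits in a short sketch. Here the corpus is pre-counted as a dict mapping each word (as a tuple of symbols) to its frequency; `bpe_train` is an illustrative name:

```python
from collections import Counter

def bpe_train(word_counts, num_merges):
    """Iteratively merge the most frequent adjacent symbol pair across the corpus."""
    vocab = dict(word_counts)
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for tok, cnt in vocab.items():
            for a, b in zip(tok, tok[1:]):
                pairs[(a, b)] += cnt
        if not pairs:
            break
        best = max(pairs, key=pairs.get)          # most frequent pair
        merges.append(best)
        new_vocab = {}
        for tok, cnt in vocab.items():            # apply the merge everywhere
            out, i = [], 0
            while i < len(tok):
                if i + 1 < len(tok) and (tok[i], tok[i + 1]) == best:
                    out.append(tok[i] + tok[i + 1])
                    i += 2
                else:
                    out.append(tok[i])
                    i += 1
            new_vocab[tuple(out)] = cnt
        vocab = new_vocab
    return merges, vocab
```

For example, with counts {('l','o','w'): 5, ('l','o','w','e','r'): 2}, two merges collapse "low" into a single symbol.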