Transformer Architecture: Attention Is All You Need

Definition

The Transformer architecture, introduced by Vaswani et al. in the seminal 2017 paper 'Attention Is All You Need', revolutionized deep learning by replacing recurrence and convolution entirely with attention mechanisms. Transformers process sequences in parallel rather than sequentially, enabling efficient training on massive datasets and leading to breakthrough models like GPT, BERT, and T5. The architecture consists of an encoder (for understanding input) and decoder (for generating output), each built from stacked identical layers. Core components are multi-head self-attention (allowing each position to attend to all positions), position-wise feed-forward networks, residual connections, and layer normalization. Modern Large Language Models (LLMs) like GPT-4, Claude, and Llama are based on decoder-only Transformer variants, while encoder-only models (BERT) excel at understanding tasks and encoder-decoder models (T5) handle sequence-to-sequence tasks.

Intuition

Imagine a group of experts in a meeting where everyone can simultaneously listen to everyone else, weighting each person's contribution by relevance. Traditional RNNs are like a single-file line where each person only talks to the person in front of them - slow and limited context. Transformers are like that meeting: every token (word/subword) can directly connect to every other token, with attention scores determining 'how much should I listen to you?' Multi-head attention is like having multiple meetings in parallel, each focusing on different aspects - one meeting tracks syntax, another tracks semantics, another tracks pronoun references. The feed-forward layers are like each expert privately processing what they learned from the meeting before the next round. Position encodings give each participant a name tag saying 'I am word #5', because unlike RNNs, Transformers process words simultaneously and need to know their order.

Mathematical Formula

Scaled Dot-Product Attention:
\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]
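
To make the formula concrete, here is a minimal sketch in plain Python (no libraries) that evaluates it on a toy input. The matrices `Q`, `K`, and `V` are made-up values for illustration, and the helper names `softmax` and `attention` are this sketch's own, not part of any library.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention for matrices stored as lists of rows."""
    d_k = len(K[0])
    weights = []
    for q in Q:
        # One score per key: dot(q, k) / sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights.append(softmax(scores))
    # Each output row is a weighted average of the value rows
    d_v = len(V[0])
    output = [
        [sum(w * v[j] for w, v in zip(row, V)) for j in range(d_v)]
        for row in weights
    ]
    return output, weights

# Toy inputs (illustrative only): 3 tokens, d_k = d_v = 2
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

output, weights = attention(Q, K, V)
for row in weights:
    print([round(w, 3) for w in row])
```

Each row of `weights` sums to 1, so every output row is a convex combination of the rows of `V`.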
Multi-Head Attention:
\[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O \]
\[ \text{where head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \]
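
The project-attend-concatenate pattern can be sketched in plain Python as follows. This is an illustrative toy, not an efficient implementation: the projection matrices simply select feature slices and the output projection is the identity, both assumptions chosen so the result is easy to inspect. All helper names are hypothetical.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def attention(Q, K, V):
    d_k = len(K[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        w = softmax(scores)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) for j in range(len(V[0]))])
    return out

def multi_head(X, W_q, W_k, W_v, W_o):
    """W_q, W_k, W_v are lists with one (d_model x d_k) matrix per head."""
    heads = [
        attention(matmul(X, wq), matmul(X, wk), matmul(X, wv))
        for wq, wk, wv in zip(W_q, W_k, W_v)
    ]
    # Concat: join the per-head outputs feature-wise for every position
    concat = [[v for head in heads for v in head[t]] for t in range(len(X))]
    return matmul(concat, W_o)

# Toy setup (illustrative only): d_model = 4, h = 2 heads, d_k = d_v = 2
select_first = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]  # picks features 0-1
select_last = [[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]   # picks features 2-3
identity = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]

X = [[1.0, 0.0, 2.0, 1.0],
     [0.0, 1.0, 0.0, 3.0],
     [1.0, 1.0, 1.0, 1.0]]

projections = [select_first, select_last]
out = multi_head(X, projections, projections, projections, identity)
for row in out:
    print([round(v, 3) for v in row])
```

The output keeps the input shape (3 positions, 4 features). Each head computes its own attention pattern from its own slice of the features, which is the point of using several heads.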
Feed-Forward Network:
\[ \text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2 \]
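
A small sketch of the feed-forward formula in plain Python, assuming toy weights with d_model = 2 and d_ff = 3 (the numbers and the helper names `affine` and `ffn` are made up for illustration). It also shows what "position-wise" means: the same weights are applied to each position independently.

```python
def affine(x, W, b):
    """Compute xW + b for a row vector x and a matrix W stored as a list of rows."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]

def ffn(x, W1, b1, W2, b2):
    hidden = [max(0.0, h) for h in affine(x, W1, b1)]  # ReLU: max(0, xW1 + b1)
    return affine(hidden, W2, b2)

# Toy parameters (illustrative only): d_model = 2, d_ff = 3
W1 = [[1.0, -1.0, 0.5], [0.0, 2.0, -1.0]]
b1 = [0.0, 0.0, 0.5]
W2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b2 = [0.1, -0.1]

# Two positions; the same network is applied to each one separately
sequence = [[1.0, 2.0], [-1.0, 0.5]]
outputs = [ffn(x, W1, b1, W2, b2) for x in sequence]
print(outputs)
```

Because no information moves between positions here, reordering the sequence only reorders the outputs. Mixing across positions is left entirely to the attention sublayer.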
Layer Normalization:
\[ \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
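
As a quick check of the formula, the sketch below normalizes a single feature vector in plain Python. It assumes the biased (population) variance for \(\sigma^2\) and an illustrative \(\epsilon\) of 1e-5. The function name `layer_norm` is this sketch's own.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize one feature vector, then apply the learned scale and shift."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)  # biased variance (assumption)
    return [g * (v - mu) / math.sqrt(var + eps) + b
            for v, g, b in zip(x, gamma, beta)]

x = [1.0, 2.0, 3.0, 4.0]
# With gamma = 1 and beta = 0 the output has mean ~0 and variance ~1
y = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)
print([round(v, 4) for v in y])
```

The statistics are computed over the features of one position, so the result does not depend on the batch or on the other positions in the sequence.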
Positional Encoding (Sinusoidal):
\[ PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}}) \]
\[ PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}}) \]
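
The two formulas can be evaluated directly. Below is a minimal plain-Python sketch for a single position, with a deliberately tiny `d_model` of 4 chosen for readability. The function name `positional_encoding` is hypothetical.

```python
import math

def positional_encoding(pos, d_model):
    """Sinusoidal encoding for one position: sin on even dims, cos on odd dims."""
    pe = [0.0] * d_model
    for even_dim in range(0, d_model, 2):
        # even_dim plays the role of 2i in the formula
        angle = pos / (10000 ** (even_dim / d_model))
        pe[even_dim] = math.sin(angle)
        if even_dim + 1 < d_model:
            pe[even_dim + 1] = math.cos(angle)
    return pe

for pos in range(3):
    print(pos, [round(v, 4) for v in positional_encoding(pos, 4)])
```

Low dimensions oscillate quickly with position while high dimensions change slowly, so each position receives a distinct pattern across the feature vector.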
Self-Attention Complexity:
\[ \text{Time: } O(n^2 \cdot d) \quad \text{Space: } O(n^2) \]
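
A back-of-the-envelope sketch shows what the quadratic terms mean in practice. It assumes the full attention weight matrix is materialized in 32-bit floats with 8 heads and batch size 1, and it counts only the two large matrix products, ignoring projections and the softmax itself. The function names are illustrative.

```python
def attention_matrix_bytes(n, num_heads=8, batch_size=1, bytes_per_value=4):
    """Memory for the n x n attention weights across all heads."""
    return batch_size * num_heads * n * n * bytes_per_value

def attention_multiply_adds(n, d):
    """QK^T costs n*n*d multiply-adds; (weights @ V) costs another n*n*d."""
    return 2 * n * n * d

for n in (1024, 2048, 4096):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"n={n}: {mib:.0f} MiB of attention weights, "
          f"{attention_multiply_adds(n, 512):,} multiply-adds")
```

Doubling the sequence length quadruples both numbers, which is why long inputs become expensive quickly.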

Step-by-Step Explanation:

  1. Scaled Dot-Product: Query-Key dot products are scaled by \(1/\sqrt{d_k}\) to keep the softmax stable; the resulting weights are then used to average the Values
  2. Multi-Head: Projects Q,K,V into h subspaces, applies attention in parallel, concatenates results
  3. Feed-Forward: Two-layer MLP with ReLU activation applied position-wise (same network for each position)
  4. LayerNorm: Normalizes across feature dimension, then learns scale \(\gamma\) and shift \(\beta\) parameters
  5. Positional Encoding: Adds position information through sinusoidal functions of varying frequencies
  6. Complexity: Quadratic in sequence length n, linear in model dimension d - main computational bottleneck
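
The effect of the \(1/\sqrt{d_k}\) scaling in step 1 can be checked numerically. The sketch below draws a random query and random keys with unit-variance components (an assumption for this demo; real activations will differ) and compares the softmax with and without scaling. The seed and sizes are arbitrary.

```python
import math
import random

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

random.seed(0)
d_k = 512
num_keys = 8

# Random query and keys with zero-mean, unit-variance components
q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
keys = [[random.gauss(0.0, 1.0) for _ in range(d_k)] for _ in range(num_keys)]

raw_scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]

unscaled_weights = softmax(raw_scores)
scaled_weights = softmax(scaled_scores)

print("largest weight without scaling:", round(max(unscaled_weights), 4))
print("largest weight with scaling:   ", round(max(scaled_weights), 4))
```

Under these assumptions the raw dot products have variance \(d_k\), so without scaling the softmax tends to put almost all its weight on one key, where its gradients are very small. Dividing by \(\sqrt{d_k}\) brings the scores back to roughly unit variance.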

Real-World Use Cases

Large Language Models

GPT-4, Claude, Llama using decoder-only Transformers for text generation

Machine Translation

Google Translate using Transformer encoder-decoder for 100+ languages

Code Generation

GitHub Copilot generating code completions using GPT-style models

Computer Vision

Vision Transformers (ViT) reaching state-of-the-art ImageNet accuracy without convolutions when pre-trained on large datasets

Scientific Discovery

AlphaFold2 using Transformers for protein structure prediction

Multimodal AI

CLIP and GPT-4V relating text and images with Transformer-based models

Implementation

Manual Implementation (No Libraries)

This implements the core Transformer components: MultiHeadAttention with parallel attention heads, FeedForward with ReLU activation, and TransformerEncoderLayer with post-norm residual connections (layer normalization is applied after each residual addition, as in the original paper). PositionalEncoding adds sinusoidal position information. The encoder stacks multiple layers, each performing self-attention and then feed-forward processing.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, V)
        
        return output, attn_weights
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # Linear projections and reshape for multi-head
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Apply attention
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Concatenate heads and apply final linear
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(attn_output)
        
        return output, attn_weights

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
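        # Shape (1, max_seq_length, d_model) so it broadcasts over batch-first input
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # x: (batch_size, seq_len, d_model); add the first seq_len position rows
        return x + self.pe[:, :x.size(1)]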

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length=5000, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length)
        
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(d_model)
    
    def forward(self, x, mask=None):
        x = self.embedding(x) * self.scale
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        for layer in self.layers:
            x = layer(x, mask)
        
        return x

# Test
vocab_size = 10000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
encoder = TransformerEncoder(vocab_size, d_model, num_heads, num_layers, d_ff)
src = torch.randint(0, vocab_size, (2, 20))  # batch_size=2, seq_len=20
output = encoder(src)
print(f'Output shape: {output.shape}')  # [2, 20, 512]
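
As a sanity check on the configuration used in the test above, the number of trainable parameters can be worked out by hand. The sketch below assumes every `nn.Linear` has a bias, each `nn.LayerNorm` has a scale and a shift vector, and the sinusoidal positional encoding contributes no parameters because it is stored as a buffer. The helper names are illustrative.

```python
def linear_params(in_features, out_features):
    """Weights plus bias of one fully connected layer."""
    return in_features * out_features + out_features

def encoder_layer_params(d_model, d_ff):
    attention = 4 * linear_params(d_model, d_model)  # W_q, W_k, W_v, W_o
    feed_forward = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
    layer_norms = 2 * (2 * d_model)                  # gamma and beta, two norms
    return attention + feed_forward + layer_norms

def encoder_params(vocab_size, d_model, num_layers, d_ff):
    embedding = vocab_size * d_model
    return embedding + num_layers * encoder_layer_params(d_model, d_ff)

per_layer = encoder_layer_params(512, 2048)
total = encoder_params(10000, 512, 6, 2048)
print(per_layer)  # 3152384
print(total)      # 24034304
```

The number of heads does not appear in the count: splitting `d_model` across heads changes how the projections are used, not how large they are.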

Using Libraries (torch, torch.nn, transformers, tensorflow, keras)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Using PyTorch's built-in Transformer
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, num_classes, dropout=0.1):
        super(TransformerClassifier, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        
        encoder_layers = nn.TransformerEncoderLayer(d_model, num_heads, dim_feedforward=4*d_model, 
                                                    dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        
        self.fc = nn.Linear(d_model, num_classes)
        self.d_model = d_model
    
    def forward(self, src, src_mask=None):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        # Global average pooling
        output = output.mean(dim=1)
        output = self.fc(output)
        return output

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:x.size(1), :].transpose(0, 1)
        return self.dropout(x)

# GPT-style Decoder-only Transformer
class GPTDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, max_seq_len=512, dropout=0.1):
        super(GPTDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        
        decoder_layer = nn.TransformerDecoderLayer(d_model, num_heads, dim_feedforward=4*d_model, 
                                                   dropout=dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        
        self.max_seq_len = max_seq_len
        self.d_model = d_model
    
    def forward(self, idx, targets=None):
        b, t = idx.size()
        
        # Token + positional embeddings
        tok_emb = self.embedding(idx)
        pos_emb = self.pos_embedding(torch.arange(t, device=idx.device))
        x = tok_emb + pos_emb
        
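        # Causal mask: in PyTorch a True entry in a boolean mask means "may NOT attend",
        # so the strict upper triangle hides future tokens
        causal_mask = torch.triu(torch.ones(t, t, device=idx.device), diagonal=1).bool()
        
        # nn.TransformerDecoder needs a memory tensor for cross-attention; with no encoder,
        # a single all-zero memory slot is passed as a stand-in
        x = self.transformer_decoder(x, memory=torch.zeros((b, 1, self.d_model), device=idx.device),
                                     tgt_mask=causal_mask)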
        x = self.ln_f(x)
        logits = self.head(x)
        
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        
        return logits, loss

# Hugging Face Transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer, BertModel, BertTokenizer

# Load pre-trained GPT-2
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Generate text
text = 'The future of AI is'
inputs = tokenizer(text, return_tensors='pt')
outputs = model.generate(**inputs, max_length=50, num_return_sequences=1)
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

# BERT for encoding
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

encoded = bert_tokenizer('Hello world', return_tensors='pt')
bert_output = bert_model(**encoded)
print(f'Last hidden state shape: {bert_output.last_hidden_state.shape}')

# TensorFlow/Keras
import tensorflow as tf

def create_transformer_encoder(vocab_size, d_model, num_heads, num_layers, seq_len):
    inputs = tf.keras.Input(shape=(seq_len,))
    
    # Embedding + positional encoding
    embedding = tf.keras.layers.Embedding(vocab_size, d_model)(inputs)
    positions = tf.keras.layers.Embedding(seq_len, d_model)(tf.range(seq_len))
    x = embedding + positions
    
    # Transformer encoder layers
    for _ in range(num_layers):
        # Multi-head attention
        attn_output = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model//num_heads)(x, x)
        x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x + attn_output)
        
        # Feed-forward
        ff_output = tf.keras.layers.Dense(d_model * 4, activation='relu')(x)
        ff_output = tf.keras.layers.Dense(d_model)(ff_output)
        x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x + ff_output)
    
    return tf.keras.Model(inputs, x)

model_tf = create_transformer_encoder(10000, 512, 8, 6, 128)

When to Use

✅ Appropriate Use Cases:

  • Sequential data where long-range dependencies matter
  • Tasks requiring parallel processing of sequences (faster than RNNs)
  • Large-scale language modeling and text generation
  • When you have sufficient compute for quadratic attention complexity
  • Multi-modal tasks combining text, images, audio
  • Transfer learning with pre-trained models (BERT, GPT, T5)

❌ Avoid When:

  • Very long sequences (>10k tokens) where O(n²) attention is prohibitive
  • Resource-constrained environments (use RNNs or linear attention variants)
  • When strict causality must be enforced and attention is left unmasked (full self-attention sees future positions; use a causal mask)
  • Small datasets where RNNs with strong inductive bias perform better
  • Real-time low-latency applications (attention computation overhead)
  • When interpretability requires understanding of local feature hierarchies

Common Pitfalls

  • Forgetting causal masking in autoregressive decoders (leaks future information)
  • Not scaling attention by \(1/\sqrt{d_k}\) causing gradient instability
  • Using absolute position embeddings when relative positions matter more
  • Insufficient gradient clipping causing training divergence in deep models
  • Not handling variable-length sequences with proper padding and masking
  • Incorrect key/query/value dimensions breaking multi-head attention
  • Layer norm placement (pre-norm vs post-norm) affecting training stability
  • Leaving attention dropout (applied after softmax, so rows no longer sum to 1) active at inference instead of switching the model to evaluation mode
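
The first pitfall in the list, missing causal masking, is easy to see on a toy example. The plain-Python sketch below sets the scores of disallowed positions to negative infinity before the softmax. The `allowed` convention (True means "may attend") is this sketch's own choice; libraries differ on whether True means keep or mask, so check the documentation of the one you use.

```python
import math

def masked_softmax(scores, allowed):
    """Softmax over one row, giving zero weight where allowed[j] is False."""
    masked = [s if ok else float("-inf") for s, ok in zip(scores, allowed)]
    m = max(masked)
    exps = [math.exp(v - m) for v in masked]
    total = sum(exps)
    return [e / total for e in exps]

n = 4
# Arbitrary toy scores; row i holds the scores of query i against every key
scores = [[0.5 * (i + j) for j in range(n)] for i in range(n)]
# Position i may attend to positions 0..i only
causal = [[j <= i for j in range(n)] for i in range(n)]

weights = [masked_softmax(row, ok) for row, ok in zip(scores, causal)]
for row in weights:
    print([round(w, 3) for w in row])
```

Every entry above the diagonal is exactly zero, so no position can draw on tokens that come after it. The same function handles padding if `allowed` is False at padded positions.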