HiRoPE: Hierarchical Rotary Position Embedding
Definition
Hierarchical Rotary Position Embedding (HiRoPE) is an extension of RoPE for modeling sequences far longer than those seen during training. RoPE models relative position well, but its quality degrades when it has to extrapolate to positions well beyond the training length. HiRoPE addresses this by decomposing each position hierarchically, combining a coarse-grained block-level position with a fine-grained intra-block position. The approach mirrors how people handle long documents: we track which chapter or section we are in (coarse) and where we are within that section (fine). Other long-context methods pursue the same goal by different means: LongLoRA relies on shifted sparse attention and Llama-2-Long on a larger RoPE base frequency, so they are useful points of comparison rather than instances of HiRoPE. The key idea is that long-range dependencies can be represented with two small position ranges (blocks and offsets) instead of one very large linear range.
Intuition
Imagine reading a very long novel. Instead of remembering 'I'm on word 50,000', you think 'I'm in Chapter 5, about halfway through the chapter.' This is hierarchical position awareness: you track both which block you are in and where you are within that block. HiRoPE works the same way. It divides a long sequence into blocks (like chapters) and encodes two separate positions for each token: its block index (coarse position) and its offset within the block (fine position). When computing attention, the model can handle local context (within-block attention) and global context (across-block attention) with a different position encoding for each. It is like a GPS that tells you both which city you are in and your exact street address. This structure helps the model scale to much longer sequences because each level covers only a small index range: a 100k-token sequence needs 100 block indices and 1,000 intra-block offsets rather than 100k distinct linear positions.
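The decomposition described above can be sketched in a few lines of plain Python. The helper name and the block size of 1,000 are illustrative assumptions, not part of any library.

```python
def decompose_position(position: int, block_size: int) -> tuple[int, int]:
    """Split an absolute position into (block index, offset within block)."""
    if position < 0 or block_size <= 0:
        raise ValueError("position must be >= 0 and block_size must be > 0")
    return divmod(position, block_size)


# Token 50,000 in a sequence split into 1,000-token blocks.
block, offset = decompose_position(50_000, block_size=1_000)
print(block, offset)  # 50 0

# A 100k-token sequence uses only 100 block indices and 1,000 offsets,
# yet every token still gets a unique (block, offset) pair.
pairs = {decompose_position(p, 1_000) for p in range(100_000)}
num_blocks = len({b for b, _ in pairs})
num_offsets = len({o for _, o in pairs})
print(num_blocks, num_offsets, len(pairs))  # 100 1000 100000
```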
Mathematical Formula
Step-by-Step Explanation:
- Position Decomposition: Absolute position m splits into block index and intra-block offset
- Block-Level RoPE: Lower frequencies for capturing long-range block relationships
- Intra-Block RoPE: Higher frequencies for fine-grained within-block positions
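- Combined HiRoPE: Both rotations are applied to the same vector; because they act on the same 2D planes they commute, so the result is a single rotation by the sum of the block-level and intra-block angles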
- Local Attention: Uses only intra-block positions for nearby token interactions
- Global Attention: Uses combined positions or only block positions for distant tokens
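The steps above can be written out as follows. This is a reconstruction from the step list and the parameters used in the implementation below, not a quoted formula: the symbols B (block size), d (head dimension), and the names of the two bases are assumptions made for the notation.

```latex
\begin{aligned}
m &= b\,B + o, \qquad b = \lfloor m / B \rfloor, \qquad o = m \bmod B \\[4pt]
\theta^{\text{local}}_j &= \text{base}_{\text{local}}^{-2j/d}, \qquad
\theta^{\text{global}}_j = \text{base}_{\text{global}}^{-2j/d}, \qquad j = 0, \dots, d/2 - 1 \\[4pt]
R(\phi)\big|_{\text{plane } j} &=
\begin{pmatrix} \cos\phi_j & -\sin\phi_j \\ \sin\phi_j & \cos\phi_j \end{pmatrix} \\[4pt]
\operatorname{HiRoPE}(x, m) &= R\!\left(o\,\theta^{\text{local}}\right)\, R\!\left(b\,\theta^{\text{global}}\right) x
 = R\!\left(o\,\theta^{\text{local}} + b\,\theta^{\text{global}}\right) x \\[4pt]
\big\langle \operatorname{HiRoPE}(q, m),\ \operatorname{HiRoPE}(k, n) \big\rangle &=
\sum_{j} \operatorname{Re}\!\left[ \tilde{q}_j\, \tilde{k}_j^{*}\,
e^{\,\mathrm{i}\left( (o_m - o_n)\,\theta^{\text{local}}_j + (b_m - b_n)\,\theta^{\text{global}}_j \right)} \right]
\end{aligned}
```

Here \tilde{q}_j and \tilde{k}_j denote the j-th channel pair of q and k read as a complex number. The last line shows that the attention score depends on the difference in block index and the difference in offset, not on the absolute positions m and n.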
Real-World Use Cases
Processing legal contracts and research papers with 100k+ tokens
Understanding entire codebases where context spans multiple files
Maintaining coherent dialogue over extended multi-turn conversations
Analyzing DNA sequences where patterns span millions of base pairs
Processing hour-long videos with frame-level attention
Training models on entire books without chunking
Implementation
Manual Implementation (PyTorch only)
import math

import torch
import torch.nn as nn


class HiRoPE(nn.Module):
    def __init__(self, dim, block_size=1024, max_blocks=128, local_base=10000, global_base=100000):
        super().__init__()
        self.dim = dim
        self.block_size = block_size
        self.max_blocks = max_blocks
        self.local_base = local_base
        self.global_base = global_base
        # Local (intra-block) frequencies - standard RoPE.
        # Build each tensor in a local variable first: assigning it to self before
        # calling register_buffer raises KeyError because the attribute already exists.
        local_inv_freq = 1.0 / (local_base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('local_inv_freq', local_inv_freq)
        # Global (inter-block) frequencies - slower rotation via a larger base
        global_inv_freq = 1.0 / (global_base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('global_inv_freq', global_inv_freq)
        # Precompute local rotations for block_size
        local_pos = torch.arange(block_size, dtype=torch.float32)
        local_freqs = torch.einsum('i,j->ij', local_pos, local_inv_freq)
        local_emb = torch.cat([local_freqs, local_freqs], dim=-1)
        self.register_buffer('local_cos', local_emb.cos())
        self.register_buffer('local_sin', local_emb.sin())
        # Precompute global rotations
        global_pos = torch.arange(max_blocks, dtype=torch.float32)
        global_freqs = torch.einsum('i,j->ij', global_pos, global_inv_freq)
        global_emb = torch.cat([global_freqs, global_freqs], dim=-1)
        self.register_buffer('global_cos', global_emb.cos())
        self.register_buffer('global_sin', global_emb.sin())

    def decompose_position(self, positions):
        block_indices = positions // self.block_size
        offsets = positions % self.block_size
        return block_indices, offsets

    def rotate_half(self, x):
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat([-x2, x1], dim=-1)

    def apply_rotation(self, x, cos, sin):
        return (x * cos) + (self.rotate_half(x) * sin)

    def forward(self, x, positions=None, use_hierarchical=True):
        # x: [..., seq_len, dim]; positions: [seq_len]. The [seq_len, dim] cos/sin
        # tables broadcast over any leading batch/head dimensions.
        seq_len = x.shape[-2]
        if positions is None:
            positions = torch.arange(seq_len, device=x.device)
        block_indices, offsets = self.decompose_position(positions)
        if use_hierarchical:
            # Apply global rotation based on block index
            # (block indices must stay below max_blocks, the size of the table)
            global_cos = self.global_cos[block_indices]
            global_sin = self.global_sin[block_indices]
            x = self.apply_rotation(x, global_cos, global_sin)
        # Apply local rotation based on intra-block offset
        local_cos = self.local_cos[offsets]
        local_sin = self.local_sin[offsets]
        x = self.apply_rotation(x, local_cos, local_sin)
        return x


class HierarchicalAttention(nn.Module):
    def __init__(self, d_model, num_heads, block_size=1024, max_blocks=128, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.block_size = block_size
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.hirope = HiRoPE(self.head_dim, block_size, max_blocks)
        self.dropout = nn.Dropout(dropout)

    def create_hierarchical_mask(self, seq_len, device):
        positions = torch.arange(seq_len, device=device)
        block_indices = positions // self.block_size
        # Local mask: attend to tokens in same block and adjacent blocks
        block_diff = block_indices.unsqueeze(0) - block_indices.unsqueeze(1)
        local_mask = (block_diff.abs() <= 1).float()
        # Global mask: sparse (strided) attention across blocks
        global_mask = ((positions.unsqueeze(0) - positions.unsqueeze(1)) % self.block_size == 0).float()
        return local_mask, global_mask

    def forward(self, x, hierarchical=True):
        batch, seq_len, _ = x.shape
        # Linear projections -> [batch, num_heads, seq_len, head_dim]
        q = self.W_q(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.W_k(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.W_v(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Apply HiRoPE directly on the 4D tensors. Flattening is unnecessary and
        # would misalign tokens with their positions.
        positions = torch.arange(seq_len, device=x.device)
        q = self.hirope(q, positions, hierarchical)
        k = self.hirope(k, positions, hierarchical)
        # Compute attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if hierarchical:
            local_mask, global_mask = self.create_hierarchical_mask(seq_len, x.device)
            # Combine masks - use local for close tokens, global for distant
            distance = positions.unsqueeze(0) - positions.unsqueeze(1)
            distance = distance.abs()
            combined_mask = torch.where(distance < self.block_size, local_mask, global_mask)
            scores = scores.masked_fill(combined_mask.unsqueeze(0).unsqueeze(0) == 0, -1e9)
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        output = self.W_o(output)
        return output
# Test HiRoPE
print('Testing HiRoPE...')
hirope = HiRoPE(dim=64, block_size=1024, max_blocks=128)
x = torch.randn(1, 5000, 64) # 5000 tokens
# Decompose absolute positions into (block index, offset) pairs
positions = torch.arange(5000)
block_indices, offsets = hirope.decompose_position(positions)
print(f'Sequence length: 5000')
print(f'Block size: {hirope.block_size}')
print(f'Number of blocks spanned: {block_indices.max().item() + 1}')
print(f'Position 2500 -> Block: {block_indices[2500].item()}, Offset: {offsets[2500].item()}')
print(f'Position 4999 -> Block: {block_indices[4999].item()}, Offset: {offsets[4999].item()}')
# Test hierarchical attention
h_attn = HierarchicalAttention(d_model=512, num_heads=8, block_size=512)
x = torch.randn(2, 2000, 512)
output = h_attn(x, hierarchical=True)
print(f'\nHierarchical Attention output shape: {output.shape}')
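To make the attention pattern built by create_hierarchical_mask easier to inspect, the sketch below rebuilds the same rule with plain Python loops on a tiny sequence. The function name hierarchical_mask_rows and the toy sizes are illustrative assumptions. Two positions less than one block apart always fall in the same or adjacent blocks, so the local rule behaves like a sliding window and the global rule adds a strided pattern on top of it.

```python
def hierarchical_mask_rows(seq_len: int, block_size: int) -> list[list[bool]]:
    """Rebuild the combined local/global mask rule with plain Python loops."""
    rows = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            if abs(i - j) < block_size:
                # Local rule: same or adjacent block (always true at this distance).
                allowed = abs(i // block_size - j // block_size) <= 1
            else:
                # Global rule: strided attention to every block_size-th position.
                allowed = (i - j) % block_size == 0
            row.append(allowed)
        rows.append(row)
    return rows


mask = hierarchical_mask_rows(seq_len=8, block_size=2)
for row in mask:
    print(''.join('x' if allowed else '.' for allowed in row))

allowed_pairs = sum(sum(row) for row in mask)
print(f'{allowed_pairs} of {8 * 8} query/key pairs are attended')  # 46 of 64
```

With realistic block sizes the strided part is far sparser than in this toy case, which is where the savings over full attention would come from.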
Using Libraries (torch, flash_attn)
import math

import torch
import torch.nn as nn


# Block-wise RoPE variant. Despite the class name, this is not the LongLoRA
# method itself (LongLoRA is built around shifted sparse attention).
class LongLoRAHiRoPE(nn.Module):
    def __init__(self, dim, max_position_embeddings=8192, original_max_position=2048,
                 block_size=256, scale_factor=4):
        super().__init__()
        if block_size > original_max_position:
            # Intra-block offsets index the cached tables, so they must fit
            raise ValueError('block_size must not exceed original_max_position')
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position = original_max_position
        self.block_size = block_size
        self.scale_factor = scale_factor
        # Standard RoPE frequencies
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        # Precompute standard RoPE embeddings. Each frequency is repeated for two
        # adjacent channels to match the interleaved pairing in apply_rotary_emb
        # (a [freqs, freqs] concatenation would pair channels with mismatched angles).
        t = torch.arange(original_max_position, dtype=torch.float32)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        emb = torch.repeat_interleave(freqs, 2, dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def forward(self, x, seq_len=None, use_hierarchical=True):
        # x shape: [batch, seq_len, num_heads, head_dim]
        if seq_len is None:
            seq_len = x.shape[1]
        if seq_len <= self.original_max_position:
            # Use standard RoPE for short sequences
            cos = self.cos_cached[:seq_len]
            sin = self.sin_cached[:seq_len]
            return self.apply_rotary_emb(x, cos, sin)
        if use_hierarchical:
            # Use HiRoPE for long sequences
            return self.apply_hierarchical_rotary(x, seq_len)
        else:
            # Position interpolation for comparison
            return self.apply_interpolated_rotary(x, seq_len)

    def apply_rotary_emb(self, x, cos, sin):
        # x shape: [batch, seq_len, num_heads, head_dim]
        cos = cos.unsqueeze(0).unsqueeze(2)  # [1, seq_len, 1, dim]
        sin = sin.unsqueeze(0).unsqueeze(2)
        # Interleaved pairing: channels (0, 1), (2, 3), ... form the 2D planes
        x1, x2 = x[..., ::2], x[..., 1::2]
        rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
        return x * cos + rot_x * sin

    def apply_hierarchical_rotary(self, x, seq_len):
        # Restart the RoPE position at every block boundary: each token is rotated
        # by its intra-block offset only. The block index is not encoded here, so
        # tokens at the same offset in different blocks receive the same rotation.
        positions = torch.arange(seq_len, device=x.device)
        offset = positions % self.block_size
        cos = self.cos_cached[offset]
        sin = self.sin_cached[offset]
        return self.apply_rotary_emb(x, cos, sin)

    def apply_interpolated_rotary(self, x, seq_len):
        # Position interpolation: squeeze positions back into the trained range
        scale = seq_len / self.original_max_position
        positions = torch.arange(seq_len, device=x.device) / scale
        freqs = torch.einsum('i,j->ij', positions.float(), self.inv_freq)
        emb = torch.repeat_interleave(freqs, 2, dim=-1)
        return self.apply_rotary_emb(x, emb.cos(), emb.sin())
# Integration with Flash Attention for efficiency
# (flash_attn_func expects fp16/bf16 CUDA tensors shaped [batch, seq_len, num_heads, head_dim])
try:
    from flash_attn import flash_attn_func

    class FlashHiRoPEAttention(nn.Module):
        def __init__(self, dim, num_heads, block_size=1024):
            super().__init__()
            self.dim = dim
            self.num_heads = num_heads
            self.head_dim = dim // num_heads
            self.block_size = block_size
            self.W_qkv = nn.Linear(dim, 3 * dim)
            self.W_o = nn.Linear(dim, dim)
            self.hirope = LongLoRAHiRoPE(self.head_dim, block_size=block_size)

        def forward(self, x):
            batch, seq_len, _ = x.shape
            # QKV projection
            qkv = self.W_qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
            q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
            # Apply HiRoPE, then cast back: the float32 cos/sin tables promote
            # half-precision inputs to float32
            q = self.hirope(q, seq_len).to(v.dtype)
            k = self.hirope(k, seq_len).to(v.dtype)
            # Flash attention (plain causal attention, no hierarchical mask)
            output = flash_attn_func(q, k, v, causal=True)
            output = output.reshape(batch, seq_len, self.dim)
            return self.W_o(output)

    print('Flash Attention integration available')
except ImportError:
    print('Flash Attention not available - using standard attention')
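# Hugging Face Transformers integration pattern (illustrative only)
# Caveat: Hugging Face rotary modules typically return (cos, sin) tables from
# forward, whereas LongLoRAHiRoPE.forward returns rotated tensors, so the
# attention forward pass must be adapted as well before a patched model will run.
def patch_model_with_hirope(model, block_size=1024):
    # Materialize the module list first so nothing is replaced mid-iteration
    for name, module in list(model.named_modules()):
        if hasattr(module, 'rotary_emb'):
            old_inv_freq = module.rotary_emb.inv_freq
            dim = old_inv_freq.shape[0] * 2
            hirope = LongLoRAHiRoPE(dim, block_size=block_size).to(old_inv_freq.device)
            module.rotary_emb = hirope
            print(f'Patched {name} with HiRoPE')
    return model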
# NTK-aware HiRoPE scaling
class NTKAwareHiRoPE(nn.Module):
    def __init__(self, dim, max_position_embeddings=32768, base=10000,
                 ntk_factor=1.0, extrapolation_factor=1.0):
        super().__init__()
        # NTK-aware scaling of base frequency
        adjusted_base = base * ntk_factor ** (dim / (dim - 2))
        inv_freq = 1.0 / (adjusted_base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.ntk_factor = ntk_factor
        self.extrapolation_factor = extrapolation_factor

    def forward(self, x, seq_len):
        # x shape: [batch, seq_len, num_heads, head_dim]
        # Dynamic frequency adjustment based on sequence length
        scale = max(1.0, seq_len / self.max_position_embeddings * self.extrapolation_factor)
        adjusted_inv_freq = self.inv_freq / scale
        t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
        freqs = torch.einsum('i,j->ij', t, adjusted_inv_freq)
        # Repeat each frequency for two adjacent channels to match the
        # interleaved pairing used below
        emb = torch.repeat_interleave(freqs, 2, dim=-1)
        cos, sin = emb.cos(), emb.sin()
        # Apply rotary embeddings
        cos = cos.unsqueeze(0).unsqueeze(2)
        sin = sin.unsqueeze(0).unsqueeze(2)
        x1, x2 = x[..., ::2], x[..., 1::2]
        rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
        return x * cos + rot_x * sin


print('\nHiRoPE implementations loaded successfully')
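The difference between position interpolation (as in apply_interpolated_rotary) and NTK-aware base scaling (as in NTKAwareHiRoPE) is easiest to see on the frequencies themselves. The following pure-Python sketch is only an illustration; the helper name rope_inv_freq and the values dim = 64 and factor = 4 are assumptions chosen for the example.

```python
import math


def rope_inv_freq(dim: int, base: float = 10000.0) -> list[float]:
    """Inverse frequencies base^(-2i/dim) for i = 0 .. dim/2 - 1."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


dim, factor = 64, 4.0
plain = rope_inv_freq(dim)

# Position interpolation: dividing positions by `factor` is equivalent to
# dividing every frequency by `factor`.
interpolated = [f / factor for f in plain]

# NTK-aware scaling: enlarge the base instead of shrinking the positions.
ntk_base = 10000.0 * factor ** (dim / (dim - 2))
ntk = rope_inv_freq(dim, base=ntk_base)

# Highest frequency: interpolation slows it down, NTK scaling leaves it alone.
print(plain[0], interpolated[0], ntk[0])  # 1.0 0.25 1.0

# Lowest frequency: both methods slow it down by (approximately) `factor`.
print(plain[-1] / interpolated[-1], plain[-1] / ntk[-1])
```

Keeping the high frequencies intact preserves the resolution between neighbouring tokens, while the low frequencies are stretched to cover the longer context.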
When to Use
✅ Appropriate Use Cases:
- Sequences longer than the context length the RoPE model was trained on (for example, beyond 8k tokens)
- Document-level understanding where chunking loses coherence
- Long-context language modeling with coherent attention patterns
- When combining local and global attention patterns is beneficial
- Video/audio processing with long temporal dependencies
- Genomic/proteomic sequence analysis with very long inputs
❌ Avoid When:
- Short sequences (<4k tokens) where standard RoPE suffices
- When using ALiBi which naturally handles extrapolation
- When Flash Attention is required and the available kernel cannot express the hierarchical attention pattern
- When training from scratch on long sequences (use standard RoPE with proper init)
- Tasks requiring fine-grained position awareness at all distances
- If the model architecture uses learned position embeddings exclusively
Common Pitfalls
- Block size too small causing excessive block boundaries and fragmentation
- Block size too large losing hierarchical benefits
- Using global attention everywhere (defeats the purpose of hierarchy)
- Not training with the same hierarchical pattern used at inference
- Mismatched block sizes between different layers of the model
- Forgetting to handle the final partial block when the sequence length is not a multiple of the block size
- Using too many different frequency bases (keep local vs global simple)
- Not validating that relative positions work correctly across block boundaries
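The last pitfall can be checked numerically. The sketch below is a minimal pure-Python check, not a reference implementation: it treats each channel pair as a complex number (an equivalent way to write the 2D rotations), and the helper names, the block size of 8, and the toy vectors are all assumptions made for illustration.

```python
import cmath
import math
import random


def hirope_rotate(vec, position, block_size, local_base=10000.0, global_base=100000.0):
    """Rotate each complex pair by the summed block-level and intra-block angle."""
    block, offset = divmod(position, block_size)
    half = len(vec)  # number of 2D planes, i.e. dim / 2
    rotated = []
    for j, z in enumerate(vec):
        angle = offset * local_base ** (-j / half) + block * global_base ** (-j / half)
        rotated.append(z * cmath.exp(1j * angle))
    return rotated


def attention_score(q, m, k, n, block_size):
    """Dot product of the rotated query (position m) and key (position n)."""
    q_rot = hirope_rotate(q, m, block_size)
    k_rot = hirope_rotate(k, n, block_size)
    return sum((a * b.conjugate()).real for a, b in zip(q_rot, k_rot))


random.seed(0)
q = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(4)]
k = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(4)]
B = 8

# Same (block difference, offset difference) = (1, -1): the scores agree.
s1 = attention_score(q, 10, k, 3, B)
s2 = attention_score(q, 26, k, 19, B)
print(math.isclose(s1, s2, abs_tol=1e-9))  # True

# Same token distance (7), but one pair crosses a block boundary and the
# other does not: the scores differ.
ones = [1 + 0j] * 4
crossing = attention_score(ones, 10, ones, 3, B)
inside = attention_score(ones, 15, ones, 8, B)
print(round(crossing, 3), round(inside, 3))
```

The second comparison is the behaviour to watch for: unlike standard RoPE, the score is a function of the block and offset differences rather than of the raw token distance, so two pairs that are equally far apart can be scored differently depending on where the block boundary falls.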