Experiment Tracking with Weights & Biases
Definition
Weights & Biases (WandB) is a machine learning experiment tracking and visualization platform that helps ML practitioners log, compare, and reproduce experiments. It provides a centralized dashboard for tracking hyperparameters, metrics, model artifacts, datasets, and training logs across multiple runs. WandB automatically captures system metrics such as GPU/CPU utilization, memory consumption, and network I/O. For team projects, it offers shared experiment workspaces and report generation. Key features include hyperparameter sweeps that can run in parallel across many machines, artifact versioning for datasets and models, and integrations with major ML frameworks including PyTorch, TensorFlow, Keras, scikit-learn, and Hugging Face. WandB also provides a model registry for managing the model lifecycle from experimentation to production, along with visualization tools for metrics, predictions, attention maps, and model graphs.
Intuition
Imagine you're a scientist running dozens of experiments in a lab. Instead of scribbling results in scattered notebooks that get lost or mixed up, WandB is like having a digital lab assistant that automatically logs every experiment detail, organizes results in a searchable database, creates charts showing what worked and what didn't, lets your team see and build on each other's work, and keeps complete records so any experiment can be reconstructed and rerun. No more wondering 'what learning rate did I use last Tuesday?'
Mathematical Formula
Step-by-Step Explanation:
- \(\sigma_{env}\): environment variance (dependencies, hardware)
- \(\sigma_{data}\): dataset version variance
- \(\sigma_{code}\): code version variance
- \(\sigma_{params}\): hyperparameter variance
- \(R\): reproducibility score, approaching 1.0 for perfect reproducibility
- WandB reduces each variance term by logging configurations and versioning code, data, and environment details
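One simple form consistent with these definitions, assuming the four sources contribute independently and additively (an illustrative assumption, not a standard metric), is:

```latex
R = \frac{1}{1 + \sigma_{env} + \sigma_{data} + \sigma_{code} + \sigma_{params}}
```

Because each term is non-negative, \(R\) lies in \((0, 1]\) and equals 1.0 only when every source of variation is pinned down.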
Real-World Use Cases
Paper reproduction: Share exact hyperparameters, code versions, and dataset artifacts to enable other researchers to reproduce published results and verify claims.
Model comparison at scale: Teams of 20+ ML engineers run concurrent experiments, comparing hundreds of model variants across the organization using shared dashboards.
Regulatory compliance: Maintain complete audit trails of all experiments, data versions, and hyperparameters for FDA/regulatory submission requirements.
Resource optimization: Track GPU utilization across training jobs to find idle or underused hardware and cut wasted cloud compute spending.
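Paper reproduction and audit trails both depend on recording the environment and code version alongside each run (the \(\sigma_{env}\) and \(\sigma_{code}\) terms above). W&B records similar metadata for each run; the stdlib-only sketch below just shows the kind of information involved. The function name `snapshot_environment` and the chosen fields are illustrative assumptions, not W&B's format.

```python
import hashlib
import platform
import sys
from importlib import metadata


def snapshot_environment(code: str, packages: list[str]) -> dict:
    """Collect metadata that helps reproduce a run (illustrative sketch)."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # package not installed in this environment
    return {
        "python": platform.python_version(),
        "platform": sys.platform,
        # A content hash identifies the exact code that ran, even without git.
        "code_sha256": hashlib.sha256(code.encode("utf-8")).hexdigest(),
        "packages": versions,
    }


snap = snapshot_environment("print('train')", ["pip", "nonexistent-package-xyz"])
print(snap["python"], snap["code_sha256"][:12])
```

Storing a snapshot like this next to each run's config turns "it worked on my machine" into a concrete diff between two environments.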
Implementation
Manual Implementation (No Tracking Libraries)
# Manual experiment tracking - error-prone
import csv
import datetime
import json
import os
import torch
# Placeholder model so the example runs; substitute your real network
model = torch.nn.Linear(10, 2)
# Create a timestamped experiment directory
experiment_id = f'exp_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}'
exp_dir = os.path.join('experiments', experiment_id)
os.makedirs(exp_dir, exist_ok=True)
# Manually log hyperparameters
config = {
    'learning_rate': 0.001,
    'batch_size': 32,
    'epochs': 100,
    'optimizer': 'adam',
    'architecture': 'resnet50'
}
with open(os.path.join(exp_dir, 'config.json'), 'w') as f:
    json.dump(config, f, indent=2)
# Write the CSV header once, before training starts
metrics_path = os.path.join(exp_dir, 'metrics.csv')
with open(metrics_path, 'w', newline='') as f:
    csv.writer(f).writerow(['epoch', 'train_loss', 'val_loss', 'val_acc'])
# Manual metrics logging during training
train_losses = []
val_losses = []
val_accuracies = []
for epoch in range(config['epochs']):
    # ... training code ...
    train_loss = 0.5  # placeholder: computed by your training step
    val_loss = 0.4    # placeholder: computed on the validation set
    val_acc = 0.85    # placeholder: computed on the validation set
    # Manual logging
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    # Append one row per epoch to the CSV
    with open(metrics_path, 'a', newline='') as f:
        csv.writer(f).writerow([epoch, train_loss, val_loss, val_acc])
    print(f'Epoch {epoch}: train={train_loss:.4f}, val={val_loss:.4f}')
# Save model weights next to the config and metrics
torch.save(model.state_dict(), os.path.join(exp_dir, 'model.pt'))
# Problems:
# - No visualization built-in
# - Hard to compare across experiments
# - No automatic artifact versioning
# - Manual code is error-prone
# - No system metrics
# - Difficult to share with team
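In this setup, every comparison across runs has to be scripted by hand. As a rough sketch of that overhead, the function below scans experiment directories laid out as in the manual example and picks the run with the highest validation accuracy. The name `best_experiment` and the assumed column order (epoch, train_loss, val_loss, val_acc) are illustrative; a header row is skipped if present.

```python
import csv
import tempfile
from pathlib import Path


def best_experiment(root: str, column: int = 3) -> tuple[str, float]:
    """Return (experiment_id, value) for the highest value in one metrics column."""
    best_id, best_val = None, float("-inf")
    for metrics_file in sorted(Path(root).glob("*/metrics.csv")):
        with open(metrics_file, newline="") as f:
            for row in csv.reader(f):
                try:
                    value = float(row[column])
                except (ValueError, IndexError):
                    continue  # skip a header row or malformed lines
                if value > best_val:
                    best_id, best_val = metrics_file.parent.name, value
    if best_id is None:
        raise ValueError(f"no metrics found under {root}")
    return best_id, best_val


# Demo: two fake experiment directories in a temporary folder.
with tempfile.TemporaryDirectory() as root:
    for exp_id, accs in {"exp_a": [0.80, 0.84], "exp_b": [0.79, 0.88]}.items():
        exp_dir = Path(root) / exp_id
        exp_dir.mkdir()
        with open(exp_dir / "metrics.csv", "w", newline="") as f:
            writer = csv.writer(f)
            for epoch, acc in enumerate(accs):
                writer.writerow([epoch, 0.5, 0.4, acc])
    result = best_experiment(root)
    print(result)  # ('exp_b', 0.88)
```

Even this small comparison ignores configs, plots, and shared access, which is the gap a tracking service fills.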
Using Libraries (wandb, torch)
# Weights & Biases experiment tracking
import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
# Assumes: ResNet50 is your own model class that accepts a dropout argument,
# and train_loader / val_loader are DataLoader instances defined elsewhere.
# Initialize run
wandb.init(
    project='image-classification',
    name='resnet50_baseline',
    config={
        'learning_rate': 0.001,
        'batch_size': 32,
        'epochs': 100,
        'optimizer': 'adam',
        'architecture': 'resnet50',
        'dropout': 0.5
    }
)
# Access config (already recorded by wandb.init)
config = wandb.config
# Create model and optimizer from the tracked config
model = ResNet50(dropout=config.dropout)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.CrossEntropyLoss()
# Training loop with automatic logging
for epoch in range(config.epochs):
    model.train()
    train_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data)
            val_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    val_accuracy = correct / len(val_loader.dataset)
    # Collect this epoch's metrics (W&B builds charts from them automatically)
    log_data = {
        'epoch': epoch,
        'train_loss': train_loss / len(train_loader),
        'val_loss': val_loss / len(val_loader),
        'val_accuracy': val_accuracy,
        'learning_rate': optimizer.param_groups[0]['lr']
    }
    # Every 10 epochs, attach sample predictions from the last validation batch
    if epoch % 10 == 0:
        log_data['predictions'] = [
            wandb.Image(img, caption=f'pred={p}, true={t}')
            for img, p, t in zip(data[:8], pred[:8].tolist(), target[:8].tolist())
        ]
    # One call per epoch so metrics and images share the same step
    wandb.log(log_data)
# Save the trained weights, then log them as a versioned artifact
torch.save(model.state_dict(), 'model.pt')
artifact = wandb.Artifact('model', type='model')
artifact.add_file('model.pt')
wandb.log_artifact(artifact)
# Finish run
wandb.finish()
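By default, every `wandb.log` call advances the run's step counter, while passing `commit=False` accumulates values into the current step instead. That is why metrics and images meant for the same epoch are best logged together. The toy class below imitates this bookkeeping in plain Python to make the effect visible; `StepLogger` is a hypothetical name, and this is not how W&B is implemented internally.

```python
class StepLogger:
    """Toy imitation of step bookkeeping (hypothetical class, not the wandb API).

    A committed log() call closes the current step and advances the counter;
    commit=False merges values into the pending step instead.
    """

    def __init__(self):
        self.step = 0
        self.history = []  # one dict per committed step
        self._pending = {}

    def log(self, data: dict, commit: bool = True) -> None:
        self._pending.update(data)
        if commit:
            self.history.append({"_step": self.step, **self._pending})
            self._pending = {}
            self.step += 1


# Two separate calls land on two different steps...
split = StepLogger()
split.log({"val_loss": 0.4})
split.log({"predictions": "<images>"})

# ...while commit=False on the first call keeps both values on one step.
joined = StepLogger()
joined.log({"val_loss": 0.4}, commit=False)
joined.log({"predictions": "<images>"})

print(len(split.history), len(joined.history))  # 2 1
```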
# Hyperparameter Sweeps
# Define sweep configuration
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
    'parameters': {
        # log_uniform_values takes the bounds as actual values;
        # plain log_uniform would treat min/max as natural-log exponents
        'learning_rate': {'distribution': 'log_uniform_values', 'min': 0.0001, 'max': 0.1},
        'batch_size': {'values': [16, 32, 64, 128]},
        'dropout': {'distribution': 'uniform', 'min': 0.1, 'max': 0.5},
        'optimizer': {'values': ['adam', 'sgd', 'adamw']}
    }
}
# Initialize sweep
sweep_id = wandb.sweep(sweep_config, project='image-classification')
# Training function for sweep: each agent call starts a run with sampled values
def train_sweep(config=None):
    with wandb.init(config=config):
        config = wandb.config
        # ... training code using config parameters; it must log 'val_accuracy' ...
        pass
# Run a sweep agent for 20 trials
wandb.agent(sweep_id, train_sweep, count=20)
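Log-uniform sampling, which suits learning rates spanning several orders of magnitude, draws the natural log of the value uniformly and then exponentiates, so each decade (1e-4 to 1e-3, 1e-3 to 1e-2, ...) is equally likely. W&B's sweep config spells this with actual-value bounds as `log_uniform_values`; plain `log_uniform` expects the bounds as log exponents. The sketch below draws random configurations from the same spec format. It is plain random search, not W&B's `bayes` method, and `sample_config` is a hypothetical name.

```python
import math
import random


def sample_config(parameters: dict, rng: random.Random) -> dict:
    """Draw one random configuration from a sweep-style parameter spec."""
    config = {}
    for name, spec in parameters.items():
        if "values" in spec:
            config[name] = rng.choice(spec["values"])
        elif spec.get("distribution") == "uniform":
            config[name] = rng.uniform(spec["min"], spec["max"])
        elif spec.get("distribution") == "log_uniform_values":
            # Uniform in log-space: every decade between min and max is equally likely.
            low, high = math.log(spec["min"]), math.log(spec["max"])
            config[name] = math.exp(rng.uniform(low, high))
        else:
            raise ValueError(f"unsupported spec for {name!r}: {spec}")
    return config


parameters = {
    "learning_rate": {"distribution": "log_uniform_values", "min": 0.0001, "max": 0.1},
    "batch_size": {"values": [16, 32, 64, 128]},
    "dropout": {"distribution": "uniform", "min": 0.1, "max": 0.5},
    "optimizer": {"values": ["adam", "sgd", "adamw"]},
}

rng = random.Random(0)  # fixed seed so the draws are repeatable
trials = [sample_config(parameters, rng) for _ in range(20)]
for trial in trials[:3]:
    print(trial)
```

A Bayesian method replaces the independent draws with choices informed by earlier trials' `val_accuracy`, but the parameter space it searches is the same.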
# Dataset and Model Versioning
# Log dataset (W&B assigns versions v0, v1, ... automatically)
run = wandb.init(project='image-classification', job_type='upload-dataset')
artifact = wandb.Artifact('dataset', type='dataset')
artifact.add_dir('data/processed/')
run.log_artifact(artifact)
run.finish()
# Use a specific artifact version in a later run
run = wandb.init(project='image-classification', job_type='train')
dataset_artifact = run.use_artifact('dataset:latest')
dataset_path = dataset_artifact.download()
# Model registry
# Register best model for production
model_artifact = run.use_artifact('model:v15')
run.link_artifact(model_artifact, 'model-registry/production-model')
run.finish()
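W&B checksums artifact contents, so re-logging files that have not changed does not produce a new version. The sketch below illustrates that content-addressed idea with a directory digest and a simple version list. It is a stdlib illustration, not W&B's implementation, and the names `directory_digest` and `register_version` are hypothetical.

```python
import hashlib
import tempfile
from pathlib import Path


def directory_digest(path: str) -> str:
    """Hash every file's relative path and bytes, in sorted order, into one digest."""
    h = hashlib.sha256()
    root = Path(path)
    for file in sorted(p for p in root.rglob("*") if p.is_file()):
        h.update(file.relative_to(root).as_posix().encode("utf-8"))
        h.update(b"\0")  # separator so path and content bytes cannot run together
        data = file.read_bytes()
        h.update(len(data).to_bytes(8, "big"))
        h.update(data)
    return h.hexdigest()


def register_version(versions: list[str], digest: str) -> str:
    """Return the alias for digest, appending a new vN only for unseen contents."""
    if digest in versions:
        return f"v{versions.index(digest)}"  # identical contents: reuse that version
    versions.append(digest)
    return f"v{len(versions) - 1}"


versions: list[str] = []
with tempfile.TemporaryDirectory() as data_dir:
    (Path(data_dir) / "train.csv").write_text("x,y\n1,2\n")
    first = register_version(versions, directory_digest(data_dir))
    again = register_version(versions, directory_digest(data_dir))
    (Path(data_dir) / "train.csv").write_text("x,y\n1,2\n3,4\n")
    changed = register_version(versions, directory_digest(data_dir))

print(first, again, changed)  # v0 v0 v1
```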
When to Use
✅ Appropriate Use Cases:
- Multiple team members working on ML experiments
- Need to compare hyperparameter configurations
- Long-running training jobs requiring monitoring
- Reproducibility requirements for research or compliance
- Distributed training across multiple GPUs/machines
- Hyperparameter optimization (sweeps)
- Model and dataset versioning
- Sharing results with stakeholders
- Debugging training issues with rich visualizations
- Tracking compute resource utilization
- Building experiment reports for presentations
❌ Avoid When:
- Simple, one-off experiments that don't need tracking
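- Environments with no network path to any W&B server, cloud or self-hosted (offline mode only defers syncing)
- Very sensitive data that cannot leave the premises, unless you run a self-hosted W&B Server
- Latency-critical code paths, such as real-time inference, where per-call logging overhead matters
- Projects with strict data residency requirements that the hosted service cannot meet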
- Very simple models with no hyperparameter tuning
- When using purely deterministic baselines with no variance
Common Pitfalls
- Logging too many metrics (increases overhead)
- Calling wandb.log() before wandb.init()
- Forgetting to call wandb.finish(), especially in notebooks, where the run otherwise stays open and keeps holding resources
- Logging sensitive data without proper configuration
- Not versioning datasets alongside models
- Running sweeps without proper resource limits
- Not using meaningful run names or tags
- Logging entire models frequently (storage costs)
- Not organizing projects properly (hard to find experiments)
- Overlooking offline mode (WANDB_MODE=offline, synced later with wandb sync) for air-gapped or unreliable-network environments
- Not setting up team workspaces for collaboration
- Logging images/plots at every step (slow)
- Not using artifacts for reproducibility