Chain-of-Thought: Reasoning in Large Language Models
Definition
Chain-of-Thought (CoT) prompting is a technique that enables Large Language Models (LLMs) to solve complex reasoning tasks by generating intermediate reasoning steps before producing the final answer. Introduced by Wei et al. in 2022, CoT was reported as an emergent ability of model scale: in their experiments its benefits appeared mainly in models of roughly 100B parameters or more, and it works through in-context learning, without any parameter updates. The technique prompts the model with examples that demonstrate step-by-step reasoning, which the model then emulates on a new question. CoT substantially improves performance on arithmetic, commonsense, and symbolic reasoning tasks. Variants include Zero-Shot CoT (appending 'Let's think step by step' instead of giving examples), Automatic CoT (constructing the demonstrations automatically), and Self-Consistency CoT (sampling multiple reasoning paths and taking a majority vote over their final answers). Why CoT works is not fully understood; common hypotheses are that generated reasoning tokens give later tokens more context to attend to, and that training on code and explanatory text teaches patterns of step-by-step reasoning.
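A minimal sketch of the two prompt styles described above, assuming a plain text-in, text-out model: few-shot CoT prepends worked examples whose answers show their steps, while zero-shot CoT adds only a trigger phrase. The helper names are illustrative, not a library API.

```python
from typing import List, Tuple

# Worked example with visible intermediate steps (the tennis-ball problem used later
# in this article). Helper names below are illustrative, not a library API.
FEW_SHOT_EXEMPLARS: List[Tuple[str, str]] = [
    (
        "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. "
        "Each can has 3 tennis balls. How many tennis balls does he have now?",
        "Roger starts with 5 balls. 2 cans of 3 balls each is 2 * 3 = 6 balls. "
        "In total, 5 + 6 = 11 balls. The answer is 11.",
    ),
]


def few_shot_cot_prompt(question: str) -> str:
    """Few-shot CoT: worked examples that show their reasoning, then the new question."""
    blocks = [f"Q: {q}\nA: {a}" for q, a in FEW_SHOT_EXEMPLARS]
    blocks.append(f"Q: {question}\nA:")
    return "\n\n".join(blocks)


def zero_shot_cot_prompt(question: str, trigger: str = "Let's think step by step.") -> str:
    """Zero-shot CoT: no examples, just a trigger phrase after 'A:'."""
    return f"Q: {question}\nA: {trigger}"


if __name__ == "__main__":
    q = "If a train travels 120 miles in 2 hours, how far will it travel in 5 hours?"
    print(few_shot_cot_prompt(q))
    print("-" * 40)
    print(zero_shot_cot_prompt(q))
```

Either string is then sent to the model. The few-shot variant costs more prompt tokens but also shows the model the answer format ('The answer is N.') that later parsing can rely on.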
Intuition
Imagine solving a complex math problem. Instead of jumping straight to the answer, you write out your work: 'First, I need to find x, then substitute it into equation y, then solve for z.' Chain-of-Thought prompting asks the model to do the same: show its work. Without CoT, the model is like a student who blurts out an answer without thinking; with CoT, it is like a student who works through the problem methodically. One explanation for why this helps: a Transformer generates text one token at a time and spends a fixed amount of computation on each token. Writing out intermediate steps produces extra tokens that hold partial results, and later tokens, including the answer, can attend to them. In effect, the model's own generated text gives it more 'thinking time'. The model picked up patterns of reasoning during pretraining on books, code, and explanations, and CoT prompts appear to elicit these learned patterns.
Mathematical Formula
Step-by-Step Explanation:
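- CoT Decomposition: the probability of a final answer is a sum over reasoning paths r, each weighted by its likelihood given the question
- Self-Consistency: majority voting over the final answers of k sampled reasoning chains to improve accuracy
- Training Loss: standard next-token prediction applied to both the reasoning steps and the answer (relevant when fine-tuning on reasoning traces; few-shot CoT prompting itself involves no training)
- Emergence Threshold: in the original experiments, CoT gains appeared fairly abruptly at around 100B parameters, resembling a phase transition
- Attention Pattern: a commonly hypothesized mechanism is that answer tokens attend to the generated reasoning tokens, letting the model build on intermediate results instead of computing the answer in a single step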
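Written out, the first three items take the following form. The notation is assumed here: q is the question, r a reasoning chain, a the final answer, k the number of sampled chains, and y the concatenated reasoning-plus-answer tokens used as a fine-tuning target.

```latex
P(a \mid q) = \sum_{r} P(a \mid q, r)\, P(r \mid q)

\hat{a} = \arg\max_{a} \sum_{i=1}^{k} \mathbb{1}\left[a_i = a\right], \qquad (r_i, a_i) \sim P(r, a \mid q)

\mathcal{L}(\theta) = -\sum_{t=1}^{|y|} \log P_\theta\left(y_t \mid q, y_{<t}\right), \qquad y = (r, a)
```

The sum over r ranges over every possible reasoning chain and cannot be computed exactly; self-consistency approximates the most probable answer by sampling k chains and voting. The loss applies only when fine-tuning on reasoning traces; few-shot CoT prompting updates no parameters.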
Real-World Use Cases
Solving GSM8K word problems: 'John has 5 apples, gives away 2...' with step-by-step arithmetic
Generating code with explanation: 'First, I need to iterate over the list, then...'
Legal case reasoning: 'The precedent in Smith v. Jones establishes... therefore...'
Medical differential diagnosis: 'Given symptoms A and B, possible causes include... most likely...'
Physics derivations showing formula substitutions and unit analysis
Solving Sudoku or logic grid puzzles by explaining elimination steps
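For the GSM8K use case, a free-form reasoning chain has to be scored against the dataset's reference answer. A minimal sketch, assuming references end with a line of the form '#### <number>' (the GSM8K convention) and that the model's final answer is the last number it writes, which is a common but imperfect heuristic. The function names are hypothetical.

```python
import re
from typing import Optional

# Matches integers and decimals, allowing thousands separators such as 1,200.
NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def gold_answer(reference: str) -> str:
    """GSM8K reference solutions end with '#### <answer>'; return that answer."""
    return reference.split("####")[-1].strip().replace(",", "")


def predicted_answer(response: str) -> Optional[str]:
    """Heuristic: treat the last number in the reasoning chain as the final answer."""
    matches = NUMBER.findall(response)
    return matches[-1].replace(",", "") if matches else None


def is_correct(response: str, reference: str) -> bool:
    """Compare numerically so that '1200' and '1200.0' count as equal."""
    pred = predicted_answer(response)
    return pred is not None and float(pred) == float(gold_answer(reference))
```

The last-number heuristic fails when a model restates an intermediate value after its answer; prompting for an explicit 'The answer is N.' line reduces that risk.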
Implementation
Manual Implementation (No Prompting Frameworks)
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class ChainOfThoughtPrompt:
    def __init__(self, model_name='gpt2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        # Standard CoT examples for few-shot prompting, keyed by task type
        self.cot_examples = {
            'math': self._get_math_examples(),
            'commonsense': self._get_commonsense_examples(),
            'code': self._get_code_examples()
        }
    def _get_math_examples(self):
        return """Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?
A: Roger starts with 5 balls. He buys 2 cans with 3 balls each, so that's 2 * 3 = 6 balls. In total, 5 + 6 = 11 balls. The answer is 11.
Q: A juggler has 16 balls. Half are golf balls and half are tennis balls. He loses 3 golf balls. How many golf balls does he have left?
A: Half of 16 balls are golf balls, so that's 16 / 2 = 8 golf balls. He loses 3, so 8 - 3 = 5 golf balls. The answer is 5.
Q: {question}
A:"""
    def _get_commonsense_examples(self):
        return """Q: John had a fever and took medicine. What happened to the fever?
A: Medicine typically reduces fever by fighting infection or reducing inflammation. John's fever likely went down.
Q: Why might someone carry an umbrella on a cloudy day?
A: Cloudy days often indicate rain. Carrying an umbrella provides protection from potential rain.
Q: {question}
A:"""
    def _get_code_examples(self):
        return """Q: Write a function to reverse a string.
A: Let me think through this step by step. First, I need to take a string as input. Then I can use slicing with step -1 to reverse it. Here's the code:
```python
def reverse_string(s):
    return s[::-1]
```
Q: Write a function to check if a number is prime.
A: To check if a number is prime, I need to verify it has no divisors other than 1 and itself. I'll check divisibility from 2 up to the square root of the number. Here's the implementation:
```python
def is_prime(n):
    if n < 2:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True
```
Q: {question}
A:"""
    def format_prompt(self, question, task_type='math', use_few_shot=True, zero_shot_trigger=''):
        if use_few_shot:
            # Few-shot CoT: worked examples followed by the new question
            return self.cot_examples[task_type].format(question=question)
        # Zero-shot CoT: no examples, optionally a trigger phrase after "A:"
        if zero_shot_trigger:
            return f'Q: {question}\nA: {zero_shot_trigger}'
        return f'Q: {question}\nA:'
    def generate_cot(self, question, task_type='math', max_new_tokens=200, temperature=0.7,
                     use_few_shot=True, zero_shot_trigger="Let's think step by step.",
                     num_samples=1):
        prompt = self.format_prompt(question, task_type, use_few_shot, zero_shot_trigger)
        inputs = self.tokenizer(prompt, return_tensors='pt')
        prompt_len = inputs['input_ids'].shape[1]
        # Pass sampling parameters only when sampling; greedy decoding ignores them
        sampling = {'do_sample': True, 'temperature': temperature} if temperature > 0 else {'do_sample': False}
        results = []
        for _ in range(num_samples):
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    pad_token_id=self.tokenizer.eos_token_id,
                    **sampling
                )
            # Decode only the new tokens; slicing the decoded text by len(prompt) is fragile
            response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
            # Few-shot prompts often make the model invent a next "Q:"; cut it off
            response = response.split('\nQ:')[0].strip()
            # Parse reasoning and answer
            parsed = self._parse_response(response)
            results.append({
                'prompt': prompt,
                'response': response,
                'reasoning': parsed['reasoning'],
                'answer': parsed['answer']
            })
        return results
    def _parse_response(self, response):
        # Split at the first marker found, trying markers in priority order;
        # slicing the original string (not the lowercased copy) preserves the answer's casing
        answer_markers = ['the answer is', 'answer:', 'therefore,', 'so,', 'thus,']
        lowered = response.lower()
        reasoning, answer = response, response
        for marker in answer_markers:
            idx = lowered.find(marker)
            if idx != -1:
                reasoning = response[:idx].strip()
                answer = response[idx + len(marker):].strip()
                break
        # Clean up: keep the first line and drop trailing punctuation, but not decimal points
        answer = answer.split('\n')[0].rstrip('.!').strip()
        return {'reasoning': reasoning, 'answer': answer}
    def self_consistency(self, question, task_type='math', num_samples=10, threshold=0.5):
        # Sample several chains (temperature > 0 so they differ), then majority-vote
        results = self.generate_cot(question, task_type, num_samples=num_samples, temperature=0.7)
        answer_counts = {}
        for r in results:
            # Normalize for comparison; keep '.' so decimals like 2.5 survive
            normalized = re.sub(r'[^\w\s.]', '', r['answer'].lower()).strip().rstrip('.')
            answer_counts[normalized] = answer_counts.get(normalized, 0) + 1
        # Most common answer and the fraction of samples that agree with it
        best_answer, votes = max(answer_counts.items(), key=lambda x: x[1])
        confidence = votes / num_samples
        return {
            'answer': best_answer,
            'confidence': confidence,
            # Below the threshold the vote is too split to trust the majority answer
            'is_confident': confidence >= threshold,
            'all_results': results,
            'answer_distribution': answer_counts
        }
    def evaluate_cot_quality(self, reasoning):
        # Crude surface heuristics; they do not check whether the reasoning is correct.
        # Match whole words so that e.g. 'reason' does not count as containing 'so'.
        words = set(re.findall(r'[a-z]+', reasoning.lower()))
        criteria = {
            'has_steps': len(reasoning.split('.')) > 2 or 'first' in words,
            'has_logic_words': bool(words & {'because', 'therefore', 'so', 'thus', 'since'}),
            'has_numbers': any(char.isdigit() for char in reasoning),
            'reasonable_length': 20 < len(reasoning) < 500
        }
        score = sum(criteria.values()) / len(criteria)
        return {
            'criteria': criteria,
            'quality_score': score,
            'is_high_quality': score >= 0.75
        }
# Demonstration
print('Chain of Thought Demonstration')
print('=' * 50)
# GPT-2 (124M parameters) is far too small to show real CoT gains; it only exercises the code path.
# For actual use, substitute a larger model such as 'meta-llama/Llama-2-7b-hf' (gated on the Hugging Face Hub)
cot = ChainOfThoughtPrompt('gpt2')
# Example question
question = 'If a train travels 120 miles in 2 hours, how far will it travel in 5 hours at the same speed?'
print(f'Question: {question}')
print('\nFew-shot CoT:')
results = cot.generate_cot(question, task_type='math', use_few_shot=True, num_samples=1)
print(f'Response: {results[0]["response"]}')
print('\nZero-shot CoT:')
results = cot.generate_cot(question, task_type='math', use_few_shot=False, zero_shot_trigger="Let's think step by step.", num_samples=1)
print(f'Response: {results[0]["response"]}')
# Quality evaluation (surface heuristics only)
quality = cot.evaluate_cot_quality(results[0]['reasoning'])
print(f'\nReasoning Quality: {quality}')
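Because the `self_consistency` method above needs a loaded model, here is its voting logic isolated as a minimal, model-free sketch. `sample_chain` is a hypothetical stand-in for one temperature > 0 LLM call returning `(reasoning, answer)`; `fake_sampler` only simulates chains that disagree.

```python
import random
from collections import Counter
from typing import Callable, Dict, Tuple


def self_consistency(
    sample_chain: Callable[[str], Tuple[str, str]], question: str, k: int = 10
) -> Tuple[str, float, Dict[str, int]]:
    """Sample k reasoning chains and majority-vote their final answers."""
    answers = []
    for _ in range(k):
        _reasoning, answer = sample_chain(question)  # only the final answers vote
        answers.append(answer.strip().lower().rstrip("."))  # light normalization
    counts = Counter(answers)
    best, votes = counts.most_common(1)[0]  # ties go to the answer seen first
    return best, votes / k, dict(counts)


def fake_sampler(question: str) -> Tuple[str, str]:
    """Simulated sampler: usually reaches 120 / 2 * 5 = 300, sometimes errs."""
    answer = random.choices(["300", "250", "600"], weights=[7, 2, 1])[0]
    return f"Speed is 120 / 2 = 60 mph, so distance is 60 * 5 = {answer} miles.", answer


if __name__ == "__main__":
    random.seed(0)
    print(self_consistency(fake_sampler, "How far in 5 hours?", k=20))
```

Swapping `fake_sampler` for a function that calls a real model (and parses its answer) is all that changes in practice; the vote itself is just a frequency count.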
Using Libraries (torch, transformers, openai, langchain)
import torch
import openai  # For API-based models; assumes openai>=1.0 (client-style API)
from transformers import AutoModelForCausalLM, AutoTokenizer
class CoTWithModernLLMs:
    def __init__(self, model_name='meta-llama/Llama-2-7b-chat-hf', use_api=False):
        # With use_api=True, model_name must be an OpenAI model name
        self.use_api = use_api
        self.model_name = model_name
        if use_api:
            # Reads OPENAI_API_KEY from the environment
            self.client = openai.OpenAI()
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # device_map='auto' requires the `accelerate` package
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map='auto'
            )
    def generate(self, prompt, max_new_tokens=512, temperature=0.7, top_p=0.9):
        if self.use_api:
            return self._generate_api(prompt, max_new_tokens, temperature)
        return self._generate_local(prompt, max_new_tokens, temperature, top_p)
    def _generate_local(self, prompt, max_new_tokens, temperature, top_p):
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)
        # Sampling parameters only apply when do_sample=True
        sampling = {'do_sample': True, 'temperature': temperature, 'top_p': top_p} if temperature > 0 else {'do_sample': False}
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
                **sampling
            )
        # Return only the completion, not the echoed prompt
        return self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    def _generate_api(self, prompt, max_tokens, temperature):
        # openai>=1.0 replaced openai.ChatCompletion.create with client.chat.completions.create
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{'role': 'user', 'content': prompt}],
            max_tokens=max_tokens,
            temperature=temperature
        )
        return response.choices[0].message.content
class AdvancedCoT:
    def __init__(self, llm_wrapper):
        self.llm = llm_wrapper
    def tree_of_thoughts(self, problem, num_branches=3, depth=2):
        # Simplified Tree of Thoughts: independent linear branches scored once at the end
        # (no per-step evaluation, pruning, or backtracking as in the full method)
        tree = {'problem': problem, 'branches': []}
        for i in range(num_branches):
            branch = self._explore_branch(problem, depth, branch_id=i)
            tree['branches'].append(branch)
        # Evaluate and select best path
        return self._evaluate_branches(tree['branches'])
    def _explore_branch(self, problem, depth, branch_id):
        branch = {'thoughts': [], 'branch_id': branch_id}
        current = problem
        for d in range(depth):
            prompt = f'Given: {current}\nThink step by step. What should I consider next?'
            thought = self.llm.generate(prompt, max_new_tokens=100)
            branch['thoughts'].append(thought)
            current = f'{current}\nThought {d + 1}: {thought}'
        # Generate final answer from the accumulated thoughts
        final_prompt = f'{current}\nBased on these thoughts, what is the answer?'
        branch['answer'] = self.llm.generate(final_prompt, max_new_tokens=150)
        return branch
    def _evaluate_branches(self, branches):
        # Simple evaluation: the model rates each branch's coherence (a noisy heuristic)
        scores = []
        for branch in branches:
            # Join outside the f-string; backslashes inside f-string braces fail before Python 3.12
            thoughts = '\n'.join(branch['thoughts'])
            coherence_prompt = f'Rate the coherence of this reasoning from 1-10:\n{thoughts}'
            score_text = self.llm.generate(coherence_prompt, max_new_tokens=10)
            # Read the first whole number (so '10' is not parsed as '1'); default to a neutral 5
            numbers = ''.join(c if c.isdigit() else ' ' for c in score_text).split()
            scores.append(min(int(numbers[0]), 10) if numbers else 5)
        best_idx = scores.index(max(scores))
        return branches[best_idx]
    def automatic_cot(self, problem, num_demonstrations=8):
        # Simplified Auto-CoT: the original method clusters questions from an existing
        # dataset; here the model is asked to invent similar questions instead
        diversity_prompt = f'Generate {num_demonstrations} diverse questions similar to: {problem}'
        questions_text = self.llm.generate(diversity_prompt, max_new_tokens=500)
        questions = [q.strip() for q in questions_text.split('\n') if '?' in q][:num_demonstrations]
        # Generate a zero-shot reasoning chain for each question
        demonstrations = []
        for q in questions:
            cot_prompt = f"Q: {q}\nA: Let's think step by step."
            reasoning = self.llm.generate(cot_prompt, max_new_tokens=200)
            demonstrations.append({'question': q, 'reasoning': reasoning})
        # Cluster and select diverse examples
        selected = self._cluster_demonstrations(demonstrations)
        # Build final prompt with selected demonstrations
        few_shot_prompt = self._build_few_shot_prompt(selected, problem)
        answer = self.llm.generate(few_shot_prompt, max_new_tokens=200)
        return {
            'demonstrations': selected,
            'prompt': few_shot_prompt,
            'answer': answer
        }
    def _cluster_demonstrations(self, demonstrations):
        # Simplified clustering: pick the shortest, median-length, and longest chains.
        # In practice, cluster question embeddings and take one example per cluster.
        if not demonstrations:
            return []
        order = sorted(range(len(demonstrations)), key=lambda i: len(demonstrations[i]['reasoning']))
        # dict.fromkeys removes duplicate indices (fewer than 3 demos) while keeping order
        picks = dict.fromkeys([order[0], order[len(order) // 2], order[-1]])
        return [demonstrations[i] for i in picks]
    def _build_few_shot_prompt(self, demonstrations, problem):
        prompt = ''
        for demo in demonstrations:
            prompt += f"Q: {demo['question']}\nA: {demo['reasoning']}\n"
        prompt += f'Q: {problem}\nA:'
        return prompt
    def reflexion_cot(self, problem, max_iterations=3):
        # Simplified self-reflection loop in the spirit of Reflexion: the model critiques
        # its own answer and retries. Self-evaluation can be wrong in both directions.
        history = []
        current_answer = None
        for iteration in range(max_iterations):
            if iteration == 0:
                prompt = f"Q: {problem}\nA: Let's think step by step."
            else:
                # Reflect on previous attempt
                prompt = self._build_reflection_prompt(problem, history[-1])
            response = self.llm.generate(prompt, max_new_tokens=300)
            # Self-evaluate
            eval_prompt = f'Is this reasoning correct? Answer YES or NO and explain why:\n{response}'
            evaluation = self.llm.generate(eval_prompt, max_new_tokens=100)
            # Check the verdict at the start; a substring test would match 'yes' inside the explanation
            is_correct = evaluation.strip().lower().startswith('yes')
            history.append({
                'iteration': iteration,
                'response': response,
                'evaluation': evaluation,
                'is_correct': is_correct
            })
            if is_correct:
                current_answer = response
                break
        return {
            'final_answer': current_answer or history[-1]['response'],
            'history': history,
            'iterations': len(history)
        }
    def _build_reflection_prompt(self, problem, previous_attempt):
        return f'''I previously tried to solve this problem but made a mistake:
Problem: {problem}
Previous attempt: {previous_attempt['response']}
Why it was wrong: {previous_attempt['evaluation']}
Let me think more carefully this time.
A:'''
# Using LangChain for structured CoT
try:
    # Recent LangChain releases expose PromptTemplate from langchain_core
    from langchain_core.prompts import PromptTemplate
    # Create structured CoT template; it starts the numbered list for the model
    cot_template = """
You are a reasoning assistant. Break down the following problem step by step:
Problem: {problem}
Step-by-step reasoning:
1. """
    cot_prompt = PromptTemplate(
        input_variables=['problem'],
        template=cot_template
    )
    print(cot_prompt.format(problem='If 3 pens cost $6, how much do 7 pens cost?'))
    print('LangChain integration available')
except ImportError:
    print('LangChain not available')
# Evaluation metrics
class CoTEvaluator:
    def accuracy(self, predictions, labels):
        # Exact match after trimming and lower-casing
        correct = sum(p.strip().lower() == l.strip().lower() for p, l in zip(predictions, labels))
        return correct / len(predictions)
    def reasoning_quality(self, reasoning_chains):
        n = len(reasoning_chains)
        metrics = {
            'avg_length': sum(len(r) for r in reasoning_chains) / n,
            'has_logical_connectors': sum(1 for r in reasoning_chains if any(w in r.lower() for w in ['because', 'therefore', 'since'])) / n,
            # Rough proxy: number of sentence-like segments split on '.'
            'step_count': sum(len(r.split('.')) for r in reasoning_chains) / n
        }
        return metrics
    def self_consistency_score(self, answers):
        import math
        from collections import Counter
        n = len(answers)
        answer_counts = Counter(answers)
        most_common = answer_counts.most_common(1)[0]
        return {
            'most_common_answer': most_common[0],
            'agreement_rate': most_common[1] / n,
            # Shannon entropy (bits) of the answer distribution: sum of p * log2(1/p); 0 when all agree
            'entropy': sum((c / n) * math.log2(n / c) for c in answer_counts.values())
        }
print('\nAdvanced CoT implementations loaded successfully')
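In both implementations above, zero-shot CoT stops after the reasoning and leaves answer parsing to string heuristics. Kojima et al. (2022), who introduced the 'Let's think step by step' trigger, instead used a second prompt that appends an answer-extraction phrase to the generated reasoning. A minimal sketch, where `llm` is a hypothetical prompt-to-completion callable (for example, a thin wrapper around `CoTWithModernLLMs.generate`) and the extraction phrase shown is of the form used for numeric answers:

```python
from typing import Callable, Dict

REASONING_TRIGGER = "Let's think step by step."
ANSWER_TRIGGER = "Therefore, the answer (arabic numerals) is"  # phrase for numeric answers


def zero_shot_cot(llm: Callable[[str], str], question: str) -> Dict[str, str]:
    """Two-stage zero-shot CoT: elicit reasoning, then extract the answer."""
    # Stage 1: the trigger phrase elicits a reasoning chain
    stage1 = f"Q: {question}\nA: {REASONING_TRIGGER}"
    reasoning = llm(stage1)
    # Stage 2: feed the reasoning back and ask only for the final answer
    stage2 = f"{stage1} {reasoning.strip()}\n{ANSWER_TRIGGER}"
    answer = llm(stage2).strip().rstrip(".")
    return {"reasoning": reasoning, "answer": answer, "extraction_prompt": stage2}
```

The second call costs one extra (short) generation but avoids marker-based parsing such as `_parse_response`.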
When to Use
✅ Appropriate Use Cases:
- Complex multi-step reasoning tasks (arithmetic, logic, planning)
- When direct prompting gives wrong answers but the model has the relevant knowledge
- Tasks requiring explicit reasoning chains for interpretability
- Mathematical word problems requiring step-by-step calculation
- Commonsense reasoning where intermediate inference steps help
- Code generation with explanation requirements
❌ Avoid When:
- Simple factual recall where reasoning adds unnecessary tokens
- When latency is critical (CoT generates more tokens)
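- Small models that lack reasoning capability (the original CoT experiments found gains mainly near 100B parameters; smaller models often produced fluent but illogical chains)
- Tasks where generated reasoning is prone to confabulation, i.e. plausible-sounding but false intermediate steps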
- When the cost of token generation is prohibitive
- Classification tasks where direct prediction works better
Common Pitfalls
- Using CoT on models too small to exhibit emergent reasoning
- Insufficient examples in few-shot prompting for complex domains
- Not verifying that reasoning actually leads to correct answer
- Over-relying on single reasoning path instead of self-consistency
- Poorly formatted examples that don't demonstrate clear step-by-step reasoning
- Mixing different reasoning styles in few-shot examples (be consistent)
- Not handling cases where the model produces plausible reasoning but reaches the wrong conclusion
- Forgetting that CoT works best on tasks close to the model's training distribution
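For the pitfalls about unverified reasoning and wrong conclusions, one cheap, partial check is to recompute the explicit arithmetic claims in a chain. A minimal sketch, assuming steps are written as 'x op y = z' with plain numbers; `check_arithmetic` is a hypothetical helper, multi-operator expressions and units are skipped, and rounded results beyond `tol` are flagged as wrong.

```python
import re
from typing import List, Tuple

# One binary operation claim: "<number> <op> <number> = <number>"
STEP = re.compile(
    r"(-?\d+(?:\.\d+)?)\s*([+\-*/])\s*(-?\d+(?:\.\d+)?)\s*=\s*(-?\d+(?:\.\d+)?)"
)

OPS = {
    "+": lambda x, y: x + y,
    "-": lambda x, y: x - y,
    "*": lambda x, y: x * y,
    # Division by zero yields NaN, which never compares as close, so the step is flagged
    "/": lambda x, y: x / y if y != 0 else float("nan"),
}


def check_arithmetic(reasoning: str, tol: float = 1e-9) -> List[Tuple[str, bool]]:
    """Return each 'x op y = z' claim found in the text and whether it holds."""
    results = []
    for a, op, b, claimed in STEP.findall(reasoning):
        value = OPS[op](float(a), float(b))
        ok = abs(value - float(claimed)) <= tol
        results.append((f"{a} {op} {b} = {claimed}", ok))
    return results


if __name__ == "__main__":
    chain = "2 * 3 = 6 balls. In total, 5 + 6 = 12 balls. The answer is 12."
    for step, ok in check_arithmetic(chain):
        print(("OK   " if ok else "WRONG"), step)
```

A chain that writes 5 + 6 = 12 has that step flagged even though its final sentence sounds confident; the check says nothing about whether the right quantities were combined in the first place.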