Chain-of-Thought: Reasoning in Large Language Models

Advanced Deep Learning

Definition

Chain-of-Thought (CoT) prompting is a technique that helps Large Language Models (LLMs) solve complex reasoning tasks by having them generate intermediate reasoning steps before the final answer. Introduced by Wei et al. in 2022, CoT behaves as an emergent ability: in their experiments it helped only sufficiently large models (roughly 100B parameters and above), and it works purely through in-context learning, without updating any weights. The model is prompted with examples that demonstrate step-by-step reasoning, which it then emulates. CoT substantially improves performance on arithmetic, commonsense, and symbolic reasoning tasks. Variants include Zero-Shot CoT (appending 'Let's think step by step' instead of providing examples), Automatic CoT (constructing the demonstrations automatically), and Self-Consistency (sampling multiple reasoning paths and taking a majority vote over their final answers). Why CoT works is not fully understood; proposed explanations include the extra computation that generated tokens give a Transformer, since each new token can attend to all earlier ones, and reasoning patterns the model acquired from training on code and explanatory text.
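To make the two prompting styles concrete, here is a minimal sketch of how the prompts are assembled. `few_shot_cot_prompt`, `zero_shot_cot`, and `fake_llm` are hypothetical names, and `fake_llm` is a canned stand-in for a real model call so the example runs offline. The second, answer-extraction prompt follows the two-stage pattern used in Zero-Shot CoT work, but its exact wording here is an assumption.

```python
from typing import Callable, Tuple

# One worked exemplar; real few-shot prompts usually contain several
FEW_SHOT_EXEMPLAR = (
    "Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. "
    "Each can has 3 tennis balls. How many tennis balls does he have now?\n"
    "A: Roger starts with 5 balls. 2 cans of 3 balls each is 2 * 3 = 6 balls. "
    "5 + 6 = 11. The answer is 11.\n\n"
)

def few_shot_cot_prompt(question: str) -> str:
    """Few-shot CoT: exemplar(s) with worked reasoning, then the new question."""
    return f"{FEW_SHOT_EXEMPLAR}Q: {question}\nA:"

def zero_shot_cot(question: str, llm: Callable[[str], str]) -> Tuple[str, str]:
    """Zero-shot CoT in two stages: elicit reasoning, then ask for the answer."""
    # Stage 1: the trigger phrase replaces the exemplars
    reasoning_prompt = f"Q: {question}\nA: Let's think step by step."
    reasoning = llm(reasoning_prompt)
    # Stage 2 (assumed wording): condition on the reasoning and request the answer
    answer_prompt = f"{reasoning_prompt} {reasoning}\nTherefore, the answer is"
    return reasoning, llm(answer_prompt).strip()

def fake_llm(prompt: str) -> str:
    """Hypothetical stand-in for a model call, returning canned text."""
    if prompt.endswith("Therefore, the answer is"):
        return " 300 miles."
    return "The speed is 120 / 2 = 60 mph. In 5 hours it travels 60 * 5 = 300 miles."

reasoning, answer = zero_shot_cot(
    "A train travels 120 miles in 2 hours. How far does it go in 5 hours?", fake_llm
)
print(few_shot_cot_prompt("A juggler has 16 balls..."))
print(reasoning)
print(answer)
```

Splitting zero-shot CoT into two calls keeps the free-form reasoning separate from a short, easy-to-parse answer.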

Intuition

💡

Imagine solving a complex math problem. Instead of jumping straight to the answer, you write out your work: 'First, I need to find x, then substitute it into equation y, then solve for z.' Chain-of-Thought prompting asks the model to do the same and show its work. Without CoT, the model is like a student who blurts out an answer; with CoT, it is like a student who works methodically through the problem. This helps because a Transformer generates text one token at a time and spends a fixed amount of computation on each token. Forcing it to write intermediate steps produces more tokens that carry partial results, and every later token, including the final answer, can attend to them. In effect, the model's own generated text gives it more 'thinking time'. The model saw patterns of step-by-step reasoning during pretraining on books, code, and explanations, and CoT prompts appear to elicit those learned patterns.

Mathematical Formula

CoT Probability Decomposition:
\[ P(y|x, \text{CoT}) = \sum_{r} P(y|r, x) P(r|x, \text{CoT}) \]
Self-Consistency:
\[ \hat{y} = \arg\max_{y} \sum_{i=1}^{k} \mathbb{1}[y_i = y] \]
where $y_i$ is the final answer extracted from the $i$-th of $k$ sampled reasoning paths $r_i \sim P(r|x, \text{CoT})$
CoT Loss (during training):
\[ \mathcal{L} = -\sum_{t=1}^{T} \log P(\text{token}_t \mid \text{token}_{<t}) \]
Emergence Threshold:
\[ \text{CoT emerges when model size } N > N_{crit} \approx 10^{11} \]
Attention Pattern for Reasoning:
\[ \alpha_{ij} = \text{softmax}\left(\frac{q_i^T k_j}{\sqrt{d_k}}\right) \quad \text{for reasoning tokens } j \]
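The attention formula can be evaluated by hand for a single query. In this sketch the 2-dimensional query and key vectors are invented so the arithmetic is easy to follow, and `attention_weights` is a hypothetical helper. In a trained model the weights depend on learned projections, so this shows only how the formula works, not that reasoning tokens are systematically favored.

```python
import math

def attention_weights(query, keys):
    """alpha_j = softmax_j(q^T k_j / sqrt(d_k)) for one query over a list of keys."""
    d_k = len(query)
    # Scaled dot-product score for every earlier token j
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k) for key in keys]
    # Subtract the max before exponentiating for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up vectors: the query belongs to the answer token, the keys to three earlier tokens
query = [1.0, 0.0]
keys = {
    "question token": [0.0, 1.0],
    "reasoning: 2 * 3 = 6": [2.0, 0.0],
    "reasoning: 5 + 6 = 11": [3.0, 0.0],
}
weights = attention_weights(query, list(keys.values()))
for name, w in zip(keys, weights):
    print(f"{name}: {w:.3f}")
```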

Step-by-Step Explanation:

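  1. CoT Decomposition: the probability of the final answer y is a sum over reasoning paths r, each weighted by how likely the model is to generate that path
  2. Self-Consistency: sample k reasoning chains, extract each chain's final answer, and return the most frequent one; this is a sampling-based estimate of the answer that maximizes the marginal above
  3. Training Loss: standard next-token cross-entropy, applied over reasoning steps and answer when a model is fine-tuned on CoT data; CoT prompting itself involves no training
  4. Emergence Threshold: in Wei et al.'s experiments, CoT gains appeared fairly abruptly at roughly 100B parameters; this is an empirical observation that varies by task and model family, not a fixed law
  5. Attention Pattern: this is ordinary scaled dot-product attention; it matters for CoT because answer tokens can attend to the reasoning tokens j generated before them, so intermediate results feed into later steps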
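The decomposition and the majority vote can be checked on a toy example. The three reasoning paths and their probabilities below are invented for illustration, and each path ends in exactly one answer, so P(y|r, x) is 1 for that answer and 0 otherwise. The helper names are hypothetical. The point of the setup is that the single most likely path is wrong, yet the correct answer has the larger marginal probability.

```python
import random
from collections import Counter

# Invented toy distribution P(r | x, CoT) over three reasoning paths for the tennis-ball problem.
# Each path ends in exactly one answer, so P(y | r, x) is 1 for that answer and 0 otherwise.
PATHS = {
    "5 + 2 = 7, then 7 * 3 = 21": (0.45, "21"),  # most likely single path, but wrong
    "2 * 3 = 6, then 5 + 6 = 11": (0.30, "11"),
    "two cans of 3 is 6 balls; 6 + 5 = 11": (0.25, "11"),
}

def marginal_answer_probs(paths):
    """Exact marginal: P(y | x, CoT) = sum over r of P(y | r, x) * P(r | x, CoT)."""
    probs = Counter()
    for p_r, answer in paths.values():
        probs[answer] += p_r
    return dict(probs)

def greedy_single_path(paths):
    """Answer of the single most probable reasoning path."""
    return max(paths.values())[1]

def self_consistency(paths, k, seed=0):
    """Sample k paths, keep only their final answers, and return the majority answer."""
    rng = random.Random(seed)
    names = list(paths)
    sampled = rng.choices(names, weights=[paths[n][0] for n in names], k=k)
    votes = Counter(paths[n][1] for n in sampled)
    return votes.most_common(1)[0][0]

print(marginal_answer_probs(PATHS))
print(greedy_single_path(PATHS))
print(self_consistency(PATHS, k=5000))
```

With many samples the vote converges on the answer with the largest marginal probability ('11' here), even though following only the single best path would return '21'.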

Real-World Use Cases

Mathematical Reasoning

Solving GSM8K word problems: 'John has 5 apples, gives away 2...' with step-by-step arithmetic

Code Generation

Generating code with explanation: 'First, I need to iterate over the list, then...'

Legal Analysis

Case reasoning: 'The precedent in Smith v. Jones establishes... therefore...'

Medical Diagnosis

Differential diagnosis: 'Given symptom A and B, possible causes include... most likely...'

Scientific Problem Solving

Physics derivations showing formula substitutions and unit analysis

Logical Puzzles

Solving Sudoku or logic grid puzzles by explaining elimination steps

Implementation

Manual Implementation (No Libraries)

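The ChainOfThoughtPrompt class implements three CoT prompting strategies: few-shot prompting with exemplar demonstrations, zero-shot prompting with a trigger phrase such as 'Let's think step by step', and self-consistency via multiple samples and a majority vote. It also parses each response into a reasoning chain and an answer, and scores the reasoning with simple surface heuristics (step markers, logical connectives, numbers, length) rather than checking whether it is correct. Apart from using Hugging Face transformers to load and sample from the model, all of the CoT logic is written by hand.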
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re

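class ChainOfThoughtPrompt:
    """Few-shot, zero-shot, and self-consistency CoT prompting for a Hugging Face causal LM."""

    def __init__(self, model_name='gpt2'):
        # Download (or load from cache) the tokenizer and model weights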
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        
        # Standard CoT examples for few-shot prompting
        self.cot_examples = {
            'math': self._get_math_examples(),
            'commonsense': self._get_commonsense_examples(),
            'code': self._get_code_examples()
        }
    
    def _get_math_examples(self):
        return """Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?
A: Roger starts with 5 balls. He buys 2 cans with 3 balls each, so that's 2 * 3 = 6 balls. In total, 5 + 6 = 11 balls. The answer is 11.

Q: A juggler has 16 balls. Half are golf balls and half are tennis balls. He loses 3 golf balls. How many golf balls does he have left?
A: Half of 16 balls are golf balls, so that's 16 / 2 = 8 golf balls. He loses 3, so 8 - 3 = 5 golf balls. The answer is 5.

Q: {question}
A:"""
    
    def _get_commonsense_examples(self):
        return """Q: John had a fever and took medicine. What happened to the fever?
A: Medicine typically reduces fever by fighting infection or reducing inflammation. John's fever likely went down.

Q: Why might someone carry an umbrella on a cloudy day?
A: Cloudy days often indicate rain. Carrying an umbrella provides protection from potential rain.

Q: {question}
A:"""
    
    def _get_code_examples(self):
        return """Q: Write a function to reverse a string.
A: Let me think through this step by step. First, I need to take a string as input. Then I can use slicing with step -1 to reverse it. Here's the code:
```python
def reverse_string(s):
    return s[::-1]
```

Q: Write a function to check if a number is prime.
A: To check if a number is prime, I need to verify it has no divisors other than 1 and itself. I'll check divisibility from 2 up to the square root of the number. Here's the implementation:
```python
def is_prime(n):
    if n < 2:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True
```

Q: {question}
A:"""
    
    def format_prompt(self, question, task_type='math', use_few_shot=True, zero_shot_trigger=''):
        """Build a few-shot CoT prompt, or a zero-shot prompt with an optional trigger phrase."""
        if use_few_shot:
            return self.cot_examples[task_type].format(question=question)

        # Zero-shot CoT: place the trigger (e.g. "Let's think step by step.") right after "A:"
        if zero_shot_trigger:
            return f'Q: {question}\nA: {zero_shot_trigger}'
        return f'Q: {question}\nA:'
    
    def generate_cot(self, question, task_type='math', max_new_tokens=200, temperature=0.7,
                     use_few_shot=True, zero_shot_trigger="Let's think step by step.",
                     num_samples=1):
        """Sample num_samples CoT completions and split each into reasoning and answer."""
        prompt = self.format_prompt(question, task_type, use_few_shot, zero_shot_trigger)
        inputs = self.tokenizer(prompt, return_tensors='pt')
        prompt_len = inputs['input_ids'].shape[1]

        # Sample when temperature > 0; otherwise decode greedily
        gen_kwargs = {'max_new_tokens': max_new_tokens, 'pad_token_id': self.tokenizer.eos_token_id}
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature)

        results = []
        for _ in range(num_samples):
            with torch.no_grad():
                outputs = self.model.generate(**inputs, **gen_kwargs)

            # Decode only the new tokens; slicing the decoded text by len(prompt)
            # breaks when detokenization does not reproduce the prompt exactly
            response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
            # In few-shot mode the model tends to continue with a new "Q:"; cut it off
            response = response.split('\nQ:')[0].strip()

            # Parse reasoning and answer
            parsed = self._parse_response(response)
            results.append({
                'prompt': prompt,
                'response': response,
                'reasoning': parsed['reasoning'],
                'answer': parsed['answer']
            })

        return results
    
    def _parse_response(self, response):
        """Split a response into reasoning and answer at the first answer marker found."""
        # Markers in priority order; matching is case-insensitive
        answer_markers = ['the answer is', 'answer:', 'therefore,', 'so,', 'thus,']
        lowered = response.lower()
        reasoning = response
        answer = response

        for marker in answer_markers:
            idx = lowered.find(marker)
            if idx != -1:
                # Slice the original string so the casing is preserved
                reasoning = response[:idx].strip()
                answer = response[idx + len(marker):].strip()
                break

        # Keep the first line of the answer and strip trailing punctuation only,
        # so decimals such as 2.5 survive
        answer = answer.split('\n')[0].strip().rstrip('.!')
        return {'reasoning': reasoning, 'answer': answer}

    def self_consistency(self, question, task_type='math', num_samples=10, threshold=0.5):
        """Sample several reasoning chains and majority-vote over their final answers."""
        results = self.generate_cot(question, task_type, num_samples=num_samples)

        # Normalize answers (lowercase, drop punctuation except decimal points) before counting
        answer_counts = {}
        for r in results:
            normalized = re.sub(r'[^\w\s.]', '', r['answer'].lower()).strip().rstrip('.')
            answer_counts[normalized] = answer_counts.get(normalized, 0) + 1

        # The most common answer wins; ties go to the answer seen first
        best_answer, best_count = max(answer_counts.items(), key=lambda kv: kv[1])
        confidence = best_count / len(results)

        return {
            'answer': best_answer,
            'confidence': confidence,
            # Agreement below the threshold means the vote is unreliable
            'is_confident': confidence >= threshold,
            'all_results': results,
            'answer_distribution': answer_counts
        }

    def evaluate_cot_quality(self, reasoning):
        """Score reasoning with surface heuristics; this does not check correctness."""
        words = set(re.findall(r'[a-z]+', reasoning.lower()))
        criteria = {
            'has_steps': len(reasoning.split('.')) > 2 or 'first' in words,
            # Match whole words so that, e.g., 'also' does not count as 'so'
            'has_logic_words': bool(words & {'because', 'therefore', 'so', 'thus', 'since'}),
            'has_numbers': any(char.isdigit() for char in reasoning),
            'reasonable_length': 20 < len(reasoning) < 500
        }

        score = sum(criteria.values()) / len(criteria)
        return {
            'criteria': criteria,
            'quality_score': score,
            'is_high_quality': score >= 0.75
        }

# Demonstration
print('Chain of Thought Demonstration')
print('=' * 50)

# GPT-2 keeps the demo small, but it is far below the scale at which CoT helps,
# so expect incoherent reasoning. For actual use, replace it with a larger model,
# e.g. 'meta-llama/Llama-2-7b-hf' (gated on the Hugging Face Hub)
cot = ChainOfThoughtPrompt('gpt2')

# Example question
question = 'If a train travels 120 miles in 2 hours, how far will it travel in 5 hours at the same speed?'
print(f'Question: {question}')

print('\nFew-shot CoT:')
results = cot.generate_cot(question, task_type='math', use_few_shot=True, num_samples=1)
print(f"Response: {results[0]['response']}")

print('\nZero-shot CoT:')
results = cot.generate_cot(question, task_type='math', use_few_shot=False,
                           zero_shot_trigger="Let's think step by step.", num_samples=1)
print(f"Response: {results[0]['response']}")

# Heuristic quality check on the zero-shot reasoning
quality = cot.evaluate_cot_quality(results[0]['reasoning'])
print(f'\nReasoning Quality: {quality}')
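The marker-based `_parse_response` above is brittle for numeric answers. A common alternative for GSM8K-style problems, sketched here under the assumption that the answer is a plain number, is to take the number after 'answer is' and otherwise fall back to the last number in the response. `extract_numeric_answer` and `vote` are hypothetical helpers, not part of the class above.

```python
import re
from collections import Counter
from typing import List, Optional

# Optional minus sign, digits (commas allowed as thousands separators), optional decimals
NUMBER = r"-?\d[\d,]*(?:\.\d+)?"
ANSWER_IS = re.compile(rf"answer is\s*\$?({NUMBER})", re.IGNORECASE)

def extract_numeric_answer(response: str) -> Optional[str]:
    """Number after 'answer is' if present, else the last number in the text, else None."""
    match = ANSWER_IS.search(response)
    if match:
        candidate = match.group(1)
    else:
        numbers = re.findall(NUMBER, response)
        if not numbers:
            return None
        candidate = numbers[-1]
    # Drop thousands separators (and any trailing comma picked up by the pattern)
    return candidate.replace(",", "")

def vote(responses: List[str]) -> Optional[str]:
    """Self-consistency over extracted numbers, ignoring responses with no number."""
    answers = [a for a in map(extract_numeric_answer, responses) if a is not None]
    return Counter(answers).most_common(1)[0][0] if answers else None

samples = [
    "60 mph * 5 hours = 300 miles. The answer is 300.",
    "Speed is 120 / 2 = 60, so 60 * 5 = 300",
    "The answer is 250.",
    "I am not sure.",
]
print([extract_numeric_answer(s) for s in samples])
print(vote(samples))
```

Answers are compared as strings, so '1200' and '1200.0' would still count as different votes; converting to `float` before voting is a simple extension.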

Using Libraries (torch, transformers, openai, langchain)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class CoTWithModernLLMs:
    """Thin wrapper that generates text from a local Hugging Face model or an OpenAI chat model."""

    def __init__(self, model_name='meta-llama/Llama-2-7b-chat-hf', use_api=False):
        # With use_api=True, model_name must be an OpenAI chat model name
        self.use_api = use_api
        self.model_name = model_name

        if use_api:
            # Requires openai>=1.0; the client reads OPENAI_API_KEY from the environment
            from openai import OpenAI
            self.client = OpenAI()
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # device_map='auto' requires the `accelerate` package
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map='auto'
            )

    def generate(self, prompt, max_new_tokens=512, temperature=0.7, top_p=0.9):
        """Dispatch to the API backend or the local model."""
        if self.use_api:
            return self._generate_api(prompt, max_new_tokens, temperature)
        return self._generate_local(prompt, max_new_tokens, temperature, top_p)

    def _generate_local(self, prompt, max_new_tokens, temperature, top_p):
        inputs = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)

        # Sample when temperature > 0; otherwise decode greedily
        gen_kwargs = {'max_new_tokens': max_new_tokens, 'pad_token_id': self.tokenizer.eos_token_id}
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated tokens, not the echoed prompt
        return self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

    def _generate_api(self, prompt, max_tokens, temperature):
        # openai>=1.0 client interface (openai.ChatCompletion was removed in 1.0)
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{'role': 'user', 'content': prompt}],
            max_tokens=max_tokens,
            temperature=temperature
        )
        return response.choices[0].message.content

class AdvancedCoT:
    """Strategies built on CoT, using any wrapper with a generate(prompt, max_new_tokens=...) method."""

    def __init__(self, llm_wrapper):
        self.llm = llm_wrapper

    def tree_of_thoughts(self, problem, num_branches=3, depth=2):
        """Simplified Tree-of-Thoughts: grow independent branches, then keep the best-rated one.

        Unlike a full tree search, partial thoughts are not evaluated, pruned, or backtracked.
        """
        tree = {'problem': problem, 'branches': []}
        for i in range(num_branches):
            tree['branches'].append(self._explore_branch(problem, depth, branch_id=i))

        # Evaluate and select the best path
        return self._evaluate_branches(tree['branches'])

    def _explore_branch(self, problem, depth, branch_id):
        """Extend one chain of thoughts for `depth` steps, then ask for an answer."""
        branch = {'thoughts': [], 'branch_id': branch_id}
        current = problem

        for d in range(depth):
            prompt = f'Given: {current}\nThink step by step. What should I consider next?'
            # Branches diverge only through sampling randomness (temperature > 0)
            thought = self.llm.generate(prompt, max_new_tokens=100)
            branch['thoughts'].append(thought)
            current = f'{current}\nThought {d + 1}: {thought}'

        # Generate the final answer conditioned on every thought in this branch
        final_prompt = f'{current}\nBased on these thoughts, what is the answer?'
        branch['answer'] = self.llm.generate(final_prompt, max_new_tokens=150)
        return branch

    def _evaluate_branches(self, branches):
        """Ask the model to rate each branch's coherence (1-10) and return the best branch."""
        scores = []
        for branch in branches:
            thoughts = '\n'.join(branch['thoughts'])
            coherence_prompt = f'Rate the coherence of this reasoning from 1-10:\n{thoughts}'
            score_text = self.llm.generate(coherence_prompt, max_new_tokens=10)
            # Take the first whole number in the reply (so "10" is not read as "1");
            # fall back to a neutral 5 if the reply contains no number
            numbers = ''.join(ch if ch.isdigit() else ' ' for ch in score_text).split()
            scores.append(int(numbers[0]) if numbers else 5)

        best_idx = scores.index(max(scores))
        return branches[best_idx]

    def automatic_cot(self, problem, num_demonstrations=8):
        """Auto-CoT-style prompting: build few-shot demonstrations automatically.

        Simplification: the published Auto-CoT method clusters the questions of an existing
        dataset using sentence embeddings; here the model invents similar questions instead.
        """
        # Ask the model for diverse questions similar to the problem
        diversity_prompt = f'Generate {num_demonstrations} diverse questions similar to: {problem}'
        questions_text = self.llm.generate(diversity_prompt, max_new_tokens=500)
        questions = [q.strip() for q in questions_text.split('\n') if '?' in q][:num_demonstrations]

        # Generate a zero-shot reasoning chain for each question
        demonstrations = []
        for q in questions:
            cot_prompt = f"Q: {q}\nA: Let's think step by step."
            reasoning = self.llm.generate(cot_prompt, max_new_tokens=200)
            demonstrations.append({'question': q, 'reasoning': reasoning})

        # Select a diverse subset and build the final few-shot prompt
        selected = self._cluster_demonstrations(demonstrations)
        few_shot_prompt = self._build_few_shot_prompt(selected, problem)
        answer = self.llm.generate(few_shot_prompt, max_new_tokens=200)

        return {
            'demonstrations': selected,
            'prompt': few_shot_prompt,
            'answer': answer
        }

    def _cluster_demonstrations(self, demonstrations):
        """Crude diversity heuristic: keep the shortest, median-length, and longest chains.

        A real implementation would cluster question embeddings instead.
        """
        if not demonstrations:
            return []
        sorted_idx = sorted(range(len(demonstrations)), key=lambda i: len(demonstrations[i]['reasoning']))
        # dict.fromkeys drops duplicate indices when there are fewer than three demonstrations
        picks = dict.fromkeys([sorted_idx[0], sorted_idx[len(sorted_idx) // 2], sorted_idx[-1]])
        return [demonstrations[i] for i in picks]

    def _build_few_shot_prompt(self, demonstrations, problem):
        """Concatenate Q/A demonstrations, then the target question."""
        prompt = ''
        for demo in demonstrations:
            prompt += f"Q: {demo['question']}\nA: {demo['reasoning']}\n\n"
        prompt += f'Q: {problem}\nA:'
        return prompt

    def reflexion_cot(self, problem, max_iterations=3):
        """Self-critique loop loosely inspired by Reflexion: answer, self-evaluate, retry with feedback.

        The model grades its own work, so a confident but wrong chain can still pass.
        """
        history = []
        current_answer = None

        for iteration in range(max_iterations):
            if iteration == 0:
                prompt = f"Q: {problem}\nA: Let's think step by step."
            else:
                # Reflect on the previous attempt and its critique
                prompt = self._build_reflection_prompt(problem, history[-1])

            response = self.llm.generate(prompt, max_new_tokens=300)

            # Self-evaluate; the reply is expected to start with YES or NO
            eval_prompt = f'Is this reasoning correct? Answer YES or NO and explain why:\n{response}'
            evaluation = self.llm.generate(eval_prompt, max_new_tokens=100)
            is_correct = evaluation.strip().lower().startswith('yes')

            history.append({
                'iteration': iteration,
                'response': response,
                'evaluation': evaluation,
                'is_correct': is_correct
            })

            if is_correct:
                current_answer = response
                break

        return {
            'final_answer': current_answer or history[-1]['response'],
            'history': history,
            'iterations': len(history)
        }

    def _build_reflection_prompt(self, problem, previous_attempt):
        """Show the failed attempt and its critique, then ask for a new attempt."""
        return f'''I previously tried to solve this problem but made a mistake:

Problem: {problem}
Previous attempt: {previous_attempt['response']}
Why it was wrong: {previous_attempt['evaluation']}

Let me think more carefully this time.

A:'''

# Using LangChain for a structured CoT prompt template
try:
    # Current import path (package langchain-core); older releases used `from langchain import PromptTemplate`
    from langchain_core.prompts import PromptTemplate

    # Structured CoT template; the model continues from "1. "
    cot_template = """You are a reasoning assistant. Break down the following problem step by step:

Problem: {problem}

Step-by-step reasoning:
1. """

    cot_prompt = PromptTemplate(
        input_variables=['problem'],
        template=cot_template
    )

    print('LangChain integration available')

except ImportError:
    print('LangChain not available')

# Evaluation metrics
import math
from collections import Counter

class CoTEvaluator:
    """Simple metrics for CoT outputs: exact-match accuracy, surface statistics, answer agreement."""

    def accuracy(self, predictions, labels):
        """Case-insensitive exact-match accuracy."""
        correct = sum(p.strip().lower() == l.strip().lower() for p, l in zip(predictions, labels))
        return correct / len(predictions)

    def reasoning_quality(self, reasoning_chains):
        """Surface statistics only; they say nothing about whether the reasoning is correct."""
        n = len(reasoning_chains)
        return {
            'avg_length': sum(len(r) for r in reasoning_chains) / n,
            'has_logical_connectors': sum(
                any(w in r.lower() for w in ['because', 'therefore', 'since']) for r in reasoning_chains
            ) / n,
            # Rough step count: number of pieces when splitting on '.'
            'step_count': sum(len(r.split('.')) for r in reasoning_chains) / n
        }

    def self_consistency_score(self, answers):
        """Agreement rate and Shannon entropy (in bits) of the sampled answer distribution."""
        answer_counts = Counter(answers)
        most_common_answer, count = answer_counts.most_common(1)[0]
        total = len(answers)
        return {
            'most_common_answer': most_common_answer,
            'agreement_rate': count / total,
            # 0 when every sample agrees; grows as answers spread out
            'entropy': -sum((c / total) * math.log2(c / total) for c in answer_counts.values())
        }

print('\nAdvanced CoT implementations loaded successfully')

When to Use

✅ Appropriate Use Cases:

  • Complex multi-step reasoning tasks (arithmetic, logic, planning)
  • When direct prompting gives wrong answers but the model has the relevant knowledge
  • Tasks requiring explicit reasoning chains for interpretability
  • Mathematical word problems requiring step-by-step calculation
  • Commonsense reasoning where intermediate inference steps help
  • Code generation with explanation requirements

❌ Avoid When:

  • Simple factual recall where reasoning adds unnecessary tokens
  • When latency is critical (CoT generates more tokens)
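  • Small models that lack reasoning capability; in the original experiments, models far below ~100B parameters often produced fluent but illogical chains
  • Tasks where the model is likely to confabulate plausible-sounding but unfounded reasoning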
  • When the cost of token generation is prohibitive
  • Classification tasks where direct prediction works better

Common Pitfalls

  • Using CoT on models too small to exhibit emergent reasoning
  • Insufficient examples in few-shot prompting for complex domains
  • Not verifying that the reasoning actually leads to the correct answer
  • Over-relying on a single reasoning path instead of using self-consistency
  • Poorly formatted examples that don't demonstrate clear step-by-step reasoning
  • Mixing different reasoning styles in few-shot examples (be consistent)
  • Not handling cases where the model produces plausible reasoning but a wrong conclusion
  • Forgetting that CoT works best on tasks resembling the model's training distribution
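Two of these pitfalls, not verifying the reasoning and accepting a conclusion that does not follow from it, can be partly caught with cheap programmatic checks on arithmetic chains. The sketch below assumes steps are written as plain numbers joined by + - * / followed by '= result' (no units, parentheses, or words inside the expression); `check_arithmetic` and `conclusion_matches_last_step` are hypothetical helpers. It verifies only the arithmetic, not whether the model combined the right quantities.

```python
import ast
import operator
import re

NUM = r"\d+(?:\.\d+)?"
# A step is "<num> <op> <num> [<op> <num> ...] = <num>", e.g. "2 * 3 = 6"
STEP = re.compile(rf"({NUM}(?:\s*[-+*/]\s*{NUM})+)\s*=\s*(-?{NUM})")
ANSWER = re.compile(rf"answer is\s*(-?{NUM})", re.IGNORECASE)
OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
       ast.Mult: operator.mul, ast.Div: operator.truediv}

def safe_eval(expr):
    """Evaluate + - * / over numeric literals without calling eval()."""
    def ev(node):
        if isinstance(node, ast.BinOp) and type(node.op) in OPS:
            return OPS[type(node.op)](ev(node.left), ev(node.right))
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        raise ValueError(f"unsupported expression: {expr!r}")
    return ev(ast.parse(expr, mode="eval").body)

def check_arithmetic(chain, tol=1e-9):
    """List (step, claimed, actual) for every arithmetic step whose stated result is wrong."""
    errors = []
    for m in STEP.finditer(chain):
        claimed = float(m.group(2))
        try:
            actual = safe_eval(m.group(1))
        except ZeroDivisionError:
            actual = float("nan")  # NaN never matches, so the step is reported
        if not abs(actual - claimed) <= tol:
            errors.append((m.group(0), claimed, actual))
    return errors

def conclusion_matches_last_step(chain):
    """Compare the stated answer with the last step's result; None if either is missing."""
    steps = list(STEP.finditer(chain))
    stated = ANSWER.search(chain)
    if not steps or not stated:
        return None
    return float(steps[-1].group(2)) == float(stated.group(1))

good = "He buys 2 cans of 3 balls, so 2 * 3 = 6 balls. In total 5 + 6 = 11. The answer is 11."
bad = "Half of 16 is 16 / 2 = 8. He loses 3, so 8 - 3 = 6. The answer is 6."
print(check_arithmetic(good))
print(check_arithmetic(bad))
print(conclusion_matches_last_step(good))
```

The two checks are complementary: the `bad` chain states a conclusion consistent with its last step, yet that step's arithmetic is wrong.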