Decision Trees: Splitting Criteria, Pruning, and Interpretability
Definition
Decision Trees are non-parametric supervised learning algorithms that model decisions through a hierarchical structure of internal nodes (feature tests), branches (test outcomes), and leaf nodes (predictions). The algorithm recursively partitions the feature space into axis-aligned rectangular regions. It predicts the majority class (classification) or the mean target value (regression) of the training samples in each region. At each node it chooses the split that maximizes the reduction in impurity. For classification this is information gain with entropy, or the analogous decrease in Gini impurity. For regression it is variance (MSE) reduction. The search is greedy and top-down: each split is locally optimal and never revisited, so the final tree is not guaranteed to be globally optimal, and a tree grown without limits tends to overfit. Trees are highly interpretable, since any prediction can be traced along its root-to-leaf path. This simplicity has costs, though. Trees are unstable, because small changes in the data can produce a very different tree. Impurity-based splitting is also biased toward features with many distinct values.
Intuition
Think of decision trees like a game of 20 Questions. You ask yes/no questions about features to narrow down the answer. 'Is it bigger than a breadbox?' splits the possibilities. 'Is it alive?' further refines them. Each question should ideally eliminate as much uncertainty as possible. The tree learns the most informative questions to ask first, creating a flowchart that anyone can follow.
Mathematical Formula
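The standard forms of these criteria are written below with the symbols used in the explanation that follows. Here p_k is the fraction of samples in node T that belong to class k (of K classes). Tᵥ is the subset of T sent to child v when splitting on attribute a, and ȳ_T is the mean target value in T. Substituting Gini for H in the information-gain formula gives the Gini decrease used by the code in this article.

```latex
\begin{aligned}
\mathrm{Gini}(T) &= 1 - \sum_{k=1}^{K} p_k^2 \\
H(T) &= -\sum_{k=1}^{K} p_k \log_2 p_k \\
\mathrm{IG}(T, a) &= H(T) - \sum_{v \in \mathrm{values}(a)} \frac{|T_v|}{|T|}\, H(T_v) \\
\mathrm{MSE}(T) &= \frac{1}{|T|} \sum_{i \in T} \left( y_i - \bar{y}_T \right)^2 = \mathrm{Var}(T) \\
\Delta\mathrm{Var}(T, a) &= \mathrm{Var}(T) - \sum_{v \in \mathrm{values}(a)} \frac{|T_v|}{|T|}\, \mathrm{Var}(T_v)
\end{aligned}
```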
Step-by-Step Explanation:
- Gini impurity is the probability of misclassifying a randomly chosen sample if it were labeled at random according to the node's class distribution
- Entropy measures the average information content, i.e. the uncertainty, of the node's class distribution
- Information gain is the parent's impurity minus the size-weighted impurity of the children after a split on attribute a
- H(T) is the entropy of the parent node T, and H(Tᵥ) is the entropy of child v, which holds the subset Tᵥ
- MSE for regression measures the spread of the targets around the node's mean prediction
- Variance reduction quantifies how much a split decreases the size-weighted target variance
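To make these quantities concrete, here is a small worked example. A parent node holds 10 samples, 5 from each class. The split sends 4 of class 0 and 1 of class 1 to the left child, and 1 of class 0 and 4 of class 1 to the right. The counts are made up for illustration. The helper names `gini`, `entropy`, and `gain` are hypothetical and work on class counts rather than label arrays.

```python
import numpy as np

def gini(counts):
    """Gini impurity from class counts: 1 - sum(p_k^2)."""
    p = np.asarray(counts, dtype=float) / np.sum(counts)
    return 1.0 - np.sum(p ** 2)

def entropy(counts):
    """Entropy in bits from class counts, skipping empty classes (0 * log 0 = 0)."""
    p = np.asarray(counts, dtype=float) / np.sum(counts)
    p = p[p > 0]
    return -np.sum(p * np.log2(p))

def gain(parent, children, impurity):
    """Parent impurity minus the size-weighted impurity of the children."""
    n = np.sum(parent)
    weighted = sum(np.sum(c) / n * impurity(c) for c in children)
    return impurity(parent) - weighted

parent = [5, 5]              # 5 samples of class 0, 5 of class 1
children = [[4, 1], [1, 4]]  # left child, right child

print(f"Gini: parent={gini(parent):.3f}, gain={gain(parent, children, gini):.3f}")
print(f"Entropy: parent={entropy(parent):.3f}, gain={gain(parent, children, entropy):.3f}")
# Gini: parent=0.500, gain=0.180
# Entropy: parent=1.000, gain=0.278
```

Both criteria agree that this split helps. They differ only in scale: Gini drops from 0.5 to 0.32, and entropy drops from 1 bit to about 0.72 bits.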
Real-World Use Cases
Medical diagnosis decision support: 'Is patient over 50? Yes → Check cholesterol. No → Check blood pressure.' Interpretable for doctors.
Loan approval rules engine: explicit if-then rules satisfying regulatory requirements for explainability.
Customer segmentation for marketing: 'High income AND frequent purchaser → Premium tier'.
Quality control: decision rules identifying defective products based on sensor thresholds.
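For rule-based uses such as loan approval, a fitted tree's root-to-leaf paths can be listed as explicit if-then rules. The sketch below assumes scikit-learn and uses Iris only as a stand-in dataset. `tree_rules` is a hypothetical helper, not a library function. It walks the fitted estimator's `tree_` arrays (`children_left`, `children_right`, `feature`, `threshold`, `value`).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

def tree_rules(clf, feature_names, class_names):
    """Hypothetical helper: return one 'IF ... THEN ...' string per leaf."""
    t = clf.tree_
    rules = []

    def walk(node, conditions):
        # A leaf has no children: both child arrays hold the same sentinel (-1)
        if t.children_left[node] == t.children_right[node]:
            label = class_names[np.argmax(t.value[node][0])]
            rules.append("IF " + " AND ".join(conditions or ["TRUE"]) + f" THEN {label}")
            return
        name = feature_names[t.feature[node]]
        thr = t.threshold[node]
        # scikit-learn sends samples with x <= threshold to the left child
        walk(t.children_left[node], conditions + [f"{name} <= {thr:.2f}"])
        walk(t.children_right[node], conditions + [f"{name} > {thr:.2f}"])

    walk(0, [])
    return rules

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(iris.data, iris.target)
rules = tree_rules(clf, iris.feature_names, iris.target_names)
for rule in rules:
    print(rule)
```

Each leaf yields exactly one rule, and the rules are mutually exclusive, so every input matches exactly one of them.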
Implementation
Manual Implementation (From Scratch with NumPy)
import numpy as np
from collections import Counter


class DecisionTree:
    """
    Manual implementation of a Decision Tree classifier.
    Supports Gini impurity and Information Gain (Entropy).
    Class labels must be non-negative integers (np.bincount requires this).
    """

    def __init__(self, max_depth=5, min_samples_split=2, criterion='gini'):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.criterion = criterion
        self.tree = None

    def _gini(self, y):
        """Calculate Gini impurity."""
        if len(y) == 0:
            return 0
        proportions = np.bincount(y) / len(y)
        return 1 - np.sum(proportions ** 2)

    def _entropy(self, y):
        """Calculate entropy in bits."""
        if len(y) == 0:
            return 0
        proportions = np.bincount(y) / len(y)
        # Skip empty classes: 0 * log2(0) is taken as 0, so log2(0) is never evaluated
        proportions = proportions[proportions > 0]
        return -np.sum(proportions * np.log2(proportions))

    def _impurity(self, y):
        """Calculate impurity based on chosen criterion."""
        if self.criterion == 'gini':
            return self._gini(y)
        return self._entropy(y)

    def _information_gain(self, y, left_idx, right_idx):
        """Parent impurity minus the size-weighted impurity of the two children."""
        parent_impurity = self._impurity(y)
        n = len(y)
        n_left = len(left_idx)
        n_right = len(right_idx)
        if n_left == 0 or n_right == 0:
            return 0
        child_impurity = (n_left / n) * self._impurity(y[left_idx]) + \
                         (n_right / n) * self._impurity(y[right_idx])
        return parent_impurity - child_impurity

    def _best_split(self, X, y):
        """Find the best feature and threshold to split on (exhaustive search)."""
        best_gain = -1
        best_feature = None
        best_threshold = None
        n_features = X.shape[1]
        for feature in range(n_features):
            # Every distinct value is tried as a threshold for "x <= threshold"
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                left_idx = np.where(X[:, feature] <= threshold)[0]
                right_idx = np.where(X[:, feature] > threshold)[0]
                gain = self._information_gain(y, left_idx, right_idx)
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = threshold
        return best_feature, best_threshold, best_gain

    def _build_tree(self, X, y, depth=0):
        """Recursively build the decision tree."""
        n_samples = len(y)
        n_classes = len(np.unique(y))
        # Stopping conditions: the leaf predicts the majority class
        if (depth >= self.max_depth or
                n_samples < self.min_samples_split or
                n_classes == 1):
            return {'prediction': Counter(y).most_common(1)[0][0]}
        # Find best split; stop if no split reduces impurity
        feature, threshold, gain = self._best_split(X, y)
        if gain <= 0:
            return {'prediction': Counter(y).most_common(1)[0][0]}
        # Split data with boolean masks
        left_idx = X[:, feature] <= threshold
        right_idx = X[:, feature] > threshold
        # Recursively build left and right subtrees
        left_subtree = self._build_tree(X[left_idx], y[left_idx], depth + 1)
        right_subtree = self._build_tree(X[right_idx], y[right_idx], depth + 1)
        return {
            'feature': feature,
            'threshold': threshold,
            'left': left_subtree,
            'right': right_subtree
        }

    def fit(self, X, y):
        """Build the decision tree."""
        self.tree = self._build_tree(np.asarray(X), np.asarray(y))
        return self

    def _predict_single(self, x, node):
        """Predict for a single sample by following the tests down to a leaf."""
        if 'prediction' in node:
            return node['prediction']
        if x[node['feature']] <= node['threshold']:
            return self._predict_single(x, node['left'])
        else:
            return self._predict_single(x, node['right'])

    def predict(self, X):
        """Predict for all samples."""
        return np.array([self._predict_single(x, self.tree) for x in np.asarray(X)])


# Demonstration
if __name__ == '__main__':
    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split

    # Generate sample data (scikit-learn is used only to create and split the dataset)
    X, y = make_classification(
        n_samples=200, n_features=4, n_redundant=0,
        n_informative=4, n_clusters_per_class=1, random_state=42
    )
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Train tree
    tree = DecisionTree(max_depth=3, criterion='gini')
    tree.fit(X_train, y_train)

    # Predict
    y_pred = tree.predict(X_test)
    accuracy = np.mean(y_pred == y_test)
    print(f'Accuracy: {accuracy:.3f}')
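The manual implementation above handles classification only. For regression, the same exhaustive search works if impurity is replaced by variance (the MSE around each node's mean) and leaves predict the mean target. The sketch below shows that split search on a single feature with made-up data. `best_variance_split` is a hypothetical helper, not part of the class above. Unlike the class, which tests every distinct value, it tries midpoints between consecutive distinct values as thresholds.

```python
import numpy as np

def best_variance_split(x, y):
    """Return (threshold, variance_reduction) for the best split x <= threshold."""
    parent_var = np.var(y)
    best_thr, best_gain = None, 0.0
    values = np.unique(x)
    # Candidate thresholds: midpoints between consecutive distinct values
    for thr in (values[:-1] + values[1:]) / 2:
        left, right = y[x <= thr], y[x > thr]
        # Size-weighted variance of the two children
        weighted = (len(left) * np.var(left) + len(right) * np.var(right)) / len(y)
        gain = parent_var - weighted
        if gain > best_gain:
            best_thr, best_gain = thr, gain
    return best_thr, best_gain

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([1.0, 1.2, 0.9, 5.0, 5.1, 4.8])   # target jumps between x=3 and x=4
thr, gain = best_variance_split(x, y)
print(f"threshold={thr}, variance reduction={gain:.3f}")
# threshold=3.5, variance reduction=3.868
```

The chosen threshold falls exactly at the jump. It removes almost all of the parent variance (about 3.883) and leaves two nearly constant children.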
Using Libraries (scikit-learn, numpy, matplotlib)
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, export_text, plot_tree
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.datasets import load_iris, fetch_california_housing
from sklearn.metrics import accuracy_score, mean_squared_error, classification_report
import numpy as np
import matplotlib.pyplot as plt

# CLASSIFICATION EXAMPLE with Iris dataset
print('=== DECISION TREE CLASSIFICATION ===')
iris = load_iris()
X_cls, y_cls = iris.data, iris.target
X_train_c, X_test_c, y_train_c, y_test_c = train_test_split(
    X_cls, y_cls, test_size=0.2, random_state=42, stratify=y_cls
)

# Train with pre-pruning parameters
clf = DecisionTreeClassifier(
    criterion='gini',        # 'gini' or 'entropy'
    max_depth=4,             # Limit tree depth
    min_samples_split=5,     # Min samples required to split a node
    min_samples_leaf=2,      # Min samples in each leaf
    random_state=42
)
clf.fit(X_train_c, y_train_c)

# Evaluate
y_pred_c = clf.predict(X_test_c)
print(f'Accuracy: {accuracy_score(y_test_c, y_pred_c):.3f}')
print(classification_report(y_test_c, y_pred_c, target_names=iris.target_names))
print('\nFeature Importances:')
for name, importance in zip(iris.feature_names, clf.feature_importances_):
    print(f'  {name}: {importance:.3f}')

# Visualize tree rules as text, then as a plot
print(f'\nTree Rules:\n{export_text(clf, feature_names=list(iris.feature_names))}')
fig, ax = plt.subplots(figsize=(12, 6))
plot_tree(clf, feature_names=iris.feature_names,
          class_names=list(iris.target_names), filled=True, ax=ax)
plt.show()

# REGRESSION EXAMPLE with California Housing
print('\n=== DECISION TREE REGRESSION ===')
housing = fetch_california_housing()
# First 1,000 rows only, for speed (not a random sample)
X_reg, y_reg = housing.data[:1000], housing.target[:1000]
X_train_r, X_test_r, y_train_r, y_test_r = train_test_split(
    X_reg, y_reg, test_size=0.2, random_state=42
)
reg = DecisionTreeRegressor(
    max_depth=6,
    min_samples_split=10,
    min_samples_leaf=5,
    random_state=42
)
reg.fit(X_train_r, y_train_r)
y_pred_r = reg.predict(X_test_r)
print(f'R² Score: {reg.score(X_test_r, y_test_r):.3f}')
print(f'RMSE: {np.sqrt(mean_squared_error(y_test_r, y_pred_r)):.3f}')

# HYPERPARAMETER TUNING
print('\n=== HYPERPARAMETER TUNING ===')
param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    'criterion': ['gini', 'entropy']
}
grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1
)
grid_search.fit(X_train_c, y_train_c)
print(f'Best parameters: {grid_search.best_params_}')
print(f'Best CV accuracy: {grid_search.best_score_:.3f}')
print(f'Test accuracy: {accuracy_score(y_test_c, grid_search.predict(X_test_c)):.3f}')
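The importances printed above come from `feature_importances_`, which is impurity-based and therefore tends to favor high-cardinality features. One way to check this is to append a pure-noise column with a unique value in every row, then compare impurity importance against `sklearn.inspection.permutation_importance` on held-out data. The noise column is an illustrative assumption. Whether the tree uses it at all, and the exact numbers, depend on the split and seed.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
rng = np.random.default_rng(0)
# Illustrative noise feature: a random permutation, so every row has a unique value
noise = rng.permutation(len(iris.target)).reshape(-1, 1).astype(float)
X = np.hstack([iris.data, noise])
names = list(iris.feature_names) + ["random_id (noise)"]

X_train, X_test, y_train, y_test = train_test_split(
    X, iris.target, test_size=0.3, random_state=0, stratify=iris.target
)
# Fully grown tree: its deepest splits are free to pick up the noise column
clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Permutation importance: drop in test accuracy when one column is shuffled
perm = permutation_importance(clf, X_test, y_test, n_repeats=20, random_state=0)
for name, imp, p_mean in zip(names, clf.feature_importances_, perm.importances_mean):
    print(f"{name:>22}: impurity={imp:.3f}  permutation={p_mean:.3f}")
```

Permutation importance measured on held-out data should stay near zero for a feature that carries no signal, even when splitting on it reduced training impurity.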
When to Use
✅ Appropriate Use Cases:
- When interpretability is critical: you can explain exactly why a prediction was made
- Small to medium datasets where model simplicity is valued
- Mixed data types (numeric and categorical) with little preprocessing, in principle; scikit-learn's trees still require categorical features to be numerically encoded
- Feature selection: tree-based importance helps identify predictive features
- Baseline model before trying ensemble methods
- When decision rules need to be explicit (regulatory compliance)
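The interpretability claim can be made concrete for a single prediction. scikit-learn's `decision_path` returns the nodes a sample visits, and `apply` returns the leaf it lands in. Together they give the chain of tests behind one prediction. The sketch below uses Iris only as a convenient built-in dataset, and the choice of row 100 is arbitrary.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

sample = iris.data[100:101]                     # one row, kept 2-D
node_ids = clf.decision_path(sample).indices    # visited nodes, root (id 0) first
leaf_id = clf.apply(sample)[0]
t = clf.tree_

steps = []
for node in node_ids:
    if node == leaf_id:                         # the leaf holds no test
        continue
    f, thr = t.feature[node], t.threshold[node]
    op = "<=" if sample[0, f] <= thr else ">"
    steps.append(f"{iris.feature_names[f]} = {sample[0, f]:.2f} {op} {thr:.2f}")

prediction = iris.target_names[clf.predict(sample)[0]]
print("\n".join(steps))
print(f"=> predicted class: {prediction}")
```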
❌ Avoid When:
- Very large, complex datasets: a single tree must grow very deep to capture the structure, losing its interpretability, and ensembles usually make better use of the data
- When prediction accuracy is paramount (use Random Forest or XGBoost instead)
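- High-dimensional sparse data (text, genomics): each split tests a single feature, so signal spread thinly across many features is hard to capture
- When you need well-calibrated probabilities: a tree's predicted probabilities are leaf class frequencies, which are coarse and often poorly calibrated
- Extrapolation tasks: regression trees predict piecewise-constant leaf means, so predictions never leave the range of the training targets
- Imbalanced datasets, unless class weights or resampling are tuned carefully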
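The extrapolation limit is easy to see with a regression tree fitted to a simple linear trend. Every prediction is the mean of some training leaf, so outputs stay inside the range of the training targets however far the input moves. The synthetic data here (y = 2x for x in [0, 10]) is an illustrative assumption.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

x_train = np.linspace(0, 10, 50).reshape(-1, 1)
y_train = 2 * x_train.ravel()                    # exact linear trend, targets in [0, 20]

reg = DecisionTreeRegressor(max_depth=4, random_state=0).fit(x_train, y_train)

x_new = np.array([[5.0], [20.0], [100.0]])       # inside, then far outside, the training range
preds = reg.predict(x_new)
for x, p in zip(x_new.ravel(), preds):
    print(f"x={x:>6.1f}  true={2 * x:>6.1f}  tree={p:6.2f}")
```

Both out-of-range inputs fall into the right-most leaf and receive the same prediction, well below the true values of 40 and 200.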
Common Pitfalls
- Overfitting: Unrestricted trees memorize the training data. Always constrain growth (max_depth, min_samples_leaf) or prune
- Instability: Small data changes cause completely different trees
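- Bias toward high-cardinality features: features with many unique values offer more candidate thresholds, so impurity-based splitting and impurity-based importances tend to favor them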
- Greedy splits: Locally optimal choices may not be globally optimal
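- Axis-aligned splits: each split tests one feature, so a diagonal decision boundary must be approximated by a staircase of many splits
- Ignoring class weights: with skewed classes, consider class_weight='balanced'
- Not pruning: pre-pruning (max_depth, min_samples_leaf) or post-pruning (ccp_alpha) is usually needed for a tree that generalizes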
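Post-pruning in scikit-learn is controlled by `ccp_alpha`, and `cost_complexity_pruning_path` lists the effective alphas at which subtrees are pruned away. The sketch below picks alpha by 5-fold cross-validation on the training split. The breast-cancer dataset is an arbitrary choice, and the selected alpha and scores will vary with the data and seed.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)

# Effective alphas of the minimal cost-complexity pruning sequence;
# the last alpha prunes the tree down to its root, so it is dropped
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas[:-1]

# Mean 5-fold CV accuracy for each candidate alpha
cv_scores = [
    cross_val_score(DecisionTreeClassifier(random_state=0, ccp_alpha=a),
                    X_train, y_train, cv=5).mean()
    for a in alphas
]
best_alpha = alphas[int(np.argmax(cv_scores))]

full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=best_alpha).fit(X_train, y_train)
print(f"best ccp_alpha={best_alpha:.4f}")
print(f"leaves: full={full.get_n_leaves()}, pruned={pruned.get_n_leaves()}")
print(f"test accuracy: full={full.score(X_test, y_test):.3f}, "
      f"pruned={pruned.score(X_test, y_test):.3f}")
```

Larger alphas remove more subtrees. Choosing alpha by cross-validation trades a little training fit for a smaller tree that is easier to read.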