Standard CATE Estimation Example
This example demonstrates a complete workflow for standard CATE (Conditional Average Treatment Effect) estimation using CausalFM.
Overview
In this example, we will:
Generate synthetic training data
Train a Standard CATE model
Evaluate the model on test data
Visualize the results
Complete Example
"""
Complete example for Standard CATE estimation with CausalFM
"""
import re
from pathlib import Path

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from causalfm.data import StandardCATEGenerator
from causalfm.models import StandardCATEModel
from causalfm.training import StandardCATETrainer, TrainingConfig
from causalfm.evaluation import compute_pehe, compute_ate_error, compute_rmse
def generate_data():
    """Step 1: write the synthetic training and test datasets to disk.

    Each split gets its own ``StandardCATEGenerator`` (1024 samples,
    10 features). The test split uses a different seed (999 instead of 42).
    Files go to ``data/standard_cate/<split>/``, with the split name as the
    filename prefix.
    """
    rule = "=" * 60
    print(rule)
    print("STEP 1: Data Generation")
    print(rule)

    # (seed, number of datasets, label for the log line, directory / prefix name)
    splits = (
        (42, 500, "training", "train"),
        (999, 50, "test", "test"),  # different seed for the test split
    )
    for seed, count, label, name in splits:
        generator = StandardCATEGenerator(
            num_samples=1024,
            num_features=10,
            seed=seed,
        )
        print(f"Generating {count} {label} datasets...")
        generator.generate_multiple(
            num_datasets=count,
            output_dir=f"data/standard_cate/{name}/",
            filename_prefix=name,
        )

    print("✓ Data generation complete!\n")
def train_model():
    """Step 2: train a Standard CATE model on the generated training CSVs.

    All settings are gathered in one mapping and passed to ``TrainingConfig``.
    ``StandardCATETrainer`` then runs training. The config points checkpoints
    at ``checkpoints/standard_cate/`` and logs at ``logs/standard_cate/``.
    """
    separator = "=" * 60
    for line in (separator, "STEP 2: Model Training", separator):
        print(line)

    hyperparams = {
        # --- data
        "data_path": "data/standard_cate/train/*.csv",
        "val_split": 0.2,
        # --- optimisation
        "epochs": 100,
        "batch_size": 16,
        "learning_rate": 0.001,
        "weight_decay": 1e-5,
        # --- early stopping
        "early_stop_patience": 30,
        # --- model: GMM output head with 5 components
        "use_gmm_head": True,
        "gmm_n_components": 5,
        # --- checkpointing / logging
        "save_dir": "checkpoints/standard_cate/",
        "save_freq": 10,
        "log_dir": "logs/standard_cate/",
        # --- hardware
        "device": 'auto',
        "num_workers": 0,  # 0 avoids multiprocessing issues in data loading
        # --- reproducibility
        "seed": 42,
    }
    config = TrainingConfig(**hyperparams)

    print("Starting training...")
    StandardCATETrainer(config).train()
    print("\n✓ Training complete!\n")
def _natural_sort_key(path):
    """Sort key that orders embedded integers numerically (``_2`` before ``_10``)."""
    # re.split with a capturing group alternates text / digit runs, so ints and
    # strs always sit at the same positions and never get compared to each other.
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", path.name)]


def evaluate_model():
    """Step 3: evaluate the trained model on every test dataset.

    Loads ``checkpoints/standard_cate/best_model.pth`` on CPU. Each
    ``test_*.csv`` file in ``data/standard_cate/test/`` is processed in
    natural numeric order:

    * the first 80% of rows (covariates, treatment, outcome) are passed to
      ``estimate_cate`` as in-context training data;
    * CATE is predicted for the remaining 20% of rows;
    * the predictions are scored against the true ``ite`` column with PEHE,
      ATE error and RMSE.

    Per-dataset metrics are written to ``results/standard_cate_results.csv``
    (the directory is created if needed) and summarized on stdout.

    Returns:
        tuple: ``(results_df, model)``. ``results_df`` has one metrics row per
        test file; ``model`` is the loaded model, reused by the visualization
        step.

    Raises:
        FileNotFoundError: if no ``test_*.csv`` file exists in the test directory.
    """
    print("=" * 60)
    print("STEP 3: Model Evaluation")
    print("=" * 60)

    # Load trained model
    model = StandardCATEModel.from_pretrained(
        "checkpoints/standard_cate/best_model.pth",
        device='cpu'
    )
    model.eval_mode()

    test_dir = Path("data/standard_cate/test/")
    # Natural order, so the log and the results table follow the dataset
    # numbering (plain sorted() would place test_dataset_10 before _2).
    test_files = sorted(test_dir.glob("test_*.csv"), key=_natural_sort_key)
    if not test_files:
        # Fail early and clearly instead of a KeyError on an empty DataFrame below
        raise FileNotFoundError(f"No test_*.csv files found in {test_dir}")

    results = []
    for file in test_files:
        df = pd.read_csv(file)

        # Covariates are the columns whose name starts with 'x'
        x_cols = [c for c in df.columns if c.startswith('x')]
        X = torch.FloatTensor(df[x_cols].values)
        A = torch.FloatTensor(df['treatment'].values).unsqueeze(1)
        Y = torch.FloatTensor(df['outcome'].values).unsqueeze(1)
        true_ite = df['ite'].values

        # In-context split: first 80% of rows are context, the rest are queried
        n_train = int(0.8 * len(X))
        x_train, x_test = X[:n_train], X[n_train:]
        a_train = A[:n_train]
        y_train = Y[:n_train]
        ite_test = true_ite[n_train:]

        with torch.no_grad():
            result = model.estimate_cate(x_train, a_train, y_train, x_test)
        # Flatten to 1-D. An (n, 1) prediction compared with the (n,) ITE
        # vector would otherwise silently broadcast to an (n, n) matrix.
        pred_cate = result['cate'].cpu().numpy().reshape(-1)

        pehe = compute_pehe(pred_cate, ite_test)
        ate_error = compute_ate_error(pred_cate, ite_test)
        rmse = compute_rmse(pred_cate, ite_test)

        results.append({
            'dataset': file.name,
            'pehe': pehe,
            'ate_error': ate_error,
            'rmse': rmse,
            'n_test': len(ite_test)
        })
        print(f" {file.name}: PEHE={pehe:.4f}, ATE Error={ate_error:.4f}")

    # Aggregate results; pandas does not create missing parent directories
    results_df = pd.DataFrame(results)
    out_path = Path("results/standard_cate_results.csv")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    results_df.to_csv(out_path, index=False)

    print("\n" + "=" * 60)
    print("SUMMARY STATISTICS")
    print("=" * 60)
    print(f"Number of test datasets: {len(results_df)}")
    print(f"\nPEHE: {results_df['pehe'].mean():.4f} ± {results_df['pehe'].std():.4f}")
    print(f"ATE Error: {results_df['ate_error'].mean():.4f} ± {results_df['ate_error'].std():.4f}")
    print(f"RMSE: {results_df['rmse'].mean():.4f} ± {results_df['rmse'].std():.4f}")
    print("\n✓ Evaluation complete!\n")
    return results_df, model
def _gmm_std(pi, mu, sigma):
    """Standard deviation of each row's 1-D Gaussian mixture.

    Uses the law of total variance:
    Var = sum_k pi_k * (sigma_k**2 + (mu_k - m)**2), where m = sum_k pi_k * mu_k.
    Every term is non-negative, so the variance cannot go negative. By
    contrast, E[X**2] - mean**2 can go negative through floating-point
    cancellation, or when the supplied mean is not the mixture mean.
    """
    mix_mean = (pi * mu).sum(axis=-1, keepdims=True)
    variance = (pi * (sigma ** 2 + (mu - mix_mean) ** 2)).sum(axis=-1)
    return np.sqrt(variance)


def _gmm_interval(pi, mu, sigma, n_samples, rng, level=0.95):
    """Monte-Carlo equal-tailed ``level`` interval for each row's Gaussian mixture.

    Args:
        pi, mu, sigma: arrays of shape (n_rows, n_components) with each row's
            mixture weights (summing to 1), means and standard deviations.
        n_samples: number of Monte-Carlo draws per row.
        rng: ``numpy.random.Generator`` used for every draw.
        level: central coverage; 0.95 gives the 2.5th and 97.5th percentiles.

    Returns:
        tuple: ``(lower, upper)``, each of shape (n_rows,).
    """
    n_rows, n_components = pi.shape
    samples = np.empty((n_rows, n_samples))
    for i in range(n_rows):
        # Pick a mixture component for every draw, then sample from its Gaussian
        comps = rng.choice(n_components, size=n_samples, p=pi[i])
        samples[i] = rng.normal(mu[i, comps], sigma[i, comps])
    tail = 100.0 * (1.0 - level) / 2.0
    return (np.percentile(samples, tail, axis=1),
            np.percentile(samples, 100.0 - tail, axis=1))


def visualize_results(results_df, model,
                      dataset_path="data/standard_cate/test/test_dataset_1.csv",
                      n_samples=10000, seed=42):
    """Step 4: visualize predictions, errors and uncertainty.

    Re-runs the model on one test dataset, using the same split as the
    evaluation step: the first 80% of rows are context and the last 20% are
    queried. Saves a 2x2 diagnostic figure to
    ``results/standard_cate_visualization.png`` with these panels:

    * predicted vs. true effect;
    * error histogram;
    * predicted std. dev. vs. absolute error;
    * sorted predictions with 95% intervals.

    A bar chart of every dataset's PEHE from ``results_df`` is saved to
    ``results/pehe_across_datasets.png``. The ``results/`` directory is
    created if missing.

    Args:
        results_df: per-dataset metrics from ``evaluate_model``; must have a
            ``'pehe'`` column.
        model: the loaded model. Its ``estimate_cate`` result must contain
            ``'cate'``, ``'gmm_pi'``, ``'gmm_mu'`` and ``'gmm_sigma'``.
        dataset_path: CSV file to visualize.
        n_samples: Monte-Carlo draws per test point for the 95% intervals.
        seed: seed for interval sampling. ``None`` gives a fresh,
            non-reproducible stream.
    """
    print("=" * 60)
    print("STEP 4: Visualization")
    print("=" * 60)

    # Load one test dataset and split it as in evaluate_model
    df = pd.read_csv(dataset_path)
    x_cols = [c for c in df.columns if c.startswith('x')]
    X = torch.FloatTensor(df[x_cols].values)
    A = torch.FloatTensor(df['treatment'].values).unsqueeze(1)
    Y = torch.FloatTensor(df['outcome'].values).unsqueeze(1)
    true_ite = df['ite'].values

    n_train = int(0.8 * len(X))
    x_train, x_test = X[:n_train], X[n_train:]
    a_train = A[:n_train]
    y_train = Y[:n_train]
    ite_test = true_ite[n_train:]
    n_test = len(x_test)

    # Get predictions together with the per-point mixture parameters
    with torch.no_grad():
        result = model.estimate_cate(x_train, a_train, y_train, x_test)
    # Flatten 'cate' and force the mixture arrays to (n_test, K). A trailing
    # singleton axis would otherwise broadcast (n,) against (n, 1) into (n, n).
    pred_cate = result['cate'].cpu().numpy().reshape(-1)
    pi = result['gmm_pi'].cpu().numpy().reshape(n_test, -1).astype(np.float64)
    mu = result['gmm_mu'].cpu().numpy().reshape(n_test, -1)
    sigma = result['gmm_sigma'].cpu().numpy().reshape(n_test, -1)
    # Renormalize the weights in float64. The sampler rejects probability
    # vectors whose sum differs from 1 beyond a tight float64 tolerance.
    pi /= pi.sum(axis=1, keepdims=True)

    # 95% intervals and predictive std. dev. from the mixture
    rng = np.random.default_rng(seed)
    ci_lower, ci_upper = _gmm_interval(pi, mu, sigma, n_samples, rng)
    std_dev = _gmm_std(pi, mu, sigma)
    errors = pred_cate - ite_test

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Plot 1: Predicted vs True
    ax1 = axes[0, 0]
    ax1.scatter(ite_test, pred_cate, alpha=0.6)
    min_val = min(ite_test.min(), pred_cate.min())
    max_val = max(ite_test.max(), pred_cate.max())
    ax1.plot([min_val, max_val], [min_val, max_val], 'r--', label='Perfect Prediction')
    ax1.set_xlabel('True ITE')
    ax1.set_ylabel('Predicted CATE')
    ax1.set_title('Predicted vs True Treatment Effects')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Error Distribution
    ax2 = axes[0, 1]
    ax2.hist(errors, bins=30, edgecolor='black', alpha=0.7)
    ax2.axvline(0, color='r', linestyle='--', linewidth=2, label='Zero Error')
    ax2.set_xlabel('Prediction Error')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Error Distribution')
    ax2.legend()

    # Plot 3: Uncertainty Calibration (larger predicted spread should go with larger error)
    ax3 = axes[1, 0]
    ax3.scatter(std_dev, np.abs(errors), alpha=0.6)
    ax3.set_xlabel('Predicted Std Dev')
    ax3.set_ylabel('Absolute Error')
    ax3.set_title('Uncertainty Calibration')
    ax3.grid(True, alpha=0.3)

    # Plot 4: Predictions with Uncertainty, sorted by predicted effect
    ax4 = axes[1, 1]
    sorted_idx = np.argsort(pred_cate)
    x = np.arange(len(sorted_idx))
    ax4.plot(x, pred_cate[sorted_idx], label='Predicted CATE', color='blue', linewidth=2)
    ax4.fill_between(x, ci_lower[sorted_idx], ci_upper[sorted_idx],
                     alpha=0.3, label='95% CI')
    ax4.scatter(x, ite_test[sorted_idx], s=20, alpha=0.6,
                color='red', label='True ITE')
    ax4.set_xlabel('Sample (sorted by prediction)')
    ax4.set_ylabel('Treatment Effect')
    ax4.set_title('Predictions with Uncertainty Bands')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    # savefig does not create missing directories
    Path('results').mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig('results/standard_cate_visualization.png', dpi=300, bbox_inches='tight')
    print("✓ Visualization saved to results/standard_cate_visualization.png")

    # Plot 5: PEHE across datasets
    fig2, ax = plt.subplots(figsize=(10, 6))
    ax.bar(range(len(results_df)), results_df['pehe'])
    ax.axhline(results_df['pehe'].mean(), color='r', linestyle='--',
               linewidth=2, label=f"Mean: {results_df['pehe'].mean():.4f}")
    ax.set_xlabel('Test Dataset')
    ax.set_ylabel('PEHE')
    ax.set_title('PEHE Across Test Datasets')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.savefig('results/pehe_across_datasets.png', dpi=300, bbox_inches='tight')
    print("✓ PEHE plot saved to results/pehe_across_datasets.png")
    print("\n✓ Visualization complete!\n")
if __name__ == '__main__':
    # Create every output location up front, in a fixed order
    for directory in (
        "data/standard_cate/train/",
        "data/standard_cate/test/",
        "checkpoints/standard_cate/",
        "logs/standard_cate/",
        "results/",
    ):
        Path(directory).mkdir(parents=True, exist_ok=True)

    # Run the four pipeline stages in sequence
    generate_data()
    train_model()
    results_df, model = evaluate_model()
    visualize_results(results_df, model)

    banner = "=" * 60
    print(banner)
    print("PIPELINE COMPLETE!")
    print(banner)
    print("\nOutputs:")
    for line in (
        " - Model: checkpoints/standard_cate/best_model.pth",
        " - Results: results/standard_cate_results.csv",
        " - Plots: results/*.png",
        " - Logs: logs/standard_cate/",
    ):
        print(line)
Expected Output
When you run this example, you should see output similar to:
============================================================
STEP 1: Data Generation
============================================================
Generating 500 training datasets...
Generating 50 test datasets...
✓ Data generation complete!
============================================================
STEP 2: Model Training
============================================================
Starting training...
Epoch 1/100 ━━━━━━━━━━━━━━━━━━━━━━━ 100% | Train Loss: 1.23 | Val Loss: 1.34
Epoch 2/100 ━━━━━━━━━━━━━━━━━━━━━━━ 100% | Train Loss: 1.15 | Val Loss: 1.25
✓ New best model saved!
...
✓ Training complete!
============================================================
STEP 3: Model Evaluation
============================================================
test_dataset_1.csv: PEHE=0.4523, ATE Error=0.0234
test_dataset_2.csv: PEHE=0.4312, ATE Error=0.0189
...
============================================================
SUMMARY STATISTICS
============================================================
Number of test datasets: 50
PEHE: 0.4456 ± 0.0234
ATE Error: 0.0201 ± 0.0089
RMSE: 0.4782 ± 0.0245
✓ Evaluation complete!
Key Takeaways
Data Generation: Generate many diverse synthetic datasets for training
Training: Use the TrainingConfig to control all aspects of training
Evaluation: Evaluate on multiple test datasets for robust estimates
Uncertainty: The GMM head provides per-sample uncertainty estimates; use the uncertainty-vs-error plot to check how well they are calibrated
Visualization: Visualize predictions and errors to understand model behavior
Next Steps
Try adjusting `num_features` or `num_samples` in data generation
Experiment with different `gmm_n_components` values
Test on real-world datasets like the Jobs dataset
Compare with other causal inference methods
Explore the IV and Front-door examples