Tutorial 1: Basics
Welcome to CausalFM! This tutorial introduces the fundamental concepts and workflow.
Learning Objectives
By the end of this tutorial, you will understand:
What are Prior-Data Fitted Networks (PFNs)
The key concepts in causal inference with CausalFM
The basic workflow from data to predictions
When to use different causal settings
What are Foundation Models for Causal Inference?
Traditional Approach
Traditional machine learning for causal inference:
Collect a single dataset
Train a model on that dataset
Make predictions on new samples from the same distribution
Problem: Requires large datasets and doesn’t transfer well.
Foundation Model Approach (CausalFM)
CausalFM uses a different paradigm:
Train once on many diverse synthetic datasets
Model learns the structure of causal inference problems
Apply to new datasets without retraining (zero-shot transfer)
Adapt in-context using just a few samples
Advantage: Transfer learning + in-context adaptation!
Key Concepts
Prior-Data Fitted Networks (PFNs)
PFNs are trained to solve a distribution of tasks, not just one task.
# Traditional: Train on one dataset
model.fit(X, y) # Requires large dataset
# PFN: Trained on many datasets, adapts in-context
model.estimate(X_context, y_context, X_query) # Few samples needed!
In CausalFM, the “task” is estimating treatment effects for a particular dataset.
In-Context Learning
The model takes training examples as input and adapts its predictions:
# Provide context (training samples)
result = model.estimate_cate(
x_train, # Context: observed covariates
a_train, # Context: treatments
y_train, # Context: outcomes
x_test # Query: new covariates to predict for
)
The model learns the relationship between (X, A) → Y from the context samples!
CATE: Conditional Average Treatment Effect
CATE is the expected treatment effect for an individual with covariates X:
Where:
Y(1) = potential outcome under treatment
Y(0) = potential outcome under control
X = covariates (features)
# Individual treatment effects
cate = model.estimate_cate(x_train, a_train, y_train, x_test)['cate']
# Average treatment effect
ate = cate.mean()
Causal Settings in CausalFM
CausalFM supports three causal inference settings:
1. Standard CATE Estimation
When to use: No unobserved confounding (all confounders measured).
from causalfm.models import StandardCATEModel
model = StandardCATEModel.from_pretrained("checkpoints/standard.pth")
result = model.estimate_cate(x_train, a_train, y_train, x_test)
2. Instrumental Variables
When to use: Unobserved confounding, but valid instrument available.
from causalfm.models import IVModel
model = IVModel.from_pretrained("checkpoints/iv.pth")
result = model.estimate_cate(x_train, z_train, a_train, y_train, x_test)
3. Front-door Adjustment
When to use: Unobserved confounding, mediator blocks backdoor path.
from causalfm.models import FrontdoorModel
model = FrontdoorModel.from_pretrained("checkpoints/fd.pth")
result = model.estimate_cate(x_train, m_train, a_train, y_train, x_test)
Basic Workflow
The CausalFM workflow has four main steps:
Step 1: Generate Data
from causalfm.data import StandardCATEGenerator
generator = StandardCATEGenerator(num_samples=1024, num_features=10)
# Generate training data
generator.generate_multiple(500, "data/train/")
# Generate test data
generator.generate_multiple(50, "data/test/")
Step 2: Train Model
from causalfm.training import StandardCATETrainer, TrainingConfig
if __name__ == '__main__':
config = TrainingConfig(
data_path="data/train/*.csv",
epochs=100,
save_dir="checkpoints/"
)
trainer = StandardCATETrainer(config)
trainer.train()
Step 3: Load and Predict
from causalfm.models import StandardCATEModel
model = StandardCATEModel.from_pretrained("checkpoints/best_model.pth")
# Prepare data
result = model.estimate_cate(x_train, a_train, y_train, x_test)
cate = result['cate'] # Treatment effect estimates
Step 4: Evaluate
from causalfm.evaluation import compute_pehe, compute_ate_error
pehe = compute_pehe(cate, true_ite)
ate_error = compute_ate_error(cate, true_ite)
print(f"PEHE: {pehe:.4f}")
print(f"ATE Error: {ate_error:.4f}")
Your First CausalFM Script
Let’s put it all together:
"""
My first CausalFM script
"""
import torch
from causalfm.data import StandardCATEGenerator
from causalfm.models import StandardCATEModel
from causalfm.training import StandardCATETrainer, TrainingConfig
from causalfm.evaluation import compute_pehe
# 1. Generate data
print("Generating data...")
gen = StandardCATEGenerator(num_samples=1024, num_features=10)
gen.generate_multiple(100, "data/train/")
gen.generate_multiple(10, "data/test/")
# 2. Train model
if __name__ == '__main__':
print("Training model...")
config = TrainingConfig(
data_path="data/train/*.csv",
epochs=50,
batch_size=16,
num_workers=0,
save_dir="checkpoints/"
)
trainer = StandardCATETrainer(config)
trainer.train()
# 3. Load and predict
print("Making predictions...")
model = StandardCATEModel.from_pretrained("checkpoints/best_model.pth")
# Prepare some test data
x_train = torch.randn(800, 10)
a_train = torch.randint(0, 2, (800, 1)).float()
y_train = torch.randn(800, 1)
x_test = torch.randn(200, 10)
result = model.estimate_cate(x_train, a_train, y_train, x_test)
cate = result['cate']
print(f"Estimated {len(cate)} treatment effects!")
print(f"Mean CATE: {cate.mean():.4f}")
Key Takeaways
✅ CausalFM uses foundation models that learn from many datasets
✅ In-context learning allows adaptation with few samples
✅ Three causal settings for different identification strategies
✅ Simple workflow: Generate → Train → Predict → Evaluate
✅ Zero-shot transfer to new datasets without retraining
Next Steps
Now that you understand the basics, continue to:
Tutorial 2: Data Generation - Learn about data generation
Tutorial 3: Training Models - Dive deep into training
Standard CATE Estimation Example - See a complete working example
Questions?
Check the Quick Start for quick reference
Read the Models for model details
Look at example notebooks in
evaluation/notebook/