Models
CausalFM provides three main model classes for different causal inference settings, all built on the TabPFN foundation model architecture.
Model Architecture
All CausalFM models share a common architecture:
- Transformer-based encoder with per-feature attention
- GMM prediction head for uncertainty quantification
- Context-based learning using training samples as context
- In-context adaptation without gradient updates
Key Features
- ✨ Foundation Model Approach
  Models are pre-trained on diverse synthetic datasets and can adapt to new datasets in-context without fine-tuning.
- 🎯 Gaussian Mixture Model Head
  Instead of only a point estimate, models output a mixture of Gaussians, providing:
  - Point estimates (mixture mean)
  - Uncertainty quantification (mixture variance)
  - Full predictive distribution
- 📊 Transformer Architecture
  - Per-feature encoding
  - Multi-head attention mechanisms
  - Layer normalization and residual connections
Standard CATE Model
The StandardCATEModel estimates the conditional average treatment effect (CATE) in the standard setting, where there is no unobserved confounding.
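For reference, the estimand can be written as below. The potential-outcomes notation $Y(1)$, $Y(0)$ is an assumption made here, since this page does not fix a notation; $X$ denotes the covariates, $A$ the binary treatment, and $Y$ the observed outcome. The second equality relies on the no-unobserved-confounding assumption (together with the usual overlap and consistency conditions).

```latex
\tau(x) = \mathbb{E}\bigl[\, Y(1) - Y(0) \mid X = x \,\bigr]
        = \mathbb{E}[\, Y \mid X = x,\ A = 1 \,] - \mathbb{E}[\, Y \mid X = x,\ A = 0 \,]
```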
Basic Usage
from causalfm.models import StandardCATEModel
import torch
# Create new model
model = StandardCATEModel(
    use_gmm_head=True,
    gmm_n_components=5,
    device='cuda'
)

# Or load pretrained
model = StandardCATEModel.from_pretrained(
    "checkpoints/best_model.pth",
    device='cuda'
)
Model Parameters
model = StandardCATEModel(
    use_gmm_head=True,     # Use GMM for uncertainty
    gmm_n_components=5,    # Number of mixture components
    gmm_min_sigma=1e-3,    # Minimum std dev
    gmm_pi_temp=1.0,       # Temperature for mixing weights
    device='cuda'          # 'cuda', 'cpu', or None (auto)
)
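This page does not show how CausalFM applies `gmm_pi_temp` and `gmm_min_sigma` internally. In many mixture-density heads, a temperature divides the mixing logits before a softmax (higher values flatten the weights), and a floor is added to a positive transform of the raw scale output. The sketch below illustrates that common convention only. The function name `mixture_params`, the softplus transform, and the raw inputs are assumptions for illustration, not the library's implementation.

```python
import math


def mixture_params(logits, raw_sigma, pi_temp=1.0, min_sigma=1e-3):
    """Map raw head outputs to mixing weights and std devs (illustrative only)."""
    # Temperature-scaled softmax, shifted by the max for numerical stability.
    scaled = [value / pi_temp for value in logits]
    shift = max(scaled)
    exps = [math.exp(value - shift) for value in scaled]
    total = sum(exps)
    pi = [value / total for value in exps]

    # Numerically stable softplus keeps sigma positive; min_sigma is a floor.
    sigma = [
        min_sigma + max(raw, 0.0) + math.log1p(math.exp(-abs(raw)))
        for raw in raw_sigma
    ]
    return pi, sigma


# Same logits, two temperatures: the higher temperature gives flatter weights.
pi_sharp, sigma_sharp = mixture_params([2.0, 0.0, -1.0], [-50.0, 0.0, 3.0], pi_temp=1.0)
pi_flat, _ = mixture_params([2.0, 0.0, -1.0], [-50.0, 0.0, 3.0], pi_temp=5.0)
```

Under this convention, a very small raw scale output still yields a standard deviation of at least `min_sigma`, which prevents a component from collapsing to a point mass.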
Estimating CATE
# Prepare data
x_train = torch.randn(800, 10) # Covariates
a_train = torch.randint(0, 2, (800, 1)).float() # Treatments
y_train = torch.randn(800, 1) # Outcomes
x_test = torch.randn(200, 10) # Test covariates
# Estimate CATE
result = model.estimate_cate(x_train, a_train, y_train, x_test)
# Extract results
cate = result['cate'] # Point estimates (200,)
# Uncertainty quantification
pi = result['gmm_pi'] # Mixture weights (200, 5)
mu = result['gmm_mu'] # Component means (200, 5)
sigma = result['gmm_sigma'] # Component std devs (200, 5)
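# Approximate 95% interval from the mixture mean and standard deviation
# (normal approximation; percentiles of the component means alone would
# ignore the mixture weights and the component variances)
variance = (pi * (sigma**2 + mu**2)).sum(dim=-1) - cate**2
std_dev = variance.clamp_min(0).sqrt()
lower = cate - 1.96 * std_dev
upper = cate + 1.96 * std_dev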
Input Format
Important: Ensure correct tensor shapes:
# ✅ Correct shapes
x_train: (n_train, n_features) # e.g., (800, 10)
a_train: (n_train, 1) # e.g., (800, 1) - NOT (800,)
y_train: (n_train, 1) # e.g., (800, 1) - NOT (800,)
x_test: (n_test, n_features) # e.g., (200, 10)
# ❌ Wrong - will cause errors
a_train_wrong: (800,) # Missing dimension
y_train_wrong: (800,) # Missing dimension
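When treatments or outcomes come from a flat array (for example a pandas column), they need the extra trailing dimension before being passed in. The helper below is a minimal sketch of one way to enforce that with NumPy; `as_column` is a hypothetical name and not part of CausalFM.

```python
import numpy as np


def as_column(values):
    """Return values as a float32 array of shape (n, 1), or raise ValueError."""
    arr = np.asarray(values, dtype=np.float32)
    if arr.ndim == 1:
        # (n,) -> (n, 1)
        arr = arr.reshape(-1, 1)
    if arr.ndim != 2 or arr.shape[1] != 1:
        raise ValueError(f"expected shape (n,) or (n, 1), got {arr.shape}")
    return arr


a_flat = [0, 1, 1, 0]
a_column = as_column(a_flat)  # shape (4, 1)
```

The result can then be wrapped with `torch.from_numpy(a_column)` to obtain a tensor of the same shape.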
Model Methods
# Set to evaluation mode
model.eval_mode()
# Set to training mode
model.train_mode()
# Save model
model.save("my_model.pth")
# Get model parameters
params = model.parameters
# Direct forward pass (for training)
output = model.forward(x, a, y, single_eval_pos)
Instrumental Variables Model
The IVModel handles settings with unobserved confounding using
instrumental variables.
Basic Usage
from causalfm.models import IVModel
# Load pretrained IV model
model = IVModel.from_pretrained(
    "checkpoints/iv_binary_model.pth",
    device='cuda'
)
Estimating CATE with Instruments
# Prepare data (including instruments)
x_train = torch.randn(800, 10) # Covariates
z_train = torch.randint(0, 2, (800, 1)).float() # Binary instrument
a_train = torch.randint(0, 2, (800, 1)).float() # Treatment
y_train = torch.randn(800, 1) # Outcome
x_test = torch.randn(200, 10) # Test covariates
# Estimate CATE using IV
result = model.estimate_cate(
    x_train, z_train, a_train, y_train, x_test
)
cate = result['cate']
Input Requirements
The IV model requires an additional instrument input:
# All inputs required:
x_train: (n_train, n_features) # Observed covariates
z_train: (n_train, 1) # Instrument (binary or continuous)
a_train: (n_train, 1) # Treatment
y_train: (n_train, 1) # Outcome
x_test: (n_test, n_features) # Test covariates
Binary vs Continuous Instruments
# Binary instrument (e.g., randomized encouragement)
z_binary = torch.randint(0, 2, (800, 1)).float()
# Continuous instrument (e.g., distance, price)
z_continuous = torch.randn(800, 1)
# Model handles both types
result_binary = model.estimate_cate(x_train, z_binary, a_train, y_train, x_test)
result_continuous = model.estimate_cate(x_train, z_continuous, a_train, y_train, x_test)
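To see why an instrument helps at all, the sketch below simulates data with an unobserved confounder and compares the naive treated-versus-control difference with the classical ratio estimator Cov(Z, Y) / Cov(Z, A), which applies to binary and continuous instruments alike. This is a textbook illustration with a constant treatment effect, not a description of what IVModel computes internally. The function names and the data-generating process are invented for the example.

```python
import random


def cov(u, v):
    """Population covariance of two equal-length sequences."""
    mean_u = sum(u) / len(u)
    mean_v = sum(v) / len(v)
    return sum((ui - mean_u) * (vi - mean_v) for ui, vi in zip(u, v)) / len(u)


def iv_ratio(z, a, y):
    """Classical IV (Wald-type) estimate: Cov(Z, Y) / Cov(Z, A)."""
    return cov(z, y) / cov(z, a)


def naive_difference(a, y):
    """Mean outcome of treated minus mean outcome of controls."""
    treated = [yi for ai, yi in zip(a, y) if ai == 1.0]
    control = [yi for ai, yi in zip(a, y) if ai == 0.0]
    return sum(treated) / len(treated) - sum(control) / len(control)


def simulate(n, binary_instrument, true_effect=2.0, seed=0):
    """Hypothetical data: u confounds a and y; z affects y only through a."""
    rng = random.Random(seed)
    z, a, y = [], [], []
    for _ in range(n):
        u = rng.gauss(0.0, 1.0)  # unobserved confounder
        zi = float(rng.random() < 0.5) if binary_instrument else rng.gauss(0.0, 1.0)
        ai = float(zi + u > 0.5)  # treatment depends on instrument and confounder
        yi = true_effect * ai + 1.5 * u + rng.gauss(0.0, 1.0)
        z.append(zi)
        a.append(ai)
        y.append(yi)
    return z, a, y


z_bin, a_bin, y_bin = simulate(50_000, binary_instrument=True)
z_con, a_con, y_con = simulate(50_000, binary_instrument=False)

print("binary instrument:     naive", naive_difference(a_bin, y_bin),
      "IV", iv_ratio(z_bin, a_bin, y_bin))
print("continuous instrument: naive", naive_difference(a_con, y_con),
      "IV", iv_ratio(z_con, a_con, y_con))
```

In this simulation the naive difference overstates the true effect of 2.0 because the confounder raises both treatment uptake and the outcome, while the ratio estimator stays close to it for either instrument type.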
Front-door Model
The FrontdoorModel uses mediators to identify causal effects when
there are unobserved confounders.
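For background, the classical front-door adjustment formula is given below for a discrete mediator $M$ (sums become integrals for a continuous mediator). It assumes that $M$ intercepts every directed path from treatment $A$ to outcome $Y$, that there is no unobserved confounding between $A$ and $M$, and that $A$ blocks all back-door paths from $M$ to $Y$. For a conditional effect, every term is additionally conditioned on the covariates $x$. How FrontdoorModel uses this identification result internally is not described on this page.

```latex
P\bigl(y \mid \mathrm{do}(a)\bigr)
  = \sum_{m} P(m \mid a) \sum_{a'} P(y \mid m, a')\, P(a')
```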
Basic Usage
from causalfm.models import FrontdoorModel
# Load pretrained front-door model
model = FrontdoorModel.from_pretrained(
    "checkpoints/frontdoor_model.pth",
    device='cuda'
)
Estimating CATE with Mediators
# Prepare data (including mediators)
x_train = torch.randn(800, 10) # Covariates
m_train = torch.randn(800, 1) # Mediator values
a_train = torch.randint(0, 2, (800, 1)).float() # Treatment
y_train = torch.randn(800, 1) # Outcome
x_test = torch.randn(200, 10) # Test covariates
# Estimate CATE using front-door adjustment
result = model.estimate_cate(
    x_train, m_train, a_train, y_train, x_test
)
cate = result['cate']
Input Requirements
The front-door model requires mediator observations:
# All inputs required:
x_train: (n_train, n_features) # Observed covariates
m_train: (n_train, 1) # Mediator values
a_train: (n_train, 1) # Treatment
y_train: (n_train, 1) # Outcome
x_test: (n_test, n_features) # Test covariates
Model Loading and Saving
Loading Pretrained Models
from causalfm.models import StandardCATEModel, IVModel, FrontdoorModel

# Standard CATE
standard_model = StandardCATEModel.from_pretrained(
    "checkpoints/standard/best_model.pth",
    device='cuda'
)

# IV
iv_model = IVModel.from_pretrained(
    "checkpoints/iv/best_model.pth",
    device='cpu'  # Use CPU
)

# Front-door
fd_model = FrontdoorModel.from_pretrained(
    "checkpoints/frontdoor/best_model.pth"
    # device='auto' by default
)
Saving Models
# Save model state
model.save("my_custom_model.pth")
# The checkpoint includes:
# - model_state_dict: Model parameters
# - Other metadata
Creating New Models
# Create a new untrained model
new_model = StandardCATEModel(
    use_gmm_head=True,
    gmm_n_components=5,
    device='cuda'
)

# Train it (see Training guide)
from causalfm.training import StandardCATETrainer, TrainingConfig
config = TrainingConfig(data_path="data/*.csv")
trainer = StandardCATETrainer(config)
trainer.train()
Uncertainty Quantification
GMM Output Interpretation
The GMM head outputs a mixture of Gaussians:
result = model.estimate_cate(x_train, a_train, y_train, x_test)
pi = result['gmm_pi'] # Shape: (n_test, n_components)
mu = result['gmm_mu'] # Shape: (n_test, n_components)
sigma = result['gmm_sigma'] # Shape: (n_test, n_components)
# Point estimate (mixture mean)
cate = result['cate'] # = sum(pi * mu, axis=-1)
# Variance (mixture variance)
variance = (pi * (sigma**2 + mu**2)).sum(dim=-1) - cate**2
std_dev = torch.sqrt(variance)
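The variance expression above follows from the first two moments of a Gaussian mixture. Writing $\pi_k$, $\mu_k$, $\sigma_k$ for the entries of `gmm_pi`, `gmm_mu`, and `gmm_sigma` at one test point, and $\bar{\mu}$ for the mixture mean:

```latex
\bar{\mu} = \mathbb{E}[Y] = \sum_{k} \pi_k \mu_k ,
\qquad
\mathbb{E}[Y^2] = \sum_{k} \pi_k \bigl(\sigma_k^2 + \mu_k^2\bigr)

\operatorname{Var}[Y] = \mathbb{E}[Y^2] - \bar{\mu}^2
  = \sum_{k} \pi_k \sigma_k^2 + \sum_{k} \pi_k \bigl(\mu_k - \bar{\mu}\bigr)^2
```

The last form separates the spread within components from the spread between component means, which is why both `sigma` and `mu` enter the uncertainty estimate.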
Computing Confidence Intervals
import numpy as np

# Sample from the GMM
n_samples = 1000
pi_np = pi.detach().cpu().numpy().astype(np.float64)
mu_np = mu.detach().cpu().numpy()
sigma_np = sigma.detach().cpu().numpy()
n_test, n_components = pi_np.shape
samples = np.zeros((n_test, n_samples))
for i in range(n_test):
    # Sample component indices (renormalize in float64 so that float32
    # rounding cannot make np.random.choice reject the weights)
    components = np.random.choice(
        n_components,
        size=n_samples,
        p=pi_np[i] / pi_np[i].sum()
    )
    # Sample from selected components
    for k in range(n_components):
        mask = (components == k)
        n_k = mask.sum()
        if n_k > 0:
            samples[i, mask] = np.random.normal(
                mu_np[i, k],
                sigma_np[i, k],
                n_k
            )
# Compute percentiles (95% interval of the predictive distribution)
ci_lower = np.percentile(samples, 2.5, axis=1)
ci_upper = np.percentile(samples, 97.5, axis=1)
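As an alternative to sampling, the quantiles of a one-dimensional Gaussian mixture can be found by numerically inverting its CDF, which removes Monte Carlo noise. The sketch below does this with bisection for a single test point. `gmm_cdf` and `gmm_quantile` are hypothetical helpers, not part of CausalFM, and they take plain Python lists (for example `pi[i].tolist()`).

```python
import math


def gmm_cdf(x, pi, mu, sigma):
    """CDF of a 1-D Gaussian mixture evaluated at x."""
    return sum(
        p * 0.5 * (1.0 + math.erf((x - m) / (s * math.sqrt(2.0))))
        for p, m, s in zip(pi, mu, sigma)
    )


def gmm_quantile(q, pi, mu, sigma, n_iter=200):
    """Invert the mixture CDF by bisection; q must lie strictly in (0, 1)."""
    # Bracket: ten standard deviations beyond the outermost components.
    lo = min(m - 10.0 * s for m, s in zip(mu, sigma))
    hi = max(m + 10.0 * s for m, s in zip(mu, sigma))
    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        if gmm_cdf(mid, pi, mu, sigma) < q:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


# Example weights, means, and std devs for one test point (made-up values).
pi_i, mu_i, sigma_i = [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0]
ci_lower_i = gmm_quantile(0.025, pi_i, mu_i, sigma_i)
ci_upper_i = gmm_quantile(0.975, pi_i, mu_i, sigma_i)
```

Bisection is used here because the mixture CDF is monotone but has no closed-form inverse; looping over test points gives the same intervals as the sampling approach, up to sampling error.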
Model Comparison
Choosing the Right Model
| Setting | Model | When to Use |
|---|---|---|
| Standard | StandardCATEModel | No unobserved confounding |
| IV | IVModel | Unobserved confounders + valid instrument |
| Front-door | FrontdoorModel | Unobserved confounders + mediator |
Example Comparison
# All models expose from_pretrained() and estimate_cate(),
# but estimate_cate() takes different inputs per setting
models = {
    'Standard': StandardCATEModel.from_pretrained("checkpoints/standard.pth"),
    'IV': IVModel.from_pretrained("checkpoints/iv.pth"),
    'Frontdoor': FrontdoorModel.from_pretrained("checkpoints/fd.pth")
}

results = {}

# Standard
results['Standard'] = models['Standard'].estimate_cate(
    x_train, a_train, y_train, x_test
)

# IV (needs instrument)
results['IV'] = models['IV'].estimate_cate(
    x_train, z_train, a_train, y_train, x_test
)

# Front-door (needs mediator)
results['Frontdoor'] = models['Frontdoor'].estimate_cate(
    x_train, m_train, a_train, y_train, x_test
)
Advanced Usage
Batch Processing
import pandas as pd
import torch
from causalfm.models import StandardCATEModel

model = StandardCATEModel.from_pretrained("checkpoints/best_model.pth")

# Process multiple test datasets
test_files = ["data/test1.csv", "data/test2.csv", "data/test3.csv"]
all_results = []
for file in test_files:
    df = pd.read_csv(file)
    # Extract features (columns whose names start with 'x')
    x_cols = [c for c in df.columns if c.startswith('x')]
    X = torch.FloatTensor(df[x_cols].values)
    A = torch.FloatTensor(df['treatment'].values).unsqueeze(1)
    Y = torch.FloatTensor(df['outcome'].values).unsqueeze(1)
    # Split: first 80% of rows as context, remaining rows as test points
    n_train = int(0.8 * len(X))
    result = model.estimate_cate(
        X[:n_train], A[:n_train], Y[:n_train], X[n_train:]
    )
    all_results.append(result['cate'])
GPU/CPU Management
# Auto-detect device
model = StandardCATEModel.from_pretrained("model.pth", device='auto')
# Force CPU
model_cpu = StandardCATEModel.from_pretrained("model.pth", device='cpu')
# Specific GPU
model_gpu = StandardCATEModel.from_pretrained("model.pth", device='cuda:0')
# Move between devices
model = model.to('cuda:1') # Move to GPU 1
API Reference
For complete API documentation, see:
- causalfm.models.StandardCATEModel
- causalfm.models.IVModel
- causalfm.models.FrontdoorModel