Models API
This page documents the model APIs in CausalFM.
Model Classes
StandardCATEModel
Class: causalfm.models.standard.StandardCATEModel
Foundation model for standard CATE estimation.
This model uses a transformer architecture with a Gaussian mixture model (GMM) prediction head to estimate conditional average treatment effects (CATE).
Example:
from causalfm.models import StandardCATEModel
import torch
# Create a new model
model = StandardCATEModel(use_gmm_head=True, gmm_n_components=5)
# Or load a pretrained one
model = StandardCATEModel.from_pretrained("checkpoints/best_model.pth")
# Estimate CATE
x_train = torch.randn(800, 10)
a_train = torch.randint(0, 2, (800, 1)).float()
y_train = torch.randn(800, 1)
x_test = torch.randn(200, 10)
result = model.estimate_cate(x_train, a_train, y_train, x_test)
cate = result['cate']  # Point estimates
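Random outcomes like those above have no ground-truth effect to compare against. A minimal sketch for sanity-checking point estimates on simulated data: the hypothetical helpers `simulate_standard` and `pehe` below (not part of CausalFM) build confounded data with a known CATE, tau(x) = 1 + x[:, 0], and score estimates by root-mean-squared error (often called PEHE). Convert the arrays with `torch.from_numpy(...)` before passing them to `estimate_cate`.

```python
import numpy as np


def simulate_standard(n, n_features=10, seed=0):
    """Hypothetical simulator with a known CATE: tau(x) = 1 + x[:, 0].

    Requires n_features >= 2 because treatment depends on x[:, 1].
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n, n_features))
    # Confounded treatment: the propensity depends on the second feature
    propensity = 1.0 / (1.0 + np.exp(-x[:, 1]))
    a = rng.binomial(1, propensity).astype(np.float32)
    tau = 1.0 + x[:, 0]
    y = x[:, 1] + tau * a + rng.normal(scale=0.5, size=n)
    # Column vectors of shape (n, 1), matching the documented input shapes
    return x.astype(np.float32), a[:, None], y[:, None].astype(np.float32), tau


def pehe(cate_hat, cate_true):
    """Root-mean-squared error between estimated and true CATE."""
    cate_hat = np.asarray(cate_hat, dtype=float).ravel()
    cate_true = np.asarray(cate_true, dtype=float).ravel()
    return float(np.sqrt(np.mean((cate_hat - cate_true) ** 2)))


x, a, y, tau = simulate_standard(1000)
print(x.shape, a.shape, y.shape)  # (1000, 10) (1000, 1) (1000, 1)
print(pehe(tau + 0.5, tau))       # approximately 0.5
```

Because `tau` is known exactly, `pehe(result['cate'].cpu().numpy(), tau_test)` gives a direct accuracy check on the test rows.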
IVModel
Class: causalfm.models.iv.IVModel
Foundation model for instrumental variables setting.
This model uses instruments to identify causal effects in the presence of unobserved confounding.
Example:
from causalfm.models import IVModel
import torch
model = IVModel.from_pretrained("checkpoints/iv_model.pth")
# Requires an instrument variable z
x_train = torch.randn(800, 10)
z_train = torch.randint(0, 2, (800, 1)).float()  # Binary instrument
a_train = torch.randint(0, 2, (800, 1)).float()
y_train = torch.randn(800, 1)
x_test = torch.randn(200, 10)
result = model.estimate_cate(x_train, z_train, a_train, y_train, x_test)
cate = result['cate']
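To see why an instrument helps, consider the classical Wald estimator. This is a textbook method, not necessarily what IVModel learns internally: it divides the instrument's effect on the outcome by its effect on the treatment. In the sketch below, the confounder `u`, the coefficients, and the constant effect of 2.0 are illustrative assumptions. The naive treated-versus-control difference is biased by `u`, while the Wald ratio recovers the effect.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
true_effect = 2.0

u = rng.normal(size=n)                # unobserved confounder
z = rng.binomial(1, 0.5, size=n)      # binary instrument, independent of u
# Treatment depends on both the instrument and the confounder
a = (0.8 * z + u + rng.normal(size=n) > 0.5).astype(float)
# Outcome depends on treatment and the confounder, but not directly on z
y = true_effect * a + 1.5 * u + rng.normal(size=n)

naive = y[a == 1].mean() - y[a == 0].mean()
wald = (y[z == 1].mean() - y[z == 0].mean()) / (a[z == 1].mean() - a[z == 0].mean())
print(f"naive: {naive:.2f}, Wald: {wald:.2f}")
```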
FrontdoorModel
Class: causalfm.models.frontdoor.FrontdoorModel
Foundation model for front-door adjustment setting.
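This model uses mediators to identify causal effects via front-door adjustment when backdoor paths between treatment and outcome cannot be blocked by observed covariates, for example because of unobserved confounding. The mediator must intercept every directed path from the treatment to the outcome.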
Example:
from causalfm.models import FrontdoorModel
import torch
model = FrontdoorModel.from_pretrained("checkpoints/fd_model.pth")
# Requires a mediator variable m
x_train = torch.randn(800, 10)
m_train = torch.randn(800, 1)  # Mediator
a_train = torch.randint(0, 2, (800, 1)).float()
y_train = torch.randn(800, 1)
x_test = torch.randn(200, 10)
result = model.estimate_cate(x_train, m_train, a_train, y_train, x_test)
cate = result['cate']
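The front-door formula behind this setting, E[Y | do(a)] = sum_m P(m | a) sum_a' E[Y | m, a'] P(a'), can be checked on discrete data. The sketch below is a textbook plug-in estimator, not FrontdoorModel's internals. A hidden binary confounder `u` affects both treatment and outcome, and the binary mediator `m` carries the whole effect of `a`. All probabilities and coefficients are illustrative assumptions, chosen so that the true average effect is 2 * (0.8 - 0.1) = 1.4.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

u = rng.binomial(1, 0.5, size=n)            # hidden confounder
a = rng.binomial(1, 0.2 + 0.6 * u)          # treatment depends on u
m = rng.binomial(1, 0.1 + 0.7 * a)          # mediator depends only on a
y = 2.0 * m + 1.5 * u + rng.normal(size=n)  # outcome depends on m and u


def frontdoor_mean(a_val):
    """Plug-in estimate of E[Y | do(A = a_val)] via front-door adjustment."""
    total = 0.0
    for m_val in (0, 1):
        p_m_given_a = np.mean(m[a == a_val] == m_val)
        # Average E[Y | M = m_val, A = a'] over the marginal distribution of A
        inner = sum(
            y[(m == m_val) & (a == a_prime)].mean() * np.mean(a == a_prime)
            for a_prime in (0, 1)
        )
        total += p_m_given_a * inner
    return total


ate_frontdoor = frontdoor_mean(1) - frontdoor_mean(0)
ate_naive = y[a == 1].mean() - y[a == 0].mean()
print(f"front-door: {ate_frontdoor:.2f}, naive: {ate_naive:.2f}")
```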
Common Methods
All model classes share the following interface:
__init__(use_gmm_head, gmm_n_components, device)
Initialize a new model.
- param bool use_gmm_head: Whether to use the GMM prediction head
- param int gmm_n_components: Number of mixture components
- param str device: Device to use ('cuda', 'cpu', or 'auto')
from_pretrained(checkpoint_path, device)
Load a pretrained model from a checkpoint.
- param str checkpoint_path: Path to the checkpoint file
- param str device: Device to use
- return: Loaded model instance
- rtype: Model
estimate_cate(...)
Estimate conditional average treatment effects. The positional arguments depend on the model class (see Input Shapes).
- return: Dictionary with keys:
  - 'cate': Point estimates (mixture mean), shape (n_test,)
  - 'gmm_pi': Mixture weights, shape (n_test, n_components)
  - 'gmm_mu': Component means, shape (n_test, n_components)
  - 'gmm_sigma': Component standard deviations, shape (n_test, n_components)
- rtype: dict
save(path)
Save a model checkpoint.
- param str path: Path to save the checkpoint
eval_mode()
Set the model to evaluation mode.
train_mode()
Set the model to training mode.
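If you write helper code that should accept any of the three model classes, the shared interface above can be expressed as a `typing.Protocol`. This is a hypothetical sketch for type checking, not a base class exported by CausalFM. `estimate_cate` is typed with `*args` because its positional arguments differ between models.

```python
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class CATEEstimator(Protocol):
    """Hypothetical structural type mirroring the documented common methods."""

    def estimate_cate(self, *args: Any) -> Dict[str, Any]: ...

    def save(self, path: str) -> None: ...

    def eval_mode(self) -> None: ...

    def train_mode(self) -> None: ...


def mean_cate(model: CATEEstimator, *inputs: Any) -> float:
    """Average the point estimates returned by any conforming model."""
    cate = model.estimate_cate(*inputs)["cate"]
    return float(sum(cate) / len(cate))
```

`isinstance(obj, CATEEstimator)` only checks that the four methods exist. It does not validate argument counts or return values.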
Input Shapes
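All models expect 2-D tensors. Treatment, outcome, instrument, and mediator inputs are column vectors of shape (n, 1), not 1-D vectors of shape (n,):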
Standard CATE:
x_train: (n_train, n_features)
a_train: (n_train, 1) # NOT (n_train,)
y_train: (n_train, 1) # NOT (n_train,)
x_test: (n_test, n_features)
IV Model:
x_train: (n_train, n_features)
z_train: (n_train, 1) # Instrument
a_train: (n_train, 1)
y_train: (n_train, 1)
x_test: (n_test, n_features)
Front-door Model:
x_train: (n_train, n_features)
m_train: (n_train, 1) # Mediator
a_train: (n_train, 1)
y_train: (n_train, 1)
x_test: (n_test, n_features)
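Passing 1-D treatment or outcome vectors is an easy mistake, so a small validation step before calling `estimate_cate` can catch it early. The hypothetical helper `as_column` reshapes a 1-D array to `(n, 1)`, and `check_inputs` verifies row counts and feature widths for the standard setting. Both are sketches written against NumPy arrays. PyTorch tensors also provide `.ndim`, `.shape`, and `.reshape`, so the same code works on them, and `as_column` applies equally to `z_train` and `m_train`.

```python
import numpy as np


def as_column(v):
    """Reshape a 1-D vector to (n, 1); leave (n, 1) inputs unchanged."""
    if v.ndim == 1:
        return v.reshape(-1, 1)
    if v.ndim == 2 and v.shape[1] == 1:
        return v
    raise ValueError(f"expected shape (n,) or (n, 1), got {tuple(v.shape)}")


def check_inputs(x_train, a_train, y_train, x_test):
    """Validate shapes for the standard CATE setting and return fixed inputs."""
    a_train, y_train = as_column(a_train), as_column(y_train)
    if x_train.ndim != 2 or x_test.ndim != 2:
        raise ValueError("x_train and x_test must be 2-D (n, n_features)")
    n_train = x_train.shape[0]
    if a_train.shape[0] != n_train or y_train.shape[0] != n_train:
        raise ValueError("a_train and y_train must have n_train rows")
    if x_test.shape[1] != x_train.shape[1]:
        raise ValueError("x_test must have the same n_features as x_train")
    return x_train, a_train, y_train, x_test


x_train, a_train, y_train, x_test = check_inputs(
    np.zeros((800, 10)), np.zeros(800), np.zeros(800), np.zeros((200, 10))
)
print(a_train.shape, y_train.shape)  # (800, 1) (800, 1)
```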
Output Format
GMM Output
All models return a dictionary with GMM parameters:
result = model.estimate_cate(x_train, a_train, y_train, x_test)
# Point estimate (mixture mean)
cate = result['cate'] # Shape: (n_test,)
# GMM parameters
pi = result['gmm_pi'] # Mixing weights: (n_test, n_components)
mu = result['gmm_mu'] # Component means: (n_test, n_components)
sigma = result['gmm_sigma'] # Component stds: (n_test, n_components)
# Compute predictive variance
variance = (pi * (sigma**2 + mu**2)).sum(dim=-1) - cate**2
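The variance line follows from the law of total variance. For weights pi_k, means mu_k, and standard deviations sigma_k, Var = sum_k pi_k (sigma_k^2 + mu_k^2) - (sum_k pi_k mu_k)^2, where the subtracted term is the squared mixture mean that `result['cate']` holds. The NumPy sketch below checks the closed form against Monte Carlo samples for one made-up mixture. The parameter values are arbitrary.

```python
import numpy as np

# Arbitrary three-component mixture for a single test point
pi = np.array([0.2, 0.5, 0.3])
mu = np.array([-1.0, 0.5, 2.0])
sigma = np.array([0.3, 0.8, 0.5])

mean = (pi * mu).sum()
variance = (pi * (sigma**2 + mu**2)).sum() - mean**2

# Monte Carlo check: pick a component per draw, then sample that Gaussian
rng = np.random.default_rng(0)
k = rng.choice(len(pi), size=500_000, p=pi)
draws = rng.normal(mu[k], sigma[k])
print(f"closed form: {variance:.3f}, sampled: {draws.var():.3f}")
```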
Confidence Intervals
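Sample from each test point's GMM and take empirical percentiles to obtain approximate 95% intervals:
import numpy as np
# Move the GMM parameters to NumPy once (detach in case autograd is tracking them)
pi_np = pi.detach().cpu().numpy()
mu_np = mu.detach().cpu().numpy()
sigma_np = sigma.detach().cpu().numpy()
n_test, n_components = pi_np.shape
n_samples = 10000
samples = np.zeros((n_test, n_samples))
for i in range(n_test):
    # Sample a mixture component for every draw
    # (renormalize in float64 to guard against rounding in the weights)
    p = pi_np[i].astype(np.float64)
    p /= p.sum()
    components = np.random.choice(n_components, size=n_samples, p=p)
    # Draw each sample from its selected Gaussian component
    samples[i] = np.random.normal(mu_np[i, components], sigma_np[i, components])
# Compute percentiles for a 95% interval
ci_lower = np.percentile(samples, 2.5, axis=1)
ci_upper = np.percentile(samples, 97.5, axis=1)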
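Monte Carlo percentiles carry sampling noise. As an alternative sketch (not a CausalFM function), the mixture CDF F(t) = sum_k pi_k Phi((t - mu_k) / sigma_k) is monotone, so it can be inverted directly by bisection. The hypothetical `gmm_quantile` works on NumPy arrays shaped like the `gmm_*` outputs. Call `.detach().cpu().numpy()` on the tensors first.

```python
import math

import numpy as np

# Element-wise error function, used to build the normal CDF
_erf = np.vectorize(math.erf, otypes=[float])


def gmm_cdf(t, pi, mu, sigma):
    """Mixture CDF at t (shape (n,)) for parameters of shape (n, K)."""
    z = (t[:, None] - mu) / (sigma * math.sqrt(2.0))
    return (pi * 0.5 * (1.0 + _erf(z))).sum(axis=1)


def gmm_quantile(q, pi, mu, sigma, n_iter=60):
    """Invert the mixture CDF by bisection, one quantile per row."""
    # Bracket: +/- 10 standard deviations beyond the extreme component means
    lo = (mu - 10 * sigma).min(axis=1)
    hi = (mu + 10 * sigma).max(axis=1)
    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        below = gmm_cdf(mid, pi, mu, sigma) < q
        lo = np.where(below, mid, lo)
        hi = np.where(below, hi, mid)
    return 0.5 * (lo + hi)


pi = np.array([[1.0, 0.0], [0.5, 0.5]])
mu = np.array([[0.0, 0.0], [-2.0, 2.0]])
sigma = np.array([[1.0, 1.0], [0.5, 0.5]])
ci_lower = gmm_quantile(0.025, pi, mu, sigma)
ci_upper = gmm_quantile(0.975, pi, mu, sigma)
print(ci_lower.round(3), ci_upper.round(3))
```

Each iteration halves the bracket, so 60 iterations are far more than enough for float64 precision. The cost is one CDF evaluation per iteration instead of thousands of samples per test point.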
Advanced Usage
Custom Model Configuration
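from causalfm.models import StandardCATEModel
model = StandardCATEModel(
    use_gmm_head=True,
    gmm_n_components=10,  # More mixture components
    gmm_min_sigma=1e-4,   # Lower bound on each component's standard deviation
    gmm_pi_temp=0.8,      # Temperature applied to the mixing weights
    device='cuda:0'
)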
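The configuration options suggest how the head's raw outputs might be turned into valid mixture parameters. The sketch below is an assumption about what `gmm_pi_temp` and `gmm_min_sigma` control, not CausalFM's implementation. Mixing logits are divided by a temperature before a softmax, so lower temperatures concentrate weight on fewer components. Standard deviations pass through a softplus plus a floor so that no component collapses to zero width.

```python
import numpy as np


def gmm_params(logits, raw_mu, raw_sigma, pi_temp=0.8, min_sigma=1e-4):
    """Hypothetical mapping from raw head outputs (n, K) to GMM parameters."""
    # Temperature-scaled softmax over components (shift by max for stability)
    scaled = logits / pi_temp
    scaled = scaled - scaled.max(axis=-1, keepdims=True)
    weights = np.exp(scaled)
    pi = weights / weights.sum(axis=-1, keepdims=True)
    # Softplus log(1 + exp(x)) keeps standard deviations positive
    sigma = np.logaddexp(0.0, raw_sigma) + min_sigma
    mu = raw_mu
    # The point estimate is the mixture mean
    cate = (pi * mu).sum(axis=-1)
    return cate, pi, mu, sigma


logits = np.array([[2.0, 1.0, 0.0]])
raw = np.zeros((1, 3))
for temp in (1.0, 0.5):
    _, pi, _, _ = gmm_params(logits, raw, raw, pi_temp=temp)
    print(temp, pi.round(3))
```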
Batch Processing
from pathlib import Path
import torch
import pandas as pd
from causalfm.models import StandardCATEModel
model = StandardCATEModel.from_pretrained("checkpoints/best_model.pth")
# Process multiple files
results = []
for file in sorted(Path("data/test/").glob("*.csv")):
    df = pd.read_csv(file)
    # Prepare data: features are the columns whose names start with 'x'
    x_cols = [c for c in df.columns if c.startswith('x')]
    X = torch.tensor(df[x_cols].values, dtype=torch.float32)
    A = torch.tensor(df['treatment'].values, dtype=torch.float32).unsqueeze(1)
    Y = torch.tensor(df['outcome'].values, dtype=torch.float32).unsqueeze(1)
    # Split: first 80% of rows as training context, the rest as test points
    n_train = int(0.8 * len(X))
    # Predict
    result = model.estimate_cate(
        X[:n_train], A[:n_train], Y[:n_train], X[n_train:]
    )
    results.append({
        'file': file.name,
        'mean_cate': result['cate'].mean().item()
    })
print(pd.DataFrame(results))
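The loop above uses the first 80% of rows as training context. That split can be biased if a file is sorted, for example by treatment or by time. A minimal alternative, assuming rows are exchangeable, is to shuffle indices with a fixed seed before splitting. `random_split` is a hypothetical helper. Wrap its index arrays with `torch.from_numpy` to index the tensors (`X[torch.from_numpy(train_idx)]`).

```python
import numpy as np


def random_split(n, train_frac=0.8, seed=0):
    """Return shuffled (train_idx, test_idx) index arrays for n rows."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_train = int(train_frac * n)
    return perm[:n_train], perm[n_train:]


train_idx, test_idx = random_split(1000)
print(len(train_idx), len(test_idx))  # 800 200
```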
GPU Management
# Auto-detect device
model = StandardCATEModel.from_pretrained("model.pth", device='auto')
# Specific GPU
model = StandardCATEModel.from_pretrained("model.pth", device='cuda:1')
# Move between devices
model = model.to('cpu')
# Check current device
print(model.device)
Model Properties
Access model attributes:
# Get underlying model
underlying_model = model.model
# Get model parameters
params = model.parameters
# Count parameters
n_params = sum(p.numel() for p in model.parameters)
print(f"Model has {n_params:,} parameters")
See Also
Models - Detailed model usage guide
Standard CATE Estimation Example - Complete example
Training API - Training API reference