Models API

This page documents the model APIs in CausalFM.

Model Classes

StandardCATEModel

Class: causalfm.models.standard.StandardCATEModel

Foundation model for standard CATE estimation.

This model uses a transformer architecture with a Gaussian mixture model (GMM) prediction head to estimate conditional average treatment effects (CATE).

Example:

from causalfm.models import StandardCATEModel
import torch

# Create new model
model = StandardCATEModel(use_gmm_head=True, gmm_n_components=5)

# Or load pretrained
model = StandardCATEModel.from_pretrained("checkpoints/best_model.pth")

# Estimate CATE
x_train = torch.randn(800, 10)
a_train = torch.randint(0, 2, (800, 1)).float()
y_train = torch.randn(800, 1)
x_test = torch.randn(200, 10)

result = model.estimate_cate(x_train, a_train, y_train, x_test)
cate = result['cate']  # Point estimates
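To sanity-check estimates, it can help to run the model on simulated data where the true CATE is known. The sketch below is one such generator in NumPy; the function name make_standard_data and its data-generating process are illustrative assumptions, not part of CausalFM. It returns treatment and outcome as (n, 1) columns, matching the shapes the model expects; convert with torch.from_numpy(...).float() before calling estimate_cate.

```python
import numpy as np

def make_standard_data(n=1000, n_features=10, seed=0):
    """Hypothetical simulator: observed confounding, heterogeneous treatment effect."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n, n_features))
    # True CATE depends only on the first covariate
    tau = 1.0 + 0.5 * x[:, 0]
    # Treatment probability depends on the second covariate (observed confounder)
    propensity = 1.0 / (1.0 + np.exp(-x[:, 1]))
    a = rng.binomial(1, propensity).astype(float)
    # Outcome = confounder effect + treatment effect + noise
    y = x[:, 1] + a * tau + rng.normal(scale=0.5, size=n)
    # a and y are returned as (n, 1) columns, x as (n, n_features)
    return x, a.reshape(-1, 1), y.reshape(-1, 1), tau

x, a, y, tau = make_standard_data()
x_train, a_train, y_train = x[:800], a[:800], y[:800]
x_test, true_cate = x[800:], tau[800:]
```

Comparing result['cate'] against true_cate (for example by RMSE) then gives a direct accuracy check on this synthetic setting.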

IVModel

Class: causalfm.models.iv.IVModel

Foundation model for instrumental variables setting.

This model uses instruments to identify causal effects in the presence of unobserved confounding.

Example:

from causalfm.models import IVModel
import torch

model = IVModel.from_pretrained("checkpoints/iv_model.pth")

# Requires instrument variable z
x_train = torch.randn(800, 10)
z_train = torch.randint(0, 2, (800, 1)).float()  # Binary instrument
a_train = torch.randint(0, 2, (800, 1)).float()
y_train = torch.randn(800, 1)
x_test = torch.randn(200, 10)

result = model.estimate_cate(x_train, z_train, a_train, y_train, x_test)
cate = result['cate']
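When checking IVModel output on simulated data, the classic Wald estimator is a useful reference point. It is a standard textbook IV estimator, not CausalFM's method, and it recovers only a single average effect (the constant effect here; more generally a local average treatment effect under monotonicity), not a CATE. The helper name wald_estimate and the simulation below are illustrative assumptions.

```python
import numpy as np

def wald_estimate(z, a, y):
    """Wald (ratio) IV estimator: Cov(Y, Z) / Cov(A, Z). Accepts (n,) or (n, 1) arrays."""
    z, a, y = (np.ravel(v) for v in (z, a, y))
    return np.cov(y, z)[0, 1] / np.cov(a, z)[0, 1]

# Simulated data: unobserved confounder u, binary instrument z, true effect 2.0
rng = np.random.default_rng(0)
n = 200_000
u = rng.normal(size=n)                              # unobserved confounder
z = rng.binomial(1, 0.5, size=n).astype(float)      # binary instrument, independent of u
# Treatment depends on the instrument and on the confounder
a = (0.8 * z + u + rng.normal(size=n) > 0.5).astype(float)
# Outcome depends on treatment and on the confounder, but not directly on z
y = 2.0 * a + 1.5 * u + rng.normal(size=n)

naive = y[a == 1].mean() - y[a == 0].mean()   # difference in means, biased upward by u
iv = wald_estimate(z, a, y)                   # close to the true effect of 2.0
```

The gap between naive and iv shows why the instrument is needed when confounding is unobserved.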

FrontdoorModel

Class: causalfm.models.frontdoor.FrontdoorModel

Foundation model for front-door adjustment setting.

This model uses mediators to identify causal effects via front-door adjustment when the backdoor paths between treatment and outcome cannot be blocked because of unobserved confounding.

Example:

from causalfm.models import FrontdoorModel
import torch

model = FrontdoorModel.from_pretrained("checkpoints/fd_model.pth")

# Requires mediator variable m
x_train = torch.randn(800, 10)
m_train = torch.randn(800, 1)  # Mediator
a_train = torch.randint(0, 2, (800, 1)).float()
y_train = torch.randn(800, 1)
x_test = torch.randn(200, 10)

result = model.estimate_cate(x_train, m_train, a_train, y_train, x_test)
cate = result['cate']
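For reference, the standard conditional front-door adjustment formula is shown below. It assumes X is pre-treatment, the mediator M intercepts every directed path from A to Y, there is no unblocked backdoor path from A to M, and every backdoor path from M to Y is blocked by A (all conditional on X). The documentation does not state how FrontdoorModel uses this formula internally.

```latex
\mathbb{E}[Y \mid do(A=a), X=x]
  = \int p(m \mid A=a, X=x) \sum_{a'} P(A=a' \mid X=x)\,
    \mathbb{E}[Y \mid A=a', M=m, X=x] \, dm

\tau(x) = \mathbb{E}[Y \mid do(A=1), X=x] - \mathbb{E}[Y \mid do(A=0), X=x]
```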

Common Methods

All model classes share the following interface:

__init__(use_gmm_head, gmm_n_components, device)

Initialize a new model. Further GMM options (gmm_min_sigma, gmm_pi_temp) appear under Custom Model Configuration.

param bool use_gmm_head:

Whether to use GMM prediction head

param int gmm_n_components:

Number of mixture components

param str device:

Device to use (‘cpu’, ‘cuda’, a specific GPU such as ‘cuda:0’, or ‘auto’)

from_pretrained(checkpoint_path, device)

Load a pretrained model from checkpoint.

param str checkpoint_path:

Path to checkpoint file

param str device:

Device to use

return:

Loaded model instance

rtype:

Model

estimate_cate(...)

Estimate conditional average treatment effects. The positional arguments differ by model class; see Input Shapes.

return:

Dictionary with keys:

- ‘cate’: Point estimates (n_test,)
- ‘gmm_pi’: Mixture weights (n_test, n_components)
- ‘gmm_mu’: Component means (n_test, n_components)
- ‘gmm_sigma’: Component standard deviations (n_test, n_components)

rtype:

dict

save(path)

Save model checkpoint.

param str path:

Path to save checkpoint

eval_mode()

Set model to evaluation mode.

train_mode()

Set model to training mode.

Input Shapes

All models expect 2-D tensors. Treatment, outcome, instrument, and mediator must be column vectors of shape (n, 1), not 1-D:

Standard CATE:

x_train: (n_train, n_features)
a_train: (n_train, 1)          # NOT (n_train,)
y_train: (n_train, 1)          # NOT (n_train,)
x_test:  (n_test, n_features)

IV Model:

x_train: (n_train, n_features)
z_train: (n_train, 1)          # Instrument
a_train: (n_train, 1)
y_train: (n_train, 1)
x_test:  (n_test, n_features)

Front-door Model:

x_train: (n_train, n_features)
m_train: (n_train, 1)          # Mediator
a_train: (n_train, 1)
y_train: (n_train, 1)
x_test:  (n_test, n_features)
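The (n, 1) requirement is easy to violate when columns come out of pandas or NumPy as 1-D arrays. A small validation step before converting to tensors can catch this; as_column and check_inputs below are hypothetical helpers sketched in NumPy, not CausalFM functions.

```python
import numpy as np

def as_column(v):
    """Reshape a 1-D array of length n to (n, 1); pass (n, 1) arrays through unchanged."""
    v = np.asarray(v, dtype=np.float32)
    if v.ndim == 1:
        return v.reshape(-1, 1)
    if v.ndim == 2 and v.shape[1] == 1:
        return v
    raise ValueError(f"expected shape (n,) or (n, 1), got {v.shape}")

def check_inputs(x_train, x_test, **columns):
    """Validate covariates and coerce every extra column (a, y, z, m) to (n_train, 1)."""
    x_train = np.asarray(x_train, dtype=np.float32)
    x_test = np.asarray(x_test, dtype=np.float32)
    if x_train.ndim != 2 or x_test.ndim != 2 or x_train.shape[1] != x_test.shape[1]:
        raise ValueError("x_train and x_test must be 2-D with the same n_features")
    out = {name: as_column(v) for name, v in columns.items()}
    for name, v in out.items():
        if v.shape[0] != x_train.shape[0]:
            raise ValueError(f"{name} has {v.shape[0]} rows, expected {x_train.shape[0]}")
    return x_train, x_test, out
```

For example, check_inputs(x_train, x_test, a_train=a, y_train=y) returns arrays that can be wrapped with torch.from_numpy before calling estimate_cate.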

Output Format

GMM Output

All models return a dictionary with GMM parameters:

result = model.estimate_cate(x_train, a_train, y_train, x_test)

# Point estimate (mixture mean)
cate = result['cate']  # Shape: (n_test,)

# GMM parameters
pi = result['gmm_pi']        # Mixing weights: (n_test, n_components)
mu = result['gmm_mu']        # Component means: (n_test, n_components)
sigma = result['gmm_sigma']  # Component stds: (n_test, n_components)

# Compute predictive variance
variance = (pi * (sigma**2 + mu**2)).sum(dim=-1) - cate**2
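The variance line above is the law of total variance for a K-component Gaussian mixture with weights π_k, means μ_k, and standard deviations σ_k (gmm_pi, gmm_mu, gmm_sigma), using the fact that cate is the mixture mean. The second form splits it into the average within-component variance plus the spread of the component means:

```latex
\hat{\tau}(x) = \sum_{k=1}^{K} \pi_k(x)\,\mu_k(x)

\operatorname{Var}[\tau \mid x]
  = \sum_{k=1}^{K} \pi_k(x)\left(\sigma_k(x)^2 + \mu_k(x)^2\right) - \hat{\tau}(x)^2
  = \sum_{k=1}^{K} \pi_k(x)\,\sigma_k(x)^2
    + \sum_{k=1}^{K} \pi_k(x)\left(\mu_k(x) - \hat{\tau}(x)\right)^2
```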

Confidence Intervals

To get a 95% interval for each test unit, draw Monte Carlo samples from its predicted GMM and take the 2.5th and 97.5th percentiles:

import numpy as np

# Convert the GMM parameters to NumPy once (detach in case gradients are tracked)
pi_np = pi.detach().cpu().numpy()
mu_np = mu.detach().cpu().numpy()
sigma_np = sigma.detach().cpu().numpy()

rng = np.random.default_rng(0)
n_samples = 10000
n_test, n_components = pi_np.shape
samples = np.empty((n_test, n_samples))

for i in range(n_test):
    # Draw a mixture component for every sample
    # (renormalize to absorb float32 rounding in the weights)
    p = pi_np[i].astype(np.float64)
    components = rng.choice(n_components, size=n_samples, p=p / p.sum())

    # Draw each sample from its selected Gaussian component
    samples[i] = rng.normal(mu_np[i, components], sigma_np[i, components])

# Compute percentiles
ci_lower = np.percentile(samples, 2.5, axis=1)
ci_upper = np.percentile(samples, 97.5, axis=1)
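Monte Carlo percentiles carry sampling noise. Because each test unit's predictive distribution is a one-dimensional Gaussian mixture, its quantiles can also be computed deterministically by bisection on the mixture CDF. The sketch below is an alternative using NumPy and the standard math module; gmm_quantile is a hypothetical helper, not a CausalFM function.

```python
import math
import numpy as np

def gmm_quantile(q, pi, mu, sigma, tol=1e-8):
    """q-quantile of a 1-D Gaussian mixture, found by bisection on its CDF."""
    pi, mu, sigma = (np.asarray(v, dtype=float) for v in (pi, mu, sigma))
    pi = pi / pi.sum()  # guard against rounding in the mixing weights

    def cdf(t):
        # Mixture CDF = weighted sum of the normal CDFs, written with erf
        return sum(p * 0.5 * (1.0 + math.erf((t - m) / (s * math.sqrt(2.0))))
                   for p, m, s in zip(pi, mu, sigma))

    # For q not extremely close to 0 or 1, the quantile lies within
    # 10 standard deviations of the outermost component means
    lo = float(np.min(mu - 10.0 * sigma))
    hi = float(np.max(mu + 10.0 * sigma))
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if cdf(mid) < q:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)
```

Applied per test unit, for example gmm_quantile(0.025, pi[i].detach().cpu().numpy(), mu[i].detach().cpu().numpy(), sigma[i].detach().cpu().numpy()), it gives the lower interval endpoint without sampling; use 0.975 for the upper one.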

Advanced Usage

Custom Model Configuration

from causalfm.models import StandardCATEModel

model = StandardCATEModel(
    use_gmm_head=True,
    gmm_n_components=10,        # More mixture components
    gmm_min_sigma=1e-4,         # Lower bound on component std dev
    gmm_pi_temp=0.8,            # Temperature for the mixing weights
    device='cuda:0'
)
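The parameters gmm_min_sigma and gmm_pi_temp are not documented beyond the inline comments. A common way such settings enter a mixture-density head is sketched below in NumPy: the mixing logits are divided by a temperature before the softmax, and a softplus plus a floor keeps every standard deviation positive. This is an assumption about their role, not CausalFM's implementation; gmm_head_outputs and its arguments are hypothetical.

```python
import numpy as np

def gmm_head_outputs(raw_logits, raw_mu, raw_sigma, pi_temp=0.8, min_sigma=1e-4):
    """Hypothetical mapping from raw head outputs to valid GMM parameters.

    All inputs have shape (n_test, n_components).
    """
    # Temperature-scaled softmax: pi_temp < 1 sharpens the weights, > 1 flattens them
    scaled = raw_logits / pi_temp
    scaled = scaled - scaled.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scaled)
    pi = weights / weights.sum(axis=-1, keepdims=True)
    # Softplus log(1 + e^x) keeps sigma positive; min_sigma is a hard lower bound
    sigma = np.logaddexp(0.0, raw_sigma) + min_sigma
    mu = raw_mu
    # Point estimate as the mixture mean
    cate = (pi * mu).sum(axis=-1)
    return pi, mu, sigma, cate
```

Under this reading, lowering gmm_pi_temp concentrates weight on fewer components, and gmm_min_sigma prevents components from collapsing to near-zero width during training.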

Batch Processing

from pathlib import Path
import torch
import pandas as pd
from causalfm.models import StandardCATEModel

model = StandardCATEModel.from_pretrained("checkpoints/best_model.pth")

# Process every CSV file in the directory (sorted for a reproducible order)
results = []
for file in sorted(Path("data/test/").glob("*.csv")):
    df = pd.read_csv(file)

    # Prepare data: covariates are the columns whose names start with 'x'
    x_cols = [c for c in df.columns if c.startswith('x')]
    X = torch.tensor(df[x_cols].values, dtype=torch.float32)
    A = torch.tensor(df['treatment'].values, dtype=torch.float32).unsqueeze(1)  # (n, 1)
    Y = torch.tensor(df['outcome'].values, dtype=torch.float32).unsqueeze(1)    # (n, 1)

    # Split: first 80% of rows as training context, the rest as test points
    n_train = int(0.8 * len(X))

    # Predict CATE for the held-out rows
    result = model.estimate_cate(
        X[:n_train], A[:n_train], Y[:n_train], X[n_train:]
    )

    results.append({
        'file': file.name,
        'mean_cate': result['cate'].mean().item()
    })

print(pd.DataFrame(results))
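The loop above takes the first 80% of rows of each file as training context, which can skew the split if a file is sorted (for example by treatment or by time). A seeded random permutation avoids that. random_split is a hypothetical helper sketched in NumPy; the commented lines show how its indices could be applied to the tensors from the loop.

```python
import numpy as np

def random_split(n, train_frac=0.8, seed=0):
    """Return (train_idx, test_idx) from a seeded random permutation of range(n)."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_train = int(train_frac * n)
    return perm[:n_train], perm[n_train:]

train_idx, test_idx = random_split(1000)
# With tensors X, A, Y built as in the loop above:
# tr, te = torch.from_numpy(train_idx), torch.from_numpy(test_idx)
# result = model.estimate_cate(X[tr], A[tr], Y[tr], X[te])
```

Fixing the seed keeps the split, and therefore the reported mean_cate per file, reproducible across runs.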

GPU Management

# Auto-detect device
model = StandardCATEModel.from_pretrained("model.pth", device='auto')

# Specific GPU
model = StandardCATEModel.from_pretrained("model.pth", device='cuda:1')

# Move between devices
model = model.to('cpu')

# Check current device
print(model.device)

Model Properties

Access model attributes:

# Get underlying model
underlying_model = model.model

# Get model parameters
params = model.parameters

# Count parameters
n_params = sum(p.numel() for p in model.parameters)
print(f"Model has {n_params:,} parameters")

See Also