Tutorial 1: Basics

Welcome to CausalFM! This tutorial introduces the fundamental concepts and workflow.

Learning Objectives

By the end of this tutorial, you will understand:

  • What Prior-Data Fitted Networks (PFNs) are

  • The key concepts in causal inference with CausalFM

  • The basic workflow from data to predictions

  • When to use different causal settings

What are Foundation Models for Causal Inference?

Traditional Approach

Traditional machine learning for causal inference:

  1. Collect a single dataset

  2. Train a model on that dataset

  3. Make predictions on new samples from the same distribution

Problem: Each new dataset needs its own trained model, which typically requires a large dataset and does not transfer to other datasets.

Foundation Model Approach (CausalFM)

CausalFM uses a different paradigm:

  1. Train once on many diverse synthetic datasets

  2. Model learns the structure of causal inference problems

  3. Apply to new datasets without retraining (zero-shot transfer)

  4. Adapt in-context using just a few samples

Advantage: A single pretrained model transfers to new datasets and adapts in-context, with no retraining.

Key Concepts

Prior-Data Fitted Networks (PFNs)

PFNs are trained to solve a distribution of tasks, not just one task.

# Traditional: Train on one dataset
model.fit(X, y)  # Requires large dataset

# PFN: Trained on many datasets, adapts in-context
model.estimate(X_context, y_context, X_query)  # Few samples needed!

In CausalFM, the “task” is estimating treatment effects for a particular dataset.
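
To make "a distribution of tasks" concrete, here is a minimal sketch that samples many small synthetic datasets from a toy prior. Everything in it is a hypothetical illustration: the `sample_task` helper, the linear outcome model, and the parameter ranges are assumptions and are not CausalFM's actual data-generating prior.

```python
import random

def sample_task(rng, n_samples=8):
    """Draw one synthetic causal task from a toy prior (hypothetical)."""
    # Task-level parameters: every task gets its own baseline slope and effect size.
    slope = rng.uniform(-2.0, 2.0)
    effect = rng.uniform(-1.0, 3.0)
    rows = []
    for _ in range(n_samples):
        x = rng.gauss(0.0, 1.0)                            # covariate
        a = rng.randint(0, 1)                              # randomized binary treatment
        y = slope * x + effect * a + rng.gauss(0.0, 0.1)   # noisy outcome
        rows.append((x, a, y))
    return {"rows": rows, "true_effect": effect}

rng = random.Random(0)
tasks = [sample_task(rng) for _ in range(100)]  # many tasks, not one dataset
effects = [t["true_effect"] for t in tasks]
print(f"{len(tasks)} tasks, effects from {min(effects):.2f} to {max(effects):.2f}")
```

A PFN-style model would be trained across tasks like these, so that at inference time a previously unseen dataset looks like one more draw from the same family.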

In-Context Learning

The model takes training examples as input and adapts its predictions:

# Provide context (training samples)
result = model.estimate_cate(
    x_train,  # Context: observed covariates
    a_train,  # Context: treatments
    y_train,  # Context: outcomes
    x_test    # Query: new covariates to predict for
)

The model infers the relationship (X, A) → Y from the context samples at inference time, without updating its weights.
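
The call pattern above can be mimicked with a deliberately simple stand-in. The sketch below is not CausalFM's model: `knn_cate` is a hypothetical nearest-neighbour estimator for one-dimensional covariates. It is shown only because it shares the key property that nothing is fitted ahead of time and all adaptation comes from the context passed in at call time.

```python
def knn_cate(x_context, a_context, y_context, x_query, k=3):
    """Toy in-context estimator (hypothetical stand-in, 1-D covariates).

    For each query point, average the outcomes of the k nearest treated and
    the k nearest control context samples and return the difference.
    Assumes both treatment arms appear in the context.
    """
    def nearest_mean(x, arm):
        # Context samples in the requested arm, sorted by distance to x.
        group = sorted(
            (abs(xc - x), yc)
            for xc, ac, yc in zip(x_context, a_context, y_context)
            if ac == arm
        )
        top = group[:k]
        return sum(y for _, y in top) / len(top)

    return [nearest_mean(x, 1) - nearest_mean(x, 0) for x in x_query]

# Context in which treatment adds exactly 2.0 to the outcome.
x_ctx = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
a_ctx = [0, 1, 0, 1, 0, 1, 0, 1]
y_ctx = [x + 2.0 * a for x, a in zip(x_ctx, a_ctx)]

estimates = knn_cate(x_ctx, a_ctx, y_ctx, x_query=[0.35], k=2)
print(estimates)
```

The estimate lands near the true effect of 2.0 but not exactly on it, because the nearest treated and control neighbours have slightly different covariates. Changing the context changes the answer immediately, with no training step in between.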

CATE: Conditional Average Treatment Effect

CATE is the expected treatment effect for individuals with covariates X = x:

\[\tau(x) = \mathbb{E}[Y(1) - Y(0) | X = x]\]

Where:

  • Y(1) = potential outcome under treatment

  • Y(0) = potential outcome under control

  • X = covariates (features)

# Conditional treatment effects, one estimate per row of x_test
cate = model.estimate_cate(x_train, a_train, y_train, x_test)['cate']

# Average treatment effect
ate = cate.mean()
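
Both potential outcomes are never observed for the same unit in real data, so the definition is easiest to check on synthetic data where they are known. The sketch below assumes a made-up effect function tau(x) = 1 + 2x and a noise-free toy outcome model. Both are illustrative assumptions, unrelated to any CausalFM dataset.

```python
import random

def tau(x):
    """Hypothetical ground-truth CATE, used only for this illustration."""
    return 1.0 + 2.0 * x

rng = random.Random(42)
xs = [rng.uniform(0.0, 1.0) for _ in range(10_000)]

# Potential outcomes for every unit: Y(0) is a baseline, Y(1) adds tau(x).
y0 = [0.5 * x for x in xs]
y1 = [base + tau(x) for base, x in zip(y0, xs)]

# With no noise in this toy model, each unit's effect equals tau(x).
effects = [treated - control for treated, control in zip(y1, y0)]

# The ATE is the average of the CATE over the covariate distribution.
# Population value: 1 + 2 * E[X] = 2.0 for X ~ Uniform(0, 1).
ate = sum(effects) / len(effects)
print(f"sample ATE = {ate:.3f}")
```

This is the same relationship used by `cate.mean()` above: averaging conditional effects over the test covariates gives an estimate of the average treatment effect for that population.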

Causal Settings in CausalFM

CausalFM supports three causal inference settings:

1. Standard CATE Estimation

When to use: No unobserved confounding (all confounders measured).

from causalfm.models import StandardCATEModel

model = StandardCATEModel.from_pretrained("checkpoints/standard.pth")
result = model.estimate_cate(x_train, a_train, y_train, x_test)

2. Instrumental Variables

When to use: Unobserved confounding is present, but a valid instrument is available.

from causalfm.models import IVModel

model = IVModel.from_pretrained("checkpoints/iv.pth")
result = model.estimate_cate(x_train, z_train, a_train, y_train, x_test)
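
To see why an instrument helps, the following sketch compares a naive treated-versus-control contrast with the classical Wald estimator on simulated data. The data-generating process, the constant effect of 2.0, and the helper `mean_where` are assumptions made for illustration. The Wald ratio is a textbook estimator of an average effect, not a description of what `IVModel` computes internally.

```python
import random

rng = random.Random(0)
n = 20_000
TRUE_EFFECT = 2.0

z, a, y = [], [], []
for _ in range(n):
    u = rng.gauss(0.0, 1.0)                # unobserved confounder
    zi = rng.randint(0, 1)                 # instrument: affects A, never Y directly
    ai = 1 if 1.5 * zi + u > 0.75 else 0   # treatment depends on Z and U
    yi = TRUE_EFFECT * ai + 3.0 * u + rng.gauss(0.0, 0.5)
    z.append(zi)
    a.append(ai)
    y.append(yi)

def mean_where(values, flags, flag):
    """Average of `values` over the units whose flag equals `flag`."""
    picked = [v for v, f in zip(values, flags) if f == flag]
    return sum(picked) / len(picked)

# Naive contrast between treated and control units: biased, because U
# raises both the chance of treatment and the outcome.
naive = mean_where(y, a, 1) - mean_where(y, a, 0)

# Wald estimator: effect of Z on Y divided by effect of Z on A.
reduced_form = mean_where(y, z, 1) - mean_where(y, z, 0)
first_stage = mean_where(a, z, 1) - mean_where(a, z, 0)
wald = reduced_form / first_stage

print(f"naive: {naive:.2f}, Wald IV: {wald:.2f}, truth: {TRUE_EFFECT:.2f}")
```

The naive contrast overshoots the true effect, while the instrument-based ratio recovers it up to sampling noise. This is the kind of situation the IV setting is meant for.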

3. Front-door Adjustment

When to use: Unobserved confounding between treatment and outcome, but an observed mediator carries the entire effect of the treatment on the outcome.

from causalfm.models import FrontdoorModel

model = FrontdoorModel.from_pretrained("checkpoints/fd.pth")
result = model.estimate_cate(x_train, m_train, a_train, y_train, x_test)
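
The front-door idea can be checked by hand on binary variables. The sketch below simulates data with an unobserved confounder U and applies the front-door formula E[Y | do(a)] = Σ_m P(m | a) Σ_a' E[Y | m, a'] P(a'). The simulation parameters and helper names are assumptions chosen for illustration, and this plug-in estimator is not how `FrontdoorModel` works internally.

```python
import random

rng = random.Random(1)
n = 20_000
rows = []
for _ in range(n):
    u = 1 if rng.random() < 0.5 else 0             # unobserved confounder
    a = 1 if rng.random() < 0.2 + 0.6 * u else 0   # treatment depends on U
    m = 1 if rng.random() < 0.1 + 0.7 * a else 0   # mediator depends only on A
    y = 2.0 * m + 3.0 * u + rng.gauss(0.0, 0.5)    # outcome depends on M and U
    rows.append((a, m, y))                         # U is never recorded

TRUE_EFFECT = 2.0 * (0.8 - 0.1)  # effect of M on Y times the shift in P(M=1)

def count(a=None, m=None):
    """Number of rows matching the given treatment and/or mediator value."""
    return sum(1 for ra, rm, _ in rows
               if (a is None or ra == a) and (m is None or rm == m))

def mean_y(a, m=None):
    """Sample mean of Y among rows with A == a (and M == m if given)."""
    ys = [ry for ra, rm, ry in rows if ra == a and (m is None or rm == m)]
    return sum(ys) / len(ys)

def frontdoor_mean(a_do):
    """Plug-in estimate of E[Y | do(A = a_do)] via the front-door formula."""
    total = 0.0
    for m in (0, 1):
        p_m_given_a = count(a=a_do, m=m) / count(a=a_do)
        inner = sum(mean_y(a2, m) * count(a=a2) / n for a2 in (0, 1))
        total += p_m_given_a * inner
    return total

naive = mean_y(1) - mean_y(0)
frontdoor = frontdoor_mean(1) - frontdoor_mean(0)
print(f"naive: {naive:.2f}, front-door: {frontdoor:.2f}, truth: {TRUE_EFFECT:.2f}")
```

The naive contrast absorbs the confounding through U, whereas the front-door estimate uses only A, M, and Y and still lands close to the true effect.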

Basic Workflow

The CausalFM workflow has four main steps:

Step 1: Generate Data

from causalfm.data import StandardCATEGenerator

generator = StandardCATEGenerator(num_samples=1024, num_features=10)

# Generate training data
generator.generate_multiple(500, "data/train/")

# Generate test data
generator.generate_multiple(50, "data/test/")

Step 2: Train Model

from causalfm.training import StandardCATETrainer, TrainingConfig

if __name__ == '__main__':
    config = TrainingConfig(
        data_path="data/train/*.csv",
        epochs=100,
        save_dir="checkpoints/"
    )
    trainer = StandardCATETrainer(config)
    trainer.train()

Step 3: Load and Predict

from causalfm.models import StandardCATEModel

model = StandardCATEModel.from_pretrained("checkpoints/best_model.pth")

# x_train, a_train, y_train are the context samples; x_test holds the covariates to predict for
result = model.estimate_cate(x_train, a_train, y_train, x_test)
cate = result['cate']  # Treatment effect estimates

Step 4: Evaluate

from causalfm.evaluation import compute_pehe, compute_ate_error

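# true_ite holds the ground-truth effects, which are known only for synthetic or semi-synthetic data
pehe = compute_pehe(cate, true_ite)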
ate_error = compute_ate_error(cate, true_ite)

print(f"PEHE: {pehe:.4f}")
print(f"ATE Error: {ate_error:.4f}")
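
For reference, the two metrics can be written out directly. The sketch below uses the common textbook definitions: root PEHE as the root-mean-squared error of the effect estimates, and ATE error as the absolute difference of the averages. The functions `root_pehe` and `abs_ate_error` are hypothetical helpers, and the conventions in `causalfm.evaluation` (for example squared versus root PEHE) may differ.

```python
import math

def root_pehe(cate_pred, cate_true):
    """Root-mean-squared error between estimated and true effects."""
    squared = [(p - t) ** 2 for p, t in zip(cate_pred, cate_true)]
    return math.sqrt(sum(squared) / len(squared))

def abs_ate_error(cate_pred, cate_true):
    """Absolute difference between the estimated and true average effects."""
    n = len(cate_true)
    return abs(sum(cate_pred) / n - sum(cate_true) / n)

true_effects = [1.0, 2.0, 3.0, 4.0]
estimated = [1.5, 1.5, 3.5, 3.5]

print(root_pehe(estimated, true_effects))      # every error is +/-0.5, so this is 0.5
print(abs_ate_error(estimated, true_effects))  # errors cancel on average, so this is 0.0
```

The toy numbers show why both metrics are reported: an estimator can get the average effect exactly right while still being off for every individual point.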

Your First CausalFM Script

Let’s put it all together:

"""
My first CausalFM script
"""
import torch
from causalfm.data import StandardCATEGenerator
from causalfm.models import StandardCATEModel
from causalfm.training import StandardCATETrainer, TrainingConfig
from causalfm.evaluation import compute_pehe

if __name__ == '__main__':
    # 1. Generate data (inside the guard so it runs only once, in the main process)
    print("Generating data...")
    gen = StandardCATEGenerator(num_samples=1024, num_features=10)
    gen.generate_multiple(100, "data/train/")
    gen.generate_multiple(10, "data/test/")

    # 2. Train model
    print("Training model...")
    config = TrainingConfig(
        data_path="data/train/*.csv",
        epochs=50,
        batch_size=16,
        num_workers=0,
        save_dir="checkpoints/"
    )
    trainer = StandardCATETrainer(config)
    trainer.train()

    # 3. Load and predict
    print("Making predictions...")
    model = StandardCATEModel.from_pretrained("checkpoints/best_model.pth")

    # Placeholder random tensors so the script runs end to end;
    # replace them with real context and query data
    x_train = torch.randn(800, 10)
    a_train = torch.randint(0, 2, (800, 1)).float()
    y_train = torch.randn(800, 1)
    x_test = torch.randn(200, 10)

    result = model.estimate_cate(x_train, a_train, y_train, x_test)
    cate = result['cate']

    print(f"Estimated {len(cate)} treatment effects!")
    print(f"Mean CATE: {cate.mean():.4f}")

Key Takeaways

  • CausalFM uses foundation models that learn from many datasets

  • In-context learning allows adaptation with few samples

  • Three causal settings for different identification strategies

  • Simple workflow: Generate → Train → Predict → Evaluate

  • Zero-shot transfer to new datasets without retraining

Next Steps

Now that you understand the basics, continue to:

Questions?

  • Check the Quick Start guide for a quick reference

  • Read the Models page for model details

  • Look at example notebooks in evaluation/notebook/