Training

CausalFM provides powerful training modules for all causal inference settings. Training is based on the PFN (Prior-Data Fitted Networks) paradigm where models learn from many synthetic datasets.

Training Overview

Key Concepts

Prior-Data Fitted Networks (PFNs)

Instead of training on a single dataset, PFNs are trained on a distribution of synthetic datasets. This enables:

  • Transfer to new datasets without fine-tuning

  • In-context learning capabilities

  • Robust performance across diverse settings

Training Process
  1. Generate many synthetic datasets (100-1000s)

  2. Train model to predict outcomes in-context

  3. Model learns to adapt to dataset characteristics

  4. Deploy on real data without retraining

Loss Function

Models are trained using Gaussian Mixture Model (GMM) negative log-likelihood, which encourages both accurate predictions and calibrated uncertainty.

Training Configuration

All training is configured via TrainingConfig:

from causalfm.training import TrainingConfig

config = TrainingConfig(
    # Data
    data_path="data/train/*.csv",      # Path to training datasets
    val_split=0.2,                      # Validation split ratio

    # Training
    epochs=100,                         # Number of training epochs
    batch_size=16,                      # Batch size
    learning_rate=0.001,                # Learning rate
    weight_decay=1e-5,                  # L2 regularization

    # Optimization
    clip_grad=1.0,                      # Gradient clipping
    early_stop_patience=50,             # Early stopping patience

    # Model
    use_gmm_head=True,                  # Use GMM prediction head
    gmm_n_components=5,                 # Number of mixture components

    # Checkpointing
    save_dir="checkpoints/",            # Save directory
    save_freq=10,                       # Save every N epochs

    # Logging
    log_dir="logs/",                    # TensorBoard logs

    # Hardware
    device='auto',                      # 'auto', 'cuda', 'cpu'
    num_workers=4,                      # DataLoader workers

    # Reproducibility
    seed=42
)

Standard CATE Training

Training a Standard CATE Model

from causalfm.training import StandardCATETrainer, TrainingConfig
from causalfm.data import StandardCATEGenerator

# Step 1: Generate training data
generator = StandardCATEGenerator(num_samples=1024, num_features=10)
generator.generate_multiple(
    num_datasets=500,
    output_dir="data/train/"
)

# Step 2: Configure training
config = TrainingConfig(
    data_path="data/train/*.csv",
    epochs=100,
    batch_size=16,
    save_dir="checkpoints/standard/"
)

# Step 3: Train (must be wrapped for multiprocessing)
if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()

Training Output

During training, you’ll see:

Epoch 1/100
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:15
Train Loss: 1.2345 | Val Loss: 1.3456

Epoch 2/100
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:14
Train Loss: 1.1234 | Val Loss: 1.2345
✓ New best model saved!

...

Early stopping triggered after epoch 75
Best validation loss: 0.8765 at epoch 65
Training completed!

Simplified Interface

from causalfm.training import StandardCATETrainer

if __name__ == '__main__':
    # One-liner training
    trainer = StandardCATETrainer.from_args(
        data_path="data/train/*.csv",
        epochs=100,
        save_dir="checkpoints/"
    )
    trainer.train()

Instrumental Variables Training

Training an IV Model

from causalfm.training import IVTrainer, TrainingConfig
from causalfm.data import IVDataGenerator

# Generate IV data
generator = IVDataGenerator(
    num_samples=1024,
    num_features=10,
    instrument_type='binary'
)
generator.generate_multiple(
    num_datasets=500,
    output_dir="data/iv_train/"
)

# Train
if __name__ == '__main__':
    config = TrainingConfig(
        data_path="data/iv_train/*.csv",
        epochs=100,
        save_dir="checkpoints/iv/"
    )
    trainer = IVTrainer(config)
    trainer.train()

Binary vs Continuous IV

# Binary IV (e.g., randomized encouragement)
binary_gen = IVDataGenerator(
    num_samples=1024,
    num_features=10,
    instrument_type='binary'
)
binary_gen.generate_multiple(500, "data/iv_binary/")

# Train binary IV model
if __name__ == '__main__':
    trainer = IVTrainer.from_args(
        data_path="data/iv_binary/*.csv",
        save_dir="checkpoints/iv_binary/"
    )
    trainer.train()

# Continuous IV
conti_gen = IVDataGenerator(
    num_samples=1024,
    num_features=10,
    instrument_type='continuous'
)
conti_gen.generate_multiple(500, "data/iv_conti/")

# Train continuous IV model
if __name__ == '__main__':
    trainer = IVTrainer.from_args(
        data_path="data/iv_conti/*.csv",
        save_dir="checkpoints/iv_conti/"
    )
    trainer.train()

Front-door Training

Training a Front-door Model

from causalfm.training import FrontdoorTrainer, TrainingConfig
from causalfm.data import FrontdoorDataGenerator

# Generate front-door data
generator = FrontdoorDataGenerator(
    num_samples=1024,
    num_features=10
)
generator.generate_multiple(
    num_datasets=500,
    output_dir="data/frontdoor_train/"
)

# Train
if __name__ == '__main__':
    config = TrainingConfig(
        data_path="data/frontdoor_train/*.csv",
        epochs=100,
        save_dir="checkpoints/frontdoor/"
    )
    trainer = FrontdoorTrainer(config)
    trainer.train()

Advanced Configuration

Custom Training Settings

config = TrainingConfig(
    # Data
    data_path="data/train/*.csv",
    val_split=0.15,                     # 15% validation

    # Optimization
    learning_rate=0.0005,               # Lower learning rate
    weight_decay=1e-4,                  # Stronger regularization
    clip_grad=0.5,                      # Stricter gradient clipping

    # Scheduler (optional)
    use_scheduler=True,
    scheduler_patience=10,
    scheduler_factor=0.5,

    # Early stopping
    early_stop_patience=30,             # Stop after 30 epochs without improvement
    min_delta=1e-4,                     # Minimum improvement threshold

    # Batch size
    batch_size=32,                      # Larger batches

    # Model architecture
    use_gmm_head=True,
    gmm_n_components=10,                # More components for better uncertainty
    gmm_min_sigma=1e-4,
    gmm_pi_temp=0.8,

    # Checkpointing
    save_dir="checkpoints/custom/",
    save_freq=5,                        # Save every 5 epochs
    save_best_only=True,                # Only save best model

    # Logging
    log_dir="logs/custom/",
    log_freq=100,                       # Log every 100 steps

    # Hardware
    device='cuda:0',                    # Specific GPU
    num_workers=8,                      # More workers for data loading
    pin_memory=True,                    # Pin memory for faster GPU transfer

    # Reproducibility
    seed=42,
    deterministic=True                  # Deterministic mode (slower but reproducible)
)

Multi-GPU Training

# Use DataParallel for multi-GPU
config = TrainingConfig(
    data_path="data/train/*.csv",
    device='cuda',                      # Will use all available GPUs
    batch_size=64,                      # Increase batch size for multi-GPU
    num_workers=16                      # More workers
)

if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()

Monitoring Training

TensorBoard Integration

# Training automatically logs to TensorBoard
config = TrainingConfig(
    data_path="data/train/*.csv",
    log_dir="logs/experiment_1/"
)

if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()

View logs with:

tensorboard --logdir logs/experiment_1/

Tracked metrics include:

  • Training loss (per epoch)

  • Validation loss (per epoch)

  • Learning rate (if using scheduler)

  • Gradient norms

  • Model parameter statistics

Manual Logging

# Access trainer's history
if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()

    # After training
    print(f"Final train loss: {trainer.train_losses[-1]:.4f}")
    print(f"Final val loss: {trainer.val_losses[-1]:.4f}")
    print(f"Best val loss: {min(trainer.val_losses):.4f}")
    print(f"Best epoch: {trainer.best_epoch}")

Checkpointing

Automatic Checkpointing

config = TrainingConfig(
    data_path="data/train/*.csv",
    save_dir="checkpoints/",
    save_freq=10,                       # Save every 10 epochs
    save_best_only=False                # Save all checkpoints
)

# This will create:
# checkpoints/
#   ├── best_model.pth         # Best model by validation loss
#   ├── checkpoint_epoch_10.pth
#   ├── checkpoint_epoch_20.pth
#   └── ...

Resume Training

# Resume from checkpoint
config = TrainingConfig(
    data_path="data/train/*.csv",
    save_dir="checkpoints/",
    resume_from="checkpoints/checkpoint_epoch_50.pth"
)

if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()  # Continues from epoch 51

Troubleshooting

Common Issues

Multiprocessing Errors

# ❌ Wrong - causes errors
trainer = StandardCATETrainer(config)
trainer.train()

# ✅ Correct - wrap in if __name__ == '__main__':
if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()

# Or disable multiprocessing
config = TrainingConfig(
    data_path="data/*.csv",
    num_workers=0  # No multiprocessing
)

CUDA Out of Memory

# Reduce batch size
config = TrainingConfig(
    data_path="data/*.csv",
    batch_size=8,    # Smaller batches
    num_workers=0
)

# Or use CPU
config = TrainingConfig(
    data_path="data/*.csv",
    device='cpu'
)

Training Too Slow

# Increase batch size
config = TrainingConfig(
    data_path="data/*.csv",
    batch_size=32,           # Larger batches
    num_workers=8,           # More workers
    pin_memory=True          # Faster GPU transfer
)

Overfitting

# Increase regularization
config = TrainingConfig(
    data_path="data/*.csv",
    weight_decay=1e-3,       # Stronger L2
    early_stop_patience=20,  # Earlier stopping
    val_split=0.3            # More validation data
)

Not Converging

# Adjust learning rate
config = TrainingConfig(
    data_path="data/*.csv",
    learning_rate=0.0001,    # Lower LR
    clip_grad=0.5,           # Stricter clipping
    epochs=200               # More epochs
)

Best Practices

Data Generation

# Generate diverse training data
generator = StandardCATEGenerator(num_samples=1024, num_features=10)

# Use many datasets (500-1000 recommended)
generator.generate_multiple(
    num_datasets=1000,
    output_dir="data/train/"
)

# Separate test data with different seed
test_gen = StandardCATEGenerator(
    num_samples=1024,
    num_features=10,
    seed=999
)
test_gen.generate_multiple(100, "data/test/")

Training Configuration

# Recommended settings for most cases
config = TrainingConfig(
    data_path="data/train/*.csv",

    # Training
    epochs=150,
    batch_size=16,
    learning_rate=0.001,
    weight_decay=1e-5,

    # Early stopping
    early_stop_patience=50,
    val_split=0.2,

    # Model
    use_gmm_head=True,
    gmm_n_components=5,

    # Checkpointing
    save_dir="checkpoints/",
    save_freq=10,

    # Hardware
    device='auto',
    num_workers=4,

    # Reproducibility
    seed=42
)

Validation Strategy

# Use held-out datasets for validation
config = TrainingConfig(
    data_path="data/train/*.csv",
    val_split=0.2  # 20% for validation
)

# Or use separate validation files
config = TrainingConfig(
    data_path="data/train/*.csv",
    val_path="data/val/*.csv"  # Explicit validation set
)

API Reference

For complete API documentation, see:

  • causalfm.training.StandardCATETrainer

  • causalfm.training.IVTrainer

  • causalfm.training.FrontdoorTrainer

  • causalfm.training.TrainingConfig