Training
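CausalFM provides training modules for each supported causal inference setting: standard CATE, instrumental variables (IV), and front-door adjustment. Training follows the PFN (Prior-Data Fitted Networks) paradigm, in which a model learns from many synthetic datasets rather than from a single one.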
Training Overview
Key Concepts
- Prior-Data Fitted Networks (PFNs)
Instead of training on a single dataset, PFNs are trained on a distribution of synthetic datasets. This enables:
Transfer to new datasets without fine-tuning
In-context learning capabilities
Robust performance across diverse settings
- Training Process
Generate many synthetic datasets (hundreds to thousands)
Train model to predict outcomes in-context
Model learns to adapt to dataset characteristics
Deploy on real data without retraining
- Loss Function
Models are trained using a Gaussian Mixture Model (GMM) negative log-likelihood loss, which encourages both accurate predictions and calibrated uncertainty.
Training Configuration
All training is configured via TrainingConfig:
from causalfm.training import TrainingConfig
config = TrainingConfig(
# Data
data_path="data/train/*.csv", # Path to training datasets
val_split=0.2, # Validation split ratio
# Training
epochs=100, # Number of training epochs
batch_size=16, # Batch size
learning_rate=0.001, # Learning rate
weight_decay=1e-5, # L2 regularization
# Optimization
clip_grad=1.0, # Gradient clipping
early_stop_patience=50, # Early stopping patience
# Model
use_gmm_head=True, # Use GMM prediction head
gmm_n_components=5, # Number of mixture components
# Checkpointing
save_dir="checkpoints/", # Save directory
save_freq=10, # Save every N epochs
# Logging
log_dir="logs/", # TensorBoard logs
# Hardware
device='auto', # 'auto', 'cuda', 'cpu'
num_workers=4, # DataLoader workers
# Reproducibility
seed=42
)
Standard CATE Training
Training a Standard CATE Model
from causalfm.training import StandardCATETrainer, TrainingConfig
from causalfm.data import StandardCATEGenerator
# Step 1: Generate training data
generator = StandardCATEGenerator(num_samples=1024, num_features=10)
generator.generate_multiple(
num_datasets=500,
output_dir="data/train/"
)
# Step 2: Configure training
config = TrainingConfig(
data_path="data/train/*.csv",
epochs=100,
batch_size=16,
save_dir="checkpoints/standard/"
)
# Step 3: Train (wrap in `if __name__ == '__main__':` because DataLoader workers use multiprocessing)
if __name__ == '__main__':
trainer = StandardCATETrainer(config)
trainer.train()
Training Output
During training, you’ll see:
Epoch 1/100
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:15
Train Loss: 1.2345 | Val Loss: 1.3456
Epoch 2/100
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:14
Train Loss: 1.1234 | Val Loss: 1.2345
✓ New best model saved!
...
Early stopping triggered after epoch 75
Best validation loss: 0.8765 at epoch 65
Training completed!
Simplified Interface
from causalfm.training import StandardCATETrainer
if __name__ == '__main__':
# One-liner training
trainer = StandardCATETrainer.from_args(
data_path="data/train/*.csv",
epochs=100,
save_dir="checkpoints/"
)
trainer.train()
Instrumental Variables Training
Training an IV Model
from causalfm.training import IVTrainer, TrainingConfig
from causalfm.data import IVDataGenerator
# Generate IV data
generator = IVDataGenerator(
num_samples=1024,
num_features=10,
instrument_type='binary'
)
generator.generate_multiple(
num_datasets=500,
output_dir="data/iv_train/"
)
# Train
if __name__ == '__main__':
config = TrainingConfig(
data_path="data/iv_train/*.csv",
epochs=100,
save_dir="checkpoints/iv/"
)
trainer = IVTrainer(config)
trainer.train()
Binary vs Continuous IV
# Binary IV (e.g., randomized encouragement)
binary_gen = IVDataGenerator(
num_samples=1024,
num_features=10,
instrument_type='binary'
)
binary_gen.generate_multiple(500, "data/iv_binary/")
# Train binary IV model
if __name__ == '__main__':
trainer = IVTrainer.from_args(
data_path="data/iv_binary/*.csv",
save_dir="checkpoints/iv_binary/"
)
trainer.train()
# Continuous IV
conti_gen = IVDataGenerator(
num_samples=1024,
num_features=10,
instrument_type='continuous'
)
conti_gen.generate_multiple(500, "data/iv_conti/")
# Train continuous IV model
if __name__ == '__main__':
trainer = IVTrainer.from_args(
data_path="data/iv_conti/*.csv",
save_dir="checkpoints/iv_conti/"
)
trainer.train()
Front-door Training
Training a Front-door Model
from causalfm.training import FrontdoorTrainer, TrainingConfig
from causalfm.data import FrontdoorDataGenerator
# Generate front-door data
generator = FrontdoorDataGenerator(
num_samples=1024,
num_features=10
)
generator.generate_multiple(
num_datasets=500,
output_dir="data/frontdoor_train/"
)
# Train
if __name__ == '__main__':
config = TrainingConfig(
data_path="data/frontdoor_train/*.csv",
epochs=100,
save_dir="checkpoints/frontdoor/"
)
trainer = FrontdoorTrainer(config)
trainer.train()
Advanced Configuration
Custom Training Settings
config = TrainingConfig(
# Data
data_path="data/train/*.csv",
val_split=0.15, # 15% validation
# Optimization
learning_rate=0.0005, # Lower learning rate
weight_decay=1e-4, # Stronger regularization
clip_grad=0.5, # Stricter gradient clipping
# Scheduler (optional)
use_scheduler=True,
scheduler_patience=10,
scheduler_factor=0.5,
# Early stopping
early_stop_patience=30, # Stop after 30 epochs without improvement
min_delta=1e-4, # Minimum improvement threshold
# Batch size
batch_size=32, # Larger batches
# Model architecture
use_gmm_head=True,
gmm_n_components=10, # More components for better uncertainty
gmm_min_sigma=1e-4,
gmm_pi_temp=0.8,
# Checkpointing
save_dir="checkpoints/custom/",
save_freq=5, # Save every 5 epochs
save_best_only=True, # Only save best model
# Logging
log_dir="logs/custom/",
log_freq=100, # Log every 100 steps
# Hardware
device='cuda:0', # Specific GPU
num_workers=8, # More workers for data loading
pin_memory=True, # Pin memory for faster GPU transfer
# Reproducibility
seed=42,
deterministic=True # Deterministic mode (slower but reproducible)
)
Multi-GPU Training
# Use DataParallel for multi-GPU
config = TrainingConfig(
data_path="data/train/*.csv",
device='cuda', # Will use all available GPUs
batch_size=64, # Increase batch size for multi-GPU
num_workers=16 # More workers
)
if __name__ == '__main__':
trainer = StandardCATETrainer(config)
trainer.train()
Monitoring Training
TensorBoard Integration
# Training automatically logs to TensorBoard
config = TrainingConfig(
data_path="data/train/*.csv",
log_dir="logs/experiment_1/"
)
if __name__ == '__main__':
trainer = StandardCATETrainer(config)
trainer.train()
View logs with:
tensorboard --logdir logs/experiment_1/
Tracked metrics include:
Training loss (per epoch)
Validation loss (per epoch)
Learning rate (if using scheduler)
Gradient norms
Model parameter statistics
Manual Logging
# Access trainer's history
if __name__ == '__main__':
trainer = StandardCATETrainer(config)
trainer.train()
# After training
print(f"Final train loss: {trainer.train_losses[-1]:.4f}")
print(f"Final val loss: {trainer.val_losses[-1]:.4f}")
print(f"Best val loss: {min(trainer.val_losses):.4f}")
print(f"Best epoch: {trainer.best_epoch}")
Checkpointing
Automatic Checkpointing
config = TrainingConfig(
data_path="data/train/*.csv",
save_dir="checkpoints/",
save_freq=10, # Save every 10 epochs
save_best_only=False # Save all checkpoints
)
# This will create:
# checkpoints/
# ├── best_model.pth # Best model by validation loss
# ├── checkpoint_epoch_10.pth
# ├── checkpoint_epoch_20.pth
# └── ...
Resume Training
# Resume from checkpoint
config = TrainingConfig(
data_path="data/train/*.csv",
save_dir="checkpoints/",
resume_from="checkpoints/checkpoint_epoch_50.pth"
)
if __name__ == '__main__':
trainer = StandardCATETrainer(config)
trainer.train() # Continues from epoch 51
Troubleshooting
Common Issues
Multiprocessing Errors
# ❌ Wrong - causes errors
trainer = StandardCATETrainer(config)
trainer.train()
# ✅ Correct - wrap in if __name__ == '__main__':
if __name__ == '__main__':
trainer = StandardCATETrainer(config)
trainer.train()
# Or disable multiprocessing
config = TrainingConfig(
data_path="data/*.csv",
num_workers=0 # No multiprocessing
)
CUDA Out of Memory
# Reduce batch size
config = TrainingConfig(
data_path="data/*.csv",
batch_size=8, # Smaller batches
num_workers=0
)
# Or use CPU
config = TrainingConfig(
data_path="data/*.csv",
device='cpu'
)
Training Too Slow
# Increase batch size
config = TrainingConfig(
data_path="data/*.csv",
batch_size=32, # Larger batches
num_workers=8, # More workers
pin_memory=True # Faster GPU transfer
)
Overfitting
# Increase regularization
config = TrainingConfig(
data_path="data/*.csv",
weight_decay=1e-3, # Stronger L2
early_stop_patience=20, # Earlier stopping
val_split=0.3 # More validation data
)
Not Converging
# Adjust learning rate
config = TrainingConfig(
data_path="data/*.csv",
learning_rate=0.0001, # Lower LR
clip_grad=0.5, # Stricter clipping
epochs=200 # More epochs
)
Best Practices
Data Generation
# Generate diverse training data
generator = StandardCATEGenerator(num_samples=1024, num_features=10)
# Use many datasets (500-1000 recommended)
generator.generate_multiple(
num_datasets=1000,
output_dir="data/train/"
)
# Generate separate test data with a different seed
test_gen = StandardCATEGenerator(
num_samples=1024,
num_features=10,
seed=999
)
test_gen.generate_multiple(100, "data/test/")
Training Configuration
# Recommended settings for most cases
config = TrainingConfig(
data_path="data/train/*.csv",
# Training
epochs=150,
batch_size=16,
learning_rate=0.001,
weight_decay=1e-5,
# Early stopping
early_stop_patience=50,
val_split=0.2,
# Model
use_gmm_head=True,
gmm_n_components=5,
# Checkpointing
save_dir="checkpoints/",
save_freq=10,
# Hardware
device='auto',
num_workers=4,
# Reproducibility
seed=42
)
Validation Strategy
# Use held-out datasets for validation
config = TrainingConfig(
data_path="data/train/*.csv",
val_split=0.2 # 20% for validation
)
# Or use separate validation files
config = TrainingConfig(
data_path="data/train/*.csv",
val_path="data/val/*.csv" # Explicit validation set
)
API Reference
For complete API documentation, see:
causalfm.training.StandardCATETrainer
causalfm.training.IVTrainer
causalfm.training.FrontdoorTrainer
causalfm.training.TrainingConfig