Training API
This page documents the training APIs in CausalFM.
Trainer Classes
StandardCATETrainer
Class: causalfm.training.standard.StandardCATETrainer
Trainer for standard CATE models.
Example:
from causalfm.training import StandardCATETrainer, TrainingConfig
config = TrainingConfig(
    data_path="data/train/*.csv",
    epochs=100,
    batch_size=16,
    save_dir="checkpoints/"
)
if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()
IVTrainer
Class: causalfm.training.iv.IVTrainer
Trainer for instrumental variables models.
Example:
from causalfm.training import IVTrainer, TrainingConfig
config = TrainingConfig(
    data_path="data/iv_train/*.csv",
    epochs=100,
    save_dir="checkpoints/iv/"
)
if __name__ == '__main__':
    trainer = IVTrainer(config)
    trainer.train()
FrontdoorTrainer
Class: causalfm.training.frontdoor.FrontdoorTrainer
Trainer for front-door adjustment models.
Example:
from causalfm.training import FrontdoorTrainer, TrainingConfig
config = TrainingConfig(
    data_path="data/fd_train/*.csv",
    epochs=100,
    save_dir="checkpoints/frontdoor/"
)
if __name__ == '__main__':
    trainer = FrontdoorTrainer(config)
    trainer.train()
Configuration
TrainingConfig
Class: causalfm.training.base.TrainingConfig
Configuration class for training.
Parameters:
Data Settings:
data_path (str): Glob pattern for training data files
val_split (float): Validation split ratio (default: 0.2)
batch_size (int): Training batch size (default: 16)
num_workers (int): Number of data loading workers (default: 4)
Training Settings:
epochs (int): Number of training epochs (default: 100)
learning_rate (float): Learning rate (default: 0.001)
weight_decay (float): L2 regularization (default: 1e-5)
clip_grad (float): Gradient clipping value (default: 1.0)
early_stop_patience (int): Early stopping patience (default: 50)
Model Settings:
use_gmm_head (bool): Use GMM prediction head (default: True)
gmm_n_components (int): Number of mixture components (default: 5)
gmm_min_sigma (float): Minimum std dev (default: 1e-3)
gmm_pi_temp (float): Mixture weight temperature (default: 1.0)
Checkpointing:
save_dir (str): Directory for checkpoints (default: "checkpoints/")
save_freq (int): Save every N epochs (default: 10)
save_best_only (bool): Only save best model (default: False)
resume_from (str): Path to resume training from (default: None)
Logging:
log_dir (str): TensorBoard log directory (default: "logs/")
log_freq (int): Log every N steps (default: 100)
Hardware:
device (str): Device to use ('auto', 'cuda', 'cpu') (default: 'auto')
gpu_id (int): GPU ID to use (default: 0)
pin_memory (bool): Pin memory for faster GPU transfer (default: True)
Reproducibility:
seed (int): Random seed (default: 42)
deterministic (bool): Use deterministic mode (default: False)
Example:
config = TrainingConfig(
    # Data
    data_path="data/train/*.csv",
    val_split=0.2,
    batch_size=16,
    num_workers=4,  # Set to 0 to avoid multiprocessing issues
    # Training
    epochs=150,
    learning_rate=0.001,
    weight_decay=1e-5,
    clip_grad=1.0,
    early_stop_patience=50,
    # Model
    use_gmm_head=True,
    gmm_n_components=5,
    # Checkpointing
    save_dir="checkpoints/",
    save_freq=10,
    # Logging
    log_dir="logs/",
    # Hardware
    device='auto',
    # Reproducibility
    seed=42
)
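The GMM head settings are easier to reason about with a concrete mapping in mind. The following is a minimal sketch, assuming the head produces unconstrained logits, means, and scale values per component; the function names and the softplus parameterization are illustrative assumptions, not CausalFM's actual implementation.

```python
import math

def _softplus(x):
    # Numerically stable softplus: log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def gmm_head_params(raw_logits, raw_means, raw_scales, min_sigma=1e-3, pi_temp=1.0):
    """Map unconstrained head outputs to valid mixture parameters (hypothetical helper)."""
    # Temperature-scaled softmax: a larger pi_temp flattens the mixture weights
    scaled = [z / pi_temp for z in raw_logits]
    shift = max(scaled)
    exps = [math.exp(z - shift) for z in scaled]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Softplus keeps each sigma positive; min_sigma bounds it away from zero
    sigmas = [_softplus(s) + min_sigma for s in raw_scales]
    return weights, list(raw_means), sigmas

def mixture_mean(weights, means):
    # Point prediction: expected value of the mixture
    return sum(w * mu for w, mu in zip(weights, means))

weights, means, sigmas = gmm_head_params(
    raw_logits=[2.0, 0.0, -1.0],
    raw_means=[0.5, 1.0, -0.2],
    raw_scales=[-20.0, 0.0, 1.0],
)
```

Under these assumptions, gmm_min_sigma guards against collapsed components with near-zero variance, and raising gmm_pi_temp spreads weight more evenly across the gmm_n_components components.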
Common Trainer Methods
All trainer classes share the following interface:
__init__(config)
Initialize the trainer with a configuration.
Parameters:
config (TrainingConfig): Training configuration
train()
Start the training loop.
Returns: None
This method will:
Load training and validation data
Initialize model and optimizer
Run training loop with progress bars
Save checkpoints periodically
Apply early stopping if validation loss doesn’t improve
Log metrics to TensorBoard
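The early stopping step in the list above can be sketched as a patience counter over per-epoch validation losses. This is a simplified illustration that assumes 1-indexed epochs and strict improvement; run_with_early_stopping is a hypothetical name, not part of the CausalFM API.

```python
def run_with_early_stopping(val_losses, patience):
    """Return (stop_epoch, best_epoch, best_loss) for a sequence of validation losses."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # New best: this is where a trainer would save the best model
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                # Patience exhausted: stop before using all epochs
                return epoch, best_epoch, best_loss
    return len(val_losses), best_epoch, best_loss

stop_epoch, best_epoch, best_loss = run_with_early_stopping(
    [1.0, 0.8, 0.9, 0.85, 0.81, 0.7], patience=3
)
```

In this example the loss last improves at epoch 2, so a patience of 3 stops training at epoch 5 and never reaches the lower loss at epoch 6. A larger early_stop_patience trades longer runs for a smaller chance of stopping on a temporary plateau.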
from_args(**kwargs)
Class method that creates a trainer from keyword arguments.
Parameters:
kwargs: Arguments passed to TrainingConfig
Returns: Trainer instance
Example:
trainer = StandardCATETrainer.from_args(
    data_path="data/*.csv",
    epochs=100,
    batch_size=16,
    save_dir="checkpoints/"
)
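To show the pattern behind from_args, here is a self-contained sketch using stand-in classes. DemoConfig and DemoTrainer are hypothetical and only mirror a few of the documented fields; the real classes may validate or transform arguments differently.

```python
from dataclasses import dataclass

@dataclass
class DemoConfig:
    # Stand-in for TrainingConfig with a handful of the documented fields
    data_path: str
    epochs: int = 100
    batch_size: int = 16
    save_dir: str = "checkpoints/"

class DemoTrainer:
    def __init__(self, config):
        self.config = config

    @classmethod
    def from_args(cls, **kwargs):
        # Forward keyword arguments to the config; unknown keys raise TypeError
        return cls(DemoConfig(**kwargs))

demo_trainer = DemoTrainer.from_args(data_path="data/*.csv", epochs=5)
```

The class method saves one step when no config object needs to be reused, and because it goes through cls, each trainer subclass returns an instance of its own type.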
Training Workflow
Basic Training
from causalfm.training import StandardCATETrainer, TrainingConfig
# Configure
config = TrainingConfig(
    data_path="data/train/*.csv",
    epochs=100,
    save_dir="checkpoints/"
)
# Train (wrap in if __name__ == '__main__' for multiprocessing)
if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()
Training with Custom Settings
config = TrainingConfig(
    # Data
    data_path="data/train/*.csv",
    val_split=0.15,
    batch_size=32,
    num_workers=8,
    # Optimization
    learning_rate=0.0005,
    weight_decay=1e-4,
    clip_grad=0.5,
    # Early stopping
    early_stop_patience=30,
    # Model
    use_gmm_head=True,
    gmm_n_components=10,
    # Hardware
    device='cuda:0',
    # Checkpointing
    save_dir="checkpoints/custom/",
    save_freq=5,
    # Logging
    log_dir="logs/custom/"
)
if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()
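As a rough illustration of what val_split and seed control together, the sketch below shuffles sample indices with a fixed seed and holds out a fraction for validation. Whether CausalFM splits by file or by row is not stated here, so split_indices is a hypothetical helper that treats the data as a flat list of samples.

```python
import random

def split_indices(n_samples, val_split, seed=42):
    """Return (train_indices, val_indices) for a seeded random split."""
    indices = list(range(n_samples))
    # A dedicated Random instance keeps the split reproducible
    # without touching the global random state
    random.Random(seed).shuffle(indices)
    n_val = int(round(n_samples * val_split))
    return indices[n_val:], indices[:n_val]

train_idx, val_idx = split_indices(200, val_split=0.15)
```

With 200 samples and val_split=0.15, 30 samples are held out and 170 remain for training; repeating the call with the same seed gives the same split.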
Resume Training
config = TrainingConfig(
    data_path="data/train/*.csv",
    save_dir="checkpoints/",
    resume_from="checkpoints/checkpoint_epoch_50.pth"
)
if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()  # Continues from epoch 51
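The resume behavior can be pictured as reading the stored epoch and continuing from the next one. The sketch below assumes the checkpoint's 'epoch' entry holds the last completed, 1-indexed epoch; remaining_epochs is a hypothetical helper and the loss values are made up for illustration.

```python
def remaining_epochs(checkpoint, total_epochs):
    """Epochs (1-indexed) still to run after restoring `checkpoint`."""
    start_epoch = checkpoint["epoch"] + 1
    return range(start_epoch, total_epochs + 1)

# Stand-in for a loaded checkpoint dict; only 'epoch' matters here
ckpt = {"epoch": 50, "train_loss": 0.42, "val_loss": 0.47}
schedule = remaining_epochs(ckpt, total_epochs=100)
```

Here the schedule covers epochs 51 through 100, so the epochs value in the resumed config is read as the total target, not as an additional count.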
Monitoring Training
TensorBoard
Training automatically logs to TensorBoard:
tensorboard --logdir logs/
Tracked metrics:
Training loss (per epoch)
Validation loss (per epoch)
Learning rate (if using scheduler)
Gradient norms
Model statistics
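The gradient norm metric is most useful when read against clip_grad. The sketch below assumes clipping is applied to the global L2 norm over all parameters, which is one common choice and not confirmed for CausalFM; gradients are plain nested lists so the example runs without any framework.

```python
import math

def global_grad_norm(grads):
    # L2 norm over every gradient entry in every layer
    return math.sqrt(sum(g * g for layer in grads for g in layer))

def clip_by_global_norm(grads, max_norm):
    """Rescale gradients so their global norm does not exceed max_norm."""
    norm = global_grad_norm(grads)
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    clipped = [[g * scale for g in layer] for layer in grads]
    # Return the pre-clipping norm, the quantity a logger would typically track
    return clipped, norm

clipped, norm_before = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=1.0)
```

If the logged norm sits well above clip_grad for most steps, nearly every update is being rescaled, which is worth knowing before tuning the learning rate.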
Progress Output
During training, you’ll see:
Epoch 1/100
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% | Loss: 1.234
Train Loss: 1.2345 | Val Loss: 1.3456
Epoch 2/100
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% | Loss: 1.123
Train Loss: 1.1234 | Val Loss: 1.2345
✓ New best model saved!
Training Output
Access training history:
if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()
    # After training
    print(f"Final train loss: {trainer.train_losses[-1]:.4f}")
    print(f"Final val loss: {trainer.val_losses[-1]:.4f}")
    print(f"Best val loss: {min(trainer.val_losses):.4f}")
    print(f"Best epoch: {trainer.best_epoch}")
Checkpointing
Automatic Checkpointing
Checkpoints are saved automatically:
config = TrainingConfig(
    data_path="data/train/*.csv",
    save_dir="checkpoints/",
    save_freq=10,  # Save every 10 epochs
    save_best_only=False  # Save all checkpoints
)
This creates:
checkpoints/
├── best_model.pth # Best model by validation loss
├── checkpoint_epoch_10.pth
├── checkpoint_epoch_20.pth
└── ...
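To predict which files a run leaves behind, the layout above can be derived from epochs and save_freq. This sketch assumes periodic checkpoints are written at every multiple of save_freq and that save_best_only=True suppresses them; expected_checkpoints is a hypothetical helper, and early stopping would shorten the list.

```python
def expected_checkpoints(epochs, save_freq, save_best_only=False):
    """List the checkpoint file names a full run would be expected to write."""
    files = ["best_model.pth"]
    if not save_best_only:
        files += [
            f"checkpoint_epoch_{epoch}.pth"
            for epoch in range(save_freq, epochs + 1, save_freq)
        ]
    return files

names = expected_checkpoints(epochs=30, save_freq=10)
```

For long runs with a small save_freq the periodic files add up quickly, so save_best_only=True is worth considering when disk space is limited.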
Checkpoint Contents
Each checkpoint includes:
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'train_loss': train_loss,
    'val_loss': val_loss,
    'config': config
}
Load Checkpoint
import torch
from causalfm.models import StandardCATEModel
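# Load checkpoint (map_location="cpu" lets it load on machines without a GPU)
# On PyTorch 2.6+, torch.load defaults to weights_only=True; a checkpoint that
# stores a config object may need weights_only=False (trusted files only)
checkpoint = torch.load("checkpoints/best_model.pth", map_location="cpu")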
# Load model
model = StandardCATEModel()
model.load_state_dict(checkpoint['model_state_dict'])
# Or use from_pretrained
model = StandardCATEModel.from_pretrained("checkpoints/best_model.pth")
Best Practices
Recommended Settings
For most use cases:
config = TrainingConfig(
    # Data
    data_path="data/train/*.csv",
    val_split=0.2,
    batch_size=16,
    num_workers=4,  # Or 0 if multiprocessing issues
    # Training
    epochs=150,
    learning_rate=0.001,
    weight_decay=1e-5,
    clip_grad=1.0,
    early_stop_patience=50,
    # Model
    use_gmm_head=True,
    gmm_n_components=5,
    # Checkpointing
    save_dir="checkpoints/",
    save_freq=10,
    # Logging
    log_dir="logs/",
    # Hardware
    device='auto',
    # Reproducibility
    seed=42
)
Multiprocessing
Always wrap training code in an if __name__ == '__main__': guard:
if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()
Or disable multiprocessing:
config = TrainingConfig(
    data_path="data/*.csv",
    num_workers=0  # No multiprocessing
)
GPU Usage
For single GPU:
config = TrainingConfig(
    data_path="data/*.csv",
    device='cuda',  # Or 'cuda:0'
    batch_size=16
)
For multi-GPU (DataParallel):
config = TrainingConfig(
    data_path="data/*.csv",
    device='cuda',  # Uses all available GPUs
    batch_size=64  # Increase for multi-GPU
)
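The reason to raise batch_size for multi-GPU runs is that DataParallel splits each batch along its first dimension across the devices. The sketch below approximates that split with chunk-style arithmetic (equal chunks of ceil(batch_size / n_gpus), with a smaller or missing last chunk); per_gpu_batch_sizes is a hypothetical helper for estimating per-device load, not a CausalFM function.

```python
import math

def per_gpu_batch_sizes(batch_size, n_gpus):
    """Approximate how many samples each GPU receives from one batch."""
    chunk = math.ceil(batch_size / n_gpus)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        take = min(chunk, remaining)
        sizes.append(take)
        remaining -= take
    # Fewer entries than n_gpus means some devices would sit idle
    return sizes

four_gpu_split = per_gpu_batch_sizes(64, 4)
```

With batch_size=64 on four GPUs each device sees 16 samples, matching the single-GPU setting above, whereas keeping batch_size=16 would leave each device with only 4.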
For CPU:
config = TrainingConfig(
    data_path="data/*.csv",
    device='cpu'
)
Troubleshooting
Common Issues
CUDA Out of Memory:
# Reduce batch size
config = TrainingConfig(
    data_path="data/*.csv",
    batch_size=8,  # Smaller batches
    num_workers=0
)
Multiprocessing Errors:
# Wrap training in an if __name__ == '__main__': guard
if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()
# Or disable multiprocessing
config = TrainingConfig(data_path="data/*.csv", num_workers=0)
Not Converging:
# Lower learning rate
config = TrainingConfig(
    data_path="data/*.csv",
    learning_rate=0.0001,
    epochs=200,
    clip_grad=0.5
)
Overfitting:
# Increase regularization
config = TrainingConfig(
    data_path="data/*.csv",
    weight_decay=1e-3,
    early_stop_patience=20,
    val_split=0.3
)
See Also
Training - Detailed training guide
Models API - Model API reference
Standard CATE Estimation Example - Complete training example