Tutorial 3: Training Models

Learn how to train CausalFM foundation models.

Coming Soon

This tutorial is under development. For now, see the quick reference below.

Quick Reference

Basic Training

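from causalfm.training import StandardCATETrainer, TrainingConfig

# Minimal configuration: where the training data lives, how long to train,
# and where checkpoints are written.
config = TrainingConfig(
    data_path="data/train/*.csv",  # glob pattern matching the training CSV files
    epochs=100,
    batch_size=16,
    save_dir="checkpoints/"
)

# Guard the entry point so the script stays safe to import, which matters
# when data loading starts worker processes (see num_workers below).
if __name__ == '__main__':
    trainer = StandardCATETrainer(config)
    trainer.train()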
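
Before starting a long run, it can help to confirm that the data_path glob actually matches some files. The sketch below uses only Python's standard glob module. The helper name list_training_files is hypothetical and not part of CausalFM, and this tutorial does not show how CausalFM itself resolves data_path, so treat it as a pre-flight check rather than a description of the library's behavior.

```python
import glob


def list_training_files(pattern):
    """Return the sorted list of files matching a glob pattern.

    Hypothetical helper (not part of CausalFM): it raises early when
    nothing matches, so a typo in the path is caught before training.
    """
    files = sorted(glob.glob(pattern))
    if not files:
        raise FileNotFoundError(f"No files match pattern: {pattern!r}")
    return files
```

Calling list_training_files("data/train/*.csv") before constructing the trainer would fail fast if the directory is empty or the pattern is misspelled.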

Custom Configuration

from causalfm.training import TrainingConfig

config = TrainingConfig(
    # Data
    data_path="data/train/*.csv",
    val_split=0.2,
    batch_size=16,
    num_workers=4,

    # Training
    epochs=150,
    learning_rate=0.001,
    weight_decay=1e-5,
    early_stop_patience=50,

    # Model
    use_gmm_head=True,
    gmm_n_components=5,

    # Checkpointing
    save_dir="checkpoints/",

    # Hardware
    device='auto'
)
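
To make the early_stop_patience setting more concrete, here is a minimal sketch of the conventional patience rule: stop once the best validation loss is at least patience epochs old. This is an assumption about what the parameter means, based on common usage. The function should_stop_early is hypothetical, and CausalFM's actual stopping logic may differ (for example, in how ties or minimum improvements are handled).

```python
def should_stop_early(val_losses, patience):
    """Return True once the best validation loss is `patience` epochs old.

    Illustrative only (not CausalFM code). `val_losses` holds one
    validation loss per completed epoch, oldest first. A tie with the
    current best does not count as an improvement.
    """
    if not val_losses:
        return False
    # Index of the first occurrence of the lowest loss seen so far.
    best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
    epochs_since_best = len(val_losses) - 1 - best_epoch
    return epochs_since_best >= patience
```

Under this convention, early_stop_patience=50 with epochs=150 would end training before epoch 150 only if 50 consecutive epochs pass without a new best validation loss.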

Next Tutorial

Continue to Tutorial 4: Model Evaluation to learn how to evaluate trained models.