# Training Scheduling in linnaeus

## Overview

The linnaeus scheduling system (`linnaeus/ops_schedule/`) provides a flexible way to configure various aspects of the training process dynamically. It manages learning rates, validation frequency, checkpointing, metadata masking probabilities, and mixup application throughout training. The system uses the centralized `TrainingProgress` tracker to ensure consistency, especially in distributed settings with gradient accumulation.
## Core Design Philosophy

- Clarity and Validation: Schedule parameters should be unambiguous. The system validates configurations at startup to catch conflicts (e.g., defining both step-based and fraction-based intervals for the same event).
- Step-Based Internal Logic: While the configuration allows epoch or fraction definitions, the internal scheduling decisions (within `OpsSchedule`) are primarily driven by the global step count maintained by `TrainingProgress`.
- Accurate Total Steps: The total number of training steps (`total_steps`) is crucial for resolving fraction-based parameters and LR decay. This value is calculated accurately in `main.py` after the training dataloader is built and its initial length (based on the first mixup group level) is determined. See Design Decisions for details.
- Automatic LR Scaling: Learning rates are automatically scaled based on the effective batch size.
## Parameter Definition Methods

Schedule timings and durations can be defined using one of three methods for each parameter. Choose only one method per parameter.

- Absolute Steps (`*_STEPS`): Direct specification of the global step count (optimizer steps). Example: `WARMUP_STEPS: 2000`.
- Fraction of Total (`*_FRACTION`): Relative to the total training steps calculated at initialization. Example: `WARMUP_FRACTION: 0.1` (10% of `total_steps`).
- Epochs (`*_EPOCHS`): Based on completed training epochs. Primarily used for validation and checkpointing intervals. Example: `VALIDATION.INTERVAL_EPOCHS: 1`.

The system resolves all fraction-based parameters into absolute step counts during initialization using `utils.schedule_utils.resolve_all_schedule_params`.
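
For illustration, the conversion from a fraction-based parameter to an absolute step count can be sketched as follows. This is a minimal sketch only; the function name and argument layout are illustrative, and the real logic lives in `utils.schedule_utils.resolve_all_schedule_params`:

```python
# Minimal illustrative sketch of fraction-to-step resolution.
# The helper name and argument layout are assumptions for illustration.
def resolve_fraction_param(steps: int | None, fraction: float | None, total_steps: int) -> int:
    """Resolve a (*_STEPS, *_FRACTION) pair into an absolute step count."""
    if steps is not None and fraction is not None:
        raise ValueError("Define either *_STEPS or *_FRACTION, not both.")
    if fraction is not None:
        return int(round(fraction * total_steps))
    return steps or 0

# Example: WARMUP_FRACTION: 0.05 with total_steps=40_000 -> 2_000 warmup steps
warmup_steps = resolve_fraction_param(steps=None, fraction=0.05, total_steps=40_000)
```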
## Schedule Components

### Learning Rate Scheduling (`LR_SCHEDULER`)

Controls how the learning rate changes over time. See `linnaeus/lr_schedulers/`.
```yaml
LR_SCHEDULER:
  NAME: 'cosine'            # Schedule type: cosine, step, linear, wsd
  BASE_LR: 5e-5             # Base learning rate (before scaling)
  REFERENCE_BS: 512         # Reference batch size for LR scaling
  # Choose ONE warmup definition:
  WARMUP_FRACTION: 0.05     # Preferred: Warmup over 5% of total steps
  # WARMUP_STEPS: 0         # Alternative: Explicit warmup steps
  # WARMUP_EPOCHS: 0.0      # Alternative: Warmup over N epochs (converted to steps)
  MIN_LR: 1e-6              # Minimum LR for cosine/linear decay
  WARMUP_LR: 5e-7           # Starting LR for warmup phase
  # Step scheduler params (if NAME='step')
  # DECAY_STEPS: 10000      # Step interval for decay
  # DECAY_FRACTION: None    # Alternative: Fraction of total steps for decay interval
  # DECAY_RATE: 0.1         # Multiplicative decay factor
  # WSD scheduler params (if NAME='wsd')
  # STABLE_DURATION_FRACTION: 0.8  # Fraction of post-warmup steps for stable phase
  # DECAY_DURATION_FRACTION: 0.1   # Fraction of post-warmup steps for decay phase
  # DECAY_TYPE: 'cosine'           # Decay shape: 'cosine' or 'linear'
```
#### Warmup-Stable-Decay (WSD) Schedule (`NAME: 'wsd'`)

This schedule implements the Warmup-Stable-Decay pattern. It requires the standard `BASE_LR` (target stable LR), `WARMUP_*`, and `MIN_LR` parameters. Additionally, configure the post-warmup phases:

- `STABLE_DURATION_FRACTION`: (Float, 0.0 to 1.0) Fraction of the post-warmup duration spent in the stable phase (constant `BASE_LR`). Default: 0.8.
- `DECAY_DURATION_FRACTION`: (Float, 0.0 to 1.0) Fraction of the post-warmup duration spent in the decay phase. Default: 0.1.
- `DECAY_TYPE`: (String) Shape of the decay curve (`'cosine'` or `'linear'`). Default: `'cosine'`.

Important: All WSD-specific duration parameters apply to the period after warmup completes. For example, if your total training is 100 steps with 20 warmup steps, and you set `STABLE_DURATION_FRACTION: 0.8` and `DECAY_DURATION_FRACTION: 0.2`, you'll get:

- 20 steps of warmup from `WARMUP_LR` to `BASE_LR`
- 64 steps (80% of the remaining 80 steps) of stable LR at `BASE_LR`
- 16 steps (20% of the remaining 80 steps) of decay from `BASE_LR` to `MIN_LR`

Note: The sum of `STABLE_DURATION_FRACTION` and `DECAY_DURATION_FRACTION` does not need to be 1.0. After the decay phase completes, the LR remains at `MIN_LR`.
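
The phase arithmetic above can be summarized in a short sketch. The boundaries follow the worked example (warmup, then stable/decay fractions of the post-warmup steps); the function itself is illustrative, not the linnaeus scheduler implementation:

```python
import math

# Illustrative WSD learning-rate curve; phase boundaries follow the worked
# example above (warmup, then stable/decay fractions of the post-warmup steps).
def wsd_lr(step, total_steps, warmup_steps, base_lr, warmup_lr, min_lr,
           stable_frac=0.8, decay_frac=0.1, decay_type="cosine"):
    if step < warmup_steps:  # linear warmup: WARMUP_LR -> BASE_LR
        return warmup_lr + (base_lr - warmup_lr) * step / max(warmup_steps, 1)
    post = total_steps - warmup_steps
    stable_steps = int(stable_frac * post)
    decay_steps = int(decay_frac * post)
    s = step - warmup_steps
    if s < stable_steps:                      # stable phase at BASE_LR
        return base_lr
    if s < stable_steps + decay_steps:        # decay phase: BASE_LR -> MIN_LR
        t = (s - stable_steps) / max(decay_steps, 1)
        if decay_type == "cosine":
            return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))
        return base_lr + (min_lr - base_lr) * t  # linear decay
    return min_lr                             # after decay, hold at MIN_LR

# 100 total steps, 20 warmup, stable_frac=0.8, decay_frac=0.2
# -> 20 warmup, 64 stable, 16 decay steps (matching the worked example above)
print(wsd_lr(step=50, total_steps=100, warmup_steps=20, base_lr=5e-5,
             warmup_lr=5e-7, min_lr=1e-6, stable_frac=0.8, decay_frac=0.2))  # 5e-05
```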
- Initialization: `total_steps` and `optimizer_steps_per_epoch` are calculated accurately in `main.py` after dataloader initialization. These values are used by `lr_schedulers.build.build_scheduler` to configure warmup steps and decay durations correctly.
- LR Scaling: Applied automatically by `utils.schedule_utils.apply_lr_scaling` based on `effective_batch_size = per_gpu_batch_size * world_size * accumulation_steps` (see the sketch after this list). Note: This scaling mechanism modifies the initial learning rates in the optimizer. It assumes that the chosen learning rate scheduler uses these initial values as its starting point, which is standard for most schedulers.
- Parameter Groups: Supports different LR schedules per parameter group via `LR_SCHEDULER.PARAMETER_GROUPS`. See `configs/` for examples.
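
As a rough illustration of the scaling rule, assuming linear scaling of `BASE_LR` against `REFERENCE_BS` (the common convention; consult `utils.schedule_utils.apply_lr_scaling` for the authoritative behavior):

```python
# Illustrative linear LR scaling by effective batch size.
# Assumes linear scaling against REFERENCE_BS, which is the usual convention.
def scaled_lr(base_lr: float, reference_bs: int,
              per_gpu_batch_size: int, world_size: int, accumulation_steps: int) -> float:
    effective_batch_size = per_gpu_batch_size * world_size * accumulation_steps
    return base_lr * effective_batch_size / reference_bs

# BASE_LR=5e-5, REFERENCE_BS=512, 8 GPUs x 32 per GPU x 4 accumulation -> effective 1024
print(scaled_lr(5e-5, 512, per_gpu_batch_size=32, world_size=8, accumulation_steps=4))  # 1e-04
```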
### Validation Scheduling (`SCHEDULE.VALIDATION`)

Determines when validation runs occur. See `linnaeus/validation.py`.
```yaml
SCHEDULE:
  VALIDATION:
    # --- Standard Validation Interval ---
    # Choose ONE method:
    INTERVAL_EPOCHS: 1                  # Method 1: Every N epochs
    # INTERVAL_STEPS: 5000              # Method 2: Every N global steps
    # INTERVAL_FRACTION: 0.1            # Method 3: Every (0.1 * total_steps) steps

    # --- Mask Meta Validation Interval (Optional) ---
    # Choose ONE method:
    MASK_META_INTERVAL_EPOCHS: 5
    # MASK_META_INTERVAL_STEPS: 25000
    # MASK_META_INTERVAL_FRACTION: 0.5  # Every (0.5 * total_steps) steps

    # --- Partial Mask Meta Validation (Optional) ---
    PARTIAL_MASK_META:
      ENABLED: True
      # Choose ONE interval method:
      INTERVAL_EPOCHS: 10               # e.g., Run every 10 epochs
      # INTERVAL_STEPS: 50000           # e.g., Run every 50k steps
      # INTERVAL_FRACTION: 0.5          # e.g., Run every (0.5 * total_steps) steps
      WHITELIST:                        # Combinations to test
        - ["TEMPORAL"]
        - ["SPATIAL", "ELEVATION"]

    # --- Final Epoch Exhaustive Validation (Optional) ---
    FINAL_EPOCH:
      EXHAUSTIVE_PARTIAL_META_VALIDATION: False  # If True, runs all combinations below
      EXHAUSTIVE_META_COMPONENTS: ["TEMPORAL", "SPATIAL", "ELEVATION"]
```
- Execution Timing: Although schedules can be defined by steps or fractions, validation runs are executed only at epoch boundaries. The `OpsSchedule` checks whether a configured interval (in steps or epochs) has been met at the end of an epoch.
- Periodic Execution: For step-based intervals, validation is triggered at epoch boundaries where the global step is divisible by the interval. For fraction-based intervals, the fraction is converted to a step interval using `interval_steps = total_steps * fraction`.
- Epoch-Based Intervals Behavior: When using `INTERVAL_EPOCHS`, the system follows these rules (see the sketch after this list):
  - If `INTERVAL_EPOCHS = 1`: Validation runs at epochs 0, 1, 2, 3, ...
  - If `INTERVAL_EPOCHS > 1`: Validation runs at epochs N, 2N, 3N, ... (skipping epoch 0)
  - This applies to all epoch-based intervals (standard validation, mask meta validation, partial mask meta validation, and checkpointing)
- See Validation Documentation for more details.
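
A minimal sketch of the end-of-epoch decision described above (the rules mirror the bullets; the function is illustrative and not the actual `OpsSchedule` method):

```python
# Illustrative end-of-epoch validation check; not the actual OpsSchedule code.
def should_validate(epoch: int, global_step: int,
                    interval_epochs: int | None = None,
                    interval_steps: int | None = None) -> bool:
    if interval_epochs is not None:
        if interval_epochs == 1:
            return True                                       # epochs 0, 1, 2, ...
        return epoch > 0 and epoch % interval_epochs == 0     # N, 2N, 3N, ...
    if interval_steps is not None:
        # step-based intervals are still evaluated only at epoch boundaries
        return global_step > 0 and global_step % interval_steps == 0
    return False
```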
### Checkpoint Scheduling and Management (`SCHEDULE.CHECKPOINT`)

Controls when and how model checkpoints are saved. See `linnaeus/utils/checkpoint.py`.
```yaml
SCHEDULE:
  CHECKPOINT:
    # Choose ONE scheduling method:
    INTERVAL_EPOCHS: 1        # Method 1: Every N epochs
    # INTERVAL_STEPS: 5000    # Method 2: Every N global steps
    # INTERVAL_FRACTION: 0.1  # Method 3: Every (0.1 * total_steps) steps

    # --- Retention Policies ---
    KEEP_TOP_N: 3             # Keep 3 best checkpoints (based on val metric)
    KEEP_LAST_N: 2            # Keep 2 most recent checkpoints
```
- Execution Timing: Similar to validation, checkpoints are saved only at epoch boundaries based on the configured interval.
- Retention: The system keeps the union of the `KEEP_TOP_N` best and `KEEP_LAST_N` most recent checkpoints (see the sketch after this list). Older or lower-performing checkpoints are automatically deleted from local disk.
- Auto-Resume: Training can be automatically resumed from the `latest.pth` checkpoint using `TRAIN.AUTO_RESUME: True`. See Auto-Resume Documentation.
- Remote Sync: Optionally syncs the entire experiment output directory (including checkpoints) to Backblaze B2 after each save using `ENV.OUTPUT.BUCKET.ENABLED: True`.
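
The retention rule amounts to a simple set union. The sketch below is illustrative (the checkpoint record layout is an assumption), not the code in `linnaeus/utils/checkpoint.py`:

```python
# Illustrative retention policy: keep the union of the best-N and most-recent-N
# checkpoints. The {"path", "metric", "step"} record layout is assumed.
def checkpoints_to_keep(ckpts, keep_top_n=3, keep_last_n=2, higher_is_better=False):
    by_metric = sorted(ckpts, key=lambda c: c["metric"], reverse=higher_is_better)
    by_recency = sorted(ckpts, key=lambda c: c["step"], reverse=True)
    keep = {c["path"] for c in by_metric[:keep_top_n]}
    keep |= {c["path"] for c in by_recency[:keep_last_n]}
    return keep  # everything else is eligible for deletion
```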
### Early Stopping (`TRAIN.EARLY_STOP`)

Allows automatic termination of training based on metric progress or other conditions.
```yaml
TRAIN:
  EARLY_STOP:
    ACTIVE: True
    METRIC: 'val_loss'      # Metric to monitor (e.g., 'val_loss', 'val_chain_accuracy')
    PATIENCE_STEPS: 10000   # Global steps to wait for improvement
    MIN_DELTA: 0.001        # Minimum change considered an improvement
    # Optional extra conditions:
    # MAX_STEPS: 100000     # Max global steps for the run
    # MAX_LOSS: 10.0        # Stop if metric exceeds this value
    # MIN_LR: 1e-7          # Stop if LR falls below this
    # MAX_GRAD_NORM: 100.0  # Stop if gradient norm exceeds this
```
- Mechanism: The `OpsSchedule.should_stop_early()` method checks conditions against the `metrics_tracker` and the current LR/gradient norm (see the sketch after this list).
- Execution Timing: Checks are performed at validation points (typically epoch boundaries).
- Metric Direction: The system automatically determines whether higher or lower is better for the chosen `METRIC` (e.g., lower for loss, higher for accuracy).
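
A simplified sketch of the patience logic, combining the metric check with the optional hard limits (the state layout and the exact set of checks are assumptions, not the `should_stop_early()` implementation):

```python
# Illustrative early-stopping check combining patience with optional hard limits.
# `state` is an assumed tracking dict, e.g.
#   state = {"best": float("inf"), "best_step": 0}   # for lower-is-better metrics
def should_stop_early(state, metric_value, global_step, current_lr, grad_norm,
                      min_delta=0.001, patience_steps=10_000, lower_is_better=True,
                      max_steps=None, max_loss=None, min_lr=None, max_grad_norm=None):
    improved = ((state["best"] - metric_value) > min_delta if lower_is_better
                else (metric_value - state["best"]) > min_delta)
    if improved:
        state["best"] = metric_value
        state["best_step"] = global_step
    if global_step - state["best_step"] >= patience_steps:
        return True                                    # no improvement within patience
    if max_steps is not None and global_step >= max_steps:
        return True                                    # MAX_STEPS limit
    if max_loss is not None and metric_value > max_loss:
        return True                                    # MAX_LOSS limit
    if min_lr is not None and current_lr < min_lr:
        return True                                    # MIN_LR limit
    if max_grad_norm is not None and grad_norm > max_grad_norm:
        return True                                    # MAX_GRAD_NORM limit
    return False
```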
### Meta-Masking Scheduling (`SCHEDULE.META_MASKING`)

Controls the probability of masking metadata during training. See `linnaeus/ops_schedule/ops_schedule.py` (`get_meta_mask_prob`, `get_partial_mask_enabled`, etc.).
```yaml
SCHEDULE:
  META_MASKING:
    ENABLED: True
    START_PROB: 1.0        # Probability of full masking at step 0
    END_PROB: 0.05         # Probability of full masking at END_STEPS/FRACTION
    # Choose ONE end definition:
    END_FRACTION: 0.3      # Reach END_PROB at 30% of total steps
    # END_STEPS: 15000     # Alternative: Reach END_PROB at step 15000

    PARTIAL:
      ENABLED: True
      # Define start/end window for *applying* partial masking
      START_FRACTION: 0.1  # Start applying partial masking at 10%
      END_FRACTION: 0.9    # Stop applying partial masking at 90%
      # Define schedule for the *probability* of applying partial mask (within the window)
      START_PROB: 0.01     # Initial probability for partial masking (if not fully masked)
      END_PROB: 0.7        # Final probability for partial masking
      # Choose ONE end definition for the probability schedule:
      PROB_END_FRACTION: 0.5   # Reach END_PROB at 50% of total steps
      # PROB_END_STEPS: 25000  # Alternative: Reach END_PROB at step 25k
      WHITELIST: [...]     # List of component combinations for partial masking
      # WEIGHTS: [...]     # Optional: Weights for sampling from WHITELIST
```
- Full Masking: The probability decreases linearly from `START_PROB` to `END_PROB` over the specified duration (see the sketch after this list).
- Partial Masking: Applied only within the `START_FRACTION`/`END_FRACTION` window and only to samples not already fully masked. The probability of applying partial masking also ramps linearly based on its own `START_PROB`/`END_PROB` schedule.
- See Meta Masking Documentation for details.
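
Both ramps are plain linear interpolations. The snippet below sketches the interpolation only (not the `get_meta_mask_prob` implementation):

```python
# Illustrative linear probability ramp from START_PROB to END_PROB over end_steps.
def linear_prob(step: int, start_prob: float, end_prob: float, end_steps: int) -> float:
    if end_steps <= 0 or step >= end_steps:
        return end_prob
    frac = step / end_steps
    return start_prob + (end_prob - start_prob) * frac

# Example: START_PROB=1.0, END_PROB=0.05, END_FRACTION=0.3 with total_steps=50_000
end_steps = int(0.3 * 50_000)                    # 15_000 steps
print(linear_prob(7_500, 1.0, 0.05, end_steps))  # halfway through the ramp -> 0.525
```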
### Mixup Scheduling (`SCHEDULE.MIXUP`)

Controls the probability of applying mixup and (conceptually) the grouping level.
```yaml
SCHEDULE:
  MIXUP:
    GROUP_LEVELS: ['taxa_L10']  # Taxonomic levels for grouping samples
    # --- LEVEL SWITCHING IS CURRENTLY DISABLED ---
    LEVEL_SWITCH_STEPS: []      # Currently disabled - must be empty
    LEVEL_SWITCH_EPOCHS: []     # Currently disabled - must be empty
    # --------------------------------------------
    PROB:
      ENABLED: True
      START_PROB: 1.0           # Mixup probability at step 0
      END_PROB: 0.2             # Mixup probability at END_STEPS/FRACTION
      # Choose ONE end definition:
      END_FRACTION: 0.4         # Reach END_PROB at 40% of total steps
      # END_STEPS: 20000        # Alternative: Reach END_PROB at step 20k
    ALPHA: 1.0                  # Beta distribution alpha parameter
    USE_GPU: True               # Use GPU implementation if available
    MIN_GROUP_SIZE: 4           # Groups smaller than this are excluded
    EXCLUDE_NULL_SAMPLES: True  # Exclude samples with null labels from mixup
```
- Probability: Decreases linearly, like meta-masking (see the sketch after this list).
- Level Switching: Currently disabled. `LEVEL_SWITCH_STEPS` and `LEVEL_SWITCH_EPOCHS` must be empty. The system uses only the first level specified in `GROUP_LEVELS` for the entire run. This is due to the schedule initialization dependency explained in Design Decisions.
- See Augmentations Documentation for details on mixup and configuration best practices.
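
As a sketch of how the scheduled probability might gate mixup on a given batch (illustrative only; the random draw and Beta sampling shown here are assumed mechanics, not the linnaeus GPU implementation):

```python
import random

# Illustrative per-batch mixup gating: compare a random draw against the
# scheduled probability, then sample a Beta(alpha, alpha) mixing coefficient.
def maybe_mixup_lambda(step, total_steps, start_prob=1.0, end_prob=0.2,
                       end_fraction=0.4, alpha=1.0):
    end_steps = int(end_fraction * total_steps)
    frac = min(step / max(end_steps, 1), 1.0)
    prob = start_prob + (end_prob - start_prob) * frac  # linear ramp, as above
    if random.random() >= prob:
        return None                              # skip mixup for this batch
    return random.betavariate(alpha, alpha)      # mixing coefficient lambda
```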
### Metrics Logging (`SCHEDULE.METRICS`)

Controls how often various metrics are logged.
```yaml
SCHEDULE:
  METRICS:
    # Choose ONE interval definition method for each metric type:
    STEP_INTERVAL: 50       # Log basic training metrics every N steps
    # STEP_FRACTION: 0.001  # Alternative: Log every 0.1% of steps
    CONSOLE_INTERVAL: 100   # Log summary to console every N steps
    WANDB_INTERVAL: 50      # Log detailed metrics to WandB every N steps
    LR_INTERVAL: 100        # Log learning rates every N steps
    PIPELINE_INTERVAL: 250  # Log data pipeline stats every N steps
    # Note: GradNorm metrics are logged automatically when GradNorm weights are updated
```
- See Metrics Documentation for details on tracked metrics.
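
A small illustrative check for step-based logging intervals (the helper is hypothetical; the fraction-to-interval conversion mirrors the resolution described earlier):

```python
# Illustrative logging-interval check; the helper and argument names are hypothetical.
def should_log(global_step: int, interval: int | None = None,
               fraction: float | None = None, total_steps: int | None = None) -> bool:
    if interval is None and fraction is not None and total_steps is not None:
        interval = max(int(fraction * total_steps), 1)  # e.g., 0.001 * 50_000 -> 50
    return bool(interval) and global_step % interval == 0
```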
## Schedule Initialization and Monitoring

The scheduling system is initialized at the start of training (`main.py`) after the dataloader length and `total_steps` are accurately determined. The `OpsSchedule` class uses the resolved step counts.

- Resolution: `utils.schedule_utils.resolve_all_schedule_params` converts all fraction-based parameters to absolute step counts.
- Validation: `utils.schedule_utils.validate_schedule_config` checks for conflicting parameter definitions (see the sketch after this list). `utils.schedule_utils.validate_schedule_sanity` checks for potentially nonsensical (but not strictly invalid) configurations.
- Summary: A detailed summary, including resolved step counts and epoch equivalents, is logged and saved to `output/assets/schedule_summary.txt`.
- Visualization: A text-based visualization is generated and included in the summary file.
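
The conflict check amounts to rejecting any parameter that is defined in more than one way; a sketch under that assumption (not the actual `validate_schedule_config`):

```python
# Illustrative conflict check: a parameter may be defined by at most one of its
# *_STEPS / *_FRACTION / *_EPOCHS variants. Names here are for illustration only.
def check_exclusive(name: str, steps=None, fraction=None, epochs=None) -> None:
    defined = [k for k, v in {"STEPS": steps, "FRACTION": fraction, "EPOCHS": epochs}.items()
               if v is not None]
    if len(defined) > 1:
        raise ValueError(
            f"{name}: conflicting definitions {defined}; choose exactly one method."
        )

# Example: both WARMUP_STEPS and WARMUP_FRACTION set -> raises ValueError
# check_exclusive("LR_SCHEDULER.WARMUP", steps=2000, fraction=0.05)
```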
Always review the schedule summary and visualization before starting a long run to ensure the configuration behaves as expected.