# Developer Guide: Scheduling System
This document details the linnaeus scheduling system: the `OpsSchedule` class (`linnaeus/ops_schedule/ops_schedule.py`) that manages most of it, the configuration parameters that drive it, and the supporting utilities.
## 1. Overview
The scheduling system coordinates various time-dependent training operations:
- Learning Rate: Managed by schedulers in `linnaeus/lr_schedulers/`, configured via `LR_SCHEDULER.*`.
- Validation Runs: Standard, mask-meta, and partial mask-meta validation, configured via `SCHEDULE.VALIDATION.*`.
- Checkpointing: Saving model state, configured via `SCHEDULE.CHECKPOINT.*`.
- Early Stopping: Terminating training based on metrics, configured via `TRAIN.EARLY_STOP.*`.
- Meta-Masking Probability: Adjusting the chance of masking metadata, configured via `SCHEDULE.META_MASKING.*`.
- Mixup Application: Controlling mixup probability and grouping level, configured via `SCHEDULE.MIXUP.*`.
- Metrics Logging: Controlling the frequency of logging different metrics, configured via `SCHEDULE.METRICS.*`.
## 2. Core Component: `OpsSchedule`

The `OpsSchedule` class is the central decision-maker for when scheduled events should occur.
- Initialization: Takes the fully resolved `config` and the `MetricsTracker` (for early-stopping metric checks). It also holds a reference to the `TrainingProgress` instance (set via `set_training_progress`) to query the current state (epoch, global step).
- Decision Methods: Provides methods like `should_validate()`, `should_save_checkpoint()`, `should_stop_early()`, `should_update_gradnorm()`, etc. These methods typically check the current progress (epoch or step) against configured intervals or thresholds.
- Value Methods: Provides methods like `get_meta_mask_prob()`, `get_null_mask_prob()`, `get_mixup_prob()`, and `get_mixup_group_level()`. These methods calculate the scheduled value based on the current step.
- Step Dependency: Crucially, methods that depend on the current progress within an epoch (like `should_update_gradnorm` and the `get_*_prob` methods) now require the `current_step` (optimizer step count) to be passed as an argument. They no longer rely solely on `TrainingProgress.global_step`, which is only updated at epoch end. (See the call-pattern sketch after this list.)
- Supported Scheduler Types: The system now supports four different learning rate scheduler patterns: 'cosine', 'linear', 'step', and 'wsd' (Warmup-Stable-Decay).
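The expected call pattern for the step-dependent methods can be summarized with a minimal sketch. It assumes an already-initialized `OpsSchedule` instance (`ops_schedule`) and a local `current_step` counter maintained by the training loop; the helper function itself is hypothetical, and only the `OpsSchedule` methods named above are part of the real API.

```python
# Hypothetical helper: intra-epoch queries pass the local optimizer-step count
# explicitly, per the "Step Dependency" note above.
def query_intra_epoch_schedule(ops_schedule, current_step: int) -> dict:
    """Collect the step-dependent decisions and values for this optimizer step."""
    return {
        "update_gradnorm": ops_schedule.should_update_gradnorm(current_step),
        "meta_mask_prob": ops_schedule.get_meta_mask_prob(current_step),
        "null_mask_prob": ops_schedule.get_null_mask_prob(current_step),
        "mixup_prob": ops_schedule.get_mixup_prob(current_step),
        "mixup_group_level": ops_schedule.get_mixup_group_level(current_step),
    }
```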
## 3. Configuration (`SCHEDULE.*`, `LR_SCHEDULER.*`, `TRAIN.EARLY_STOP.*`)

Scheduling is primarily configured within the `SCHEDULE` section of the YAML config, supplemented by `LR_SCHEDULER` and `TRAIN.EARLY_STOP`.
### Parameter Definition Methods

As outlined in the Style Guide (to be replaced by this doc section):

- Use absolute steps (`*_STEPS`), a fraction of total steps (`*_FRACTION`), or epochs (`*_EPOCHS`).
- Choose only ONE method per parameter. Set unused step/epoch parameters to `0` and unused fraction parameters to `None` (see the checker sketch after this list).
- Fraction-based parameters are generally preferred for portability across different dataset sizes and training durations.
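The exclusivity rule can be illustrated with a small checker. This is not linnaeus code (the real check lives in `utils.schedule_utils.validate_schedule_config`); it is only a sketch of the rule.

```python
# Illustrative checker (not the real implementation): exactly one of the
# step / fraction / epoch variants may be set for a scheduled parameter.
def assert_single_definition(name: str, steps: int = 0, fraction=None, epochs: int = 0) -> None:
    defined = sum([steps > 0, fraction is not None, epochs > 0])
    if defined != 1:
        raise ValueError(
            f"{name}: expected exactly one of *_STEPS, *_FRACTION, *_EPOCHS, got {defined}"
        )

# Valid: fraction only. Setting both steps and epochs would raise.
assert_single_definition("VALIDATION.INTERVAL", fraction=0.1)
```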
### Resolution

- At startup (`main.py`), `utils.schedule_utils.resolve_all_schedule_params` converts all `*_FRACTION` settings into absolute `*_STEPS` based on the calculated `total_steps`. `OpsSchedule` then uses these resolved step counts for its decisions. (A sketch of this conversion follows below.)
- Epoch-based intervals (`*_EPOCHS`) are used directly by checking `TrainingProgress.current_epoch`.
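A flat-dict sketch of the fraction-to-steps arithmetic; the real `resolve_all_schedule_params` walks the full config tree, so the helper below is illustrative only.

```python
# Sketch of the *_FRACTION -> *_STEPS resolution; illustrative, not the real code.
def resolve_fractions(params: dict, total_steps: int) -> dict:
    resolved = dict(params)
    for key, value in params.items():
        if key.endswith("_FRACTION") and value is not None:
            resolved[key.replace("_FRACTION", "_STEPS")] = int(round(value * total_steps))
    return resolved

# Example: a 0.1 interval fraction over 50_000 total steps resolves to 5_000 steps.
print(resolve_fractions({"INTERVAL_FRACTION": 0.1}, total_steps=50_000))
```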
### Validation & Sanity Checks

- `utils.schedule_utils.validate_schedule_config`: Checks for conflicting definitions (e.g., both `INTERVAL_STEPS` and `INTERVAL_EPOCHS` set > 0). Raises errors.
- `utils.schedule_utils.validate_schedule_sanity`: Checks for potentially problematic configurations (e.g., no validation interval defined). Logs warnings.
- Schedule Summary: A detailed summary and text visualization are logged and saved to `assets/schedule_summary.txt`, showing the resolved step counts and timings for all major scheduled events. Reviewing this summary before long runs is essential (a quick way to print it is shown below).
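For a quick pre-run review, printing the generated summary is usually enough; the snippet below uses only the standard library and the path given above.

```python
from pathlib import Path

# Print the resolved schedule summary written at startup (see path above).
summary_path = Path("assets/schedule_summary.txt")
if summary_path.exists():
    print(summary_path.read_text())
else:
    print("No schedule summary found; run startup (main.py) first.")
```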
## 4. Integration with Training Loop
### Learning Rate

- Built by `lr_schedulers.build.build_scheduler` using the resolved `total_steps` and `warmup_steps`.
- Stepped after each optimizer update in `train_one_epoch` using `lr_scheduler.step_update(current_step)` (see the loop sketch below).
- Scaling is applied once at initialization based on the effective batch size.
- Supports multiple scheduler types: 'cosine', 'linear', 'step', and 'wsd' (Warmup-Stable-Decay).
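A minimal loop sketch of the stepping contract. It assumes PyTorch-style `model`/`optimizer` objects, batches of `(inputs, targets)`, and an `lr_scheduler` exposing `step_update(current_step)` as described above; everything else is passed in as a parameter to keep the snippet self-contained and is not taken from the real `train_one_epoch`.

```python
def train_steps(model, data_loader, optimizer, lr_scheduler, loss_fn,
                accumulation_steps: int, current_step: int) -> int:
    """Sketch only: the LR scheduler is stepped once per optimizer update."""
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        loss = loss_fn(model(inputs), targets) / accumulation_steps
        loss.backward()
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            current_step += 1                       # local optimizer-step counter
            lr_scheduler.step_update(current_step)  # stepped after each optimizer update
    return current_step
```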
### Validation & Checkpointing

- Checks (`should_validate`, `should_save_checkpoint`) are performed at the end of each epoch loop in `main.py` (sketched below).
- These methods use `TrainingProgress.current_epoch` and `TrainingProgress.global_step` along with the configured intervals (`INTERVAL_EPOCHS`, resolved `INTERVAL_STEPS`).
- Validation runs (`validate_one_pass`, `validate_with_partial_mask`) and checkpoint saving (`save_checkpoint`) are triggered based on these checks.
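The end-of-epoch flow can be sketched as follows; the two callables stand in for the checkpoint-saving and validation routines named above, and the helper itself is hypothetical.

```python
# Hypothetical end-of-epoch helper mirroring the flow in main.py.
def end_of_epoch_actions(ops_schedule, save_checkpoint_fn, run_validations_fn) -> bool:
    """Returns True if training should stop early."""
    if ops_schedule.should_save_checkpoint():
        save_checkpoint_fn()
    if ops_schedule.should_validate():
        run_validations_fn()
    return ops_schedule.should_stop_early()
```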
### GradNorm Updates

- The check `ops_schedule.should_update_gradnorm(current_step)` is performed inside the `train_one_epoch` loop before the potential optimizer step.
- It receives the local `current_step` variable, ensuring the check uses the up-to-date optimizer step count.
- If it returns `True`, the GradNorm update logic (`grad_weighting.update_gradnorm_weights_reforward`) is executed before the optimizer step for that batch (see the placement sketch below).
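Its placement relative to the optimizer step can be sketched as follows; `run_gradnorm_update` is a placeholder callable standing in for `grad_weighting.update_gradnorm_weights_reforward`, and the helper itself is not real code.

```python
# Placement sketch: called inside the batch loop, just before the optimizer step.
def maybe_run_gradnorm(ops_schedule, current_step: int, run_gradnorm_update) -> bool:
    if ops_schedule.should_update_gradnorm(current_step):
        run_gradnorm_update()  # executes before optimizer.step() for this batch
        return True
    return False
```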
### Meta-Masking / Null Masking / Mixup Probabilities

- Probabilities (`get_meta_mask_prob`, `get_null_mask_prob`, `get_mixup_prob`) are calculated based on the current step.
- `get_null_mask_prob(current_step)` is called within `loss.masking.apply_loss_masking`, receiving the correct step.
- `get_meta_mask_prob(current_step)` and `get_mixup_prob(current_step)` are primarily used by `StepMetricsLogger.log_schedule_values` for logging, receiving the correct step (see the sketch below). If these probabilities were used directly to apply masking/mixup per step (they currently aren't applied dynamically per step in the dataloader collate), those call sites would also need the `current_step`.
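A logging sketch in the spirit of `StepMetricsLogger.log_schedule_values` (not the real implementation): each scheduled probability is read with the local `current_step` and handed to a generic logging callable.

```python
# Illustrative only; the real StepMetricsLogger handles metric naming and sinks itself.
def log_schedule_probs(ops_schedule, current_step: int, log_fn=print) -> None:
    log_fn("schedule/meta_mask_prob", ops_schedule.get_meta_mask_prob(current_step))
    log_fn("schedule/null_mask_prob", ops_schedule.get_null_mask_prob(current_step))
    log_fn("schedule/mixup_prob", ops_schedule.get_mixup_prob(current_step))
```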
### Mixup Group Level

- `get_mixup_group_level(current_step)` determines the group key.
- It is called once at the start of each epoch in `main.py` to configure the `GroupedBatchSampler` (see the sketch below). It uses the `TrainingProgress.global_step` at the start of the epoch, which is sufficient because the level is fixed for the entire epoch. Note: scheduled switching is disabled.
## 5. Mermaid Diagrams
### `train_one_epoch` Inner Loop Decisions
```mermaid
graph TD
    subgraph train_one_epoch Batch Loop
        direction TB
        B1(Get Batch) --> B2(Forward Pass);
        B2 --> B3(Loss Calculation);
        B3 --> B4{Is GradNorm Step?};
        B4 -- Call ops_schedule.should_update_gradnorm(current_step) --> B4;
        B4 -- Yes --> B5(Skip Backward);
        B4 -- No --> B6(Standard Backward);
        B5 --> B7{Accumulation Boundary?};
        B6 --> B7;
        B7 -- Yes --> B8[Optional GradNorm Update Logic];
        B8 --> B9(Optimizer Step);
        B9 --> B10(Scheduler Step Update);
        B10 --> B11(Increment current_step);
        B11 --> B12{Log Step Metrics?};
        B12 -- Call ops_schedule.should_log_metrics() etc. --> B12;
        B12 -- Yes --> B13[Log via StepMetricsLogger];
        B7 -- No --> B12;
        B13 --> B14(Next Batch);
        B12 -- No --> B14;
        B14 --> B2;
    end
```
### End-of-Epoch Decisions (`main.py`)
```mermaid
graph TD
    subgraph "End of Epoch in main.py"
        direction TB
        E1(train_one_epoch returns) --> E2(Finalize Train Metrics);
        E2 --> E3(Update TrainingProgress.global_step);
        E3 --> E4{Save Checkpoint?};
        E4 -- Call ops_schedule.should_save_checkpoint() --> E4;
        E4 -- Yes --> E5[Save Checkpoint];
        E4 -- No --> E6{Run Validation?};
        E5 --> E6;
        E6 -- Call ops_schedule.should_validate() etc. --> E6;
        E6 -- Yes --> E7["Run Validation(s)"];
        E6 -- No --> E8{Early Stop?};
        E7 -- Validation Done --> E8;
        E8 -- Call ops_schedule.should_stop_early() --> E8;
        E8 -- Yes --> E9(Break Epoch Loop);
        E8 -- No --> E10(Log Epoch Results);
        E10 --> E11(Next Epoch);
    end
```
## 6. Conclusion
The scheduling system provides a robust framework for managing training dynamics. Key takeaways for developers:
- Configure schedules primarily via `SCHEDULE.*`, `LR_SCHEDULER.*`, and `TRAIN.EARLY_STOP.*`.
- Prefer fraction-based parameters where applicable.
- Understand that `OpsSchedule` uses the correct `current_step` for intra-epoch decisions and `TrainingProgress.global_step`/`current_epoch` for epoch-boundary decisions.
- Always review the generated schedule summary and visualization.
- Use the centralized `TrainingProgress` API for all progress tracking needs.