# Developer Guide: Training Loop and Progress Tracking

This document details the core training loop structure in `linnaeus/main.py` and the centralized `TrainingProgress` system used for robustly tracking training progression, especially in distributed and gradient-accumulation scenarios.
## 1. High-Level Training Flow

The main training entry point is `linnaeus/main.py`. The overall flow is:
- Initialization:
  - Parse command-line arguments and configuration files (`parse_option`).
  - Initialize the distributed environment (`init_distributed`).
  - Set up output directories and logging (`setup_output_dirs`, `create_logger`).
  - Build datasets, dataloaders, and augmentation pipelines (`build_datasets`, `build_loaders`, `AugmentationPipelineFactory`).
  - Process dataset metadata (`process_and_save_dataset_metadata`, `TaxonomyTree`).
  - Build the model (`build_model`).
  - Build optimizer(s) and LR scheduler (`build_optimizer`, `build_scheduler`).
  - Calculate `total_steps` based on dataloader length and epochs/accumulation. Crucially, this happens after dataloader initialization (a sketch of this calculation appears after this list). See Design Decisions.
  - Initialize `TrainingProgress`, `OpsSchedule`, `MetricsTracker`, `StepMetricsLogger`.
  - Resolve schedule parameters (fractions -> steps) using `resolve_all_schedule_params`.
  - Attempt checkpoint loading / auto-resume (`load_checkpoint`, `auto_resume_helper`). If resuming, complete any pending validation runs.
  - Initialize WandB (`initialize_wandb`).
- Main Training Loop (`main` function):
  - Iterates through epochs (`for epoch in range(...)`).
  - Epoch Start:
    - Sets the dataloader epoch (`data_loader_train.set_epoch`).
    - Determines the current mixup group level (`ops_schedule.get_mixup_group_level`) and sets it on the sampler.
    - Notifies `TrainingProgress` that a new epoch is starting (`progress.start_training_epoch`).
  - Inner Training Loop (`train_one_epoch`):
    - Iterates through mini-batches from the dataloader.
    - Performs the forward pass and calculates the loss (`weighted_hierarchical_loss`).
    - Performs the backward pass (`scaler.scale(loss).backward()`), skipping it if this is a GradNorm step.
    - Accumulates gradients if `accumulation_steps > 1`.
    - Optimizer Step: on non-accumulation steps or GradNorm steps:
      - Optionally performs a GradNorm update (`grad_weighting.update_gradnorm_weights_reforward`).
      - Unscales and clips gradients.
      - Calls `optimizer.step()` and `scaler.update()`.
      - Calls `optimizer.zero_grad()`.
      - Updates the LR scheduler (`lr_scheduler.step_update`) using the global step.
      - Updates `TrainingProgress` (`progress.update_step`), which increments `global_step`.
    - Logs step metrics via `StepMetricsLogger`.
  - Epoch End:
    - Finalizes epoch metrics in `MetricsTracker`.
    - Checks `OpsSchedule` to see if a checkpoint should be saved (`ops_schedule.should_save_checkpoint`). Saves if needed (`save_checkpoint`), including `TrainingProgress` state (a minimal sketch of the saved payload appears after this list).
    - Checks `OpsSchedule` to see if validation should run (`ops_schedule.should_validate`, `should_validate_mask_meta`, etc.). Runs validation passes (`validate_one_pass`, `validate_with_partial_mask`) as needed and logs results via `StepMetricsLogger`. Checkpoints are saved before/after validation runs to ensure resumability.
    - Checks for early stopping (`ops_schedule.should_stop_early`).
    - Logs epoch summary results to WandB via `StepMetricsLogger`.
- Cleanup: After the loop finishes or is interrupted, resources (dataloaders, distributed group) are cleaned up.
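To make the `total_steps` calculation above concrete, here is a minimal sketch of the arithmetic, assuming the dataloader length already reflects per-rank sharding by the distributed sampler. The helper name and arguments are illustrative, not the actual code in `linnaeus/main.py`.

```python
def compute_total_steps(batches_per_epoch: int, num_epochs: int, accumulation_steps: int) -> int:
    """Estimate the number of optimizer updates (global_step increments) for a run.

    Assumes `batches_per_epoch` is the per-rank dataloader length, i.e. the
    distributed sampler has already sharded the dataset across ranks.
    """
    accumulation_steps = max(1, accumulation_steps)
    # Only every `accumulation_steps`-th batch triggers an optimizer.step(),
    # so the number of global steps per epoch shrinks accordingly.
    optimizer_steps_per_epoch = batches_per_epoch // accumulation_steps
    return optimizer_steps_per_epoch * num_epochs


# Example: 10,000 batches/epoch, 90 epochs, accumulating over 4 batches
# -> 2,500 optimizer steps per epoch, 225,000 total.
total_steps = compute_total_steps(batches_per_epoch=10_000, num_epochs=90, accumulation_steps=4)
```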
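Similarly, the "saves if needed, including `TrainingProgress` state" step can be pictured as a minimal checkpoint payload. This is a sketch under assumptions: the real `save_checkpoint` persists additional objects (LR scheduler, AMP scaler, config, metrics, ...) and uses a different signature.

```python
import torch


def save_checkpoint_sketch(path: str, model, optimizer, progress) -> None:
    """Minimal illustration of a checkpoint that bundles TrainingProgress state."""
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        # TrainingProgress.state_dict() captures current_stage, global_step, and the
        # pending_validations queue, which is what makes mid-validation resume possible.
        "training_progress": progress.state_dict(),
    }
    torch.save(checkpoint, path)
```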
**Mermaid Diagram: Overall Flow**

```mermaid
graph TD
A[Start main.py] --> B(Parse Config & Args);
B --> C{Initialize Distributed?};
C -- Yes --> D[torch.distributed.init_process_group];
C -- No --> E[Single Process Mode];
D --> F(Setup Dirs & Logging);
E --> F;
F --> G(Build Datasets/Loaders);
G --> H(Build Model);
H --> I(Build Optimizer/Scheduler);
I --> J(Calculate total_steps);
J --> K(Init TrainingProgress);
K --> L(Init OpsSchedule);
L --> M(Init MetricsTracker/Logger);
M --> N(Resolve Schedule Params);
N --> O{Auto-Resume?};
O -- Yes --> P[Load Checkpoint & Restore State];
O -- No --> Q[Continue];
P --> R{Pending Validation?};
R -- Yes --> S[Run Pending Validations];
R -- No --> Q;
S --> Q;
Q --> T(Init WandB);
T --> U{Epoch Loop};
U -- Start Epoch --> V(Set Loader Epoch/Mixup);
W --> X(train_one_epoch);
V --> W(progress.start_training_epoch);
X -- End Epoch --> Y(Finalize Train Metrics);
Y --> Z{Save Checkpoint?};
Z -- Yes --> AA[save_checkpoint];
Z -- No --> AB{Run Validation?};
AA --> AB;
AB -- Yes --> AC[Run Validation Passes];
AB -- No --> AD{Early Stop?};
AC --> AD;
AD -- Yes --> AE(Finish Training);
AD -- No --> AF(Log Epoch Results);
AF --> U;
U -- Training Complete --> AE;
AE --> AG(Final Logging/Cleanup);
AG --> AH[End];
subgraph train_one_epoch [train_one_epoch Loop]
direction TB
T1(Start Inner Loop) --> T2{For each batch};
T2 --> T3(Forward Pass);
T3 --> T4(Loss Calculation);
T4 --> T5{GradNorm Step?};
T5 -- No --> T6(Backward Pass);
T5 -- Yes --> T7(Skip Backward);
T6 --> T8{Accumulation Boundary?};
T7 --> T8;
T8 -- Yes --> T9[Optional GradNorm Update];
T9 --> T10(Optimizer Step);
T10 --> T11(Scheduler Step Update);
T11 --> T12(progress.update_step);
T12 --> T13(Log Step Metrics);
T8 -- No --> T13;
T13 --> T2;
T2 -- Loop End --> T14(Return);
end
```
## 2. TrainingProgress Class

(`linnaeus/ops_schedule/training_progress.py`)
This class is the central source of truth for the training state.
**Key Attributes:**

- `current_stage` (`TrainingStageEnum`): Tracks if the process is currently in `TRAINING`, `VALIDATION_NORMAL`, `VALIDATION_MASK_META`, or `VALIDATION_PARTIAL_MASK_META`. Crucial for resuming correctly after interruptions during validation.
- `current_epoch` (int): The current epoch number (0-based internally, potentially adjusted for logging).
- `global_step` (int): Counts optimizer updates. This is the primary step counter used for most scheduling decisions (LR decay, validation intervals based on steps, etc.). It increments only when the optimizer actually performs a step (i.e., not during gradient accumulation steps).
- `expected_total_steps` (Optional[int]): The total number of `global_step` updates expected for the entire training run. Calculated once at the start based on dataset size, epochs, batch size, world size, and accumulation steps. Used for resolving fraction-based schedules.
- `pending_validations` (List[TrainingStage]): Stores validation types scheduled to run at the end of the current epoch but not yet completed. Used for robust resumption.
- `completed_validations` (List[TrainingStage]): Tracks validation types completed within the current epoch's boundary handling.
- `partial_validation_indices` (List[int]): Specific indices for pending partial mask-meta validations.
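For orientation, the attribute shape described above can be pictured as a small dataclass. This is an illustrative sketch only; the real class in `linnaeus/ops_schedule/training_progress.py` carries additional behaviour, and the concrete enum values shown here are assumptions.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


class TrainingStage(Enum):
    # Stage names follow the description above; the string values are illustrative.
    TRAINING = "training"
    VALIDATION_NORMAL = "validation_normal"
    VALIDATION_MASK_META = "validation_mask_meta"
    VALIDATION_PARTIAL_MASK_META = "validation_partial_mask_meta"


@dataclass
class TrainingProgressState:
    """Sketch of the state tracked by TrainingProgress (not the real implementation)."""

    current_stage: TrainingStage = TrainingStage.TRAINING
    current_epoch: int = 0
    global_step: int = 0                          # optimizer updates, not raw batches
    expected_total_steps: Optional[int] = None    # fixed once at startup
    pending_validations: list[TrainingStage] = field(default_factory=list)
    completed_validations: list[TrainingStage] = field(default_factory=list)
    partial_validation_indices: list[int] = field(default_factory=list)
```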
**Key Methods:**

- `start_training_epoch(epoch)`: Resets epoch-specific state (pending/completed validations), updates `current_epoch`, and sets `current_stage` to `TRAINING`.
- `update_step(batch_size, is_accumulation_step)`: Deprecated/internal. Called within `train_one_epoch`; increments internal counters and only increments `global_step` if `is_accumulation_step` is `False`. Note: direct use of this method is less critical now that `train_one_epoch` manages the local `current_step` variable passed to `OpsSchedule`.
- `schedule_validation(validation_type, partial_index)`: Adds a validation type to the `pending_validations` list.
- `start_validation(validation_type)`: Sets `current_stage` to the specified validation type.
- `complete_validation(validation_type, partial_index)`: Removes a validation type from `pending_validations` and marks it as completed. Resets `current_stage` to `TRAINING` once all pending validations are done.
- `has_pending_validations()`: Checks whether the `pending_validations` list is non-empty.
- `state_dict()`, `load_state_dict(state_dict)`: Save and load the tracker's state during checkpointing; crucial for auto-resume.
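These methods combine into the end-of-epoch validation flow roughly as follows. This is a hedged sketch of how a caller might drive the API: the import path follows the file location given above, the `validators` mapping is a hypothetical convenience, and exact signatures (e.g. whether `partial_index` is required) may differ in the real code.

```python
# Illustrative driver for the pending-validation queue; not the actual code in main.py.
from linnaeus.ops_schedule.training_progress import TrainingProgress


def run_scheduled_validations(
    progress: TrainingProgress,
    validators: dict,  # maps a validation type to a zero-arg callable that runs the pass
) -> None:
    """Drain the pending_validations queue recorded by schedule_validation()."""
    while progress.has_pending_validations():
        validation_type = progress.pending_validations[0]
        progress.start_validation(validation_type)     # stage is recorded, so an interrupt
                                                       # here can be detected on resume
        validators[validation_type]()                  # e.g. a wrapped validate_one_pass
        progress.complete_validation(validation_type)  # removes from pending; resets stage
                                                       # to TRAINING once the queue is empty
```

On auto-resume, the same drain loop could run right after `load_state_dict` restores a non-empty `pending_validations` queue, which is how interrupted validation runs get completed before training continues.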
### Step Counting Logic

- The `global_step` tracks optimizer updates.
- The `train_one_epoch` function maintains a local `current_step` variable that accurately reflects the optimizer step within the epoch loop.
- Crucially, this local `current_step` is now passed to `OpsSchedule` methods (`should_update_gradnorm`, `get_*_prob`) for making schedule decisions during the epoch.
- `TrainingProgress.global_step` is updated after `train_one_epoch` completes, reflecting the total progress up to the end of the epoch.
This separation ensures that schedule decisions during an epoch use the correct, up-to-date step count, while the `TrainingProgress` object maintains the overall progress for checkpointing and epoch-boundary decisions.
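As a self-contained illustration of this counting pattern, the toy loop below uses plain PyTorch; the `OpsSchedule` call is only referenced in a comment, and none of this is the literal `train_one_epoch` code.

```python
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

current_step = 0  # local optimizer-step counter for this epoch; drives schedule decisions
for batch_idx in range(16):
    x = torch.randn(32, 8)
    loss = model(x).pow(2).mean()
    # Scale the loss so accumulated gradients average rather than sum.
    (loss / accumulation_steps).backward()

    is_accumulation_step = (batch_idx + 1) % accumulation_steps != 0
    if not is_accumulation_step:
        optimizer.step()
        optimizer.zero_grad()
        current_step += 1  # one increment per optimizer update, mirroring global_step
        # Schedule decisions made during the epoch would use this local counter, e.g.
        # ops_schedule.should_update_gradnorm(current_step)

# After the epoch, TrainingProgress.global_step would be advanced to reflect current_step.
print(current_step)  # -> 4 optimizer steps for 16 batches with accumulation over 4
```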