# Developer Guide: Model Metadata, Parameter Filtering & Checkpoint Loading

## 1. Introduction
This guide explains the system used in linnaeus for managing model parameters and handling pretrained checkpoints through model-defined metadata. This approach promotes modularity and consistency, and simplifies the integration of new model architectures and diverse pretrained weights.
Core Concepts:
- Model-Defined Metadata: Each model architecture class (inheriting from `BaseModel`) defines two crucial properties (a minimal skeleton follows this list):
  - `parameter_groups_metadata` (`@property`): Describes the semantic grouping of its parameters (e.g., which parameters belong to convolution stages, transformer stages, specific heads, embeddings). This informs training-time operations but doesn't execute them.
  - `pretrained_ckpt_handling_metadata` (`@property`): Specifies instructions for the `load_pretrained` utility on how checkpoints should be processed before being loaded into an instance of this specific model architecture (e.g., which layers/buffers to drop, whether to interpolate positional embeddings).
- Configuration-Driven Filtering: The actual rules for grouping parameters for multi-optimizer, multi-LR scheduling, or GradNorm exclusion are defined in the YAML configuration file.
- Unified Filtering Utility: The `UnifiedParamFilter` system (`utils/unified_filtering.py`) reads the rules from the YAML config and applies them to the model's parameters. It does not directly parse the model's `parameter_groups_metadata`.
- Metadata-Aware Checkpoint Loading: The `load_pretrained` utility (`utils/checkpoint.py`) reads the target model's `pretrained_ckpt_handling_metadata` to adapt the checkpoint dictionary before attempting to load it using `model.load_state_dict`.
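As a quick orientation, the skeleton below shows where the two properties live on a model class. It is a minimal sketch only: the `BaseModel` import path, the class name, and the concrete values are assumptions, and realistic contents are covered in the sections that follow.

```python
# Minimal skeleton (illustrative): a linnaeus-style model declaring both
# metadata properties. Import path and values are assumptions.
from typing import Any, Dict

from linnaeus.models.base_model import BaseModel  # assumption: adjust to the actual import path


class MyModel(BaseModel):
    @property
    def parameter_groups_metadata(self) -> Dict[str, Any]:
        # Semantic grouping hints (documentation only; filtering is YAML-driven).
        return {"heads": {"classification_heads": ["head."]}}

    @property
    def pretrained_ckpt_handling_metadata(self) -> Dict[str, Any]:
        # Instructions consumed by load_pretrained before load_state_dict.
        return {"drop_params": ["head."], "strict": False}
```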
## 2. Parameter Grouping Metadata (`parameter_groups_metadata`)

### Purpose

This `@property` within a model class allows the model architecture to declare the semantic roles of its different parameter sets. It serves as self-documentation and provides hints for setting up filtering rules in the configuration.

It does NOT directly control filtering. Filtering logic is defined in the YAML config.

### Definition within a Model Class

Models should define `parameter_groups_metadata` returning a dictionary. Keys are categories, values are lists of representative name patterns associated with that category.
```python
# Inside your model class (e.g., linnaeus/models/mFormerV1.py)
# Requires `from typing import Any, Dict` at module level.
@property
def parameter_groups_metadata(self) -> Dict[str, Any]:
    return {
        "stages": {
            "convnext_stages": ["stem.", "stages.0.", "stages.1.", "downsample_layers.0", "downsample_layers.1"],
            "rope_stages": ["stages.2.", "stages.3.", "downsample_layers.2", "downsample_layers.3"],
            "rope_freqs": ["freqs"],  # Specific learnable parameter
        },
        "heads": {
            "classification_heads": ["head."],  # Use dot to avoid matching 'headroom' etc.
            "meta_heads": ["meta_"],
        },
        "embeddings": ["cls_token"],
        "norm_layers": ["norm", ".bn", "LayerNorm"],  # Include specific layer types if needed
        "aggregation": ["cl_1_fc.", "aggregate.", "final_norm."],
    }
```
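Because this metadata is purely declarative, it can help to preview how the declared patterns line up with a model's actual parameter names while drafting YAML filter rules. The helper below is an illustrative sketch, not part of linnaeus; it only assumes the property shape shown above and PyTorch's `named_parameters()`.

```python
from collections import defaultdict


def preview_parameter_groups(model):
    """Illustrative helper: bucket parameter names by the declared metadata patterns.

    For inspection only -- training-time filtering is driven by the YAML config.
    """
    # Flatten nested categories like {"stages": {"convnext_stages": [...]}}
    # into a single {group_name: [patterns]} mapping.
    flat = {}
    for category, value in model.parameter_groups_metadata.items():
        if isinstance(value, dict):
            flat.update(value)
        else:
            flat[category] = value

    groups = defaultdict(list)
    for param_name, _ in model.named_parameters():
        for group_name, patterns in flat.items():
            if any(pattern in param_name for pattern in patterns):
                groups[group_name].append(param_name)
                break  # First matching group wins (illustrative policy)
    return dict(groups)
```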
### How Filtering Actually Happens

Filtering for multi-optimizer, multi-LR, or GradNorm is configured in YAML and executed by `UnifiedParamFilter`. The YAML `FILTER` definitions use rules based on name, dimension, layer type, etc.

Example (`config.yaml` for GradNorm exclusion):
```yaml
LOSS:
  GRAD_WEIGHTING:
    TASK:
      # ... other GradNorm params ...
      EXCLUDE_CONFIG:        # Filter definition read by UnifiedParamFilter
        TYPE: "or"           # Combine rules with OR
        FILTERS:
          - TYPE: "name"     # Rule 1: Match by name
            MATCH_TYPE: "startswith"
            PATTERNS: ["head.", "meta_"]  # Exclude classification and meta heads
          # Add more rules if needed, e.g., excluding LayerNorm biases:
          # - TYPE: "and"
          #   FILTERS:
          #     - {TYPE: "layer_type", LAYER_TYPES: ["LayerNorm"], model: "@MODEL"}  # Needs model access
          #     - {TYPE: "name", MATCH_TYPE: "endswith", PATTERNS: [".bias"]}
```
- `UnifiedParamFilter` is instantiated with this config and the model instance.
- It applies the defined rules (e.g., `NameFilter`, `LayerTypeFilter`) to categorize parameters (see the sketch after this list).
- The model's `parameter_groups_metadata` is not directly used by the filter execution but serves as a reference for writing the filter rules in the config.
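To make the behaviour of a name-based rule concrete, here is a small self-contained sketch of the kind of matching a `startswith` name rule performs. This is not the `UnifiedParamFilter` API (whose constructor and rule classes live in `utils/unified_filtering.py`); the function names and config shape below are assumptions for illustration only.

```python
import torch.nn as nn


def matches_name_rule(param_name: str, rule: dict) -> bool:
    """Illustrative: evaluate a single name-based rule like the YAML above."""
    patterns = rule.get("PATTERNS", [])
    if rule.get("MATCH_TYPE") == "startswith":
        return any(param_name.startswith(p) for p in patterns)
    if rule.get("MATCH_TYPE") == "endswith":
        return any(param_name.endswith(p) for p in patterns)
    return any(p in param_name for p in patterns)


def excluded_params(model: nn.Module, exclude_config: dict) -> list[str]:
    """Illustrative: apply an OR-combined list of name rules to a model."""
    rules = [r for r in exclude_config.get("FILTERS", []) if r.get("TYPE") == "name"]
    return [
        name
        for name, _ in model.named_parameters()
        if any(matches_name_rule(name, rule) for rule in rules)
    ]
```

For the YAML example above, such a filter would match every parameter whose name starts with `head.` or `meta_`, i.e. the classification and meta heads excluded from GradNorm.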
## 3. Pretrained Checkpoint Handling Metadata (`pretrained_ckpt_handling_metadata`)

### Purpose

This `@property` within a model class provides instructions to the `load_pretrained` utility on how to process a checkpoint before calling `model.load_state_dict`. This allows the target model architecture to define how it wants to adapt potentially incompatible checkpoints.

### Definition within a Model Class
```python
# Inside your model class (e.g., linnaeus/models/mFormerV1.py)
@property
def pretrained_ckpt_handling_metadata(self) -> Dict[str, Any]:
    return {
        "drop_buffers": [],  # e.g., ["relative_position_index"] for mFormerV0
        "drop_params": ["head.", "meta_", "pos_embed"],  # Always drop classifier, meta heads, and any absolute pos embeds
        "interpolate_rel_pos_bias": False,  # mFormerV1 uses RoPE, no relative bias tables
        "supports_module_prefix": True,  # Allow handling DDP 'module.' prefix
        "strict": False,  # Use strict=False for load_state_dict by default
    }
```
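For contrast, the comments above suggest how a model that still uses relative position bias tables (e.g., mFormerV0) would fill in these fields. The values below are an illustrative sketch based on those comments, not the actual mFormerV0 implementation, and assume the same class context and `Dict`/`Any` imports as above.

```python
# Illustrative sketch (not the actual mFormerV0 code): a model with relative
# position bias tables drops the stale index buffers and asks load_pretrained
# to interpolate the bias tables to the new input resolution.
@property
def pretrained_ckpt_handling_metadata(self) -> Dict[str, Any]:
    return {
        "drop_buffers": ["relative_position_index"],
        "drop_params": ["head.", "meta_"],
        "interpolate_rel_pos_bias": True,  # triggers relative_bias_interpolate
        "supports_module_prefix": True,
        "strict": False,
    }
```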
### How Checkpoint Loading Actually Happens (`utils/checkpoint.py`)
- `load_pretrained` is called with the target `model` instance and the `config`.
- It loads the raw checkpoint dictionary from the path specified in `config.MODEL.PRETRAINED`.
- Crucially, it accesses `model.pretrained_ckpt_handling_metadata` from the target model instance.
- It modifies the loaded checkpoint dictionary based on the rules specified in the metadata (`drop_buffers`, `drop_params`, handling the `module.` prefix). A sketch of this adaptation step is shown after this list.
- If `interpolate_rel_pos_bias` is `True` in the metadata, it calls `relative_bias_interpolate`.
- If a custom source mapping is specified (e.g., `config.MODEL.PRETRAINED_SOURCE == 'metaformer'` or `'stitched_convnext_ropevit'`), it calls the corresponding `map_*_checkpoint` function before applying the metadata rules (or as part of a dedicated loading function like the planned `load_stitched_pretrained`).
- Finally, it calls `model.load_state_dict(modified_checkpoint['model'], strict=metadata.get('strict', False))`.
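The core of the metadata-driven adaptation can be pictured with the following sketch. It is a simplified stand-in for the logic in `load_pretrained` (the real utility also handles source-specific `map_*` functions, relative-bias interpolation, and logging); the helper name and the substring-based drop policy are assumptions.

```python
def apply_ckpt_handling_metadata(state_dict: dict, metadata: dict) -> dict:
    """Illustrative sketch of the checkpoint adaptation performed by load_pretrained."""
    # Strip the DDP 'module.' prefix when the target model supports it.
    if metadata.get("supports_module_prefix", False):
        state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

    # Drop buffers/params the target architecture has declared it does not want.
    # Substring matching is an assumed policy for this sketch.
    drop_patterns = metadata.get("drop_buffers", []) + metadata.get("drop_params", [])
    return {
        k: v
        for k, v in state_dict.items()
        if not any(pattern in k for pattern in drop_patterns)
    }


# Usage sketch (names are placeholders):
# checkpoint = torch.load(config.MODEL.PRETRAINED, map_location="cpu")
# metadata = model.pretrained_ckpt_handling_metadata
# state_dict = apply_ckpt_handling_metadata(checkpoint["model"], metadata)
# model.load_state_dict(state_dict, strict=metadata.get("strict", False))
```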
## 4. Key Differences Summarized

| Feature | Defined In | Used By | Purpose |
|---|---|---|---|
| Parameter Grouping Rules | YAML Config (`OPTIMIZER`, `LR_SCHEDULER`, `LOSS`) | `UnifiedParamFilter` (called by `optimizers/build.py`, `loss/gradient_weighting.py`) | Executes the filtering logic for multi-optimizer, multi-LR, GradNorm exclusion. |
| `parameter_groups_metadata` | Model Class (`@property`) | Developers (for writing YAML filters), future tools | Declares the semantic structure of the model's parameters. Informs filter design but doesn't execute it. |
| Checkpoint Handling Rules | Model Class (`@property`) | `load_pretrained` utility (`utils/checkpoint.py`) | Instructs the loading utility on how to modify a checkpoint before loading into this model type. |
| Checkpoint Key/Shape Mapping | Custom `map_*` functions (`utils/checkpoint.py`) | `load_pretrained` utility (if `PRETRAINED_SOURCE` triggers it) | Performs complex renaming/reshaping for specific checkpoint sources before metadata rules are applied. |
## 5. Debugging (Recap)

- Filtering Issues: Check the YAML filter definitions. Use `inspect_multilr_filters` / `inspect_gradnorm_filters` or `UnifiedParamFilter.log_matches` with `DEBUG.OPTIMIZER=True`.
- Loading Issues: Check `pretrained_ckpt_handling_metadata` in your model. Use `debug_load_checkpoint` before modifications in `load_pretrained`. Check the `load_state_dict` return value (see the example after this list). Use `DEBUG.CHECKPOINT=True`.
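When diagnosing loading issues, the return value of `model.load_state_dict(..., strict=False)` is often the quickest signal, since PyTorch reports the keys that did not line up. A minimal example (variable names are placeholders):

```python
# With strict=False, PyTorch reports mismatched keys instead of raising.
result = model.load_state_dict(modified_checkpoint["model"], strict=False)
print("Missing keys (left at init values):", result.missing_keys)
print("Unexpected keys (ignored from checkpoint):", result.unexpected_keys)
```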
This revised structure emphasizes that model metadata informs configuration and utilities, while the execution of filtering and loading is handled by dedicated configuration sections (YAML) and utility functions (`UnifiedParamFilter`, `load_pretrained`).