# The Linnaeus Inference Bundle
The Linnaeus Inference Bundle is a self-contained package that includes all the necessary artifacts and configurations required for running inference with a trained Linnaeus model. It's designed to be portable and easy to use, whether for direct PyTorch inference or serving with tools like LitServe.
## Bundle Structure
A typical inference bundle is a directory containing the following key files:
- `inference_config.yaml`: The main configuration file for the `LinnaeusInferenceHandler`. It specifies paths to the other artifacts, model parameters, preprocessing settings, and inference options.
- **Model Weights** (e.g., `pytorch_model.bin`): The saved state dictionary of the trained PyTorch model. The exact filename is specified in `inference_config.yaml`.
- **Taxonomy Data** (e.g., `taxonomy.json`): A JSON file representing the taxonomic tree structure used by the model. This file is typically generated by `linnaeus.utils.taxonomy.TaxonomyTree.save()`. Its path is specified in `inference_config.yaml`.
- **Class Index Map** (e.g., `class_index_map.json`): A JSON file that maps between the model's output class indices and the `typus` library's `taxon_id`s for each taxonomic rank the model predicts. It also records the null taxon IDs and the number of classes per rank. Its path is specified in `inference_config.yaml`.
**Example Directory Layout:**

```text
my_model_inference_bundle/
├── inference_config.yaml
├── pytorch_model.bin
├── taxonomy.json
└── class_index_map.json
```
## Key Components in Detail
### 1. `inference_config.yaml`
This YAML file is the entry point for the inference handler. It contains several sections:
- `model`:
  - `architecture_name`: Name of the model architecture (e.g., `mFormerV1_sm`).
  - `weights_path`: Path (relative to the bundle root, or absolute) to the model's weights file.
  - `model_task_keys_ordered`: Ordered list of internal Linnaeus task keys the model predicts (e.g., `["taxa_L70", "taxa_L60", ..., "taxa_L10"]`). This order must match the model's output structure.
  - `num_classes_per_task`: List of class counts (including the null class) for each task in `model_task_keys_ordered`.
  - `null_class_indices`: Dictionary mapping each Linnaeus `task_key` to the model output index that represents the "null" or "unknown" class for that task.
  - `expected_aux_vector_length`: (Optional) The expected length of the auxiliary feature vector if metadata is used. If `null` or not provided, the `LinnaeusInferenceHandler` will attempt to derive it from the `metadata_preprocessing` section. Setting it explicitly is recommended when metadata is used.
- `input_preprocessing`:
  - `image_size`: Expected image input dimensions `[C, H, W]`.
  - `image_mean`: Mean values for image normalization.
  - `image_std`: Standard deviation values for image normalization.
  - `image_interpolation`: Interpolation method for resizing (e.g., `bilinear`).
- `metadata_preprocessing`:
  - `use_geolocation`: Boolean; whether latitude/longitude are used.
  - `use_temporal`: Boolean; whether date/time are used.
  - `temporal_use_julian_day`: Boolean; use day-of-year (if true) or month-of-year (if false) for temporal encoding.
  - `temporal_use_hour`: Boolean; include hour-of-day sinusoidal features.
  - `use_elevation`: Boolean; whether elevation is used.
  - `elevation_scales`: List of scale values for elevation encoding.
- `taxonomy_data`:
  - `source_name`: Source of the taxonomy (e.g., `CoL2024`).
  - `version`: Version of the taxonomy.
  - `root_identifier`: Root taxon ID or name covered by the model (for context).
  - `taxonomy_tree_path`: Path to the `taxonomy.json` file.
  - `class_index_map_path`: Path to the `class_index_map.json` file.
- `inference_options`:
  - `default_top_k`: Default K for top-K predictions.
  - `device`: Device for inference (`cpu`, `cuda`, `mps`, or `auto`).
  - `batch_size`: Maximum batch size for the handler's internal processing.
  - `enable_hierarchical_consistency_check`: Boolean; whether to enforce parent-child consistency in predictions.
  - `handler_version`: Version of the `LinnaeusInferenceHandler` this bundle is intended for.
  - `artifacts_source_uri`: (Optional) URI indicating where the bundle was downloaded from (e.g., a Hugging Face Hub path).
- `model_description`: (Optional) A brief human-readable description of the model configuration.
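Putting these sections together, a minimal `inference_config.yaml` for a hypothetical two-rank model might look like the following. All values are illustrative, not taken from a real bundle:

```yaml
model:
  architecture_name: mFormerV1_sm
  weights_path: pytorch_model.bin
  model_task_keys_ordered: ["taxa_L20", "taxa_L10"]
  num_classes_per_task: [120, 450]
  null_class_indices:
    taxa_L20: 0
    taxa_L10: 0
  expected_aux_vector_length: 8

input_preprocessing:
  image_size: [3, 224, 224]
  image_mean: [0.485, 0.456, 0.406]
  image_std: [0.229, 0.224, 0.225]
  image_interpolation: bilinear

metadata_preprocessing:
  use_geolocation: true
  use_temporal: true
  temporal_use_julian_day: true
  temporal_use_hour: false
  use_elevation: false
  elevation_scales: []

taxonomy_data:
  source_name: CoL2024
  version: "2024"
  root_identifier: "Insecta"
  taxonomy_tree_path: taxonomy.json
  class_index_map_path: class_index_map.json

inference_options:
  default_top_k: 5
  device: auto
  batch_size: 32
  enable_hierarchical_consistency_check: true
  handler_version: "1.0.0"

model_description: "Illustrative two-rank (genus + species) classifier."
```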
### 2. Model Weights (e.g., `pytorch_model.bin`)
This is a standard PyTorch state dictionary, saved with `torch.save(model.state_dict(), ...)`. It contains the learned parameters of your trained model.
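For reference, this is ordinary PyTorch serialization; a small stand-in module is used here in place of a trained Linnaeus model:

```python
import torch
import torch.nn as nn

# Stand-in for a trained Linnaeus model; any nn.Module works the same way.
model = nn.Linear(4, 2)

# Save only the learned parameters (the state dict), not the full module object.
torch.save(model.state_dict(), "pytorch_model.bin")

# At load time, the weights are restored into a matching architecture.
state = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state)
```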
### 3. Taxonomy Data (`taxonomy.json`)
This file stores the taxonomic hierarchy relevant to the model. It is created by calling the `.save()` method of a `linnaeus.utils.taxonomy.taxonomy_tree.TaxonomyTree` instance. The `TaxonomyTree` is typically built from the `hierarchy_map` generated during dataset processing. The JSON file includes:
* `task_keys`: Ordered list of Linnaeus task keys representing hierarchy levels (typically lowest rank to highest, e.g., `["taxa_L10_species", "taxa_L20_genus", ...]`).
* `num_classes`: Dictionary mapping each task key to the number of classes at that level.
* `hierarchy_map_raw`: The core map defining parent-child relationships: `Dict[child_task_key, Dict[child_model_idx, parent_model_idx]]`.
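As an illustration of what `hierarchy_map_raw` enables, the sketch below checks a set of per-rank predictions for parent-child consistency. The map shape follows the description above, but the toy keys and indices are hypothetical, and this is not the handler's actual implementation:

```python
def is_hierarchy_consistent(preds, hierarchy_map_raw, task_keys):
    """Return True if each predicted child index maps to the predicted parent index.

    preds:             {task_key: predicted model index}
    hierarchy_map_raw: {child_task_key: {child_model_idx: parent_model_idx}}
    task_keys:         ordered lowest rank first (e.g., species before genus)
    """
    for child_key, parent_key in zip(task_keys, task_keys[1:]):
        expected_parent = hierarchy_map_raw[child_key][preds[child_key]]
        if expected_parent != preds[parent_key]:
            return False
    return True

# Toy two-level hierarchy: species indices 1 and 2 belong to genus index 1,
# species index 3 belongs to genus index 2.
hmap = {"taxa_L10_species": {1: 1, 2: 1, 3: 2}}
tasks = ["taxa_L10_species", "taxa_L20_genus"]

print(is_hierarchy_consistent({"taxa_L10_species": 2, "taxa_L20_genus": 1}, hmap, tasks))  # True
print(is_hierarchy_consistent({"taxa_L10_species": 3, "taxa_L20_genus": 1}, hmap, tasks))  # False
```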
### 4. Class Index Map (`class_index_map.json`)
This JSON file provides the critical mappings needed to translate the model's numerical outputs into meaningful taxonomic information using `typus` standards. It contains:
* `idx_to_taxon_id`: Maps `RankLevel.value` to a dictionary of `{model_class_index: typus_taxon_id}`.
* `taxon_id_to_idx`: The inverse of `idx_to_taxon_id`. Maps `RankLevel.value` to `{typus_taxon_id: model_class_index}`.
* `null_taxon_ids`: Maps `RankLevel.value` to the `typus_taxon_id` that represents the "null" or "unknown" concept for that rank.
* `num_classes_per_rank`: Maps `RankLevel.value` to the total number of classes (including null) that the model predicts for that rank.
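A minimal sketch of how these mappings decode a raw prediction. The map content and taxon IDs below are made up for illustration (note that JSON object keys are strings, hence the `str(...)` lookups):

```python
# Hypothetical class_index_map.json content for a model predicting two ranks:
# species (RankLevel.value 10) and genus (RankLevel.value 20).
class_index_map = {
    "idx_to_taxon_id": {
        "10": {"0": 999001, "1": 47219, "2": 54327},  # index 0 is the null class
        "20": {"0": 999002, "1": 52775},
    },
    "null_taxon_ids": {"10": 999001, "20": 999002},
    "num_classes_per_rank": {"10": 3, "20": 2},
}

def decode_prediction(rank_value, model_index, cim):
    """Translate a model output index into a taxon id, or None for the null class."""
    taxon_id = cim["idx_to_taxon_id"][str(rank_value)][str(model_index)]
    if taxon_id == cim["null_taxon_ids"][str(rank_value)]:
        return None
    return taxon_id

print(decode_prediction(10, 1, class_index_map))  # 47219
print(decode_prediction(10, 0, class_index_map))  # None
```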
## Creating an Inference Bundle
Currently, creating an inference bundle is a manual process that involves:

1. **Training a Model:** Train your Linnaeus model and save its state dictionary.
2. **Preparing Taxonomy Artifacts:**
   - During your data preparation or training setup, you should have access to the `TaxonomyTree` instance. Save it using `taxonomy_tree.save("taxonomy.json")`.
   - You will need to construct the `class_index_map.json` file. This requires knowing:
     - how your model's output indices for each task head map to specific `typus_taxon_id`s;
     - which `typus_taxon_id` represents the "null" class for each rank;
     - the total number of classes (outputs) for each rank-specific head in your model. This mapping is often established during dataset creation and model configuration.
3. **Writing `inference_config.yaml`:**
   - Carefully create this file, ensuring all paths correctly point to your artifact files (model weights, `taxonomy.json`, `class_index_map.json`) relative to the bundle's root.
   - Fill in all model parameters (`model_task_keys_ordered`, `num_classes_per_task`, `null_class_indices`) to match your trained model's architecture precisely.
   - Configure preprocessing and inference options as needed.
4. **Assembling the Bundle:** Place all of these files into a single directory.
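The class-index-map step can be sketched as a small helper that derives `taxon_id_to_idx` and `num_classes_per_rank` from the forward mappings. This is a hypothetical helper, not part of Linnaeus, and the rank keys and taxon IDs are illustrative:

```python
import json
from pathlib import Path

def build_class_index_map(idx_to_taxon_id, null_taxon_ids, out_path):
    """Write a class_index_map.json from per-rank {model_index: taxon_id} maps.

    idx_to_taxon_id: {rank_level_value: {model_class_index: taxon_id}}
    null_taxon_ids:  {rank_level_value: taxon_id of the null class at that rank}
    """
    payload = {
        # JSON object keys must be strings, so cast all keys explicitly.
        "idx_to_taxon_id": {
            str(r): {str(i): t for i, t in m.items()} for r, m in idx_to_taxon_id.items()
        },
        "taxon_id_to_idx": {
            str(r): {str(t): i for i, t in m.items()} for r, m in idx_to_taxon_id.items()
        },
        "null_taxon_ids": {str(r): t for r, t in null_taxon_ids.items()},
        "num_classes_per_rank": {str(r): len(m) for r, m in idx_to_taxon_id.items()},
    }
    Path(out_path).write_text(json.dumps(payload, indent=2))
    return payload

# Illustrative two-rank model: species (10) and genus (20); index 0 is the null class.
payload = build_class_index_map(
    idx_to_taxon_id={10: {0: 999001, 1: 47219}, 20: {0: 999002, 1: 52775}},
    null_taxon_ids={10: 999001, 20: 999002},
    out_path="class_index_map.json",
)
```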
A helper script or CLI tool (`tools/prepare_inference_bundle.py`, mentioned in the issue) could greatly simplify and standardize this process. Such a tool would ideally take your trained model checkpoint, taxonomy data, and class-mapping information as inputs and generate a complete, validated bundle.
## Using the Bundle
Once created, the bundle can be used with `LinnaeusInferenceHandler`:
```python
from pathlib import Path

from linnaeus.inference.handler import LinnaeusInferenceHandler

bundle_config_path = Path("/path/to/my_model_inference_bundle/inference_config.yaml")
handler = LinnaeusInferenceHandler.load_from_artifacts(config_file_path=bundle_config_path)

# Now the handler is ready for predictions:
# from PIL import Image
# image = Image.open(...)
# results = handler.predict(images=[image])
# print(results[0].model_dump_json(indent=2))
```
This bundle structure ensures that all necessary components for inference are co-located and explicitly defined, promoting reproducibility and ease of deployment.