The Linnaeus Inference Bundle

The Linnaeus Inference Bundle is a self-contained package holding every artifact and configuration file needed to run inference with a trained Linnaeus model. It's designed to be portable and easy to use, whether for direct PyTorch inference or for serving with tools like LitServe.

Bundle Structure

A typical inference bundle is a directory containing the following key files:

  • inference_config.yaml: The main configuration file for the LinnaeusInferenceHandler. It specifies paths to other artifacts, model parameters, preprocessing settings, and inference options.
  • Model Weights (e.g., pytorch_model.bin): The saved state dictionary of the trained PyTorch model. The exact filename is specified in inference_config.yaml.
  • Taxonomy Data (e.g., taxonomy.json): A JSON file representing the taxonomic tree structure used by the model. This file is typically generated by linnaeus.utils.taxonomy.TaxonomyTree.save(). Its path is specified in inference_config.yaml.
  • Class Index Map (e.g., class_index_map.json): A JSON file that maps between the model's output class indices and the typus library's taxon_ids for each taxonomic rank the model predicts. It also includes information about null taxon IDs and the number of classes per rank. Its path is specified in inference_config.yaml.

Example Directory Layout:

my_model_inference_bundle/
├── inference_config.yaml
├── pytorch_model.bin
├── taxonomy.json
└── class_index_map.json
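
As a quick sanity check before loading a bundle, the layout above can be verified with a few lines of standard-library Python. This is only a sketch: the helper name missing_bundle_files is hypothetical, and the file names are the ones from the example layout, whereas a real bundle declares its actual paths in inference_config.yaml.

```python
from pathlib import Path

# File names follow the example layout above. Real bundles may use other
# names, because the actual paths are declared in inference_config.yaml.
EXPECTED_FILES = (
    "inference_config.yaml",
    "pytorch_model.bin",
    "taxonomy.json",
    "class_index_map.json",
)


def missing_bundle_files(bundle_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are absent from bundle_dir."""
    root = Path(bundle_dir)
    return [name for name in expected if not (root / name).is_file()]
```

An empty list means every expected file is present; anything else names the files to add before handing the directory to the handler.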

Key Components in Detail

1. inference_config.yaml

This YAML file is the entry point for the inference handler. It contains several sections:

  • model:

    • architecture_name: Name of the model architecture (e.g., mFormerV1_sm).
    • weights_path: Path (relative to the bundle root or absolute) to the model's weights file.
    • model_task_keys_ordered: Ordered list of internal Linnaeus task keys the model predicts (e.g., ["taxa_L70", "taxa_L60", ..., "taxa_L10"]). This order must match the model's output structure.
    • num_classes_per_task: List of class counts (including null) for each task in model_task_keys_ordered.
    • null_class_indices: Dictionary mapping each Linnaeus task_key to the model's output index that represents the "null" or "unknown" class for that task.
    • expected_aux_vector_length: (Optional) The expected length of the auxiliary feature vector if metadata is used. If null or not provided, the LinnaeusInferenceHandler will attempt to derive this from the metadata_preprocessing section. It's recommended to set this explicitly if metadata is used.
  • input_preprocessing:

    • image_size: Expected image input dimensions [C, H, W].
    • image_mean: Mean values for image normalization.
    • image_std: Standard deviation values for image normalization.
    • image_interpolation: Interpolation method for resizing (e.g., bilinear).
  • metadata_preprocessing:

    • use_geolocation: Boolean, whether latitude/longitude are used.
    • use_temporal: Boolean, whether date/time are used.
    • temporal_use_julian_day: Boolean, use day-of-year (if true) or month-of-year (if false) for temporal encoding.
    • temporal_use_hour: Boolean, include hour-of-day sinusoidal features.
    • use_elevation: Boolean, whether elevation is used.
    • elevation_scales: List of scale values for elevation encoding.
  • taxonomy_data:

    • source_name: Source of the taxonomy (e.g., CoL2024).
    • version: Version of the taxonomy.
    • root_identifier: Root taxon ID or name covered by the model (for context).
    • taxonomy_tree_path: Path to the taxonomy.json file.
    • class_index_map_path: Path to the class_index_map.json file.
  • inference_options:

    • default_top_k: Default K for top-K predictions.
    • device: Device for inference (cpu, cuda, mps, or auto).
    • batch_size: Maximum batch size for the handler's internal processing.
    • enable_hierarchical_consistency_check: Boolean, whether to enforce parent-child consistency in predictions.
    • handler_version: Version of the LinnaeusInferenceHandler this bundle is intended for.
    • artifacts_source_uri: (Optional) URI indicating where the bundle might have been downloaded from (e.g., a Hugging Face Hub path).
  • model_description: (Optional) A brief human-readable description of the model configuration.
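
Because the three per-task fields in the model section must agree with each other, it can help to cross-check them before use. The sketch below works on a plain dictionary that mirrors the model section; the function name check_model_section, the class counts, and the choice of index 0 for the null class are all assumptions made for illustration, not values taken from a real bundle.

```python
def check_model_section(model_cfg):
    """Return a list of human-readable problems found in the `model` section."""
    problems = []
    task_keys = model_cfg.get("model_task_keys_ordered", [])
    num_classes = model_cfg.get("num_classes_per_task", [])
    null_indices = model_cfg.get("null_class_indices", {})

    # One class count is expected per task key, in the same order.
    if len(task_keys) != len(num_classes):
        problems.append(
            f"{len(task_keys)} task keys but {len(num_classes)} class counts"
        )
    # Every task needs a null index that is a valid output index for that task.
    for key, count in zip(task_keys, num_classes):
        if key not in null_indices:
            problems.append(f"no null class index for {key}")
        elif not 0 <= null_indices[key] < count:
            problems.append(f"null index {null_indices[key]} out of range for {key}")
    return problems


# Illustrative values only; the class counts and null indices are invented.
example_model_cfg = {
    "architecture_name": "mFormerV1_sm",
    "weights_path": "pytorch_model.bin",
    "model_task_keys_ordered": ["taxa_L20", "taxa_L10"],
    "num_classes_per_task": [12, 48],
    "null_class_indices": {"taxa_L20": 0, "taxa_L10": 0},
}
print(check_model_section(example_model_cfg))  # prints []
```

The same dictionary could be obtained by parsing the YAML file with any YAML library and passing its model entry to the function.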

2. Model Weights (e.g., pytorch_model.bin)

This is a standard PyTorch state dictionary, saved using torch.save(model.state_dict(), ...). It contains the learned parameters of your trained model.

3. Taxonomy Data (taxonomy.json)

This file stores the taxonomic hierarchy relevant to the model. It's created by calling the .save() method of a linnaeus.utils.taxonomy.taxonomy_tree.TaxonomyTree instance. The TaxonomyTree is typically built from the hierarchy_map generated during dataset processing. The JSON file includes:

  • task_keys: Ordered list of Linnaeus task keys representing hierarchy levels (typically lowest rank to highest, e.g., ["taxa_L10_species", "taxa_L20_genus", ...]).
  • num_classes: Dictionary mapping each task key to the number of classes at that level.
  • hierarchy_map_raw: The core map defining parent-child relationships: Dict[child_task_key, Dict[child_model_idx, parent_model_idx]].
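
To illustrate how hierarchy_map_raw can be read, here is a minimal sketch that follows parent links upward from one class. It rests on two assumptions that the text above does not confirm: that the parent level of task_keys[i] is task_keys[i + 1], and that the indices appear as string keys in the file (JSON object keys are always strings). The function name lineage and all toy values, including the third task key, are invented for this example.

```python
import json


def lineage(taxonomy, task_key, class_idx):
    """Follow parent links upward from (task_key, class_idx).

    Assumes the parent level of task_keys[i] is task_keys[i + 1].
    """
    task_keys = taxonomy["task_keys"]
    hierarchy = taxonomy["hierarchy_map_raw"]
    path = [(task_key, class_idx)]
    level = task_keys.index(task_key)
    while level + 1 < len(task_keys):
        parents = hierarchy.get(task_keys[level], {})
        # JSON object keys are strings, so look the index up as a string.
        parent_idx = parents.get(str(class_idx))
        if parent_idx is None:
            break  # no recorded parent for this class
        level += 1
        class_idx = int(parent_idx)
        path.append((task_keys[level], class_idx))
    return path


# Toy data shaped like the fields described above; all values are invented.
toy_taxonomy = json.loads("""
{
  "task_keys": ["taxa_L10_species", "taxa_L20_genus", "taxa_L30_family"],
  "num_classes": {"taxa_L10_species": 4, "taxa_L20_genus": 3, "taxa_L30_family": 2},
  "hierarchy_map_raw": {
    "taxa_L10_species": {"1": 1, "2": 1, "3": 2},
    "taxa_L20_genus": {"1": 1, "2": 1}
  }
}
""")
print(lineage(toy_taxonomy, "taxa_L10_species", 3))
# prints [('taxa_L10_species', 3), ('taxa_L20_genus', 2), ('taxa_L30_family', 1)]
```

A walk like this is also the basic building block for a parent-child consistency check: a set of per-rank predictions is consistent when each predicted parent matches the parent recorded for the predicted child.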

4. Class Index Map (class_index_map.json)

This JSON file provides the critical mappings needed to translate the model's numerical outputs into meaningful taxonomic information using typus standards. It contains:

  • idx_to_taxon_id: Maps RankLevel.value to a dictionary of {model_class_index: typus_taxon_id}.
  • taxon_id_to_idx: The inverse of idx_to_taxon_id. Maps RankLevel.value to {typus_taxon_id: model_class_index}.
  • null_taxon_ids: Maps RankLevel.value to the typus_taxon_id that represents the "null" or "unknown" concept for that rank.
  • num_classes_per_rank: Maps RankLevel.value to the total number of classes (including null) that the model predicts for that rank.
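
The following sketch shows one way these four mappings might be used together: decoding a model output index into a taxon ID, and confirming that the two directions of the mapping agree. The helper names decode_prediction and inverse_is_consistent are hypothetical, the rank key "10" and the taxon IDs are invented, and the string keys reflect how JSON serializes dictionary keys rather than anything verified against a real bundle.

```python
def decode_prediction(class_index_map, rank_value, model_idx):
    """Translate one model output index into (taxon_id, is_null)."""
    # JSON object keys are always strings, so both lookups use str().
    taxon_id = class_index_map["idx_to_taxon_id"][str(rank_value)][str(model_idx)]
    null_id = class_index_map["null_taxon_ids"][str(rank_value)]
    return taxon_id, taxon_id == null_id


def inverse_is_consistent(class_index_map):
    """Check that taxon_id_to_idx is the exact inverse of idx_to_taxon_id."""
    forward = class_index_map["idx_to_taxon_id"]
    backward = class_index_map["taxon_id_to_idx"]
    if set(forward) != set(backward):
        return False  # the two mappings cover different ranks
    for rank, idx_map in forward.items():
        inverted = {str(taxon_id): int(idx) for idx, taxon_id in idx_map.items()}
        actual = {str(t): int(i) for t, i in backward[rank].items()}
        if inverted != actual:
            return False
    return True


# Toy map for a single rank; the rank key and taxon IDs are invented.
toy_map = {
    "idx_to_taxon_id": {"10": {"0": 0, "1": 47219, "2": 52775}},
    "taxon_id_to_idx": {"10": {"0": 0, "47219": 1, "52775": 2}},
    "null_taxon_ids": {"10": 0},
    "num_classes_per_rank": {"10": 3},
}
print(decode_prediction(toy_map, 10, 1))  # prints (47219, False)
print(inverse_is_consistent(toy_map))     # prints True
```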

Creating an Inference Bundle

Currently, creating an inference bundle is a manual process that involves:

  1. Training a Model: Train your Linnaeus model and save its state dictionary.
  2. Preparing Taxonomy Artifacts:
    • During your data preparation or training setup, you should have access to the TaxonomyTree instance. Save it using taxonomy_tree.save("taxonomy.json").
    • You will need to construct the class_index_map.json file. The mapping is often established during dataset creation and model configuration, and building it requires knowing:
      • How your model's output indices for each task head map to specific typus_taxon_ids.
      • Which typus_taxon_id represents the "null" class for each rank.
      • The total number of classes (outputs) for each rank-specific head in your model.
  3. Writing inference_config.yaml:
    • Carefully create this file, ensuring all paths correctly point to your artifact files (model weights, taxonomy.json, class_index_map.json) relative to the bundle's root.
    • Fill in all model parameters (model_task_keys_ordered, num_classes_per_task, null_class_indices) to match your trained model's architecture precisely.
    • Configure preprocessing and inference options as needed.
  4. Assembling the Bundle: Place all these files into a single directory.
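
The class_index_map.json part of step 2 can be sketched as follows. This is a minimal illustration, not a tool shipped with Linnaeus: the function names build_class_index_map and write_class_index_map are hypothetical, and the input format (one list of taxon IDs per rank, ordered by model output index) is an assumption about how the mapping is available after training.

```python
import json
from pathlib import Path


def build_class_index_map(taxon_ids_by_rank, null_taxon_ids):
    """Build the four mappings from per-rank lists of taxon IDs.

    taxon_ids_by_rank: {rank_value: [taxon_id at output index 0, index 1, ...]}
    null_taxon_ids:    {rank_value: taxon_id used as the null class}
    """
    result = {
        "idx_to_taxon_id": {},
        "taxon_id_to_idx": {},
        "null_taxon_ids": {},
        "num_classes_per_rank": {},
    }
    for rank, taxon_ids in taxon_ids_by_rank.items():
        if null_taxon_ids[rank] not in taxon_ids:
            raise ValueError(f"null taxon ID missing from rank {rank}")
        if len(set(taxon_ids)) != len(taxon_ids):
            raise ValueError(f"duplicate taxon IDs at rank {rank}")
        key = str(rank)  # JSON object keys end up as strings anyway
        result["idx_to_taxon_id"][key] = {str(i): t for i, t in enumerate(taxon_ids)}
        result["taxon_id_to_idx"][key] = {str(t): i for i, t in enumerate(taxon_ids)}
        result["null_taxon_ids"][key] = null_taxon_ids[rank]
        result["num_classes_per_rank"][key] = len(taxon_ids)
    return result


def write_class_index_map(bundle_dir, class_index_map):
    """Write class_index_map.json into the bundle directory and return its path."""
    bundle_dir = Path(bundle_dir)
    bundle_dir.mkdir(parents=True, exist_ok=True)
    out_path = bundle_dir / "class_index_map.json"
    out_path.write_text(json.dumps(class_index_map, indent=2))
    return out_path
```

Deriving both directions of the mapping from a single ordered list keeps idx_to_taxon_id and taxon_id_to_idx from drifting apart, which is harder to guarantee when the two are written by hand.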

A helper script or CLI tool (such as the proposed tools/prepare_inference_bundle.py) could greatly simplify and standardize this process. Such a tool would ideally take your trained model checkpoint, taxonomy data, and class mapping information as inputs and generate the complete, validated bundle.

Using the Bundle

Once created, the bundle can be used with LinnaeusInferenceHandler:

from pathlib import Path
from linnaeus.inference.handler import LinnaeusInferenceHandler

bundle_config_path = Path("/path/to/my_model_inference_bundle/inference_config.yaml")
handler = LinnaeusInferenceHandler.load_from_artifacts(config_file_path=bundle_config_path)

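# The handler is now ready for predictions. The lines below are left commented out
# because they need an input image; they also require `from PIL import Image`.
# image = Image.open(...)
# results = handler.predict(images=[image])
# print(results[0].model_dump_json(indent=2))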

This bundle structure ensures that all necessary components for inference are co-located and explicitly defined, promoting reproducibility and ease of deployment.