# Source code for hyrax.verbs.infer

import logging

from colorama import Back, Fore, Style

from hyrax.trace import trace_verb_data

from .verb_registry import Verb, hyrax_verb

logger = logging.getLogger(__name__)


@hyrax_verb
class Infer(Verb):
    """Inference verb"""

    # Name of this verb on the command line.
    cli_name = "infer"
    # Extra keyword arguments for the verb's argument parser (none needed).
    add_parser_kwargs = {}
    description = "Run inference on a model using a dataset."

    # Dataset groups that the Infer verb knows about.
    REQUIRED_DATA_GROUPS = ("infer",)
    OPTIONAL_DATA_GROUPS = ()

    @staticmethod
    def setup_parser(parser):
        """We don't need any parser setup for CLI opts"""
        pass

    def run_cli(self, args=None):
        """CLI stub for Infer verb; logs and delegates to :meth:`run`."""
        logger.info("infer run from CLI")
        self.run()

    @trace_verb_data
    def run(self):
        """Run inference on a model using a dataset.

        Configuration is read from ``self.config``. The method:

        1. creates a results directory for this run;
        2. builds the datasets for the ``infer`` group (unshuffled);
        3. builds the model and loads its weights;
        4. runs the evaluator over the ``infer`` data, writing each batch
           through the save-batch callback;
        5. loads the written results back.

        The tensorboard logger opened at the start is closed even if one of
        the steps above raises.

        Returns
        -------
        The object returned by ``load_results_dataset(config, results_dir)``
        for the results directory created by this run.
        """
        from hyrax.config_utils import (
            create_results_dir,
            log_runtime_config,
        )
        from hyrax.datasets.result_factories import load_results_dataset
        from hyrax.models.model_utils import load_model_weights
        from hyrax.pytorch_ignite import (
            create_evaluator,
            create_save_batch_callback,
            dist_data_loader,
            setup_dataset,
            setup_model,
        )
        from hyrax.tensorboardx_logger import close_tensorboard_logger, init_tensorboard_logger

        config = self.config
        # Populated below; not otherwise read within this method.
        context = {}

        # Create a results directory and dump our config there
        results_dir = create_results_dir(config, "infer")

        # Create a tensorboardX logger
        init_tensorboard_logger(log_dir=results_dir)
        try:
            dataset = setup_dataset(
                config,
                splits=Infer.REQUIRED_DATA_GROUPS + Infer.OPTIONAL_DATA_GROUPS,
                shuffle=False,
            )
            # setup_dataset returns a dataset dictionary keyed by split name.
            # Extract the "infer" split once so that the model and the data
            # loader are guaranteed to see the same dataset.
            infer_dataset = dataset["infer"]

            model = setup_model(config, infer_dataset)
            logger.info(
                f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Inference model:{Style.RESET_ALL} "
                f"{model.__class__.__name__}"
            )
            logger.info(
                f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Inference dataset(s):{Style.RESET_ALL}\n{dataset}"
            )

            # When split_fraction is defined on the "infer" group, setup_dataset
            # will have already computed split_indices on the DataProvider.
            # dist_data_loader with split=False will automatically apply a
            # deterministic sampler to restrict the dataloader to those indices.
            logger.debug(f"Inference dataset has length: {len(infer_dataset)}")  # type: ignore[arg-type]
            data_loader, _ = dist_data_loader(infer_dataset, config, False)

            load_model_weights(config, model, "infer")
            log_runtime_config(config, results_dir)
            context["results_dir"] = results_dir

            # Log Results directory
            logger.info(f"Saving inference results at: {results_dir}")
            model.save(results_dir / "inference_weights.pth")

            # Create the save batch callback
            save_batch_callback = create_save_batch_callback(results_dir)

            # Run inference
            evaluator = create_evaluator(model, save_batch_callback, config)
            evaluator.run(data_loader)

            # Write out a dictionary to map IDs->Batch
            save_batch_callback.data_writer.commit()  # type: ignore[attr-defined]
        finally:
            # Write out our tensorboard stuff. This runs in a `finally` so the
            # logger opened above is closed even when inference fails.
            close_tensorboard_logger()

        # Log completion
        logger.info("Inference Complete.")

        return load_results_dataset(config, results_dir)