import logging
from pathlib import Path
from typing import Union
from colorama import Back, Fore, Style
from .verb_registry import Verb, hyrax_verb
[docs]
logger = logging.getLogger(__name__)
@hyrax_verb
[docs]
class Infer(Verb):
"""Inference verb"""
[docs]
@staticmethod
[docs]
def setup_parser(parser):
"""We don't need any parser setup for CLI opts"""
pass
[docs]
def run_cli(self, args=None):
"""CLI stub for Infer verb"""
logger.info("infer run from CLI")
self.run()
[docs]
def run(self):
"""Run inference on a model using a dataset
Parameters
----------
config : dict
The parsed config file as a nested dict
"""
import numpy as np
from tensorboardX import SummaryWriter
from torch import Tensor
from hyrax.config_utils import (
create_results_dir,
log_runtime_config,
)
from hyrax.data_sets.inference_dataset import InferenceDataSet, InferenceDataSetWriter
from hyrax.pytorch_ignite import (
create_evaluator,
dist_data_loader,
setup_dataset,
setup_model,
)
config = self.config
context = {}
# Create a results directory and dump our config there
results_dir = create_results_dir(config, "infer")
# Create a tensorboardX logger
tensorboardx_logger = SummaryWriter(log_dir=results_dir)
dataset = setup_dataset(config, tensorboardx_logger)
model = setup_model(config, dataset["infer"])
logger.info(
f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Inference model:{Style.RESET_ALL} "
f"{model.__class__.__name__}"
)
logger.info(
f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Inference dataset(s):{Style.RESET_ALL}\n{dataset}"
)
# Inference doesnt work at all with the dataloader doing additional shuffling:
if config["data_loader"]["shuffle"]:
msg = "Data loader shuffling not supported in inference mode. "
msg += "Setting config['data_loader']['shuffle'] = False"
logger.warning(msg)
config["data_loader"]["shuffle"] = False
# If `dataset` is a dict containing the key "infer", we'll pull that out.
# The only time it wouldn't be is if the dataset is an iterable dataset.
if isinstance(dataset, dict) and "infer" in dataset:
dataset = dataset["infer"]
if dataset.is_map():
logger.debug(f"Inference dataset has length: {len(dataset)}") # type: ignore[arg-type]
data_loader, data_loader_indexes = dist_data_loader(dataset, config, False)
Infer.load_model_weights(config, model)
log_runtime_config(config, results_dir)
context["results_dir"] = results_dir
# Log Results directory
logger.info(f"Saving inference results at: {results_dir}")
model.save(results_dir / "inference_weights.pth")
data_writer = InferenceDataSetWriter(dataset, results_dir)
# These are values the _save_batch callback needs to run
write_index = 0
object_ids = np.array(list(dataset.ids()))[data_loader_indexes] # type: ignore[attr-defined]
def _save_batch(batch: Union[Tensor, list, tuple, dict], batch_results: Tensor):
"""Receive and write results tensors to results_dir immediately
This function writes a single numpy binary file for each object.
"""
nonlocal write_index
nonlocal object_ids
nonlocal data_writer
batch_len = len(batch_results)
batch_results = batch_results.detach().to("cpu")
batch_is_list = isinstance(batch, (tuple, list))
# Batch lacks ids if it is a Tensor, or a list/tuple of tensors
batch_lacks_ids = isinstance(batch, Tensor) or (
batch_is_list and isinstance(batch.get(0), Tensor)
)
# Batch has IDs if it is dict of tensors with the needed key
batch_has_ids = isinstance(batch, dict) and "object_id" in batch
if batch_lacks_ids:
# This fallback is brittle to any re-ordering of data that occurs during data loading
batch_object_ids = [
object_ids[id] for id in range(write_index, write_index + len(batch_results))
]
elif batch_has_ids:
if isinstance(batch["object_id"], list):
batch_object_ids = batch["object_id"]
else:
batch_object_ids = batch["object_id"].tolist()
elif isinstance(batch, dict):
msg = "Dataset dictionary should be returning object_ids to avoid ordering errors. "
msg += "Modify the __getitem__ or __iter__ function of your dataset to include 'object_id' "
msg += "with unique values per data member in the dictionary it returns."
raise RuntimeError(msg)
else:
msg = f"Could not determine object IDs from batch. Batch has type {type(batch)}"
raise RuntimeError(msg)
# Save results from this batch in a numpy file as a structured array
data_writer.write_batch(np.array(batch_object_ids), [t.numpy() for t in batch_results])
write_index += batch_len
# Run inference
evaluator = create_evaluator(model, _save_batch, config)
evaluator.run(data_loader)
# Write out a dictionary to map IDs->Batch
data_writer.write_index()
# Write out our tensorboard stuff
tensorboardx_logger.close()
# Log completion
logger.info("Inference Complete.")
return InferenceDataSet(config, results_dir)
@staticmethod
[docs]
def load_model_weights(config, model):
"""Loads the model weights from a file. Raises RuntimeError if this is not possible due to
config, missing or malformed file
Parameters
----------
config : dict
Full runtime configuration
model : nn.Module
The model class to load weights into
"""
from hyrax.config_utils import find_most_recent_results_dir
weights_file: Union[str, Path] | None = (
config["infer"]["model_weights_file"] if config["infer"]["model_weights_file"] else None
)
if weights_file is None:
recent_results_path = find_most_recent_results_dir(config, "train")
if recent_results_path is None:
raise RuntimeError("Must define model_weights_file in the [infer] section of hyrax config.")
weights_file = recent_results_path / config["train"]["weights_filename"]
# Ensure weights file is a path object.
weights_file_path = Path(weights_file)
if not weights_file_path.exists():
raise RuntimeError(f"Model Weights file {weights_file_path} does not exist")
try:
model.load(weights_file_path)
config["infer"]["model_weights_file"] = str(weights_file_path)
except Exception as err:
msg = f"Model weights file {weights_file_path} did not load properly. Are you sure you are "
msg += "predicting using the correct model"
raise RuntimeError(msg) from err