# Source code for hyrax.verbs.engine
import logging
from hyrax.trace import get_trace, trace_verb_data
from .verb_registry import Verb, hyrax_verb
# Module-level logger named after this module, so its output can be
# configured through the standard logging hierarchy.
logger = logging.getLogger(__name__)
@hyrax_verb
class Engine(Verb):
    """This verb drives inference with an ONNX model in production."""

    description = "Run inference with an ONNX model."

    @staticmethod
    def setup_parser(parser):
        """Setup parser for engine verb.

        Parameters
        ----------
        parser
            The (argparse-style) parser for the ``engine`` verb; an optional
            ``--model-directory`` string argument is added to it.
        """
        parser.add_argument(
            "--model-directory",
            type=str,
            required=False,
            help="Directory containing the ONNX model.",
        )

    def run_cli(self, args=None):
        """CLI stub for Engine verb.

        Parameters
        ----------
        args : optional
            Parsed CLI arguments. ``args.model_directory`` is forwarded to
            :meth:`run`; when ``args`` is None (or lacks that attribute),
            ``None`` is forwarded so :meth:`run` falls back to the config.
        """
        logger.info("`engine` run from CLI.")
        self.run(model_directory=getattr(args, "model_directory", None))

    @trace_verb_data
    def run(self, model_directory: str = None):
        """
        Run inference with an ONNX model.

        This method performs the following steps:
        - Read in the user config
        - Locate the model directory and load its ``prepare_inputs`` function
          (falling back to the deprecated ``to_tensor`` function)
        - Prepare the ``infer`` dataset
        - Read data samples in batches of ``config["data_loader"]["batch_size"]``
        - Collate each batch with the dataset's ``collate`` function
        - Pass the collated batch to the ``prepare_inputs`` function
        - Send that output to the ONNX-ified model
        - Persist the results of inference

        Parameters
        ----------
        model_directory : str or path-like, optional
            Directory containing the ONNX model. If not provided, uses
            ``config["engine"]["model_directory"]`` or, failing that, the most
            recent ONNX export directory.

        Returns
        -------
        None
            A missing model directory or a missing ``prepare_inputs``/``to_tensor``
            function is logged as an error and causes an early return.

        Raises
        ------
        ValueError
            If ``config["data_loader"]["batch_size"]`` is less than 1.
        RuntimeError
            If a collated batch does not contain an ``"object_id"`` key.
        """
        # Heavy/optional dependencies are imported lazily so that merely
        # registering the verb or building the CLI parser stays cheap.
        from pathlib import Path

        import onnxruntime

        from hyrax.config_utils import (
            create_results_dir,
            find_most_recent_results_dir,
        )
        from hyrax.datasets.data_provider import DataProvider
        from hyrax.datasets.result_factories import create_results_writer
        from hyrax.plugin_utils import load_prepare_inputs, load_to_tensor
        from hyrax.pytorch_ignite import setup_dataset

        config = self.config

        # Find the directory that contains the ONNX model, prepare_inputs.py, etc.
        # Precedence: explicit argument > config file > most recent ONNX export.
        if model_directory:
            input_directory = Path(model_directory)
            if not input_directory.exists():
                logger.error(f"Model directory {input_directory} does not exist.")
                return
        elif config["engine"]["model_directory"]:
            input_directory = Path(config["engine"]["model_directory"])
            if not input_directory.exists():
                logger.error(f"Model directory in the config file {input_directory} does not exist.")
                return
        else:
            input_directory = find_most_recent_results_dir(config, "onnx")
            if not input_directory:
                logger.error("No previous training results directory found for ONNX export.")
                return

        # Load prepare_inputs (current name) from the ONNX output. Only when it is
        # absent do we load the deprecated to_tensor, for backward compatibility.
        prepare_inputs_fn = load_prepare_inputs(input_directory)
        if prepare_inputs_fn:
            logger.debug("Using prepare_inputs function from model directory.")
        else:
            to_tensor_fn = load_to_tensor(input_directory)
            if not to_tensor_fn:
                logger.error("No prepare_inputs or to_tensor function found in the model directory.")
                return
            logger.warning(
                "Using deprecated to_tensor function. "
                "Please update to prepare_inputs for future compatibility."
            )
            prepare_inputs_fn = to_tensor_fn

        # Setup tracing on all data handling functions for this verb (noop if tracing not enabled.)
        prepare_inputs_fn = self._setup_trace(prepare_inputs_fn)

        # ~ Load the ONNX model from the input directory. The path is passed as a
        # str because older onnxruntime releases accept only str/bytes, not Path.
        onnx_file_name = input_directory / "model.onnx"
        self.ort_session = onnxruntime.InferenceSession(str(onnx_file_name))

        # For now we use `setup_dataset` to get our datasets back. Later we can
        # optimize this, because we know that we'll only need the `infer` part
        # of the data_request dictionary. And we can assume that we'll be working
        # with map-style datasets. But for now, this gets us going.
        dataset = setup_dataset(config, splits=("infer",), shuffle=False)

        # In the `train` and `infer` verbs, we use `dist_data_loader` to create
        # our data loaders. But here in `engine`, we can assume that we can simply
        # find the length of our dataset and then iterate over it in batches.
        infer_dataset = dataset["infer"]

        batch_size = config["data_loader"]["batch_size"]
        # range() raises on a zero step and yields nothing for a negative one,
        # which would silently commit empty results; reject both up front.
        if batch_size < 1:
            raise ValueError(f"data_loader.batch_size must be a positive integer, got {batch_size}.")

        # Initialize the ResultDatasetWriter to persist results of inference
        result_dir = create_results_dir(config, "engine")
        self.results_writer = create_results_writer(result_dir)

        # If split_fraction is configured, setup_dataset will have already
        # computed and assigned split_indices to the DataProvider. We need
        # to respect those indices rather than iterating over the full dataset.
        if isinstance(infer_dataset, DataProvider) and infer_dataset.split_indices is not None:
            indices_to_process = infer_dataset.split_indices
        else:
            indices_to_process = list(range(len(infer_dataset)))

        # Work through the dataset in steps of `batch_size`. Slicing past the end
        # is safe, so the final (possibly short) batch needs no special casing.
        for start_idx in range(0, len(indices_to_process), batch_size):
            batch_indices = indices_to_process[start_idx : start_idx + batch_size]
            batch = [infer_dataset[i] for i in batch_indices]

            # Convert the batch from a list of dictionaries into a dictionary
            # of lists by using the dataset's collate function.
            collated_batch = infer_dataset.collate(batch)

            # Validate object ids before spending any inference work on the batch,
            # and capture them now in case prepare_inputs mutates the batch.
            # For now, collated_batch will always have an "object_id" key that
            # is a list of strings. However, we should move to a state where the
            # object ids are taken from the primary dataset's "primary_id_field",
            # which will contain the required data - then remove the "object_id" key.
            if "object_id" not in collated_batch:
                msg = "Dataset dictionary should be returning object_ids to avoid ordering errors. "
                msg += f"Could not determine object IDs from batch. Batch has keys {collated_batch.keys()}"
                raise RuntimeError(msg)
            object_ids = collated_batch["object_id"]

            # Pass the collated batch to the prepare_inputs function, then run it
            # through the ONNX model.
            prepared_batch = prepare_inputs_fn(collated_batch)
            ort_inputs = self.create_ort_inputs(prepared_batch)
            onnx_results = self.run_onnx_batch(ort_inputs)

            # Persist the first model output, split along its leading (batch)
            # dimension so each object id gets its own result row.
            self.results_writer.write_batch(object_ids, list(onnx_results[0]))

        # Write the final index file for the inference results.
        self.results_writer.commit()

    def run_onnx_batch(self, ort_inputs):
        """
        Run the batch using our onnx runtime session.

        Only split out because this is when data is mutated and we need to be able to trace it.

        Parameters
        ----------
        ort_inputs
            Input feed for ``self.ort_session.run`` (presumably a mapping of
            model input names to arrays, as built by ``create_ort_inputs``).

        Returns
        -------
        list
            The session outputs; ``output_names=None`` requests every model output.
        """
        return self.ort_session.run(None, ort_inputs)

    def _setup_trace(self, prepare_inputs_fn):
        """Instrument this verb and ``prepare_inputs_fn`` when tracing is enabled.

        Parameters
        ----------
        prepare_inputs_fn : callable
            The loaded ``prepare_inputs`` (or legacy ``to_tensor``) function.

        Returns
        -------
        callable
            The trace-instrumented wrapper when ``get_trace()`` returns an active
            trace; otherwise ``prepare_inputs_fn`` unchanged.
        """
        trace = get_trace()
        if trace:
            trace.instrument_engine_verb(self)
            return trace.instrument_prepare_inputs_fn(prepare_inputs_fn)
        return prepare_inputs_fn