Source code for hyrax.verbs.engine

import logging

from hyrax.trace import get_trace, trace_verb_data

from .verb_registry import Verb, hyrax_verb

# Module-level logger shared by the `engine` verb; named after this module so
# log records can be filtered per module.
logger = logging.getLogger(__name__)
@hyrax_verb
class Engine(Verb):
    """This verb drives inference with an ONNX model in production."""

    # Name under which the verb is exposed on the command line.
    cli_name = "engine"
    # Extra keyword arguments used when this verb's parser is created.
    add_parser_kwargs = {}
    # Short description of the verb.
    description = "Run inference with an ONNX model."

    @staticmethod
    def setup_parser(parser):
        """Setup parser for engine verb.

        Parameters
        ----------
        parser
            Parser object exposing ``add_argument`` (presumably an
            ``argparse.ArgumentParser``). The optional ``--model-directory``
            argument is registered on it.
        """
        parser.add_argument(
            "--model-directory",
            type=str,
            required=False,
            help="Directory containing the ONNX model.",
        )

    def run_cli(self, args=None):
        """CLI stub for Engine verb.

        Parameters
        ----------
        args : optional
            Parsed command-line arguments. When truthy, its
            ``model_directory`` attribute is forwarded to :meth:`run`;
            otherwise ``None`` is forwarded.
        """
        logger.info("`engine` run from CLI.")
        self.run(model_directory=args.model_directory if args else None)

    @trace_verb_data
    def run(self, model_directory: str = None):
        """
        Run inference with an ONNX model.

        This method performs the following steps:

        - Read in the user config
        - Prepare all the datasets requested
        - Implement a simple strategy for reading in batches of data samples
        - Process the samples with any custom collate functions as well as
          a default collate function
        - Pass the collated batch to the appropriate to_tensor function
        - Send that output to the ONNX-ified model
        - Persist the results of inference

        Parameters
        ----------
        model_directory : str, optional
            Directory containing the ONNX model. If not provided, uses the
            config file or finds the most recent ONNX export directory.

        Returns
        -------
        None
            Results are persisted through the results writer. The method
            also returns early (after logging an error) when no usable model
            directory or input-preparation function can be found.

        Raises
        ------
        RuntimeError
            If a collated batch does not contain an ``"object_id"`` key.
        """
        # Imports are kept local to this method, as in the original code.
        from pathlib import Path

        import onnxruntime

        from hyrax.config_utils import (
            create_results_dir,
            find_most_recent_results_dir,
        )
        from hyrax.datasets.data_provider import DataProvider
        from hyrax.datasets.result_factories import create_results_writer
        from hyrax.plugin_utils import load_prepare_inputs, load_to_tensor
        from hyrax.pytorch_ignite import setup_dataset

        config = self.config

        # Find the directory that contains the ONNX model, prepare_inputs.py, etc.
        # Precedence: explicit argument, then config value, then most recent export.
        if model_directory:
            input_directory = Path(model_directory)
            if not input_directory.exists():
                logger.error(f"Model directory {input_directory} does not exist.")
                return
        elif config["engine"]["model_directory"]:
            input_directory = Path(config["engine"]["model_directory"])
            if not input_directory.exists():
                logger.error(f"Model directory in the config file {input_directory} does not exist.")
                return
        else:
            input_directory = find_most_recent_results_dir(config, "onnx")
            if not input_directory:
                logger.error("No previous training results directory found for ONNX export.")
                return

        # Here we load the appropriate prepare_inputs function from onnx output.
        # Try loading prepare_inputs first (new name), fall back to to_tensor
        # for backward compatibility.
        prepare_inputs_fn = load_prepare_inputs(input_directory)
        to_tensor_fn = load_to_tensor(input_directory)

        if prepare_inputs_fn:
            logger.debug("Using prepare_inputs function from model directory.")
        elif to_tensor_fn:
            # Backward compatibility: use to_tensor if prepare_inputs is not found
            logger.warning(
                "Using deprecated to_tensor function. "
                "Please update to prepare_inputs for future compatibility."
            )
            prepare_inputs_fn = to_tensor_fn
        else:
            logger.error("No prepare_inputs or to_tensor function found in the model directory.")
            return

        # Setup tracing on all data handling functions for this verb
        # (noop if tracing not enabled.)
        prepare_inputs_fn = self._setup_trace(prepare_inputs_fn)

        # Load the ONNX model from the input directory.
        onnx_file_name = input_directory / "model.onnx"
        self.ort_session = onnxruntime.InferenceSession(onnx_file_name)

        # For now we use `setup_dataset` to get our datasets back. Later we can
        # optimize this, because we know that we'll only need the `infer` part
        # of the data_request dictionary. And we can assume that we'll be working
        # with map-style datasets. But for now, this gets us going.
        dataset = setup_dataset(config, splits=("infer",), shuffle=False)

        # In the `train` and `infer` verbs, we use `dist_data_loader` to create
        # our data loaders. But here in `engine`, we can assume that we can simply
        # find the length of our dataset and then iterate over it in batches.
        infer_dataset = dataset["infer"]
        batch_size = config["data_loader"]["batch_size"]

        # Initialize the ResultDatasetWriter to persist results of inference
        result_dir = create_results_dir(config, "engine")
        self.results_writer = create_results_writer(result_dir)

        # Determine which indices to iterate over.
        # If split_fraction is configured, setup_dataset will have already
        # computed and assigned split_indices to the DataProvider. We need
        # to respect those indices rather than iterating over the full dataset.
        if isinstance(infer_dataset, DataProvider) and infer_dataset.split_indices is not None:
            indices_to_process = infer_dataset.split_indices
        else:
            indices_to_process = list(range(len(infer_dataset)))

        # Loop-invariant: the number of samples does not change while iterating.
        num_samples = len(indices_to_process)

        # Work through the dataset in steps of `batch_size`
        for start_idx in range(0, num_samples, batch_size):
            end_idx = min(start_idx + batch_size, num_samples)
            batch_indices = indices_to_process[start_idx:end_idx]
            batch = [infer_dataset[i] for i in batch_indices]

            # Here we convert the batch from a list of dictionaries into a
            # dictionary of lists by using the DataProvider.collate function.
            collated_batch = infer_dataset.collate(batch)

            # Pass the collated batch to the prepare_inputs function
            prepared_batch = prepare_inputs_fn(collated_batch)

            ort_inputs = self.create_ort_inputs(prepared_batch)
            onnx_results = self.run_onnx_batch(ort_inputs)

            # Finally, we persist the results of inference.
            # For now, collated_batch will always have an "object_id" key that
            # is a list of strings. However, we should move to a state where the
            # object ids are taken from the primary dataset's "primary_id_field",
            # which will contain the required data - then remove the "object_id" key.
            if "object_id" not in collated_batch:
                msg = "Dataset dictionary should be returning object_ids to avoid ordering errors. "
                msg += f"Could not determine object IDs from batch. Batch has keys {collated_batch.keys()}"
                raise RuntimeError(msg)

            # Save the output of the onnx model per batch. Onnx results are
            # returned as a 1-element list containing a numpy array with first
            # dimension as batch size.
            self.results_writer.write_batch(collated_batch["object_id"], list(onnx_results[0]))

        # Write the final index file for the inference results.
        self.results_writer.commit()

    def create_ort_inputs(self, prepared_batch):
        """
        Create the inputs array for the ONNX model using the expected inputs
        from the loaded ONNX model and the type and shape of the prepared batch.

        Parameters
        ----------
        prepared_batch
            Output of the prepare_inputs function. A ``tuple`` is mapped
            positionally onto the model's inputs; any other value is passed
            as the model's first input.

        Returns
        -------
        dict
            Mapping from ONNX input name to the value to feed for that input.
        """
        ort_inputs = {}
        # Query the session once rather than on every loop iteration.
        model_inputs = self.ort_session.get_inputs()

        if isinstance(prepared_batch, tuple):
            for position, element in enumerate(prepared_batch):
                # For a supervised model, we expect that at least one of the
                # elements in the prepared batch will be empty, so we only
                # add non-empty inputs. `len()` is used instead of truthiness
                # because array-like elements may not support `bool()`.
                if len(element):
                    ort_inputs[model_inputs[position].name] = element
        else:
            ort_inputs = {model_inputs[0].name: prepared_batch}

        return ort_inputs

    def run_onnx_batch(self, ort_inputs):
        """
        Run the batch using our onnx runtime session.

        Only split out because this is when data is mutated and we need to be
        able to trace it.

        Parameters
        ----------
        ort_inputs : dict
            Mapping of ONNX input names to input values, as built by
            :meth:`create_ort_inputs`.

        Returns
        -------
        The value returned by ``self.ort_session.run(None, ort_inputs)``.
        """
        return self.ort_session.run(None, ort_inputs)

    def _setup_trace(self, prepare_inputs_fn):
        """Instrument this verb and ``prepare_inputs_fn`` when tracing is active.

        Parameters
        ----------
        prepare_inputs_fn : callable
            The function that prepares a collated batch for the model.

        Returns
        -------
        callable
            The instrumented function returned by the active trace, or
            ``prepare_inputs_fn`` unchanged when ``get_trace()`` returns a
            falsy value.
        """
        trace = get_trace()
        if trace:
            trace.instrument_engine_verb(self)
            return trace.instrument_prepare_inputs_fn(prepare_inputs_fn)
        return prepare_inputs_fn