# Source code for hyrax.verbs.to_onnx

import logging

from .verb_registry import Verb, hyrax_verb

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@hyrax_verb
class ToOnnx(Verb):
    """Export a trained model to ONNX format."""

    # Name of this verb's sub-command on the command line.
    cli_name = "to_onnx"
    # Extra keyword arguments for this verb's sub-parser (none needed);
    # presumably consumed by the verb registry — confirm against Verb.
    add_parser_kwargs = {}

    @staticmethod
    def setup_parser(parser):
        """Setup parser for ONNX export verb.

        Parameters
        ----------
        parser : argparse.ArgumentParser
            Parser to which the ``--input-model-directory`` option is added.
        """
        parser.add_argument(
            "--input-model-directory",
            type=str,
            required=False,
            help="Directory containing the trained model to export.",
        )

    def run_cli(self, args=None):
        """Run the ONNX export verb from the CLI.

        Parameters
        ----------
        args : argparse.Namespace, optional
            Parsed CLI arguments. When ``None`` (or when it lacks
            ``input_model_directory``), the input directory is resolved from the
            config or the most recent training run instead of crashing.
        """
        logger.info("Exporting model to ONNX format.")
        # getattr with a default tolerates args=None, which the signature allows.
        input_model_directory = getattr(args, "input_model_directory", None)
        self.run(input_model_directory)

    def run(self, input_model_directory: str = None):
        """Export a trained model to ONNX format and save it to a new results directory.

        The input model directory is resolved in this order:

        1. the ``input_model_directory`` argument,
        2. ``config["onnx"]["input_model_directory"]``,
        3. the most recent ``train`` results directory.

        If the resolved directory (or its ``runtime_config.toml``) is missing, an
        error is logged and the export is skipped.

        Parameters
        ----------
        input_model_directory : str, optional
            Directory containing the trained model (its ``runtime_config.toml`` and
            weights file). When falsy, the config value or most recent training run
            is used.

        Raises
        ------
        RuntimeError
            If the trained weights file cannot be found, or if the inference
            dataset yields no batch to use as the export sample.
        """
        import shutil
        from pathlib import Path

        from hyrax.config_utils import (
            ConfigManager,
            create_results_dir,
            find_most_recent_results_dir,
        )
        from hyrax.model_exporters import export_to_onnx
        from hyrax.pytorch_ignite import dist_data_loader, setup_dataset, setup_model

        config = self.config

        # Resolve the input directory in this order: 1) input_model_directory arg,
        # 2) config value, 3) most recent train results.
        if input_model_directory:
            input_directory = Path(input_model_directory)
            if not input_directory.exists():
                logger.error(f"Input model directory {input_directory} does not exist.")
                return
        elif config["onnx"]["input_model_directory"]:
            input_directory = Path(config["onnx"]["input_model_directory"])
            if not input_directory.exists():
                logger.error(f"Input model directory in the config file {input_directory} does not exist.")
                return
        else:
            input_directory = find_most_recent_results_dir(config, "train")
            if not input_directory:
                logger.error("No previous training results directory found for ONNX export.")
                return

        # Load the runtime config that was saved alongside the trained model.
        config_file = input_directory / "runtime_config.toml"
        if not config_file.exists():
            logger.error(f"Runtime config file {config_file} not found in the input model directory.")
            return
        config_manager = ConfigManager(runtime_config_filepath=config_file)
        config_from_training = config_manager.config

        # Use the training config to locate the trained weights. Validate them
        # before creating the output directory so a failed export does not leave
        # an empty results directory behind.
        weights_file_path = input_directory / config_from_training["train"]["weights_filename"]
        if not weights_file_path.exists():
            raise RuntimeError(f"Could not find trained model weights: {weights_file_path}")

        output_dir = create_results_dir(config, "onnx")

        # Copy the to_tensor.py file (if present) from the input directory to the
        # output directory.
        to_tensor_src = input_directory / "to_tensor.py"
        if to_tensor_src.exists():
            shutil.copy(to_tensor_src, output_dir / "to_tensor.py")

        # Rebuild the dataset(s) and model from the training-time config so the
        # model matches the saved weights.
        dataset = setup_dataset(config_from_training)
        model = setup_model(config_from_training, dataset["infer"])

        # Load the trained weights and switch the model out of training mode
        # (train(False) is the torch nn.Module equivalent of eval()).
        model.load(weights_file_path)
        model.train(False)

        # The data loader is only used to obtain one sample batch for the exporter.
        infer_data_loader, _ = dist_data_loader(dataset["infer"], config_from_training, False)

        # Context dictionary provided to the ONNX exporter.
        context = {
            "results_dir": output_dir,
            "ml_framework": "pytorch",
        }

        # Get a sample of input data. Turn an empty dataset into a descriptive
        # error instead of a bare StopIteration.
        try:
            batch_sample = next(iter(infer_data_loader))
        except StopIteration as err:
            raise RuntimeError(
                "The inference dataset produced no batches; cannot build a sample input for ONNX export."
            ) from err
        batch_sample = model.to_tensor(batch_sample)

        # NOTE(review): the *current* config (not config_from_training) is passed to
        # the exporter, presumably so ONNX export settings come from this run — confirm.
        export_to_onnx(model, batch_sample, config, context)