Source code for hyrax.verbs.lookup
import logging
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Union
import numpy as np
from .verb_registry import Verb, hyrax_verb
[docs]
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@hyrax_verb
class Lookup(Verb):
    """Verb that fetches the stored inference result for a single data ID."""

    description = "Look up an inference result using the ID of a data member."

    @staticmethod
    def setup_parser(parser: ArgumentParser):
        """Register this verb's command line options on its sub-parser.

        Parameters
        ----------
        parser : ArgumentParser
            The sub-parser that receives the ``--id`` (required) and
            ``--results-dir`` (optional) options.
        """
        parser.add_argument(
            "-i",
            "--id",
            type=str,
            required=True,
            help="ID of image",
        )
        parser.add_argument(
            "-r",
            "--results-dir",
            type=str,
            required=False,
            help="Directory containing inference results.",
        )

    def run_cli(self, args: Namespace | None = None):
        """Run the lookup from the command line and print any result found.

        Parameters
        ----------
        args : Optional[Namespace], optional
            The parsed command line arguments; ``args.id`` and
            ``args.results_dir`` are forwarded to :meth:`run`.

        Raises
        ------
        RuntimeError
            If ``args`` is None.
        """
        logger.info("Lookup run from cli")

        if args is None:
            raise RuntimeError("Run CLI called with no arguments.")

        # Translate the parsed CLI namespace into the keyword arguments of run().
        found = self.run(id=args.id, results_dir=args.results_dir)

        if found is not None:
            logger.info("Inference result found")
            print(found)
        else:
            logger.info("No inference result found")

    def run(self, id: str, results_dir: Union[Path, str] | None = None) -> np.ndarray | None:
        """Return the stored inference output associated with ``id``.

        Requires the relevant dataset to be configured, and for inference to
        have been run.

        Parameters
        ----------
        id : str
            The ID of the input data whose inference result is wanted.
        results_dir : Union[Path, str], optional
            The directory containing the inference results. Passed through
            unchanged to ``load_results_dataset``.

        Returns
        -------
        Optional[np.ndarray]
            The result for the matching ID converted with ``np.asarray``, or
            None when no stored ID equals ``id``.

        Raises
        ------
        RuntimeError
            If more than one stored ID equals ``id``.
        """
        # Imported at call time rather than at module import time, as in the original.
        from hyrax.datasets.result_factories import load_results_dataset

        results = load_results_dataset(self.config, results_dir=results_dir, verb="infer")

        # np.argwhere yields one row per position where the comparison is true.
        known_ids = np.array(results.ids())
        matches = np.argwhere(known_ids == id)
        match_count = len(matches)

        if match_count > 1:
            raise RuntimeError("Inference result directory has duplicate ID numbers")
        if match_count == 0:
            return None

        # Exactly one match: index the dataset with that match's index row.
        return np.asarray(results[matches[0]])