hyrax.trace

Attributes#

Classes#

TraceContext

To trace, we (1) shim class methods and (2) modify the hyrax config.

TraceDef

A record that needs to be filled out whenever a function is instrumented for tracing in TraceResult.

TracePrintable

Base class defining foundational behavior for TraceResult, TraceStage, and TraceCall, which are the user-accessible building blocks of a trace.

TraceResult

Result of a hyrax data tracing run, returned from certain data-handling verbs when trace=<non-zero number> is passed in a notebook.

TraceStage

This is a container that holds a list of TraceCalls in order of execution, representing an entire stage of a TraceResult.

TraceCall

An individual function call that is part of a trace. Captures argument and return values of the given function.

Functions#

trace_dataset_func([func, params_to_capture, ...])

Decorator to add tracing to a custom dataset function. By default, captures all parameters and the return value, placing the function in the ‘dataset_getter’ stage.

trace_model_func([func, params_to_capture, ...])

Decorator to add tracing to a custom model function. By default, captures all parameters and the return value, placing the function in the ‘evaluation’ stage.

trace_func([func, params_to_capture, result_name])

Generic decorator to trace a user-defined function in a particular stage.

trace_verb_data(verb_run_func)

Simple wrapper decorator for verbs to implement the trace=<num data items> interface.

get_trace()

Get the current global trace results object. Returns None if no trace is active.

reset_trace()

Reset the current global trace results object, removing all captured data.

Module Contents#

trace_result = None[source]#
logger[source]#
trace_dataset_func(func=None, *, params_to_capture=None, result_name=None, stage_name='dataset_getter')[source]#

Decorator to add tracing to a custom dataset function. By default, captures all parameters and the return value, placing the function in the ‘dataset_getter’ stage.

trace_model_func(func=None, *, params_to_capture=None, result_name=None, stage_name='evaluation')[source]#

Decorator to add tracing to a custom model function. By default, captures all parameters and the return value, placing the function in the ‘evaluation’ stage.

trace_func(func=None, *, params_to_capture=None, result_name=None, stage_name)[source]#

Generic decorator to trace a user-defined function in a particular stage.

The name of a TraceResult stage must be provided (via the keyword-only stage_name argument) to use this decorator directly.
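The signature above (func=None followed by keyword-only arguments) is a common pattern for decorators that accept options. The following is a minimal sketch of that pattern, not Hyrax's implementation: trace_func_sketch, CALL_LOG, and scale are hypothetical names, and treating params_to_capture as a collection of keyword names is an assumption made only for this illustration.

```python
import functools

# Hypothetical stand-in for the global trace: a plain list of call records.
CALL_LOG = []


def trace_func_sketch(func=None, *, params_to_capture=None, result_name=None, stage_name):
    """Illustrative only; this is not Hyrax's implementation of trace_func."""

    def decorate(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            result = f(*args, **kwargs)
            # Assumption: params_to_capture is a collection of keyword names; None means all.
            captured = {
                name: value
                for name, value in kwargs.items()
                if params_to_capture is None or name in params_to_capture
            }
            CALL_LOG.append(
                {
                    "stage": stage_name,
                    "func_name": f.__name__,
                    "params": captured,
                    result_name or "result": result,
                }
            )
            return result

        return wrapper

    # func is None when the decorator is written with parentheses and keyword arguments.
    return decorate if func is None else decorate(func)


@trace_func_sketch(stage_name="evaluation", result_name="scaled")
def scale(x, factor=2):
    return x * factor


value = scale(3, factor=5)
```

Because stage_name is keyword-only and has no default, the sketch can only be applied with parentheses, which mirrors the requirement that a stage name be supplied.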

trace_verb_data(verb_run_func)[source]#

Simple wrapper decorator for verbs to implement the trace=<num data items> interface.

This decorator does two things:

  • Adds a trace keyword argument that controls how many data items are run through the verb by modifying Hyrax config and shimming selected DataProvider methods.

  • Preserves the verb’s return value in normal mode, but when trace is set it returns a TraceResult object that captures call order, parameter values, and return values for major steps in Hyrax’s default pipeline.
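The two behaviors listed above can be sketched with a small stand-in. This is an assumption-laden illustration rather than the Hyrax code: TraceResultSketch, InferVerbSketch, and the "max_items" config key are all hypothetical, and the real decorator shims DataProvider methods instead of only editing a dictionary.

```python
import functools


class TraceResultSketch:
    """Hypothetical stand-in for TraceResult; only remembers the requested item count."""

    def __init__(self, trace_batch_size):
        self.trace_batch_size = trace_batch_size


def trace_verb_data_sketch(verb_run_func):
    @functools.wraps(verb_run_func)
    def run(self, *args, trace=0, **kwargs):
        if not trace:
            # Normal mode: the verb's own return value passes through unchanged.
            return verb_run_func(self, *args, **kwargs)

        original_config = self.config
        # "max_items" is a made-up key standing in for whatever config Hyrax modifies.
        self.config = {**original_config, "max_items": trace}
        try:
            verb_run_func(self, *args, **kwargs)
            return TraceResultSketch(trace)
        finally:
            # Undo the config change even if the verb raised.
            self.config = original_config

    return run


class InferVerbSketch:
    def __init__(self):
        self.config = {"max_items": None}

    @trace_verb_data_sketch
    def run(self):
        return f"ran with max_items={self.config['max_items']}"


verb = InferVerbSketch()
normal_result = verb.run()
traced_result = verb.run(trace=3)
```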

class TraceContext(trace_arg: Any, config)[source]#

To trace, we (1) shim class methods and (2) modify the hyrax config.

Due to the class-level shims it is absolutely vital that even during exception handling we are able to remove these shims. This removal returns classes to their pre-trace state and keeps the effects of the shimming contained to the runtime of a single verb in a long-running notebook.

Therefore verbs using data tracing should use the @trace_verb_data decorator or implement the pattern below:

with TraceContext(trace, self.config) as modified_config:
    self.config = modified_config

    # ... verb code that computes retval ...

    return get_trace() if trace else retval

trace_arg[source]#
config[source]#
__enter__()[source]#
__exit__(exc_type, exc_value, traceback)[source]#
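Why a context manager suits class-level shims can be shown with a small self-contained sketch. Everything here is hypothetical (LoaderSketch, ShimContextSketch, SEEN, and the "trace_items" key); it only demonstrates that __exit__ runs, and the class is restored, even when the body raises.

```python
class LoaderSketch:
    """Hypothetical class whose method is shimmed while a trace is active."""

    def fetch(self, idx):
        return idx * 10


ORIGINAL_FETCH = LoaderSketch.fetch
SEEN = []  # indices observed by the shim


class ShimContextSketch:
    def __init__(self, trace_arg, config):
        self.trace_arg = trace_arg
        self.config = config

    def __enter__(self):
        if not self.trace_arg:
            return self.config

        def shim(loader, idx):
            SEEN.append(idx)
            return ORIGINAL_FETCH(loader, idx)

        # Class-level shim: affects every instance until it is removed.
        LoaderSketch.fetch = shim
        return {**self.config, "trace_items": self.trace_arg}

    def __exit__(self, exc_type, exc_value, traceback):
        # Always restore the class, then let any exception propagate.
        LoaderSketch.fetch = ORIGINAL_FETCH
        return False


try:
    with ShimContextSketch(2, {"batch_size": 8}) as modified_config:
        fetched = LoaderSketch().fetch(1)
        raise RuntimeError("simulated verb failure")
except RuntimeError:
    pass
```

After the block runs, LoaderSketch.fetch is the original function again, so a later cell in the same notebook session would see an unshimmed class.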
get_trace()[source]#

Get the current global trace results object. Returns None if no trace is active.

reset_trace()[source]#

Reset the current global trace results object, removing all captured data. Valid to call whether or not a trace is active.

class TraceDef[source]#

A record that needs to be filled out whenever a function is instrumented for tracing in TraceResult.

Contains values that must be passed through TraceResult.instrument_*, TraceResult._make_shim, and TraceResult.trace_call so that TraceResults are legible when printed.

disp_name: str[source]#
func_name: str[source]#
params_to_capture: dict[str, int][source]#
result_name: str[source]#
stage_name: str[source]#
class TracePrintable[source]#

Bases: abc.ABC

Base class defining foundational behavior for TraceResult, TraceStage, and TraceCall, which are the user-accessible building blocks of a trace.

Child classes must implement __str__ for printing and __getitem__ for inspection.

__repr__()[source]#

__repr__ and __str__ mean the same thing. This goes against the Python convention that __repr__ is essentially a serialized string of the object; however, notebooks call __repr__ to display objects, and we would like the __str__ code to have the correct connotation for robots and humans viewing the code through a peephole. That is, __str__ means “human-readable and perhaps incomplete representation”.

__getattr__(attr)[source]#

__getattr__ always calls __getitem__. This implements the notion that if you get a trace object in a notebook, you should be able to write either trace_result["evaluation"] or trace_result.evaluation to ask for just the function calls in the evaluation stage. The intent is that any attempt by the user to look inside the object routes to the things they probably want.

__dir__()[source]#

Force implementation of __dir__ on subclasses to direct typeahead in notebook environments toward valid identifiers within the trace.

abstractmethod __str__()[source]#
abstractmethod __getitem__(idx)[source]#
abstractmethod _valid_keys()[source]#
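The behavior described for __repr__, __getattr__, and __dir__ can be sketched as follows. This is a hedged illustration and not the library's code: PrintableSketch and StagesSketch are hypothetical names, and the underscore guard in __getattr__ is an addition made here to keep the sketch safe from recursion.

```python
from abc import ABC, abstractmethod


class PrintableSketch(ABC):
    def __repr__(self):
        # Notebooks display objects via __repr__, so reuse the human-readable form.
        return self.__str__()

    def __getattr__(self, attr):
        # Only invoked when normal attribute lookup fails.
        if attr.startswith("_"):
            raise AttributeError(attr)
        try:
            return self[attr]
        except KeyError as err:
            raise AttributeError(attr) from err

    def __dir__(self):
        # Steer tab completion toward keys that __getitem__ accepts.
        return list(self._valid_keys())

    @abstractmethod
    def __str__(self): ...

    @abstractmethod
    def __getitem__(self, idx): ...

    @abstractmethod
    def _valid_keys(self): ...


class StagesSketch(PrintableSketch):
    def __init__(self, stages):
        self._stages = stages

    def __str__(self):
        return "Stages: " + ", ".join(self._stages)

    def __getitem__(self, key):
        return self._stages[key]

    def _valid_keys(self):
        return self._stages.keys()


trace = StagesSketch({"evaluation": ["forward"], "collate": []})
same_object = trace.evaluation is trace["evaluation"]
```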
class TraceResult(trace_batch_size: int)[source]#

Bases: TracePrintable

Result of a hyrax data tracing run, returned from certain data-handling verbs when trace=<non-zero number> is passed in a notebook.

This object represents a small set of calls intended to track a handful of data values through the entire hyrax data processing pipeline in order to enable debugging of data issues.

This object is meant to be printed out in a notebook, and contains multiple stages that are accessible using either trace_result.stage_name or trace_result["stage_name"] syntax.

1. “dataset_getter” stage: In this stage, the HyraxQL getter functions on whatever datasets are in use are called. If you implemented a custom dataset, these are functions you wrote. Any dataset class functions decorated with @trace_dataset_func also have their calls reported in this stage.

2. “resolve_data” stage: In this stage, DataProvider.resolve_data combines the results of the individual data getters into data dictionaries, each of which contains all requested columns for one datum.

3. “collate” stage: Each data column is combined into a single batch tensor in this stage. If your dataset defines a custom collate function (e.g. time-series data with different lengths), it will be evaluated in this stage. Any NaN handling that is configured into hyrax is also performed in this stage.

4. “prepare_inputs” stage: The ML model’s prepare_inputs function is called in this stage and converts the data dictionary containing each column of batched tensor data into a single batch tensor that will form the input to the model’s evaluation functions. If the model is doing supervised learning, the output will be a tuple of numpy arrays (inputs_0, [inputs_1, …, inputs_n], labels).

5. “evaluation” stage: The ML model is evaluated or the training loop is run. The calls reported here are the functions run on the model during this process, including train_batch, forward, and similar. If you implemented a custom model, you wrote these functions. Any model functions decorated with @trace_model_func will also show up here.

shimmed_funcs = [][source]#
trace_batch_size[source]#
reset()[source]#

Reset the TraceResult object so that it contains no calls.

__getitem__(ref)[source]#
_valid_keys()[source]#
reduce_len(cls)[source]#

Inserts a len method that reduces the reported length of the passed-in class in order to accommodate early return in trace mode.

This is necessary because hyrax does not control the main loop of inference/training for most ML verbs, so the layer that does control it must get an appropriate stop condition from Hyrax’s data structures.
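The idea of shortening a class's reported length so that an external loop stops early can be sketched as below. This is an illustration under assumptions: DatasetSketch and reduce_len_sketch are hypothetical, the explicit limit argument is invented for the sketch (reduce_len itself takes only cls), and the patched method is assumed to be __len__.

```python
class DatasetSketch:
    """Hypothetical dataset; an outside training loop would call len() on it."""

    def __init__(self, n):
        self.items = list(range(n))

    def __len__(self):
        return len(self.items)


def reduce_len_sketch(cls, limit):
    original_len = cls.__len__

    def short_len(self):
        # Never report more than `limit` items while tracing.
        return min(limit, original_len(self))

    cls.__len__ = short_len
    return original_len  # kept so the class can be restored afterwards


dataset = DatasetSketch(100)
original = reduce_len_sketch(DatasetSketch, 4)
traced_length = len(dataset)

# Restore the class, as cleanup of class-level shims would.
DatasetSketch.__len__ = original
restored_length = len(dataset)
```

Patching the class rather than the instance matters here because len() looks up __len__ on the type, not on the instance.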

remove_class_level_shims()[source]#

Clean up all of our class level shims. This should happen when verbs exit even if via exception. See TraceContext for the mechanism by which this is achieved.

trace_call(trace_def: TraceDef, *args)[source]#

This is the main location where data is collected. Shim functions call this method in order to log to the trace that a call to the shimmed function has occurred.

We capture parameters and return value here.

instrument_prepare_inputs(model)[source]#

Instrument the prepare_inputs function on an instance of a model. This occurs when we load the model and will be using a prepare_inputs function which was attached to the model by hyrax machinery (@hyrax_model). This is usually an old to_tensor function, a prepare_inputs function loaded from a checkpoint, or our default prepare_inputs function.

Note: Class-level shimming of prepare_inputs occurs in the constructor and covers the case where the model class defines prepare_inputs directly.

instrument_prepare_inputs_fn(prepare_inputs_fn)[source]#

Instrument the prepare_inputs function on a bare function. This is used in the engine verb when we don’t have a PyTorch model class to attach to.

instrument_dataset_getter(dataset, getter, friendly_name, field_name)[source]#

Instrument a dataset get_* function. Called by DataProvider to insert shims before any getters are called.

instrument_dataset_collate(dataset, collate_fn, friendly_name)[source]#

Instrument a dataset collate function. Also called by DataProvider to insert shims into all the custom dataset collate functions it finds during dataset preparation.

instrument_dataprovider(dataprovider)[source]#

Instrument the various data handling functions in DataProvider.

We use instance-level shims here.

instrument_engine_verb(engine_verb)[source]#

Instrument the various data handling functions in the engine verb.

These are instance level shims, because by the time we know whether a verb is tracing or not, the verb class instance has already been constructed, so we must operate on the instance.

instrument_instance_data_handler(obj, original_member, trace_def: TraceDef)[source]#

Inserts trace instrumentation on a method of a Python class instance.

DOES NOT WORK ON classes; see instrument_class_data_handler.

Parameters:
  • obj (class instance) – The instance of the object that has the member function we are shimming

  • original_member (callable) – The callable we are shimming. Obtain via obj.method_name or getattr(obj, "method_name").

  • trace_def (TraceDef) – A TraceDef defining what we’re tracing from this function.

Returns:

The shim callable that has been set on obj at trace_def.func_name.

Return type:

callable

instrument_class_data_handler(cls, trace_def: TraceDef)[source]#

Inserts trace instrumentation on a method of a Python class.

The shimmed method is placed on the class and is not returned.

DOES NOT WORK ON class instances; see instrument_instance_data_handler.

Parameters:
  • cls (class) – The class that has the member function we are shimming

  • trace_def (TraceDef) – A TraceDef defining what we’re tracing from this function.

Return type:

None
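The difference between the instance-level and class-level handlers can be illustrated with plain Python. The sketch below is not the Hyrax implementation: ModelSketch, shim_instance, and shim_class are hypothetical names, and a simple list stands in for the trace.

```python
class ModelSketch:
    def forward(self, x):
        return x + 1


def shim_instance(obj, name, log):
    original = getattr(obj, name)  # bound method, already carries self

    def shim(*args, **kwargs):
        log.append(name)
        return original(*args, **kwargs)

    # Stored in obj.__dict__, so it shadows the class attribute for this object only.
    setattr(obj, name, shim)
    return shim


def shim_class(cls, name, log):
    original = getattr(cls, name)  # plain function, needs self passed explicitly

    def shim(self, *args, **kwargs):
        log.append(name)
        return original(self, *args, **kwargs)

    setattr(cls, name, shim)
    return original  # kept so the class can be restored


log = []
a, b = ModelSketch(), ModelSketch()

shim_instance(a, "forward", log)
a.forward(1)
b.forward(1)  # not logged: b was never shimmed
instance_calls = len(log)

del a.forward  # remove the instance-level shim
original_forward = shim_class(ModelSketch, "forward", log)
a.forward(1)
b.forward(1)  # logged: every instance sees the class-level shim
class_calls = len(log) - instance_calls

ModelSketch.forward = original_forward  # restore the class
```

The class-level variant reaches objects that were constructed before the shim was installed, which is also why it has to be removed reliably afterwards.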

static _make_shim(original_func, trace_def: TraceDef)[source]#

Make a shim function for the instrument_* functions to use.

Parameters:
  • original_func (callable) – The function (or bound method) being shimmed.

  • trace_def (TraceDef) – Describes what data to capture during the call.
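A shim of this kind typically wraps the original callable, records selected arguments and the return value, and passes the result through unchanged. The sketch below is an assumption-based illustration: TraceDefSketch mirrors the documented TraceDef fields, params_to_capture is assumed to map parameter names to positional indices, and RECORDS and get_flux are hypothetical.

```python
import time
from dataclasses import dataclass


@dataclass
class TraceDefSketch:
    disp_name: str
    func_name: str
    params_to_capture: dict  # assumed meaning: parameter name -> positional index
    result_name: str
    stage_name: str


RECORDS = []  # hypothetical stand-in for the trace's call storage


def make_shim_sketch(original_func, trace_def):
    def shim(*args, **kwargs):
        start = time.perf_counter_ns()
        retval = original_func(*args, **kwargs)
        duration_ns = time.perf_counter_ns() - start
        # Capture only the positional arguments named in the trace definition.
        params = {
            name: args[pos]
            for name, pos in trace_def.params_to_capture.items()
            if pos < len(args)
        }
        RECORDS.append(
            {
                "disp_name": trace_def.disp_name,
                "stage_name": trace_def.stage_name,
                "params": params,
                "retval": {trace_def.result_name: retval},
                "duration_ns": duration_ns,
            }
        )
        return retval  # the caller sees exactly what the original returned

    return shim


def get_flux(idx):
    return [idx, idx + 1]


shimmed = make_shim_sketch(
    get_flux,
    TraceDefSketch("MyDataset.get_flux", "get_flux", {"idx": 0}, "flux", "dataset_getter"),
)
result = shimmed(3)
```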

__str__()[source]#

Print out the stages of the trace.

class TraceStage[source]#

Bases: TracePrintable

This is a container that holds a list of TraceCalls in order of execution, representing an entire stage of a TraceResult.

It is intended to be printed and examined from a notebook.

It supports two modes of user access through [] / __getitem__:
  1. [] with a number returns a single TraceCall by its position in execution order

  2. [] with a function name returns all calls to that function as a list[TraceCall]

calls = [][source]#
func_dict[source]#
append(call_record)[source]#

Append a single call record to this stage.

__getitem__(idx_or_func_name)[source]#
_valid_keys()[source]#
__len__()[source]#
__str__()[source]#
_repr_calls()[source]#
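The two access modes of TraceStage described above (by position and by function name) can be sketched with a small container. This is a hypothetical illustration: StageSketch is not the library class, call records are plain dictionaries, and the use made of calls and func_dict is assumed from their names.

```python
class StageSketch:
    def __init__(self):
        self.calls = []      # every call, in execution order
        self.func_dict = {}  # function name -> list of calls to that function

    def append(self, call_record):
        self.calls.append(call_record)
        self.func_dict.setdefault(call_record["func_name"], []).append(call_record)

    def __getitem__(self, idx_or_func_name):
        if isinstance(idx_or_func_name, int):
            return self.calls[idx_or_func_name]
        return self.func_dict[idx_or_func_name]

    def __len__(self):
        return len(self.calls)


stage = StageSketch()
stage.append({"func_name": "get_flux", "retval": 1.0})
stage.append({"func_name": "get_label", "retval": "star"})
stage.append({"func_name": "get_flux", "retval": 2.0})

second_call = stage[1]          # access by position
flux_calls = stage["get_flux"]  # access by function name
```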
class TraceCall[source]#

Bases: TracePrintable

An individual function call that is part of a trace. Captures argument and return values of the given function, which are accessible via the [] or . operators.

This object is intended to be printed and examined from a notebook.

disp_name: str[source]#
func_name: str[source]#
params: dict[str, Any][source]#
retval: dict[str, Any][source]#
duration_ns: float[source]#
__str__()[source]#
__getitem__(key)[source]#
_valid_keys()[source]#
_repr_value(param_value)[source]#