hyrax.trace
===========

.. py:module:: hyrax.trace


Attributes
----------

.. autoapisummary::

   hyrax.trace.trace_result
   hyrax.trace.logger


Classes
-------

.. autoapisummary::

   hyrax.trace.TraceContext
   hyrax.trace.TraceDef
   hyrax.trace.TracePrintable
   hyrax.trace.TraceResult
   hyrax.trace.TraceStage
   hyrax.trace.TraceCall


Functions
---------

.. autoapisummary::

   hyrax.trace.trace_dataset_func
   hyrax.trace.trace_model_func
   hyrax.trace.trace_func
   hyrax.trace.trace_verb_data
   hyrax.trace.get_trace
   hyrax.trace.reset_trace


Module Contents
---------------

.. py:data:: trace_result
   :value: None


.. py:data:: logger

.. py:function:: trace_dataset_func(func=None, *, params_to_capture=None, result_name=None, stage_name='dataset_getter')

   Decorator to add tracing to a custom dataset function. By default it captures all parameters and
   the return value, placing the function in the ``dataset_getter`` stage.


.. py:function:: trace_model_func(func=None, *, params_to_capture=None, result_name=None, stage_name='evaluation')

   Decorator to add tracing to a custom model function. By default it captures all parameters and
   the return value, placing the function in the ``evaluation`` stage.


.. py:function:: trace_func(func=None, *, params_to_capture=None, result_name=None, stage_name)

   Generic decorator to trace a user-defined function in a particular stage.

   The name of a TraceResult stage must be passed as ``stage_name`` to use this decorator directly.
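
   The keyword-only signature shared by ``trace_dataset_func``, ``trace_model_func``, and ``trace_func``
   suggests the common Python idiom of a decorator that works both bare (``@dec``) and with arguments
   (``@dec(stage_name=...)``). The sketch below shows only that idiom, under stated assumptions: the
   ``CALLS`` dictionary stands in for the global trace, ``params_to_capture`` is assumed to be a list of
   parameter names, and ``trace_func_sketch`` / ``trace_dataset_func_sketch`` are hypothetical names,
   not Hyrax's implementation.

   .. code-block:: python

       import functools
       import inspect
       from collections import defaultdict

       # Hypothetical stand-in for the global trace: stage name -> list of call records.
       CALLS = defaultdict(list)


       def trace_func_sketch(func=None, *, params_to_capture=None, result_name=None, stage_name):
           """Hypothetical decorator: records params and return value under `stage_name`."""
           if func is None:
               # Called with arguments only: return a decorator that waits for the function.
               return functools.partial(
                   trace_func_sketch,
                   params_to_capture=params_to_capture,
                   result_name=result_name,
                   stage_name=stage_name,
               )

           sig = inspect.signature(func)

           @functools.wraps(func)
           def wrapper(*args, **kwargs):
               bound = sig.bind(*args, **kwargs)
               bound.apply_defaults()
               # Capture every parameter unless a subset of names was requested (assumption).
               names = params_to_capture or list(bound.arguments)
               params = {name: bound.arguments[name] for name in names}
               retval = func(*args, **kwargs)
               CALLS[stage_name].append(
                   {"func": func.__name__, "params": params, result_name or "retval": retval}
               )
               return retval

           return wrapper


       def trace_dataset_func_sketch(func=None, **kwargs):
           """Hypothetical wrapper that pre-fills the stage so it can also be applied bare."""
           kwargs.setdefault("stage_name", "dataset_getter")
           return trace_func_sketch(func, **kwargs)


       @trace_dataset_func_sketch
       def get_flux(idx, scale=2.0):
           return idx * scale


       @trace_func_sketch(stage_name="evaluation", params_to_capture=["x"], result_name="loss")
       def train_step(x, lr=0.25):
           return x * lr

   Because ``stage_name`` has no default in the generic form, it must always be called with arguments,
   while a wrapper that pre-fills the stage can also be used as a bare decorator.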


.. py:function:: trace_verb_data(verb_run_func)

   Simple wrapper decorator for verbs that implements the ``trace=<num data items>`` interface.

   This decorator does two things:

   * Adds a ``trace`` keyword argument that controls how many data items are
     run through the verb by modifying Hyrax config and shimming selected
     DataProvider methods.
   * Preserves the verb's return value in normal mode, but when ``trace`` is
     set, it returns a TraceResult object that captures call order, parameter
     values, and return values for major steps in Hyrax's default pipeline.
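
   A minimal sketch of the ``trace`` keyword behavior described above. It assumes a hypothetical
   ``max_items`` attribute as the mechanism that limits how many items the verb processes (Hyrax instead
   modifies its config and shims DataProvider methods) and uses a plain list in place of TraceResult;
   ``trace_verb_data_sketch`` and ``ToyVerb`` are illustrative names only.

   .. code-block:: python

       import functools

       # Hypothetical global trace; a plain list stands in for TraceResult.
       _trace = None


       def trace_verb_data_sketch(verb_run_func):
           """Add a `trace` keyword: falsy runs normally, N > 0 returns the trace instead."""

           @functools.wraps(verb_run_func)
           def wrapper(self, *args, trace=0, **kwargs):
               global _trace
               if not trace:
                   # Normal mode: behave exactly like the undecorated verb.
                   return verb_run_func(self, *args, **kwargs)
               _trace = []
               self.max_items = trace  # hypothetical limit on processed items
               try:
                   verb_run_func(self, *args, **kwargs)
               finally:
                   # Undo the temporary change even if the verb raises.
                   self.max_items = None
               return _trace

           return wrapper


       class ToyVerb:
           max_items = None

           @trace_verb_data_sketch
           def run(self, data):
               items = data if self.max_items is None else data[: self.max_items]
               if self.max_items is not None:
                   _trace.extend(items)  # record which items went through the verb
               return sum(items)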



.. py:class:: TraceContext(trace_arg: Any, config)

   To trace, we 1) shim class methods and 2) modify the Hyrax config.

   Because the shims are installed at class level, it is vital that they can be removed even during
   exception handling. Removing them returns classes to their pre-trace state and keeps the effects of
   shimming confined to the runtime of a single verb in a long-running notebook.

   Therefore, verbs that use data tracing should use the ``@trace_verb_data`` decorator or
   implement the pattern below:

   .. code-block:: python

       with TraceContext(trace, self.config) as modified_config:
           self.config = modified_config

           # ... verb code that computes retval ...

           return get_trace() if trace else retval
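
   A minimal sketch of why the context-manager form gives the cleanup guarantee described above:
   ``__exit__`` runs whether the ``with`` body returns normally or raises. ``Loader``, ``ShimContext``,
   and the ``max_items`` config key are hypothetical; the real TraceContext shims Hyrax classes and
   modifies the Hyrax config instead.

   .. code-block:: python

       import copy


       class Loader:
           def load(self, i):
               return i * 10


       class ShimContext:
           """Hypothetical stand-in: shim Loader.load and return a modified copy of a config dict."""

           def __init__(self, trace_arg, config):
               self.trace_arg = trace_arg
               self.config = config
               self.calls = []

           def __enter__(self):
               original = self._original = Loader.load
               calls = self.calls

               def shim(obj, i):
                   calls.append(i)  # record the call, then delegate to the original
                   return original(obj, i)

               Loader.load = shim  # class-level shim: affects every instance
               modified = copy.deepcopy(self.config)
               modified["max_items"] = self.trace_arg
               return modified

           def __exit__(self, exc_type, exc_value, traceback):
               Loader.load = self._original  # always restore, even on exceptions
               return False  # do not swallow the exception

   Returning ``False`` from ``__exit__`` lets the original exception propagate after cleanup, so a
   failing verb still reports its error.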



   .. py:attribute:: trace_arg


   .. py:attribute:: config


   .. py:method:: __enter__()


   .. py:method:: __exit__(exc_type, exc_value, traceback)


.. py:function:: get_trace()

   Get the current global trace results object. Returns ``None`` if no trace is active.


.. py:function:: reset_trace()

   Reset the current global trace results object, removing all captured data.
   It is valid to call whether or not a trace is active.
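
   A minimal sketch of the module-level pattern that ``get_trace`` and ``reset_trace`` describe,
   assuming (as the ``trace_result`` attribute's default of ``None`` suggests) a single global that is
   ``None`` when no trace is active. The ``*_sketch`` names and the list standing in for TraceResult are
   hypothetical.

   .. code-block:: python

       # Hypothetical module-level state; None means no trace is active.
       _active_trace = None


       def start_trace_sketch():
           """Hypothetical helper that activates a trace (a list stands in for TraceResult)."""
           global _active_trace
           _active_trace = []
           return _active_trace


       def get_trace_sketch():
           """Return the active trace, or None when no trace is active."""
           return _active_trace


       def reset_trace_sketch():
           """Drop captured calls; a no-op when no trace is active, so it is always safe to call."""
           if _active_trace is not None:
               _active_trace.clear()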


.. py:class:: TraceDef

   A record that must be filled out whenever a function is instrumented for tracing in TraceResult.

   It holds values that must be passed through ``TraceResult.instrument_*``, ``TraceResult._make_shim``,
   and ``TraceResult.trace_call`` so that TraceResults are legible when printed.


   .. py:attribute:: disp_name
      :type:  str


   .. py:attribute:: func_name
      :type:  str


   .. py:attribute:: params_to_capture
      :type:  dict[str, int]


   .. py:attribute:: result_name
      :type:  str


   .. py:attribute:: stage_name
      :type:  str


.. py:class:: TracePrintable

   Bases: :py:obj:`abc.ABC`


   Base class defining foundational behavior for TraceResult, TraceStage, and TraceCall, which are the
   user-accessible building blocks of a trace.

   Child classes must implement ``__str__`` for printing, ``__getitem__`` for inspection, and
   ``_valid_keys``.
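
   A minimal sketch of the conventions documented for this base class (``__repr__`` reusing ``__str__``,
   ``__getattr__`` routing to ``__getitem__``, and ``__dir__`` exposing valid keys). ``PrintableSketch``
   and ``StagesSketch`` are hypothetical; converting ``KeyError`` to ``AttributeError`` is this sketch's
   own choice, made so that ``hasattr`` and ``getattr`` with a default keep working.

   .. code-block:: python

       from abc import ABC, abstractmethod


       class PrintableSketch(ABC):
           """Hypothetical mirror of the TracePrintable conventions."""

           def __repr__(self):
               # Notebooks display objects via __repr__, so reuse the human-readable form.
               return str(self)

           def __getattr__(self, attr):
               # Only called when normal lookup fails; private names never route to __getitem__.
               if attr.startswith("_"):
                   raise AttributeError(attr)
               try:
                   return self[attr]
               except KeyError as err:
                   raise AttributeError(attr) from err

           def __dir__(self):
               # Steer notebook tab completion toward valid trace keys.
               return sorted(set(super().__dir__()) | set(self._valid_keys()))

           @abstractmethod
           def __str__(self): ...

           @abstractmethod
           def __getitem__(self, idx): ...

           @abstractmethod
           def _valid_keys(self): ...


       class StagesSketch(PrintableSketch):
           def __init__(self, stages):
               self._stages = stages

           def __str__(self):
               return "\n".join(f"{name}: {len(calls)} call(s)" for name, calls in self._stages.items())

           def __getitem__(self, key):
               return self._stages[key]

           def _valid_keys(self):
               return list(self._stages)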


   .. py:method:: __repr__()

      ``__repr__`` and ``__str__`` return the same thing. This goes against the Python convention that
      ``__repr__`` is essentially a serialized string of the object; however, notebooks call ``__repr__``
      to display objects, and we want the ``__str__`` code to carry the correct connotation for robots and
      humans viewing the code through a peephole. That is, ``__str__`` means "human-readable and perhaps
      incomplete representation".



   .. py:method:: __getattr__(attr)

      ``__getattr__`` always calls ``__getitem__``. This implements the notion that, given a trace object
      in a notebook, you should be able to write either ``trace_result["evaluation"]`` or
      ``trace_result.evaluation`` to ask for just the function calls in the evaluation stage. The intent is
      that any attempt by the user to look inside the object routes to the things they most likely want.



   .. py:method:: __dir__()

      Force implementation of __dir__ on subclasses to direct typeahead in notebook environments toward
      valid identifiers within the trace.



   .. py:method:: __str__()
      :abstractmethod:



   .. py:method:: __getitem__(idx)
      :abstractmethod:



   .. py:method:: _valid_keys()
      :abstractmethod:



.. py:class:: TraceResult(trace_batch_size: int)

   Bases: :py:obj:`TracePrintable`


   Result of a Hyrax data tracing run, returned from certain data-handling verbs when
   ``trace=<non-zero number>`` is passed in a notebook.

   This object represents a small set of calls intended to track a handful of data values through the
   entire Hyrax data processing pipeline in order to enable debugging of data issues.

   This object is meant to be printed in a notebook, and it contains multiple stages that are accessible
   using either ``trace_result.stage_name`` or ``trace_result["stage_name"]`` syntax.

   1. ``"dataset_getter"`` stage

      In this stage the HyraxQL getter functions of the datasets in use are called. If you implemented
      a custom dataset, these are functions you wrote. Calls to any dataset class functions decorated
      with ``@trace_dataset_func`` are also reported in this stage.

   2. ``"resolve_data"`` stage

      In this stage ``DataProvider.resolve_data`` combines the results of the individual data getters
      into data dictionaries, each containing all requested columns for a single datum.

   3. ``"collate"`` stage

      Each data column is combined into a single batch tensor in this stage. If your dataset defines a
      custom collate function (e.g. for time-series data with different lengths), it is evaluated in this
      stage. Any NaN handling configured in Hyrax is also performed here.

   4. ``"prepare_inputs"`` stage

      The ML model's ``prepare_inputs`` function is called in this stage. It converts the data dictionary
      containing each column of batched tensor data into a single batch tensor that forms the input to
      the model's evaluation functions. If the model is doing supervised learning, the output is a tuple
      of NumPy arrays ``(inputs_0, [inputs_1, ..., inputs_n], labels)``.

   5. ``"evaluation"`` stage

      The ML model is evaluated or the training loop is run. The functions reported are those run on the
      model during this process, including ``train_batch``, ``forward``, and similar. If you implemented a
      custom model, you wrote these functions. Any model functions decorated with ``@trace_model_func``
      also show up here.
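
   As a rough illustration of how data changes shape between the ``resolve_data``, ``collate``, and
   ``prepare_inputs`` stages, the sketch below uses NumPy arrays and hypothetical column names
   (``image``, ``label``). The NaN-to-zero replacement is only one possible NaN-handling choice, and
   ``collate_sketch`` / ``prepare_inputs_sketch`` are illustrative, not Hyrax functions.

   .. code-block:: python

       import numpy as np

       # Output of a resolve_data-like step: one dictionary per datum (hypothetical columns).
       resolved = [
           {"image": np.array([1.0, np.nan]), "label": 0},
           {"image": np.array([3.0, 4.0]), "label": 1},
       ]


       def collate_sketch(items):
           """Stack each column across data items into a single batch array."""
           batch = {key: np.stack([np.asarray(item[key]) for item in items]) for key in items[0]}
           # One possible NaN-handling choice: replace NaNs with zeros.
           batch["image"] = np.nan_to_num(batch["image"], nan=0.0)
           return batch


       def prepare_inputs_sketch(batch):
           """Turn the batched column dictionary into an (inputs, labels) tuple."""
           return batch["image"], batch["label"]


       inputs, labels = prepare_inputs_sketch(collate_sketch(resolved))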



   .. py:attribute:: shimmed_funcs
      :value: []



   .. py:attribute:: trace_batch_size


   .. py:method:: reset()

      Reset the TraceResult object so that it contains no calls.



   .. py:method:: __getitem__(ref)


   .. py:method:: _valid_keys()


   .. py:method:: reduce_len(cls)

      Insert a len method on the passed-in class that reduces its reported length, in order to
      accommodate early return in trace mode.

      This is necessary because Hyrax does not control the main loop of inference/training
      for most ML verbs, so the layer that does control it must get an appropriate stop condition
      from Hyrax's data structures.
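
      Python's ``len()`` looks up ``__len__`` on the object's type rather than on the instance, which is
      consistent with this method shimming the class itself. A minimal sketch under that reading, using a
      hypothetical ``ToyDataset`` and ``reduce_len_sketch``; the original ``__len__`` is returned so it can
      be restored afterwards.

      .. code-block:: python

          class ToyDataset:
              def __init__(self, n):
                  self.n = n

              def __len__(self):
                  return self.n


          def reduce_len_sketch(cls, limit):
              """Shim cls.__len__ so no instance reports more than `limit` items; return the original."""
              original_len = cls.__len__

              def capped_len(self):
                  return min(original_len(self), limit)

              cls.__len__ = capped_len
              return original_len


          ds = ToyDataset(1000)
          original = reduce_len_sketch(ToyDataset, 4)
          traced_len = len(ds)            # the loop owner now sees a short dataset
          ToyDataset.__len__ = original   # restore the class-level method

          ds.__len__ = lambda: 1          # an instance attribute is ignored by len()
          instance_len = len(ds)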



   .. py:method:: remove_class_level_shims()

      Remove all class-level shims. This must happen whenever a verb exits, even via an exception.
      See TraceContext for the mechanism by which this is achieved.



   .. py:method:: trace_call(trace_def: TraceDef, *args)

      This is the main point where data is collected. Shim functions call this method to record in the
      trace that a call to the shimmed function has occurred.

      Parameters and the return value are captured here.



   .. py:method:: instrument_prepare_inputs(model)

      Instrument the prepare_inputs function on an instance of a model. This occurs when we load the
      model and it will use a prepare_inputs function that was attached to the model by
      Hyrax machinery (``@hyrax_model``). This is usually an old ``to_tensor`` function, a prepare_inputs
      function loaded from a checkpoint, or our default prepare_inputs function.

      Note: class-level shimming of prepare_inputs occurs in the constructor and covers the case where
      the model class defines prepare_inputs directly.



   .. py:method:: instrument_prepare_inputs_fn(prepare_inputs_fn)

      Instrument a bare prepare_inputs function. This is used in the engine verb
      when we don't have a PyTorch model class to attach to.



   .. py:method:: instrument_dataset_getter(dataset, getter, friendly_name, field_name)

      Instrument a dataset ``get_*`` function. Called by DataProvider to insert shims before
      any getters are called.



   .. py:method:: instrument_dataset_collate(dataset, collate_fn, friendly_name)

      Instrument a dataset collate function. Also called by DataProvider to insert shims
      into all the custom dataset collate functions it finds during dataset preparation.



   .. py:method:: instrument_dataprovider(dataprovider)

      Instrument the various data handling functions in DataProvider.

      These are instance-level shims.



   .. py:method:: instrument_engine_verb(engine_verb)

      Instrument the various data handling functions in the engine verb.

      These are instance level shims, because by the time we know whether a verb is tracing or not, the
      verb class instance has already been constructed, so we must operate on the instance.



   .. py:method:: instrument_instance_data_handler(obj, original_member, trace_def: TraceDef)

      Inserts trace instrumentation on a method of a python class instance.

      Does **not** work on classes; see ``instrument_class_data_handler``.

      :param obj: The instance of the object that has the member function we are shimming
      :type obj: class instance
      :param original_member: The callable we are shimming. Obtain via ``obj.method_name`` or
                              ``getattr(obj, "method_name")``.
      :type original_member: callable
      :param trace_def: A TraceDef defining what we're tracing from this function.
      :type trace_def: TraceDef

      :returns: The shim callable that has been set on ``obj`` at ``trace_def.func_name``.
      :rtype: callable
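
      A minimal sketch of instance-level shimming as described here: the shim is stored in the instance's
      ``__dict__``, so it shadows the class method for that one object only, and deleting the attribute
      re-exposes the class method. ``Provider``, ``instrument_instance_sketch``, and the ``log`` list are
      hypothetical.

      .. code-block:: python

          import functools


          class Provider:
              def resolve_data(self, idx):
                  return {"idx": idx}


          def instrument_instance_sketch(obj, original_member, func_name, log):
              """Set a recording shim on one instance; the class and other instances are untouched."""

              @functools.wraps(original_member)
              def shim(*args, **kwargs):
                  # original_member is already bound, so `self` is not passed again.
                  result = original_member(*args, **kwargs)
                  log.append((func_name, args, result))
                  return result

              setattr(obj, func_name, shim)
              return shim


          log = []
          traced, plain = Provider(), Provider()
          instrument_instance_sketch(traced, traced.resolve_data, "resolve_data", log)
          traced.resolve_data(7)
          plain.resolve_data(8)   # not recorded: only `traced` was shimmed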



   .. py:method:: instrument_class_data_handler(cls, trace_def: TraceDef)

      Inserts trace instrumentation on a method of a python class.

      The shimmed method is placed on the class and is not returned.

      Does **not** work on class instances; see ``instrument_instance_data_handler``.

      :param cls: The class that has the member function we are shimming
      :type cls: class
      :param trace_def: A TraceDef defining what we're tracing from this function.
      :type trace_def: TraceDef

      :rtype: None
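
      A minimal sketch of class-level shimming paired with a registry of originals, in the spirit of the
      ``shimmed_funcs`` attribute and ``remove_class_level_shims``. It assumes plain instance methods
      (a staticmethod or classmethod would need its descriptor preserved); ``Model``,
      ``instrument_class_sketch``, and ``remove_class_level_shims_sketch`` are hypothetical.

      .. code-block:: python

          import functools

          # Hypothetical registry of (class, attribute name, original function) for later cleanup.
          shimmed_funcs = []


          class Model:
              def prepare_inputs(self, batch):
                  return [x * 2 for x in batch]


          def instrument_class_sketch(cls, func_name, log):
              """Replace cls.<func_name> with a recording shim and remember the original."""
              original = getattr(cls, func_name)

              @functools.wraps(original)
              def shim(self, *args, **kwargs):
                  result = original(self, *args, **kwargs)
                  log.append((func_name, result))
                  return result

              shimmed_funcs.append((cls, func_name, original))
              setattr(cls, func_name, shim)


          def remove_class_level_shims_sketch():
              """Undo shims in reverse order so stacked shims unwind correctly."""
              while shimmed_funcs:
                  cls, func_name, original = shimmed_funcs.pop()
                  setattr(cls, func_name, original)


          log = []
          model = Model()  # created before shimming, yet still affected
          instrument_class_sketch(Model, "prepare_inputs", log)
          model.prepare_inputs([1, 2])
          remove_class_level_shims_sketch()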



   .. py:method:: _make_shim(original_func, trace_def: TraceDef)
      :staticmethod:


      Make a shim function for the instrument_* functions to use.

      :param original_func: The function (or bound method) being shimmed.
      :type original_func: callable
      :param trace_def: Describes what data to capture during the call.
      :type trace_def: TraceDef
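
      A minimal sketch of a shim factory that records parameters, the return value, and a duration in
      nanoseconds, loosely mirroring the TraceDef and TraceCall fields. It assumes that
      ``params_to_capture`` maps a parameter name to its positional index, which this page does not
      confirm; ``TraceDefSketch``, ``make_shim_sketch``, and the ``record`` list are hypothetical.

      .. code-block:: python

          import functools
          import time
          from dataclasses import dataclass


          @dataclass
          class TraceDefSketch:
              """Hypothetical mirror of TraceDef."""
              disp_name: str
              func_name: str
              params_to_capture: dict  # assumed: parameter name -> positional index
              result_name: str
              stage_name: str


          def make_shim_sketch(original_func, trace_def, record):
              """Wrap original_func so each call appends a record with params, result, and duration."""

              @functools.wraps(original_func)
              def shim(*args, **kwargs):
                  # Capture the requested positional arguments before the call can mutate them.
                  params = {
                      name: args[pos]
                      for name, pos in trace_def.params_to_capture.items()
                      if pos < len(args)
                  }
                  start = time.perf_counter_ns()
                  result = original_func(*args, **kwargs)
                  duration_ns = time.perf_counter_ns() - start
                  record.append({
                      "stage": trace_def.stage_name,
                      "disp_name": trace_def.disp_name,
                      "params": params,
                      trace_def.result_name: result,
                      "duration_ns": duration_ns,
                  })
                  return result

              return shim


          def collate(batch, pad_value):
              return [x + pad_value for x in batch]


          records = []
          tdef = TraceDefSketch("Dataset.collate", "collate", {"batch": 0}, "collated", "collate")
          traced_collate = make_shim_sketch(collate, tdef, records)
          out = traced_collate([1, 2], 10)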



   .. py:method:: __str__()

      Print out the stages of the trace.



.. py:class:: TraceStage

   Bases: :py:obj:`TracePrintable`


   This is a container that holds a list of TraceCalls in order of execution, representing an entire
   stage of a TraceResult.

   It is intended to be printed and examined from a notebook.

   It supports two modes of user access through ``[]`` (``__getitem__``):

   1. ``[]`` with an integer returns a single TraceCall by position.
   2. ``[]`` with a function name returns all calls to that function as a ``list[TraceCall]``.
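
   A minimal sketch of the two access modes, assuming calls are stored both in execution order
   (``calls``) and grouped by function name (``func_dict``). ``StageSketch`` and the dictionary call
   records are hypothetical; checking membership before indexing keeps a ``defaultdict`` from silently
   creating empty entries for misspelled names.

   .. code-block:: python

       from collections import defaultdict


       class StageSketch:
           """Hypothetical mirror of TraceStage's two access modes."""

           def __init__(self):
               self.calls = []                     # per-instance list, in execution order
               self.func_dict = defaultdict(list)  # function name -> its calls

           def append(self, call_record):
               self.calls.append(call_record)
               self.func_dict[call_record["func_name"]].append(call_record)

           def __getitem__(self, idx_or_func_name):
               if isinstance(idx_or_func_name, int):
                   return self.calls[idx_or_func_name]      # single call by position
               if idx_or_func_name in self.func_dict:
                   return self.func_dict[idx_or_func_name]  # every call to that function
               raise KeyError(idx_or_func_name)

           def __len__(self):
               return len(self.calls)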


   .. py:attribute:: calls
      :value: []



   .. py:attribute:: func_dict


   .. py:method:: append(call_record)

      Append a single call record to this stage.



   .. py:method:: __getitem__(idx_or_func_name)


   .. py:method:: _valid_keys()


   .. py:method:: __len__()


   .. py:method:: __str__()


   .. py:method:: _repr_calls()


.. py:class:: TraceCall

   Bases: :py:obj:`TracePrintable`


   An individual function call that is part of a trace. It captures the argument and return values of
   the given function, which are accessible via the ``[]`` or ``.`` operators.

   This object is intended to be printed and examined from a notebook.


   .. py:attribute:: disp_name
      :type:  str


   .. py:attribute:: func_name
      :type:  str


   .. py:attribute:: params
      :type:  dict[str, Any]


   .. py:attribute:: retval
      :type:  dict[str, Any]


   .. py:attribute:: duration_ns
      :type:  float


   .. py:method:: __str__()


   .. py:method:: __getitem__(key)


   .. py:method:: _valid_keys()


   .. py:method:: _repr_value(param_value)


