import logging
import warnings
from pathlib import Path
from colorama import Back, Fore, Style
from hyrax.trace import trace_verb_data
from .verb_registry import Verb, hyrax_verb
[docs]
# Module-level logger named after this module; the Train verb below uses it
# for its status messages.
logger = logging.getLogger(__name__)
@hyrax_verb
class Train(Verb):
    """Train verb: fit the configured model on the configured dataset group(s)."""

    description = "Train a model using provided data."

    # Dataset groups that the Train verb knows about.
    # REQUIRED_DATA_GROUPS must be present in the dataset dict returned by setup_dataset.
    # OPTIONAL_DATA_GROUPS are used when present but do not cause an error if absent.
    REQUIRED_DATA_GROUPS = ("train",)
    OPTIONAL_DATA_GROUPS = ("validate", "test")

    @staticmethod
    def setup_parser(parser):
        """We don't need any parser setup for CLI opts"""
        pass

    def run_cli(self, args=None):
        """CLI stub for Train verb.

        Parameters
        ----------
        args : optional
            Parsed CLI arguments. Unused; training is driven entirely by ``self.config``.
        """
        logger.info("train run from CLI.")
        self.run()

    @trace_verb_data
    def run(self):
        """
        Run the training process for the configured model and data loader.

        Returns
        -------
        The trained model instance created by ``setup_model``.

        Raises
        ------
        ValueError
            If both ``resume`` and ``model_weights_file`` are set in the ``[train]``
            config section.
        """
        # Heavy dependencies are imported lazily, inside the method.
        import mlflow

        from hyrax.config_utils import create_results_dir, log_runtime_config
        from hyrax.gpu_monitor import GpuMonitor
        from hyrax.pytorch_ignite import (
            attach_best_checkpoint,
            create_trainer,
            create_validator,
            setup_dataset,
            setup_model,
        )
        from hyrax.tensorboardx_logger import close_tensorboard_logger, init_tensorboard_logger

        config = self.config

        # Validate that the user hasn't set both `resume` and `model_weights_file`.
        # These are mutually exclusive: `resume` restores a full training checkpoint
        # (model weights, optimizer state, epoch counter), while `model_weights_file`
        # loads only model parameters and starts training fresh.
        if config["train"]["resume"] and config["train"]["model_weights_file"]:
            raise ValueError(
                "Cannot set both `resume` and `model_weights_file` in the [train] config. "
                "Use `resume` to continue from a full checkpoint (restores optimizer state and epoch), "
                "or use `model_weights_file` to start fresh training from pre-trained weights."
            )

        # Create a results directory
        results_dir = create_results_dir(config, "train")
        log_runtime_config(config, results_dir)

        # Create a tensorboardX logger. Everything after this point runs inside
        # try/finally so the logger is closed even when training fails part-way.
        init_tensorboard_logger(log_dir=results_dir)
        try:
            # Instantiate the model and dataset
            dataset = setup_dataset(
                config,
                splits=Train.REQUIRED_DATA_GROUPS + Train.OPTIONAL_DATA_GROUPS,
            )
            model = setup_model(config, dataset["train"])
            logger.info(
                f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Training model:{Style.RESET_ALL} "
                f"{model.__class__.__name__}"
            )
            logger.info(
                f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Training dataset(s):{Style.RESET_ALL}\n{dataset}"
            )

            # If a pre-trained weights file is specified, load it before creating the trainer.
            # This must happen before create_trainer() wraps the model with idist.auto_model
            # (the distributed wrapper) to avoid parameter key name mismatches.
            if config["train"]["model_weights_file"]:
                from hyrax.models.model_utils import load_model_weights

                load_model_weights(config, model, "train")
                logger.info(
                    f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Loading pre-trained weights:"
                    f"{Style.RESET_ALL} {config['train']['model_weights_file']}"
                )
                logger.info(
                    f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Fine-tuning mode:{Style.RESET_ALL} "
                    "Training will start from epoch 1 with a fresh optimizer."
                )

            # Each entry is a (DataLoader, indices) tuple; only the loaders are used here.
            data_loaders = Train._build_data_loaders(config, dataset)
            train_data_loader, _ = data_loaders["train"]
            validation_data_loader, _ = data_loaders.get("validate", (None, None))

            # Create trainer, a pytorch-ignite `Engine` object
            trainer = create_trainer(model, config, results_dir)

            # Create a validator if a validation data loader is available; the
            # best-checkpoint handler is attached to the validator when there is
            # one and to the trainer itself otherwise.
            if validation_data_loader is not None:
                validator = create_validator(model, config, validation_data_loader, trainer)
                attach_best_checkpoint(validator, model, trainer, results_dir)
            else:
                attach_best_checkpoint(trainer, model, trainer, results_dir)

            monitor = GpuMonitor()
            try:
                # Go up to the parent of the results dir so all mlflow results show up
                # in the same directory.
                results_root_dir = Path(config["general"]["results_dir"]).expanduser().resolve()
                (results_root_dir / "mlflow").mkdir(parents=True, exist_ok=True)
                mlflow.set_tracking_uri("sqlite:///" + str(results_root_dir / "mlflow" / "mlflow.db"))

                # Get experiment_name and cast to string (it's a tomlkit.string by default)
                experiment_name = str(config["train"]["experiment_name"])

                # This will create the experiment if it doesn't exist
                mlflow.set_experiment(experiment_name)

                # If run_name is not `false` in the config, use it as the MLFlow run name in
                # this experiment. Otherwise use the name of the results directory
                run_name = (
                    str(config["train"]["run_name"]) if config["train"]["run_name"] else results_dir.name
                )

                with mlflow.start_run(log_system_metrics=True, run_name=run_name):
                    Train._log_params(config, results_dir)

                    # Run the training process
                    trainer.run(train_data_loader, max_epochs=config["train"]["epochs"])

                # Save the trained model
                model.save(results_dir / config["train"]["weights_filename"])
            finally:
                # Always stop the GPU monitor, including when training or saving raises.
                monitor.stop()

            logger.info("Finished Training")
        finally:
            close_tensorboard_logger()

        return model

    @staticmethod
    def _build_data_loaders(config, dataset):
        """Create one data loader entry per dataset split used for training.

        Parameters
        ----------
        config : dict
            The main configuration dictionary
        dataset : dict
            Mapping of group name to dataset object, as returned by ``setup_dataset``.
            The "train" group is required.

        Returns
        -------
        dict[str, tuple]
            Mapping of split name to the value returned by ``dist_data_loader`` for
            that split (unpacked by ``run`` as a ``(DataLoader, indices)`` pair).
        """
        from hyrax.pytorch_ignite import dist_data_loader

        # There are three ways splits can be defined:
        #
        # 1) Separate dataset groups: the user defined distinct "train" and
        #    "validate" groups in their data_request (possibly pointing to
        #    different data_locations). Each group is loaded independently and
        #    its group name is passed as the split argument to dist_data_loader.
        #    NOTE(review): an earlier comment said this path passes split=False;
        #    the code passes the group name -- confirm which is intended.
        #
        # 2) split_fraction on shared data: the user defined "train" and
        #    "validate" groups pointing to the *same* data_location with
        #    split_fraction values. setup_dataset has already computed
        #    non-overlapping split_indices on each DataProvider, so
        #    dist_data_loader is called with split=False and the precomputed
        #    indices are applied.
        #
        # 3) Legacy percentage-based splits: only a "train" group exists and
        #    no split_fraction is set. We fall back to the old behaviour of
        #    calling dist_data_loader with the full list of split names, which
        #    uses config["data_set"] train_size / validate_size.

        # Collect split names in two ways:
        # - all_splits: all split names that this verb knows about
        #   (required + optional), used for legacy percentage-based
        #   splitting where only a "train" group may be defined.
        # - dataset_splits: those desired splits that are actually present
        #   in the dataset dict returned by setup_dataset, used by the
        #   multi-provider paths where each split is an explicit group.
        all_splits = list(Train.REQUIRED_DATA_GROUPS) + list(Train.OPTIONAL_DATA_GROUPS)
        dataset_splits = [s for s in all_splits if s in dataset]

        # Check whether split_fraction was used (path 2 above).
        # This is true when a required split's dataset object has split_indices assigned.
        has_split_groups = isinstance(dataset, dict) and any(
            getattr(dataset.get(s), "split_indices", None) is not None
            for s in Train.REQUIRED_DATA_GROUPS
        )

        data_loaders: dict[str, tuple] = {}

        if has_split_groups:
            # Path 2: split_fraction was used -- each DataProvider has split_indices.
            # Create a dataloader per group with split_indices already applied.
            # NOTE: Paths 1 and 3 will be completely deprecated in a future release,
            # and this will be the only path for training.
            for split_name in dataset_splits:
                data_loaders[split_name] = dist_data_loader(dataset[split_name], config, False)
        elif len(dataset) > 1:
            # Path 1: separate dataset groups defined in data_request without split_fraction.
            # Each group is an independent DataProvider pointing to different data_locations.
            # Create a dataloader per group.
            for split_name in dataset_splits:
                data_loaders[split_name] = dist_data_loader(dataset[split_name], config, split_name)
        else:
            # Path 3 (legacy): only "train" exists -- use percentage-based
            # splitting from config["data_set"].
            warnings.warn(
                "Defining dataset splits via config['data_set'] "
                "(train_size / validate_size / test_size) is deprecated. "
                "Please define separate dataset groups with 'split_fraction' "
                "in the [data_request] configuration instead. "
                "See https://hyrax.readthedocs.io for migration guidance.",
                DeprecationWarning,
                stacklevel=1,
            )
            raw = dist_data_loader(dataset["train"], config, all_splits)

            # dist_data_loader returns a bare (DataLoader, indices) tuple
            # when given a single split name, or a dict when given multiple.
            if isinstance(raw, dict):
                for split_name in all_splits:
                    if split_name in raw:
                        data_loaders[split_name] = raw[split_name]
            else:
                # Single split -- raw is already the (DataLoader, indices) tuple.
                data_loaders[all_splits[0]] = raw

        return data_loaders

    @staticmethod
    def _log_params(config, results_dir):
        """Log the various parameters to mlflow from the config file.

        Must be called while an mlflow run is active (``run`` calls it inside
        ``mlflow.start_run``).

        Parameters
        ----------
        config : dict
            The main configuration dictionary
        results_dir : path-like
            The full path to the results sub-directory. ``run`` passes the object
            returned by ``create_results_dir``, which it uses as a path
            (``results_dir.name``, ``results_dir / ...``).
        """
        import mlflow

        # Log full path to results subdirectory
        mlflow.log_param("Results Directory", results_dir)

        # Log all model params
        mlflow.log_params(config["model"])

        # Log some training and data loader params
        mlflow.log_param("epochs", config["train"]["epochs"])
        mlflow.log_param("batch_size", config["data_loader"]["batch_size"])

        # Log the criterion and optimizer params. A config section named after the
        # criterion/optimizer is optional; its contents are logged only when present.
        criterion_name = config["criterion"]["name"]
        mlflow.log_param("criterion", criterion_name)
        if criterion_name in config:
            mlflow.log_params(config[criterion_name])

        optimizer_name = config["optimizer"]["name"]
        mlflow.log_param("optimizer", optimizer_name)
        if optimizer_name in config:
            mlflow.log_params(config[optimizer_name])