# Source code for hyrax.verbs.train

import logging
from pathlib import Path

from colorama import Back, Fore, Style

from hyrax.trace import trace_verb_data

from .verb_registry import Verb, hyrax_verb

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@hyrax_verb
class Train(Verb):
    """Train verb

    Builds the dataset splits and model described by the configuration, runs
    the training loop inside an MLflow run, and saves the resulting weights
    into a freshly created results directory.
    """

    cli_name = "train"
    add_parser_kwargs = {}
    description = "Train a model using provided data."

    # Dataset groups that the Train verb knows about.
    # REQUIRED_DATA_GROUPS must be present in the dataset dict returned by setup_dataset.
    # OPTIONAL_DATA_GROUPS are used when present but do not cause an error if absent.
    REQUIRED_DATA_GROUPS = ("train",)
    OPTIONAL_DATA_GROUPS = ("validate", "test")

    @staticmethod
    def setup_parser(parser):
        """We don't need any parser setup for CLI opts"""

    def run_cli(self, args=None):
        """CLI stub for Train verb"""
        logger.info("train run from CLI.")
        self.run()

    @trace_verb_data
    def run(self):
        """Run the training process for the configured model and data loader.

        Steps, in order: validate the ``[train]`` config, create the results
        directory, build the dataset splits and the model, optionally load
        pre-trained weights, build one data loader per available split, create
        the trainer (plus a validator when a ``validate`` split exists), run
        training inside an MLflow run, and save the model weights.

        The GPU monitor and the tensorboardX logger are shut down even when a
        step raises, so a failed run does not leave them open.

        Returns
        -------
        The trained model (the object created by ``setup_model``).

        Raises
        ------
        ValueError
            If both ``resume`` and ``model_weights_file`` are set in the
            ``[train]`` config.
        """
        # Heavy dependencies are imported lazily, inside the method.
        import mlflow

        from hyrax.config_utils import create_results_dir, log_runtime_config
        from hyrax.gpu_monitor import GpuMonitor
        from hyrax.pytorch_ignite import (
            Events,
            attach_best_checkpoint,
            create_trainer,
            create_validator,
            dist_data_loader,
            setup_dataset,
            setup_model,
        )
        from hyrax.splitting_utils import create_splits
        from hyrax.tensorboardx_logger import close_tensorboard_logger, init_tensorboard_logger

        config = self.config

        # Validate that the user hasn't set both `resume` and `model_weights_file`.
        # These are mutually exclusive: `resume` restores a full training checkpoint
        # (model weights, optimizer state, epoch counter), while `model_weights_file`
        # loads only model parameters and starts training fresh.
        if config["train"]["resume"] and config["train"]["model_weights_file"]:
            raise ValueError(
                "Cannot set both `resume` and `model_weights_file` in the [train] config. "
                "Use `resume` to continue from a full checkpoint (restores optimizer state and epoch), "
                "or use `model_weights_file` to start fresh training from pre-trained weights."
            )

        # Create a results directory
        results_dir = create_results_dir(config, "train")
        log_runtime_config(config, results_dir)

        # Create a tensorboardX logger
        init_tensorboard_logger(log_dir=results_dir)

        # The tensorboardX logger is open from here on; the `finally` below
        # closes it whether or not the remaining steps succeed.
        try:
            all_groups = Train.REQUIRED_DATA_GROUPS + Train.OPTIONAL_DATA_GROUPS

            # Instantiate the model and dataset
            dataset = setup_dataset(config, splits=all_groups)
            create_splits(config, dataset, results_dir=results_dir, persist=True)
            model = setup_model(config, dataset["train"])
            logger.info(
                f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Training model:{Style.RESET_ALL} "
                f"{model.__class__.__name__}"
            )
            logger.info(
                f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Training dataset(s):{Style.RESET_ALL}\n{dataset}"
            )

            # If a pre-trained weights file is specified, load it before creating the trainer.
            # This must happen before create_trainer() wraps the model with idist.auto_model
            # (the distributed wrapper) to avoid parameter key name mismatches.
            if config["train"]["model_weights_file"]:
                from hyrax.models.model_utils import load_model_weights

                load_model_weights(config, model, "train")
                logger.info(
                    f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Loading pre-trained weights:"
                    f"{Style.RESET_ALL} {config['train']['model_weights_file']}"
                )
                logger.info(
                    f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN}Fine-tuning mode:{Style.RESET_ALL} "
                    "Training will start from epoch 1 with a fresh optimizer."
                )

            train_shuffle = config["train"]["shuffle"]

            # One data loader per known split that is actually present in the dataset.
            # Only the "train" split is ever shuffled, and only if the config asks for it.
            data_loaders = {
                split_name: dist_data_loader(
                    dataset[split_name],
                    config,
                    shuffle=split_name == "train" and train_shuffle,
                )
                for split_name in all_groups
                if split_name in dataset
            }

            train_data_loader = data_loaders["train"]
            validation_data_loader = data_loaders.get("validate")

            # Create trainer, a pytorch-ignite `Engine` object
            trainer = create_trainer(model, config, results_dir)

            # Dispatch on_epoch_start to all DataProviders at the start of each epoch.
            @trainer.on(Events.EPOCH_STARTED)
            def dispatch_epoch_start(engine):
                for provider in dataset.values():
                    provider.on_epoch_start("train")

            # Create a validator if a validation data loader is available; otherwise
            # the best-checkpoint handler is attached to the trainer itself.
            if validation_data_loader is not None:
                validator = create_validator(model, config, validation_data_loader, trainer)
                attach_best_checkpoint(validator, model, trainer, results_dir)
            else:
                attach_best_checkpoint(trainer, model, trainer, results_dir)

            monitor = GpuMonitor()
            # The monitor is stopped in `finally` so a failed run does not leave it running.
            try:
                # Go up to the parent of the results dir so all mlflow results show up
                # in the same directory.
                results_root_dir = Path(config["general"]["results_dir"]).expanduser().resolve()
                (results_root_dir / "mlflow").mkdir(parents=True, exist_ok=True)
                mlflow.set_tracking_uri("sqlite:///" + str(results_root_dir / "mlflow" / "mlflow.db"))

                # Get experiment_name and cast to string (it's a tomlkit.string by default)
                experiment_name = str(config["train"]["experiment_name"])

                # This will create the experiment if it doesn't exist
                mlflow.set_experiment(experiment_name)

                # If run_name is not `false` in the config, use it as the MLFlow run name in
                # this experiment. Otherwise use the name of the results directory
                run_name = (
                    str(config["train"]["run_name"]) if config["train"]["run_name"] else results_dir.name
                )

                with mlflow.start_run(log_system_metrics=True, run_name=run_name):
                    Train._log_params(config, results_dir)

                    # Run the training process
                    trainer.run(train_data_loader, max_epochs=config["train"]["epochs"])

                # Save the trained model
                model.save(results_dir / config["train"]["weights_filename"])
            finally:
                monitor.stop()

            logger.info("Finished Training")
        finally:
            close_tensorboard_logger()

        return model

    @staticmethod
    def _log_params(config, results_dir):
        """Log the various parameters to mlflow from the config file.

        Parameters
        ----------
        config : dict
            The main configuration dictionary
        results_dir : str or pathlib.Path
            The full path to the results sub-directory
        """
        import mlflow

        # Log full path to results subdirectory
        mlflow.log_param("Results Directory", results_dir)

        # Log all model params
        mlflow.log_params(config["model"])

        # Log some training and data loader params
        mlflow.log_param("epochs", config["train"]["epochs"])
        mlflow.log_param("batch_size", config["data_loader"]["batch_size"])

        # Log the criterion and optimizer params. For each, the selected name is
        # logged, and if the config has a table keyed by that name, its entries
        # are logged too.
        for section in ("criterion", "optimizer"):
            selected_name = config[section]["name"]
            mlflow.log_param(section, selected_name)
            if selected_name in config:
                mlflow.log_params(config[selected_name])