# Source code for ablator.main.mp

import copy
import itertools
import multiprocessing as mp
import sys
import traceback
import types as tys
import typing as ty
import uuid
from collections import defaultdict
from functools import cached_property
from pathlib import Path
import numpy as np

import ray
import torch
from ablator.modules.loggers.file import RemoteFileLogger
from ablator.mp.gpu import GPUError

import ablator.utils.base as butils
from ablator.config.mp import ParallelConfig
from ablator.main.model.wrapper import ModelWrapper
from ablator.main.proto import ProtoTrainer
from ablator.main.state import ExperimentState, TrialState
from ablator.mp.cluster import ClusterManager
from ablator.mp.utils import get_node_ip, ray_init
from ablator.utils.progress_bar import RemoteDisplay, RemoteProgressBar
from ablator.mp.train_remote import train_main_remote
from ablator.config.types import Optional


class ParallelTrainer(ProtoTrainer):
    """
    A class for parallelizing multiple training processes of models of different
    configurations with ray.

    Parameters
    ----------
    wrapper : ModelWrapper
        The model wrapper for the ``ParallelTrainer``.
    run_config : ParallelConfig
        The runtime configuration for this trainer.

    Attributes
    ----------
    run_config : ParallelConfig
        Running configuration for parallel training.
    logger : RemoteFileLogger
        A centralized logger that writes messages to a file and prints them to
        the console.
    experiment_state : ExperimentState
        This attribute manages optuna trials.
    ray_address : str
        The address of the ray cluster.
    total_trials : int | None
        Total number of trials to run. Stored on ``run_config.total_trials``.
    running_futures : dict[str, list]
        A dictionary with keys the Node IP and values a list of Ray remote tasks
        executing on the node aka `futures`.
    cluster_manager : ClusterManager
        The cluster manager responsible for scheduling tasks and managing
        resources

    Examples
    --------
    Below is a complete workflow on how to launch a parallel experiment with
    ``ParallelTrainer``, from defining config, getting the model wrapper ready, to
    launching the experiment:

    - Define training config:

    >>> my_optimizer_config = OptimizerConfig("sgd", {"lr": 0.5, "weight_decay": 0.5})
    >>> my_scheduler_config = SchedulerConfig("step", arguments={"step_size": 1, "gamma": 0.99})
    >>> train_config = TrainConfig(
    ...     dataset="[Dataset Name]",
    ...     batch_size=32,
    ...     epochs=10,
    ...     optimizer_config = my_optimizer_config,
    ...     scheduler_config = my_scheduler_config
    ... )

    - Define model config, we want to run HPO on activation functions and model
      hidden size:

    >>> @configclass
    ... class CustomModelConfig(ModelConfig):
    ...     hidden_size: int
    ...     activation: str
    >>> model_config = CustomModelConfig(hidden_size=100, activation="relu")

    - Define search space:

    >>> search_space = {
    ...     "train_config.optimizer_config.arguments.lr": SearchSpace(
    ...         value_range = [0.001, 0.01],
    ...         value_type = 'float'
    ...     ),
    ...     "model_config.hidden_size": SearchSpace(value_range = [32, 64], value_type = 'int'),
    ...     "model_config.activation": SearchSpace(categorical_values = ["relu", "elu", "leakyRelu"]),
    ... }

    - Define run config (remember to redefine the parallel config to update the
      model config type to be ``CustomModelConfig``):

    >>> @configclass
    ... class CustomParallelConfig(ParallelConfig):
    ...     model_config: CustomModelConfig
    >>>
    >>> parallel_config = CustomParallelConfig(
    ...     train_config=train_config,
    ...     model_config=model_config,
    ...     metrics_n_batches = 800,
    ...     experiment_dir = "/tmp/experiments/",
    ...     device="cuda",
    ...     amp=True,
    ...     random_seed = 42,
    ...     total_trials = 20,
    ...     concurrent_trials = 3,
    ...     search_space = search_space,
    ...     optim_metrics = {"val_loss": "min"},
    ...     optim_metric_name = "val_loss",
    ...     gpu_mb_per_experiment = 1024
    ... )

    - Create model wrapper:

    >>> class MyModelWrapper(ModelWrapper):
    ...     def __init__(self, *args, **kwargs):
    ...         super().__init__(*args, **kwargs)
    ...
    ...     def make_dataloader_train(self, run_config: CustomParallelConfig):
    ...         return torch.utils.data.DataLoader(<train_dataset>, batch_size=32, shuffle=True)
    ...
    ...     def make_dataloader_val(self, run_config: CustomParallelConfig):
    ...         return torch.utils.data.DataLoader(<val_dataset>, batch_size=32, shuffle=False)

    - After gathering all configurations and model wrapper, we can initialize and
      launch the parallel trainer:

    >>> wrapper = MyModelWrapper(
    ...     model_class=<your_ModelModule_class>,
    ... )
    >>> ablator = ParallelTrainer(
    ...     wrapper=wrapper,
    ...     run_config=parallel_config,
    ... )
    >>> ablator.launch(working_directory = os.getcwd(), ray_head_address=None)
    """

    def __init__(self, wrapper: ModelWrapper, run_config: ParallelConfig):
        # Initialize ``ParallelTrainer`` using config from ``run_config``.
        self.run_config: ParallelConfig

        super().__init__(wrapper=wrapper, run_config=run_config)

        assert issubclass(type(self.run_config), ParallelConfig), (
            f"run_config must be of a type - { ParallelConfig.__name__} received"
            f" {type(self.run_config)}"
        )
        assert issubclass(type(self.wrapper), ModelWrapper), (
            f"wrapper must be of a type - { ModelWrapper.__name__} received"
            f" {self.wrapper}"
        )

        # The attributes below that carry no value are only declared here; they
        # are assigned in `_init_ray` / `_init_state`.
        self.logger: RemoteFileLogger
        self.experiment_state: ExperimentState
        self.total_trials: int | None
        self.ray_address: str
        self._progress_bar: ty.Optional[RemoteProgressBar] = None
        self._display: butils.Dummy | RemoteDisplay = butils.Dummy()
        # node ip -> list of ray futures currently tracked for that node
        self.running_futures: dict[str, list] = defaultdict(list)
        self.cluster_manager: ClusterManager

    @cached_property
    def _gpu(self) -> float:
        """
        _gpu virtual number of GPUs used to schedule remotes on a GPU nodes.
        We handle GPU allocation internally.

        Returns
        -------
        float
            mock gpu value i.e. 0.001, or 0 when the configured device is not
            a cuda device.

        Raises
        ------
        ValueError
            if the `gpu_mb_per_experiment` configuration is not specified when
            using `device='cuda'`
        """
        device = butils.parse_device(self.run_config.device)
        if not device.startswith("cuda"):
            return 0
        if self.run_config.gpu_mb_per_experiment is None:
            raise ValueError(
                "config attribute `gpu_mb_per_experiment` can not be `None` when"
                " device=`cuda`"
            )
        return 0.001

    @cached_property
    def _cpu(self) -> float:
        """
        _cpu expected to be run AFTER _init_state as it requires the cluster to
        be initialized (it logs through ``self.logger``). It is used as a
        virtual number of `num_cpus` for ray while we handle resource allocation
        manually.

        Returns
        -------
        float
            a virtual number of _cpus to use i.e. 0.01
        """
        concurrent_trials = self.run_config.concurrent_trials
        if concurrent_trials is None or concurrent_trials > mp.cpu_count():
            self.logger.warn(
                "Expected CPU core util. can exceed system capacity"
                f" {mp.cpu_count()}.\nConsider adjusting `concurrent_trials`."
            )
        return 0.01

    def _make_remote(
        self,
        trial_id: int,
        run_config: ParallelConfig,
        node_ip: str,
        max_error_retries: int = 0,
        resume: bool = False,
    ):
        """
        Submit a single trial as a ray remote task that requests the
        ``node:<node_ip>`` resource.

        Parameters
        ----------
        trial_id : int
            Identifier of the trial, used to update the experiment state and
            forwarded to the remote as ``uid``.
        run_config : ParallelConfig
            The trial configuration. Its ``experiment_dir`` is overwritten in
            place with the trial specific directory.
        node_ip : str
            IP of the node to schedule the remote on.
        max_error_retries : int
            Forwarded to ray as ``max_retries``, by default ``0``.
        resume : bool
            Whether the trial is resumed, by default ``False``.

        Returns
        -------
        The ray future returned by submitting ``train_main_remote``.

        Raises
        ------
        RuntimeError
            If a GPU is requested but the ray node with ``node_ip`` does not
            advertise a ``GPU`` resource, or if the remote is scheduled on a
            node other than the current one while ``remote_config`` is ``None``.
        """
        trial_uuid = f"{run_config.uid}_{str(uuid.uuid4())[:4]}"

        gpu, manager = (None, None)
        if self._gpu > 0:
            # NOTE(review): `_make_futures` catches `GPUError` around this call,
            # presumably raised by `get_gpu` when no GPU is available.
            gpu, manager = self.cluster_manager.get_gpu(
                node_ip=node_ip, process_name=trial_uuid
            )
            for node in ray.nodes():
                if (
                    node["NodeManagerAddress"] == node_ip
                    and "GPU" not in node["Resources"]
                ):
                    raise RuntimeError("Misconfigured Ray cluster.")

        wrapper = copy.deepcopy(self.wrapper)
        # pylint: disable=protected-access
        wrapper._uid = trial_uuid
        model_obj = ray.put(wrapper)

        remote_fn = ray.remote(
            num_gpus=self._gpu,
            num_cpus=self._cpu,
            max_calls=1,
            max_retries=max_error_retries,
        )(train_main_remote).options(
            resources={f"node:{node_ip}": 0.001}, name=trial_uuid
        )

        if node_ip == get_node_ip():
            run_config.experiment_dir = (self.experiment_dir / trial_uuid).as_posix()
        elif run_config.remote_config is None:
            # NOTE this should never happen during normal use-case
            # the remote_config is automatically created on multi-node cluster
            # to be the head node of the cluster.
            raise RuntimeError(
                "Could not identify remote_config. Critical error encountered."
                " remote_config unspecified when scheduling remotes on multi-node"
                " cluster."
            )
        else:
            # Re-root the configured local path (minus its first path component)
            # under ``~/ablator``.
            run_config.experiment_dir = (
                (Path("~") / "ablator").joinpath(
                    *Path(run_config.remote_config.local_path).parts[1:]
                )
                / trial_uuid
            ).as_posix()

        list_diffs = self.run_config.diff_str(run_config)
        diffs = "\n\t".join(list_diffs)
        action = "Scheduling" if resume is False else "Resuming"
        msg = (
            f"{action} @ {node_ip} with uid: {trial_uuid}\nParameters:"
            f" \n\t{diffs}\n-----"
        )
        self.logger.info(msg)
        self.experiment_state.update_trial_state(trial_id, None, TrialState.RUNNING)
        data_lock = butils.Lock()

        return remote_fn.remote(
            model=model_obj,
            run_config=copy.deepcopy(run_config),
            mp_logger=self.logger,
            resource_manager=manager,
            gpu=gpu,
            uid=trial_id,
            fault_tollerant=True,
            crash_exceptions_types=None,
            resume=resume,
            clean_reset=True,
            progress_bar=self._progress_bar,
            data_lock=data_lock,
        )

    def _heartbeat(self):
        """Force a refresh of the display."""
        self._display.refresh(force=True)

    # pylint: disable=too-complex
    def _make_futures(self, soft_limit: int = 10) -> list:
        """
        Schedule new trials and return every future tracked in
        ``running_futures``, interleaved across nodes.

        Parameters
        ----------
        soft_limit : int
            Scheduling stops once at least this many futures were added during
            this call, by default ``10``.

        Returns
        -------
        list
            The futures of all nodes, taken round-robin from each node's list.
        """
        # make enough futures such that there are concurrent_trials running.
        concurrent_trial_limit: int | None = self.run_config.concurrent_trials
        gpu_util = self.run_config.gpu_mb_per_experiment if self._gpu > 0 else None
        starting_futures = np.array([len(v) for v in self.running_futures.values()])

        def is_limit(node_ip: str | None = None):
            futures = np.array([len(v) for v in self.running_futures.values()])
            return (
                # enough futures were added during this call
                futures.sum() - starting_futures.sum() >= soft_limit
                # every node is at the concurrency limit
                or (
                    node_ip is not None
                    and len(self.running_futures[node_ip]) > 0
                    and concurrent_trial_limit is not None
                    and (futures >= concurrent_trial_limit).all()
                )
                # the total trial budget is used up
                or (
                    self.total_trials is not None
                    and len(self.experiment_state.valid_trials()) >= self.total_trials
                )
            )

        def interleaved_running_futures():
            # interleaves the futures from all nodes that are running.
            return [
                x
                for x in itertools.chain(
                    *itertools.zip_longest(*self.running_futures.values())
                )
                if x is not None
            ]

        while not is_limit():
            resources = self.cluster_manager.sorted_resources(gpu_mem=gpu_util)
            remote_config = self.cluster_manager.remote_config
            if len(resources) == 0:
                break
            for node_ip in resources:
                if is_limit(node_ip):
                    return interleaved_running_futures()
                if (
                    concurrent_trial_limit is not None
                    and len(self.running_futures[node_ip]) >= concurrent_trial_limit
                ):
                    continue
                try:
                    trial_id, trial = self.experiment_state.sample_trial()
                except StopIteration:
                    self.logger.warn(
                        "Received StopIteration signal, trial limit possibly reached"
                        f" {self.total_trials}"
                    )
                    return interleaved_running_futures()
                try:
                    trial.remote_config = remote_config
                    future = self._make_remote(trial_id, trial, node_ip)
                    self.running_futures[node_ip].append(future)
                except GPUError:
                    self.logger.warn(f"Not Enough GPU resources for {node_ip}.")
                    continue
        return interleaved_running_futures()

    def pre_train_setup(self):
        """
        Used to prepare resources to avoid stalling during training or when
        resources are shared between trainers.

        Runs ``wrapper.init_state`` as a smoke test (``smoke_test=True``,
        ``debug=True``) inside a ray remote on copies of the wrapper and the
        configuration, and blocks until it finishes.
        """
        mock_wrapper = copy.deepcopy(self.wrapper)
        mock_config = copy.deepcopy(self.run_config)
        mock_config.experiment_dir = None
        future = (
            ray.remote(
                num_gpus=self._gpu,
                num_cpus=self._cpu,
                max_calls=1,
                max_retries=0,
            )(
                lambda wrapper: wrapper.init_state(
                    run_config=mock_config, smoke_test=True, debug=True
                )
            )
            .options()
            .remote(ray.put(mock_wrapper))
        )
        ray.get(future)

    @property
    def total_trials(self) -> Optional[int]:
        """Total number of trials to run, read from ``run_config.total_trials``."""
        return self.run_config.total_trials

    @total_trials.setter
    def total_trials(self, value):
        self.run_config.total_trials = value

    def _init_ray(
        self,
        working_dir: str = "",
        address: str | None = None,
        modules: list[tys.ModuleType] | None = None,
        excluding_files: list[str] | None = None,
        verbose: ty.Literal["console", "progress", "silent"] = "silent",
    ) -> bool:
        """
        Attach to the running ray instance or start a new one and set
        ``ray_address``.

        Parameters
        ----------
        working_dir : str
            Used as ``working_dir`` of the ray runtime environment.
        address : str | None
            Ray address forwarded to ``ray_init``.
        modules : list[tys.ModuleType] | None
            Modules used as ``py_modules`` of the runtime environment. The
            ``ablator`` module is always included. The list is not modified.
        excluding_files : list[str] | None
            Patterns excluded from the runtime environment in addition to
            ``.git``. Defaults to ``[".git/**"]``.
        verbose : ty.Literal["console", "progress", "silent"]
            ``log_to_driver`` is enabled only for ``"console"``.

        Returns
        -------
        bool
            ``True`` when ray was already initialized (no new instance is
            started and no runtime environment is set up), ``False`` otherwise.
        """
        if excluding_files is None:
            excluding_files = [".git/**"]

        if ray.is_initialized():
            ray_context = ray.get_runtime_context()
            self.ray_address = ray_context.gcs_address
            # TODO find a way to set-up runtime env on running cluster
            # NOTE this is because https://docs.ray.io/en/latest/ray-core/handling-dependencies.html
            # `Note: Setting options (1) and (3) per-task or per-actor is
            # currently unsupported, it can only be set per-job (i.e., in ray.init()).`
            return True

        # pylint: disable=cyclic-import,import-outside-toplevel
        import ablator as ablator_module

        # Work on a copy so that the list supplied by the caller is left intact.
        py_modules = [] if modules is None else list(modules)
        if ablator_module not in py_modules:
            py_modules.append(ablator_module)

        runtime_env = {
            "working_dir": working_dir,
            "excludes": [".git"] + excluding_files,
            "py_modules": py_modules,
        }
        ray_kwargs = {
            "log_to_driver": verbose == "console",
            "logging_level": "warning",
            "include_dashboard": True,  # required for `list_nodes` function
            "address": address,
            "runtime_env": runtime_env,
        }
        ray_cluster = ray_init(**ray_kwargs)
        self.ray_address = ray_cluster.address_info["address"]
        return False

    def _init_state(
        self,
        working_dir: str = "",
        address: str | None = None,
        modules: list[tys.ModuleType] | None = None,
        resume: bool = False,
        excluding_files: list[str] | None = None,
        debug: bool = False,
    ):
        """
        Prepare the trainer for launching: mount the experiment directory,
        initialize ray, the cluster manager, the logger and the experiment
        state.

        Parameters
        ----------
        working_dir : str
            Forwarded to ``_init_ray`` and ``_get_diffs``.
        address : str | None
            Ray address forwarded to ``_init_ray``.
        modules : list[tys.ModuleType] | None
            Modules forwarded to ``_init_ray``.
        resume : bool
            Whether an existing experiment is resumed, by default ``False``.
        excluding_files : list[str] | None
            Forwarded to ``_init_ray``.
        debug : bool
            Forwarded to ``_mount``, by default ``False``.

        Raises
        ------
        RuntimeError
            If the experiment directory exists and ``resume`` is ``False``.
        NotImplementedError
            If ``run_config.verbose`` is ``"progress"``.
        """
        self.stop()
        verbose = self.run_config.verbose

        if self.experiment_dir.exists() and not resume:
            raise RuntimeError(f"Experiment Directory {self.experiment_dir} exists.")

        self._mount(resume=resume, debug=debug)
        _is_ray_init = self._init_ray(
            working_dir=working_dir,
            address=address,
            modules=modules,
            excluding_files=excluding_files,
            verbose=verbose,
        )
        self.cluster_manager = ClusterManager(
            private_key_home=Path.home(),
            sync_directory=self.experiment_dir,
            ray_address=self.ray_address,
            remote_config=self.run_config.remote_config,
        )
        self.logger = RemoteFileLogger(
            path=self.experiment_dir / "mp.log", verbose=verbose == "console"
        )
        self.experiment_dir.joinpath("master_config.yaml").write_text(
            self.run_config.to_yaml(), encoding="utf-8"
        )
        self.experiment_state = ExperimentState(
            self.experiment_dir, self.run_config, self.logger, resume=resume
        )
        self.logger.to_remote()
        # TODO check if this causes an error because it was placed before the ray init
        if verbose == "progress":
            raise NotImplementedError(
                "verbose='progress' currently not supported for mp-training."
            )
        if _is_ray_init:
            self.logger.warn(
                "Ray is already initialized. Can not start another instance. Unexpected"
                " behavior can occur. We recommend to perform `ray.shutdown()` or `ray"
                " stop` before starting the experiment. You can set 'address=\"local\"'"
                " on `.launch` to start another cluster."
            )
        # first heartbeat <3
        self._heartbeat()
        diffs = self._get_diffs(working_dir)
        self.logger.warn(diffs)

    # flake8: noqa: DOC201, DOC502
    # pylint: disable=arguments-renamed,too-complex
    def launch(  # type: ignore[override]
        self,
        working_directory: str,
        auxilary_modules: list[tys.ModuleType] | None = None,
        ray_head_address: str | None = None,
        resume: bool = False,
        excluding_files: list[str] | None = None,
        debug: bool = False,
    ):
        """
        Set up and launch the parallel ablation experiment. This sets up a ray
        cluster, and trials of different configuration initialized (or
        retrieved) will be pushed to the ray cluster to run in parallel.

        Parameters
        ----------
        working_directory : str
            The working directory that stores codes and modules that will be
            used by ray.
        auxilary_modules : list[tys.ModuleType] | None
            A list of modules to be used as ray clusters' working environment.
        ray_head_address : str | None
            Ray cluster address.
        resume : bool
            Whether to resume training the model from existing checkpoints and
            existing experiment state, by default ``False``.
        excluding_files : list[str] | None
            A list of files in `.gitignore` format, that will be excluded from
            being uploaded to the ray cluster. If unspecified it ignores
            `.git/**` folder.
        debug : bool, optional
            Whether to train model in debug mode. By default ``False``

        Raises
        ------
        RuntimeError
            If the `config.experiment_id` is unspecified but resuming an
            experiment or the experiment directory is not empty but uses a
            remote storage configuration.
        """
        try:
            # NOTE(review): when the first call raises ``RuntimeError`` the
            # forced call below is skipped as well - confirm this is intended.
            torch.multiprocessing.set_start_method("spawn")
            mp.set_start_method("spawn", force=True)
        except RuntimeError:
            pass

        self._init_state(
            working_dir=working_directory,
            address=ray_head_address,
            modules=auxilary_modules,
            resume=resume,
            excluding_files=excluding_files,
            debug=debug,
        )
        if debug:
            self.pre_train_setup()

        valid_trials = self.experiment_state.valid_trials()
        if self.total_trials is not None and len(valid_trials) >= self.total_trials:
            self.logger.error(f"Trial limit {self.total_trials} was reached. Exiting.")
            return

        futures = self._make_futures()
        metrics: dict[str, float] | None
        trial_state: TrialState
        heart_beat_interval = 1
        while len(futures) > 0:
            # pylint: disable=broad-exception-caught
            try:
                # Wait at most one heartbeat interval for a single trial.
                done_id, futures = ray.wait(
                    futures, num_returns=1, timeout=heart_beat_interval
                )
                if len(done_id) > 0:
                    done_future = done_id[0]
                    # Stop tracking the finished future on whichever node ran it.
                    for v in self.running_futures.values():
                        if done_future in v:
                            v.remove(done_future)
                    uid, metrics, trial_state = ray.get(done_future)
                    self.experiment_state.update_trial_state(uid, metrics, trial_state)
                futures = self._make_futures()
            except KeyboardInterrupt:
                self.logger.warn("KeyboardInterrupt signal received.")
                self._print_summary()
                sys.exit(0)
            except StopIteration:
                # Reached maximum number of sample trials
                continue
            except Exception:
                exception = traceback.format_exc()
                self.logger.error(f"Unhandled Exception: {exception}")
            finally:
                self._heartbeat()
        self._print_summary()

    def _print_summary(self):
        """Log the ids of the complete, unfinished and errored trials."""
        pending_trials = [
            c.id
            for c in self.experiment_state.get_trials_by_state(TrialState.WAITING)
            + self.experiment_state.get_trials_by_state(TrialState.RUNNING)
        ]
        complete_trials = [
            c.id for c in self.experiment_state.get_trials_by_state(TrialState.COMPLETE)
        ]
        errored_trials = [
            c.id for c in self.experiment_state.get_trials_by_state(TrialState.FAIL)
        ]
        self.logger.info(
            f"There are {len(complete_trials)} complete trials. with ids:"
            f" {complete_trials}"
        )
        if len(pending_trials) > 0:
            self.logger.warn(
                f"There are {len(pending_trials)} unfinished trials. with ids:"
                f" {pending_trials}"
            )
        if len(errored_trials) > 0:
            self.logger.error(
                f"There are {len(errored_trials)} errored trials. with ids:"
                f" {errored_trials}"
            )

    def stop(self):
        """Stop the trainer and, when one was created, the cluster manager."""
        super().stop()
        # `cluster_manager` is only assigned in `_init_state`, so it can be
        # missing when `stop` runs before the trainer was launched.
        if getattr(self, "cluster_manager", None) is not None:
            self.cluster_manager.stop()