Prototype Trainer

class ablator.main.proto.ProtoTrainer(wrapper: ModelWrapper, run_config: RunConfig)

Manages resources for prototyping. This trainer runs an experiment with a single prototype model (therefore, no ablation study or HPO).

Parameters:
wrapper : ModelWrapper

The main model wrapper.

run_config : RunConfig

Running configuration for the model.

Raises:
RuntimeError

If the experiment directory is not defined in the running configuration.
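The constructor check described above can be pictured roughly as follows. This is a minimal sketch, not ablator's implementation: SketchRunConfig and resolve_experiment_dir are hypothetical names. The only behavior taken from this page is that an undefined experiment directory raises RuntimeError and that the trainer exposes experiment_dir as a Path.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class SketchRunConfig:
    # Hypothetical stand-in for RunConfig, reduced to the one field checked here.
    experiment_dir: Optional[str] = None


def resolve_experiment_dir(run_config: SketchRunConfig) -> Path:
    """Hypothetical helper: validate the config and return the experiment directory as a Path."""
    if run_config.experiment_dir is None:
        # Mirrors the documented RuntimeError for an undefined experiment directory.
        raise RuntimeError("An experiment directory must be set in the run configuration.")
    return Path(run_config.experiment_dir)


print(resolve_experiment_dir(SketchRunConfig(experiment_dir="/tmp/experiments")))
```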

Examples

Below is a complete workflow for launching a prototype experiment with ProtoTrainer, from defining the configuration to launching the experiment:

  • Define training config:

>>> my_optimizer_config = OptimizerConfig("sgd", {"lr": 0.5, "weight_decay": 0.5})
>>> my_scheduler_config = SchedulerConfig("step", arguments={"step_size": 1, "gamma": 0.99})
>>> train_config = TrainConfig(
...     dataset="[Dataset Name]",
...     batch_size=32,
...     epochs=10,
...     optimizer_config=my_optimizer_config,
...     scheduler_config=my_scheduler_config
... )
  • Define model config: here we use the default model config with no custom hyperparameters. (To run an ablation study or HPO over the model’s hyperparameters, you would instead customize the model config and run a parallel experiment, which requires ParallelTrainer and ParallelConfig instead of ProtoTrainer and RunConfig.)

>>> model_config = ModelConfig()
  • Define run config:

>>> run_config = RunConfig(
...     train_config=train_config,
...     model_config=model_config,
...     metrics_n_batches=800,
...     experiment_dir="/tmp/experiments",
...     device="cpu",
...     amp=False,
...     random_seed=42
... )
  • Create model wrapper:

>>> import torch
>>> class MyModelWrapper(ModelWrapper):
...     def make_dataloader_train(self, run_config: RunConfig):
...         # train_dataset is a placeholder for your own torch Dataset
...         return torch.utils.data.DataLoader(
...             train_dataset, batch_size=run_config.train_config.batch_size, shuffle=True
...         )
...
...     def make_dataloader_val(self, run_config: RunConfig):
...         # val_dataset is a placeholder for your own torch Dataset
...         return torch.utils.data.DataLoader(
...             val_dataset, batch_size=run_config.train_config.batch_size, shuffle=False
...         )
  • With all configurations and the model wrapper in place, initialize and launch the prototype trainer. Launching the experiment requires a working directory that points to a git repository, which is used to keep track of code differences:

>>> import os
>>> wrapper = MyModelWrapper(
...     model_class=MyModel,  # placeholder: your own ModelModule class
... )
>>> trainer = ProtoTrainer(
...     wrapper=wrapper,
...     run_config=run_config,
... )
>>> metrics = trainer.launch(working_directory=os.getcwd())  # assumes the current directory is tracked by git
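The "step" scheduler in the training config above uses step_size=1 and gamma=0.99. Assuming these arguments follow the usual step-decay rule (as in PyTorch's StepLR, where the learning rate is multiplied by gamma once every step_size scheduler steps) and that the scheduler steps once per epoch, which is an assumption worth checking against your ablator version, the learning rate over the 10 configured epochs can be worked out in plain Python:

```python
def step_lr(base_lr: float, gamma: float, step_size: int, step: int) -> float:
    """Learning rate after `step` scheduler steps under step decay."""
    # The rate is multiplied by gamma once every step_size steps.
    return base_lr * gamma ** (step // step_size)


# Values taken from the example's OptimizerConfig and SchedulerConfig.
base_lr, gamma, step_size, epochs = 0.5, 0.99, 1, 10

# Assumption: one scheduler step per epoch, so index k is the rate after k epochs.
schedule = [step_lr(base_lr, gamma, step_size, epoch) for epoch in range(epochs + 1)]
for epoch, lr in enumerate(schedule):
    print(f"after {epoch:2d} steps: lr = {lr:.4f}")
```

Under these assumptions, ten steps reduce the rate by less than 10% (0.99**10 is about 0.904), so this schedule decays slowly.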
Attributes:
wrapper : ModelWrapper

The main model wrapper.

run_config : RunConfig

Running configuration for the model.

experiment_dir : Path

The path object to the experiment directory.

launch(working_directory: str, resume: bool = False, debug: bool = False) → dict[str, float]

Launch the prototype experiment (train and evaluate the single prototype model) and return its metrics.

Parameters:
working_directory : str

Path to a git repository that is used to keep track of code differences.

resume : bool, optional

Whether to resume training from existing checkpoints and experiment state, by default False.

debug : bool, optional

Whether to train models in debug mode, by default False.
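Because working_directory must point to a git repository, a quick pre-flight check before calling launch might look like the sketch below. find_git_root is a hypothetical helper, not part of ablator. It only approximates git's own repository discovery by walking up parent directories looking for a .git entry, and it ignores cases such as GIT_DIR overrides.

```python
import os
from pathlib import Path
from typing import Optional


def find_git_root(working_directory: str) -> Optional[Path]:
    """Return the closest directory at or above working_directory that contains .git, else None."""
    start = Path(working_directory).resolve()
    for candidate in (start, *start.parents):
        # .git is a directory in a normal clone and a file in worktrees and submodules.
        if (candidate / ".git").exists():
            return candidate
    return None


root = find_git_root(os.getcwd())
if root is None:
    print("Not inside a git repository; pick a tracked directory before calling launch.")
else:
    print(f"Git repository root: {root}")
```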

Returns:
dict[str, float]

Metrics returned after training.

Raises:
RuntimeError

If config.experiment_id is unspecified when resuming an experiment, or if the experiment directory is not empty when using a remote storage configuration.
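The two RuntimeError conditions above can be expressed as a small pre-launch check. This is a hypothetical sketch for illustration only: check_launch_preconditions and its arguments (experiment_id, uses_remote_storage) are assumed names, and ablator performs its own checks internally.

```python
from pathlib import Path
from typing import Optional


def check_launch_preconditions(
    resume: bool,
    experiment_id: Optional[str],
    experiment_dir: Path,
    uses_remote_storage: bool,
) -> None:
    """Hypothetical check: raise RuntimeError under the two conditions documented for launch."""
    # Resuming needs to know which experiment to resume.
    if resume and experiment_id is None:
        raise RuntimeError("Cannot resume an experiment without config.experiment_id.")
    # With remote storage configured, the experiment directory must be empty.
    if uses_remote_storage and experiment_dir.exists() and any(experiment_dir.iterdir()):
        raise RuntimeError(f"Experiment directory {experiment_dir} is not empty.")
```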