Source code for ablator.modules.scheduler

import typing as ty
from abc import abstractmethod

from torch import nn
from torch.optim.lr_scheduler import OneCycleLR, ReduceLROnPlateau, StepLR, _LRScheduler
from torch.optim import Optimizer

from ablator.config.main import ConfigBase, Derived, configclass


Scheduler = ty.Union[_LRScheduler, ReduceLROnPlateau, ty.Any]

StepType = ty.Literal["train", "val", "epoch"]


[docs]@configclass class SchedulerArgs(ConfigBase): """ Abstract base class for defining arguments to initialize a learning rate scheduler. Attributes ---------- step_when : StepType The step type at which the scheduler.step() should be invoked: ``'train'``, ``'val'``, or ``'epoch'``. """ # step every train step or every validation step step_when: StepType
[docs] @abstractmethod def init_scheduler(self, model, optimizer): """ Abstract method to be implemented by derived classes, which creates and returns a scheduler object. """ raise NotImplementedError("init_optimizer method not implemented.")
[docs]@configclass class SchedulerConfig(ConfigBase): """ Class that defines a configuration for a learning rate scheduler. Attributes ---------- name : str The name of the scheduler. arguments : SchedulerArgs The arguments needed to initialize the scheduler. """ name: str arguments: SchedulerArgs
[docs] def __init__(self, name, arguments: dict[str, ty.Any]): """ Initializes the scheduler configuration. Parameters ---------- name : str The name of the scheduler, this can be any in ``['None', 'step', 'cycle', 'plateau']``. arguments : dict[str, ty.Any] The arguments for the scheduler, specific to a certain type of scheduler. Examples -------- In the following example, ``scheduler_config`` will initialize property ``arguments`` of type ``StepLRConfig``, setting ``step_size=1``, ``gamma=0.99`` as its properties. We also have access to ``init_scheduler()`` method of the property, which initalizes an StepLR scheduler. This method is actually called in ``make_scheduler()`` >>> scheduler_config = SchedulerConfig("step", arguments={"step_size": 1, "gamma": 0.99}) """ _arguments: None | StepLRConfig | OneCycleConfig | PlateuaConfig if (argument_cls := SCHEDULER_CONFIG_MAP[name]) is None: _arguments = StepLRConfig(gamma=1) else: _arguments = argument_cls(**arguments) super().__init__(name=name, arguments=_arguments)
[docs] def make_scheduler(self, model: nn.Module, optimizer: Optimizer) -> Scheduler: """ Creates a new scheduler for an optimizer, based on the configuration. Parameters ---------- model: nn.Module Some schedulers require information from the model. The model is passed as an argument. optimizer The optimizer used to update the model parameters, whose learning rate we want to monitor. Returns ------- Scheduler The scheduler. Examples -------- >>> scheduler_config = SchedulerConfig("step", arguments={"step_size": 1, "gamma": 0.99}) >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9) >>> scheduler_config.make_scheduler(model, optimizer) """ return self.arguments.init_scheduler(model, optimizer)
[docs]@configclass class OneCycleConfig(SchedulerArgs): """ Configuration class for the OneCycleLR scheduler. Attributes ---------- max_lr : float Upper learning rate boundaries in the cycle. total_steps : Derived[int] The total number of steps to run the scheduler in a cycle. step_when : StepType The step type at which the scheduler.step() should be invoked: ``'train'``, ``'val'``, or ``'epoch'``. """ max_lr: float total_steps: Derived[int] step_when: StepType = "train"
[docs] def init_scheduler(self, model: nn.Module, optimizer: Optimizer): """ Initializes the OneCycleLR scheduler. Creates and returns a OneCycleLR scheduler that monitors optimizer's learning rate. Parameters ---------- model : nn.Module The model. optimizer : Optimizer The optimizer used to update the model parameters, whose learning rate we want to monitor. Returns ------- OneCycleLR The OneCycleLR scheduler, initialized with arguments defined as attributes of this class. Examples -------- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9) >>> scheduler = OneCycleConfig(max_lr=0.5, total_steps=100) >>> scheduler.init_scheduler(model, optimizer) """ kwargs = self.to_dict() del kwargs["step_when"] return OneCycleLR(optimizer, **kwargs)
[docs]@configclass class PlateuaConfig(SchedulerArgs): """Configuration class for ReduceLROnPlateau scheduler. Attributes ---------- patience : int Number of epochs with no improvement after which learning rate will be reduced. min_lr : float A lower bound on the learning rate. mode : str One of ``'min'``, ``'max'``, or ``'auto'``, which defines the direction of optimization, so as to adjust the learning rate accordingly, i.e when a certain metric ceases improving. factor : float Factor by which the learning rate will be reduced. ``new_lr = lr * factor``. threshold : float Threshold for measuring the new optimum, to only focus on significant changes. verbose : bool If ``True``, prints a message to ``stdout`` for each update. step_when : StepType The step type at which the scheduler should be invoked: ``'train'``, ``'val'``, or ``'epoch'``. """ patience: int = 10 min_lr: float = 1e-5 mode: str = "min" factor: float = 0.0 # TODO {fixme} this is error prone -> new_lr = 0 threshold: float = 1e-4 verbose: bool = False # TODO fix mypy errors for custom types # type: ignore step_when: StepType = "val"
[docs] def init_scheduler(self, model: nn.Module, optimizer: Optimizer): """ Initialize the ReduceLROnPlateau scheduler. Parameters ---------- model : nn.Module The model being optimized. optimizer : Optimizer The optimizer used to update the model parameters, whose learning rate we want to monitor. Returns ------- ReduceLROnPlateau The ReduceLROnPlateau scheduler, initialized with arguments defined as attributes of this class. Examples -------- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9) >>> scheduler = PlateuaConfig(min_lr=1e-7, mode='min') >>> scheduler.init_scheduler(model, optimizer) """ kwargs = self.to_dict() del kwargs["step_when"] return ReduceLROnPlateau(optimizer, **kwargs)
[docs]@configclass class StepLRConfig(SchedulerArgs): """ Configuration class for StepLR scheduler. Parameters ---------- step_size : int Period of learning rate decay, by default 1. gamma : float Multiplicative factor of learning rate decay, by default 0.99. step_when : StepType The step type at which the scheduler should be invoked: ``'train'``, ``'val'``, or ``'epoch'``. """ step_size: int = 1 gamma: float = 0.99 # TODO fix mypy errors for custom types # type: ignore step_when: StepType = "epoch"
[docs] def init_scheduler(self, model: nn.Module, optimizer: Optimizer): """ Initialize the StepLR scheduler for a given model and optimizer. Parameters ---------- model : nn.Module The model to apply the scheduler. optimizer : Optimizer The optimizer used to update the model parameters, whose learning rate we want to monitor. Returns ------- StepLR The StepLR scheduler, initialized with arguments defined as attributes of this class. Examples -------- >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9) >>> scheduler = StepLRConfig(step_size=20, gamma=0.9) >>> scheduler.init_scheduler(model, optimizer) """ kwargs = self.to_dict() del kwargs["step_when"] return StepLR(optimizer, **kwargs)
SCHEDULER_CONFIG_MAP = { "none": None, "step": StepLRConfig, "cycle": OneCycleConfig, "plateau": PlateuaConfig, }