Source code for lightning_ir.schedulers.schedulers

  1"""Generic schedulers for LightningIR."""
  2
  3from abc import ABC, abstractmethod
  4from collections.abc import Sequence
  5from typing import Any
  6
  7from lightning import Callback, LightningModule, Trainer
  8
  9from ..base import LightningIRModule
 10
 11# TODO add final value to all schedulers
 12# TODO add cosine decay scheduler
 13
 14
class LambdaWarmupScheduler(ABC):
    def __init__(
        self,
        num_warmup_steps: int,
        num_delay_steps: int = 0,
        *args,
        **kwargs,
    ) -> None:
        """Base class for schedulers with warmup.

        Args:
            num_warmup_steps (int): Number of warmup steps.
            num_delay_steps (int): Number of steps to delay the scheduler for. Defaults to 0.
        """
        self.num_warmup_steps = num_warmup_steps
        self.num_delay_steps = num_delay_steps
        super().__init__(*args, **kwargs)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float:
        """Lambda function to adjust the value at each step.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        ...

    def _check_delay(self, current_step: int) -> bool:
        return current_step < self.num_delay_steps

    def _check_warmup(self, current_step: int) -> bool:
        return current_step < self.num_warmup_steps + self.num_delay_steps

class LinearSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    def __init__(
        self,
        num_warmup_steps: int,
        num_training_steps: int,
        *args,
        final_value: float = 0.0,
        num_delay_steps: int = 0,
        **kwargs,
    ) -> None:
        """Scheduler for linearly decreasing values with linear warmup.

        Args:
            num_warmup_steps (int): Number of warmup steps.
            num_training_steps (int): Number of training steps.
            final_value (float, optional): Final value that should be reached at the end of decay. Defaults to 0.0.
            num_delay_steps (int): Number of steps to delay warmup / decay. Defaults to 0.
        """
        self.num_training_steps = num_training_steps
        self.final_value = final_value
        super().__init__(num_warmup_steps, num_delay_steps, *args, **kwargs)

    def value_lambda(self, current_step: int) -> float:
        """Lambda function for linearly decreasing values with linear warmup.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        current_step = current_step - self.num_delay_steps - self.num_warmup_steps
        remaining_steps = self.num_training_steps - self.num_delay_steps - self.num_warmup_steps
        step_size = (1 - self.final_value) / remaining_steps
        return max(self.final_value, 1 - step_size * current_step)

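As a quick illustration of the schedule shape, the value lambda can be queried directly; a minimal sketch (the step counts are arbitrary and only for illustration):

# Minimal sketch: inspect the schedule directly.
scheduler = LinearSchedulerWithLinearWarmup(
    num_warmup_steps=10, num_training_steps=100, final_value=0.1, num_delay_steps=5
)
for step in (0, 5, 10, 15, 60, 100):
    print(step, scheduler.value_lambda(step))
# 0 -> 0.0 (delayed), 5 -> 0.0 (warmup begins), 10 -> 0.5 (warming up),
# 15 -> 1.0 (peak), 60 -> ~0.52 (decaying), 100 -> 0.1 (final value)
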
class ConstantSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with linear warmup.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        return 1.0

class ConstantSchedulerWithQuadraticWarmup(LambdaWarmupScheduler):
    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with quadratic warmup.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return ((current_step - self.num_delay_steps) / self.num_warmup_steps) ** 2
        return 1.0

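The two constant schedulers plateau at 1.0 after warmup and differ only in the warmup ramp (linear vs. quadratic); a minimal sketch with arbitrary step counts:

# Minimal sketch: compare the warmup shapes of the two constant schedulers.
linear = ConstantSchedulerWithLinearWarmup(num_warmup_steps=100)
quadratic = ConstantSchedulerWithQuadraticWarmup(num_warmup_steps=100)
print(linear.value_lambda(50), quadratic.value_lambda(50))    # 0.5 0.25
print(linear.value_lambda(200), quadratic.value_lambda(200))  # 1.0 1.0
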
class GenericScheduler(Callback, ABC):
    def __init__(self, *args, keys: Sequence[str] | None = None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if keys is None:
            raise ValueError("keys must be provided")
        self.keys = keys
        self.values: dict[str, float] = {}

    def step(self, key: str, current_step: int) -> float:
        # Scale the value captured at the start of training by the scheduler lambda.
        value = self.values[key]
        return value * self.value_lambda(current_step)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float: ...

    def get_value(self, sub_keys: Sequence[str], obj: object) -> object:
        # Resolve a dotted key: integer segments index into sequences,
        # all other segments are attribute lookups.
        for sub_key in sub_keys:
            try:
                obj = obj[int(sub_key)]
            except ValueError:
                obj = getattr(obj, sub_key)
        return obj

    def set_value(self, sub_keys: Sequence[str], obj: object, value: float) -> None:
        obj = self.get_value(sub_keys[:-1], obj)
        setattr(obj, sub_keys[-1], value)

    def on_train_start(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        # Record the initial value of each scheduled attribute.
        for key in self.keys:
            sub_keys = key.split(".")
            self.values[key] = float(self.get_value(sub_keys, pl_module))

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        # Update each scheduled attribute before every training batch.
        step = trainer.global_step + 1
        for key in self.keys:
            value = self.step(key, step)
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Restore the original values once training has finished.
        for key in self.keys:
            value = self.values[key]
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)

class GenericLinearSchedulerWithLinearWarmup(LinearSchedulerWithLinearWarmup, GenericScheduler):
    pass


class GenericConstantSchedulerWithLinearWarmup(ConstantSchedulerWithLinearWarmup, GenericScheduler):
    pass


class GenericConstantSchedulerWithQuadraticWarmup(ConstantSchedulerWithQuadraticWarmup, GenericScheduler):
    pass
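
A minimal usage sketch, assuming a LightningIRModule that exposes a scalar attribute at the hypothetical dotted path "loss_functions.0.weight" (the actual path depends on the module's attribute layout):

# Minimal sketch: schedule an attribute of the module via a callback.
from lightning import Trainer

scheduler = GenericLinearSchedulerWithLinearWarmup(
    keys=["loss_functions.0.weight"],  # hypothetical attribute path on the module
    num_warmup_steps=1_000,
    num_training_steps=10_000,
)
trainer = Trainer(callbacks=[scheduler], max_steps=10_000)
# trainer.fit(module, datamodule)  # module and datamodule defined elsewhere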