Source code for lightning_ir.schedulers.schedulers

  1"""Generic schedulers for LightningIR."""
  2
  3from abc import ABC, abstractmethod
  4from typing import Any, Dict, Sequence
  5
  6from lightning import Callback, LightningModule, Trainer
  7
  8from ..base import LightningIRModule
  9
 10# TODO add final value to all schedulers
 11# TODO add cosine decay scheduler
 12
 13
class LambdaWarmupScheduler(ABC):
    def __init__(
        self,
        num_warmup_steps: int,
        num_delay_steps: int = 0,
        *args,
        **kwargs,
    ) -> None:
        """Base class for schedulers with warmup.

        Args:
            num_warmup_steps (int): Number of warmup steps.
            num_delay_steps (int): Number of steps to delay the scheduler for. Defaults to 0.
        """
        self.num_warmup_steps = num_warmup_steps
        self.num_delay_steps = num_delay_steps
        super().__init__(*args, **kwargs)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float:
        """Lambda function to adjust the value at each step.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        ...

    def _check_delay(self, current_step: int) -> bool:
        return current_step < self.num_delay_steps

    def _check_warmup(self, current_step: int) -> bool:
        return current_step < self.num_warmup_steps + self.num_delay_steps

class LinearSchedulerWithLinearWarmup(LambdaWarmupScheduler):

    def __init__(
        self,
        num_warmup_steps: int,
        num_training_steps: int,
        *args,
        final_value: float = 0.0,
        num_delay_steps: int = 0,
        **kwargs,
    ) -> None:
        """Scheduler for linearly decreasing values with linear warmup.

        Args:
            num_warmup_steps (int): Number of warmup steps.
            num_training_steps (int): Number of training steps.
            final_value (float, optional): Final value that should be reached at the end of decay. Defaults to 0.0.
            num_delay_steps (int): Number of steps to delay warmup / decay. Defaults to 0.
        """
        self.num_training_steps = num_training_steps
        self.final_value = final_value
        super().__init__(num_warmup_steps, num_delay_steps, *args, **kwargs)

    def value_lambda(self, current_step: int) -> float:
        """Lambda function for linearly decreasing values with linear warmup.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        current_step = current_step - self.num_delay_steps - self.num_warmup_steps
        remaining_steps = self.num_training_steps - self.num_delay_steps - self.num_warmup_steps
        step_size = (1 - self.final_value) / remaining_steps
        return max(self.final_value, 1 - step_size * current_step)

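# Illustrative sketch (not part of the original module): with 8 warmup steps and 40 training
# steps, the multiplier rises linearly to 1.0 during warmup and then decays linearly towards
# final_value.
#
#     >>> scheduler = LinearSchedulerWithLinearWarmup(num_warmup_steps=8, num_training_steps=40)
#     >>> scheduler.value_lambda(4)   # halfway through warmup
#     0.5
#     >>> scheduler.value_lambda(8)   # warmup finished
#     1.0
#     >>> scheduler.value_lambda(24)  # halfway through the decay phase
#     0.5
#     >>> scheduler.value_lambda(40)  # end of training
#     0.0
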
class ConstantSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with linear warmup.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        return 1.0

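# Illustrative sketch (not part of the original module): with a delay of 4 steps and 8 warmup
# steps, the multiplier stays at 0.0, ramps up linearly, and then remains at 1.0.
#
#     >>> scheduler = ConstantSchedulerWithLinearWarmup(num_warmup_steps=8, num_delay_steps=4)
#     >>> scheduler.value_lambda(2)   # still in the delay phase
#     0.0
#     >>> scheduler.value_lambda(8)   # halfway through warmup
#     0.5
#     >>> scheduler.value_lambda(20)  # after warmup
#     1.0
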
class ConstantSchedulerWithQuadraticWarmup(LambdaWarmupScheduler):
    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with quadratic warmup.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return ((current_step - self.num_delay_steps) / self.num_warmup_steps) ** 2
        return 1.0

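# Illustrative sketch (not part of the original module): identical to the linear variant above,
# except that the warmup fraction is squared, so the multiplier grows slowly at first.
#
#     >>> scheduler = ConstantSchedulerWithQuadraticWarmup(num_warmup_steps=8)
#     >>> scheduler.value_lambda(2)
#     0.0625
#     >>> scheduler.value_lambda(4)
#     0.25
#     >>> scheduler.value_lambda(8)
#     1.0
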
class GenericScheduler(Callback, ABC):

    def __init__(self, *args, keys: Sequence[str] | None = None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if keys is None:
            raise ValueError("keys must be provided")
        self.keys = keys
        self.values: Dict[str, float] = {}

    def step(self, key: str, current_step: int) -> float:
        # Return the initial value of the keyed attribute scaled by the current multiplier.
        value = self.values[key]
        return value * self.value_lambda(current_step)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float: ...

    def get_value(self, sub_keys: Sequence[str], obj: object) -> object:
        # Resolve a dotted key path: parts that parse as integers are used as indices,
        # all other parts are looked up as attributes.
        for sub_key in sub_keys:
            try:
                obj = obj[int(sub_key)]
            except ValueError:
                obj = getattr(obj, sub_key)
        return obj

    def set_value(self, sub_keys: Sequence[str], obj: object, value: float) -> None:
        # Resolve the parent object of the key path and overwrite its final attribute.
        obj = self.get_value(sub_keys[:-1], obj)
        setattr(obj, sub_keys[-1], value)

    def on_train_start(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        # Capture the initial value of every keyed attribute before training starts.
        for key in self.keys:
            sub_keys = key.split(".")
            self.values[key] = float(self.get_value(sub_keys, pl_module))

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        # Set every keyed attribute to its initial value scaled by the multiplier for the current step.
        step = trainer.global_step + 1
        for key in self.keys:
            value = self.step(key, step)
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Restore the original values captured in on_train_start.
        for key in self.keys:
            value = self.values[key]
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)

class GenericLinearSchedulerWithLinearWarmup(LinearSchedulerWithLinearWarmup, GenericScheduler):
    pass


class GenericConstantSchedulerWithLinearWarmup(ConstantSchedulerWithLinearWarmup, GenericScheduler):
    pass


class GenericConstantSchedulerWithQuadraticWarmup(ConstantSchedulerWithQuadraticWarmup, GenericScheduler):
    pass

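# Illustrative sketch (not part of the original module): the generic schedulers are Lightning
# callbacks that anneal arbitrary module attributes addressed by dotted key paths. The key
# "loss_functions.0.weight" below is hypothetical; any attribute / index path reachable from the
# LightningIRModule could be scheduled instead.
#
#     from lightning import Trainer
#
#     scheduler = GenericLinearSchedulerWithLinearWarmup(
#         keys=["loss_functions.0.weight"],  # hypothetical dotted key path
#         num_warmup_steps=100,
#         num_training_steps=10_000,
#     )
#     trainer = Trainer(callbacks=[scheduler])
#
# During training, on_train_batch_start rescales each keyed value by value_lambda(step);
# on_train_end restores the original values captured in on_train_start.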