Source code for lightning_ir.schedulers.schedulers

  1"""Generic schedulers for LightningIR."""
  2
  3from abc import ABC, abstractmethod
  4from typing import Any, Dict, Sequence
  5
  6from lightning import Callback, LightningModule, Trainer
  7
  8from ..base import LightningIRModule
  9
 10# TODO add final value to all schedulers
 11# TODO add cosine decay scheduler
 12
 13
[docs] 14class LambdaWarmupScheduler(ABC):
[docs] 15 def __init__( 16 self, 17 num_warmup_steps: int, 18 num_delay_steps: int = 0, 19 *args, 20 **kwargs, 21 ) -> None: 22 """Base class for schedulers with warmup. 23 24 :param num_warmup_steps: Number of warmup steps 25 :type num_warmup_steps: int 26 :param num_delay_steps: Number of steps to delay scheduler for, defaults to 0 27 :type num_delay_steps: int, optional 28 """ 29 self.num_warmup_steps = num_warmup_steps 30 self.num_delay_steps = num_delay_steps 31 super().__init__(*args, **kwargs)
32
[docs] 33 @abstractmethod 34 def value_lambda(self, current_step: int) -> float: 35 """Lambda function to adjust the value at each step. 36 37 :param current_step: Current step 38 :type current_step: int 39 :return: Value at the current step 40 :rtype: float 41 """ 42 ...
43 44 def _check_delay(self, current_step: int) -> bool: 45 return current_step < self.num_delay_steps 46 47 def _check_warmup(self, current_step: int) -> bool: 48 return current_step < self.num_warmup_steps + self.num_delay_steps
49 50
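
# Sketch, not part of the module: the TODO above mentions a cosine decay
# scheduler. Under the same warmup interface it could look roughly like the
# subclass below; the class name and the exact decay formula are assumptions.
import math  # only needed for this sketch


class CosineSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    def __init__(
        self,
        num_warmup_steps: int,
        num_training_steps: int,
        *args,
        num_delay_steps: int = 0,
        **kwargs,
    ) -> None:
        self.num_training_steps = num_training_steps
        super().__init__(num_warmup_steps, num_delay_steps, *args, **kwargs)

    def value_lambda(self, current_step: int) -> float:
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        current_step = current_step - self.num_delay_steps - self.num_warmup_steps
        remaining_steps = self.num_training_steps - self.num_delay_steps - self.num_warmup_steps
        progress = min(1.0, current_step / remaining_steps)
        # cosine decay from 1.0 down to 0.0 over the remaining steps
        return 0.5 * (1 + math.cos(math.pi * progress))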


class LinearSchedulerWithLinearWarmup(LambdaWarmupScheduler):

    def __init__(
        self,
        num_warmup_steps: int,
        num_training_steps: int,
        *args,
        final_value: float = 0.0,
        num_delay_steps: int = 0,
        **kwargs,
    ) -> None:
        """Scheduler for linearly decreasing values with linear warmup.

        :param num_warmup_steps: Number of warmup steps
        :type num_warmup_steps: int
        :param num_training_steps: Number of training steps
        :type num_training_steps: int
        :param final_value: The final value that should be reached at the end of decay, defaults to 0.0
        :type final_value: float, optional
        :param num_delay_steps: Number of steps to delay warmup / decay, defaults to 0
        :type num_delay_steps: int, optional
        """
        self.num_training_steps = num_training_steps
        self.final_value = final_value
        super().__init__(num_warmup_steps, num_delay_steps, *args, **kwargs)

    def value_lambda(self, current_step: int) -> float:
        """Lambda function for linearly decreasing values with linear warmup.

        :param current_step: Current step
        :type current_step: int
        :return: Value at the current step
        :rtype: float
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        current_step = current_step - self.num_delay_steps - self.num_warmup_steps
        remaining_steps = self.num_training_steps - self.num_delay_steps - self.num_warmup_steps
        step_size = (1 - self.final_value) / remaining_steps
        return max(self.final_value, 1 - step_size * current_step)
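
# Illustration, not part of the module: arbitrary parameter values chosen to
# show the shape of LinearSchedulerWithLinearWarmup (10 warmup steps, linear
# decay to 0.0 by step 100).
def _example_linear_schedule() -> None:
    scheduler = LinearSchedulerWithLinearWarmup(num_warmup_steps=10, num_training_steps=100)
    print(scheduler.value_lambda(0))    # 0.0, start of warmup
    print(scheduler.value_lambda(5))    # 0.5, halfway through warmup
    print(scheduler.value_lambda(10))   # 1.0, warmup finished, decay starts
    print(scheduler.value_lambda(55))   # ~0.5, halfway through decay
    print(scheduler.value_lambda(100))  # ~0.0, final_value reached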


class ConstantSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with linear warmup.

        :param current_step: Current step
        :type current_step: int
        :return: Value at the current step
        :rtype: float
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        return 1.0


class ConstantSchedulerWithQuadraticWarmup(LambdaWarmupScheduler):
    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with quadratic warmup.

        :param current_step: Current step
        :type current_step: int
        :return: Value at the current step
        :rtype: float
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return ((current_step - self.num_delay_steps) / self.num_warmup_steps) ** 2
        return 1.0
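
# Illustration, not part of the module: with the same 10-step warmup, the
# quadratic variant ramps up more slowly at first than the linear one; both
# stay at 1.0 once warmup is over.
def _example_constant_schedules() -> None:
    linear = ConstantSchedulerWithLinearWarmup(num_warmup_steps=10)
    quadratic = ConstantSchedulerWithQuadraticWarmup(num_warmup_steps=10)
    print(linear.value_lambda(5), quadratic.value_lambda(5))    # 0.5 0.25
    print(linear.value_lambda(20), quadratic.value_lambda(20))  # 1.0 1.0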


class GenericScheduler(Callback, ABC):

    def __init__(self, *args, keys: Sequence[str] | None = None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if keys is None:
            raise ValueError("keys must be provided")
        self.keys = keys
        self.values: Dict[str, float] = {}

    def step(self, key: str, current_step: int) -> float:
        value = self.values[key]
        return value * self.value_lambda(current_step)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float: ...

    def get_value(self, sub_keys: Sequence[str], obj: object) -> object:
        for sub_key in sub_keys:
            try:
                obj = obj[int(sub_key)]
            except ValueError:
                obj = getattr(obj, sub_key)
        return obj

    def set_value(self, sub_keys: Sequence[str], obj: object, value: float) -> None:
        obj = self.get_value(sub_keys[:-1], obj)
        setattr(obj, sub_keys[-1], value)

    def on_train_start(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        for key in self.keys:
            sub_keys = key.split(".")
            self.values[key] = float(self.get_value(sub_keys, pl_module))

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        step = trainer.global_step + 1
        for key in self.keys:
            value = self.step(key, step)
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        for key in self.keys:
            value = self.values[key]
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)
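
# Illustration, not part of the module: how GenericScheduler resolves its
# dot-separated keys with get_value / set_value. Numeric sub-keys index into
# sequences, all other sub-keys are attribute lookups. The classes and the key
# "loss_functions.0.weight" below are made up for this example.
def _example_key_resolution() -> None:
    class Loss:
        weight = 2.0

    class Module:
        loss_functions = [Loss()]

    class ExampleScheduler(GenericScheduler):
        def value_lambda(self, current_step: int) -> float:
            return 1.0  # no-op scaling, only needed to make the class concrete

    scheduler = ExampleScheduler(keys=["loss_functions.0.weight"])
    module = Module()
    print(scheduler.get_value("loss_functions.0.weight".split("."), module))  # 2.0
    scheduler.set_value("loss_functions.0.weight".split("."), module, 1.0)
    print(module.loss_functions[0].weight)  # 1.0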


class GenericLinearSchedulerWithLinearWarmup(LinearSchedulerWithLinearWarmup, GenericScheduler):
    pass


class GenericConstantSchedulerWithLinearWarmup(ConstantSchedulerWithLinearWarmup, GenericScheduler):
    pass


class GenericConstantSchedulerWithQuadraticWarmup(ConstantSchedulerWithQuadraticWarmup, GenericScheduler):
    pass
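

# Usage sketch, not part of the module: the generic schedulers are Lightning
# callbacks that scale an attribute of the LightningIRModule during training.
# on_train_start records the attribute's initial value, on_train_batch_start
# replaces it with the warmed-up value before every batch, and on_train_end
# restores the original value. The key "loss_functions.0.weight" and the
# `module` / `data_module` objects are assumptions for illustration; the
# actual attribute paths depend on the configured module.
def _example_trainer_usage() -> None:
    scheduler = GenericConstantSchedulerWithLinearWarmup(
        keys=["loss_functions.0.weight"],  # hypothetical attribute path into the module
        num_warmup_steps=1_000,
    )
    trainer = Trainer(callbacks=[scheduler])
    # trainer.fit(module, datamodule=data_module)  # module / data_module defined elsewhere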