1"""Generic schedulers for LightningIR."""
2
3from abc import ABC, abstractmethod
4from typing import Any, Dict, Sequence
5
6from lightning import Callback, LightningModule, Trainer
7
8from ..base import LightningIRModule
9
10# TODO add final value to all schedulers
11# TODO add cosine decay scheduler
12
13
class LambdaWarmupScheduler(ABC):
    def __init__(
        self,
        num_warmup_steps: int,
        num_delay_steps: int = 0,
        *args,
        **kwargs,
    ) -> None:
        """Base class for schedulers with warmup.

        :param num_warmup_steps: Number of warmup steps
        :type num_warmup_steps: int
        :param num_delay_steps: Number of steps to delay the scheduler for, defaults to 0
        :type num_delay_steps: int, optional
        """
        self.num_warmup_steps = num_warmup_steps
        self.num_delay_steps = num_delay_steps
        super().__init__(*args, **kwargs)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float:
        """Lambda function to adjust the value at each step.

        :param current_step: Current step
        :type current_step: int
        :return: Value at the current step
        :rtype: float
        """
        ...

    def _check_delay(self, current_step: int) -> bool:
        # True while the step is still inside the delay window
        return current_step < self.num_delay_steps

    def _check_warmup(self, current_step: int) -> bool:
        # True while the step is still inside the delay or warmup window
        return current_step < self.num_warmup_steps + self.num_delay_steps


class LinearSchedulerWithLinearWarmup(LambdaWarmupScheduler):

    def __init__(
        self,
        num_warmup_steps: int,
        num_training_steps: int,
        *args,
        final_value: float = 0.0,
        num_delay_steps: int = 0,
        **kwargs,
    ) -> None:
        """Scheduler for linearly decreasing values with linear warmup.

        :param num_warmup_steps: Number of warmup steps
        :type num_warmup_steps: int
        :param num_training_steps: Number of training steps
        :type num_training_steps: int
        :param final_value: The final value that should be reached at the end of decay, defaults to 0.0
        :type final_value: float, optional
        :param num_delay_steps: Number of steps to delay warmup / decay, defaults to 0
        :type num_delay_steps: int, optional
        """
        self.num_training_steps = num_training_steps
        self.final_value = final_value
        super().__init__(num_warmup_steps, num_delay_steps, *args, **kwargs)

    def value_lambda(self, current_step: int) -> float:
        """Lambda function for linearly decreasing values with linear warmup.

        :param current_step: Current step
        :type current_step: int
        :return: Value at the current step
        :rtype: float
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        current_step = current_step - self.num_delay_steps - self.num_warmup_steps
        remaining_steps = self.num_training_steps - self.num_delay_steps - self.num_warmup_steps
        step_size = (1 - self.final_value) / remaining_steps
        return max(self.final_value, 1 - step_size * current_step)


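# A minimal sketch of how the linear schedule behaves, assuming 5 delay steps, 10 warmup steps, and
# 100 training steps (the returned values are multipliers of the scheduled quantity's starting value):
#
#   scheduler = LinearSchedulerWithLinearWarmup(num_warmup_steps=10, num_training_steps=100, num_delay_steps=5)
#   scheduler.value_lambda(3)    # 0.0  (still inside the delay window)
#   scheduler.value_lambda(10)   # 0.5  (halfway through warmup)
#   scheduler.value_lambda(100)  # 0.0  (fully decayed to final_value)

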
class ConstantSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    """Scheduler that keeps the value constant after linear warmup."""

    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with linear warmup.

        :param current_step: Current step
        :type current_step: int
        :return: Value at the current step
        :rtype: float
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        return 1.0


class ConstantSchedulerWithQuadraticWarmup(LambdaWarmupScheduler):
    """Scheduler that keeps the value constant after quadratic warmup."""

    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with quadratic warmup.

        :param current_step: Current step
        :type current_step: int
        :return: Value at the current step
        :rtype: float
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return ((current_step - self.num_delay_steps) / self.num_warmup_steps) ** 2
        return 1.0


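# A minimal sketch of how the two constant-schedule warmup variants differ, assuming 10 warmup steps and no
# delay; quadratic warmup ramps up more slowly at first:
#
#   linear = ConstantSchedulerWithLinearWarmup(num_warmup_steps=10)
#   quadratic = ConstantSchedulerWithQuadraticWarmup(num_warmup_steps=10)
#   linear.value_lambda(5)      # 0.5
#   quadratic.value_lambda(5)   # 0.25
#   linear.value_lambda(20)     # 1.0 (constant after warmup)
#   quadratic.value_lambda(20)  # 1.0

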
class GenericScheduler(Callback, ABC):
    """Base class for callbacks that schedule arbitrary attributes of a module during training."""

    def __init__(self, *args, keys: Sequence[str] | None = None, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if keys is None:
            raise ValueError("keys must be provided")
        self.keys = keys
        self.values: Dict[str, float] = {}

    def step(self, key: str, current_step: int) -> float:
        # Scale the attribute's initial value by the scheduler's lambda at the current step
        value = self.values[key]
        return value * self.value_lambda(current_step)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float: ...

    def get_value(self, sub_keys: Sequence[str], obj: object) -> object:
        for sub_key in sub_keys:
            try:
                obj = obj[int(sub_key)]
            except ValueError:
                obj = getattr(obj, sub_key)
        return obj

    def set_value(self, sub_keys: Sequence[str], obj: object, value: float) -> None:
        obj = self.get_value(sub_keys[:-1], obj)
        setattr(obj, sub_keys[-1], value)

    def on_train_start(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        # Record the initial value of every scheduled attribute
        for key in self.keys:
            sub_keys = key.split(".")
            self.values[key] = float(self.get_value(sub_keys, pl_module))

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        # Update every scheduled attribute before each training batch
        step = trainer.global_step + 1
        for key in self.keys:
            value = self.step(key, step)
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Restore every scheduled attribute to its initial value
        for key in self.keys:
            value = self.values[key]
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)


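# A minimal sketch of how dotted keys are resolved, assuming a hypothetical module layout (the attribute
# names are illustrative, not part of the API). Integer segments index into sequences; all other segments
# are looked up with getattr:
#
#   "loss_functions.0.temperature"
#       -> get_value(["loss_functions", "0", "temperature"], pl_module)
#       -> pl_module.loss_functions[0].temperature

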
class GenericLinearSchedulerWithLinearWarmup(LinearSchedulerWithLinearWarmup, GenericScheduler):
    """Callback that linearly decays the scheduled attributes with linear warmup."""


class GenericConstantSchedulerWithLinearWarmup(ConstantSchedulerWithLinearWarmup, GenericScheduler):
    """Callback that keeps the scheduled attributes constant after linear warmup."""


class GenericConstantSchedulerWithQuadraticWarmup(ConstantSchedulerWithQuadraticWarmup, GenericScheduler):
    """Callback that keeps the scheduled attributes constant after quadratic warmup."""
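

# A minimal usage sketch: attach one of the generic schedulers as a Trainer callback to warm up an attribute
# of the module during training. The dotted key path ("loss_functions.0.temperature") is a hypothetical
# example, not a fixed API:
#
#   from lightning import Trainer
#
#   scheduler = GenericLinearSchedulerWithLinearWarmup(
#       keys=["loss_functions.0.temperature"],
#       num_warmup_steps=1_000,
#       num_training_steps=10_000,
#   )
#   trainer = Trainer(callbacks=[scheduler])
#
# At the start of training the attribute's value is recorded; before each batch it is multiplied by
# value_lambda(global_step + 1), and at the end of training it is restored to its initial value.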