1"""Generic schedulers for LightningIR."""
2
3from abc import ABC, abstractmethod
4from typing import Any, Dict, Sequence
5
6from lightning import Callback, LightningModule, Trainer
7
8from ..base import LightningIRModule
9
10# TODO add final value to all schedulers
11# TODO add cosine decay scheduler
12
13
class LambdaWarmupScheduler(ABC):
    def __init__(
        self,
        num_warmup_steps: int,
        num_delay_steps: int = 0,
        *args,
        **kwargs,
    ) -> None:
        """Base class for schedulers with warmup.

        Args:
            num_warmup_steps (int): Number of warmup steps.
            num_delay_steps (int): Number of steps to delay the scheduler for. Defaults to 0.
        """
        self.num_warmup_steps = num_warmup_steps
        self.num_delay_steps = num_delay_steps
        super().__init__(*args, **kwargs)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float:
        """Lambda function to adjust the value at each step.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        ...

    def _check_delay(self, current_step: int) -> bool:
        # True while the scheduler is still in its initial delay phase.
        return current_step < self.num_delay_steps

    def _check_warmup(self, current_step: int) -> bool:
        # True while the scheduler is delaying or warming up.
        return current_step < self.num_warmup_steps + self.num_delay_steps


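# Phase timeline (illustrative): with num_delay_steps=D and num_warmup_steps=W,
# steps [0, D) fall into the delay phase, steps [D, D + W) into the warmup
# phase, and subclasses define the value from step D + W onwards.
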
class LinearSchedulerWithLinearWarmup(LambdaWarmupScheduler):

    def __init__(
        self,
        num_warmup_steps: int,
        num_training_steps: int,
        *args,
        final_value: float = 0.0,
        num_delay_steps: int = 0,
        **kwargs,
    ) -> None:
        """Scheduler for linearly decreasing values with linear warmup.

        Args:
            num_warmup_steps (int): Number of warmup steps.
            num_training_steps (int): Total number of training steps.
            final_value (float, optional): Final value that should be reached at the end of the decay. Defaults to 0.0.
            num_delay_steps (int): Number of steps to delay warmup / decay. Defaults to 0.
        """
        self.num_training_steps = num_training_steps
        self.final_value = final_value
        super().__init__(num_warmup_steps, num_delay_steps, *args, **kwargs)

    def value_lambda(self, current_step: int) -> float:
        """Lambda function for linearly decreasing values with linear warmup.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        current_step = current_step - self.num_delay_steps - self.num_warmup_steps
        remaining_steps = self.num_training_steps - self.num_delay_steps - self.num_warmup_steps
        step_size = (1 - self.final_value) / remaining_steps
        return max(self.final_value, 1 - step_size * current_step)


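# Worked example for LinearSchedulerWithLinearWarmup (hypothetical numbers,
# kept as a comment so nothing runs at import time):
#
#     scheduler = LinearSchedulerWithLinearWarmup(
#         num_warmup_steps=20, num_training_steps=100, final_value=0.1, num_delay_steps=10
#     )
#     scheduler.value_lambda(5)   # 0.0  -> still inside the delay phase
#     scheduler.value_lambda(20)  # 0.5  -> (20 - 10) / 20, halfway through warmup
#     scheduler.value_lambda(65)  # 0.55 -> 1 - (0.9 / 70) * 35, halfway through decay
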
class ConstantSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with linear warmup.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        return 1.0


class ConstantSchedulerWithQuadraticWarmup(LambdaWarmupScheduler):
    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with quadratic warmup.

        Args:
            current_step (int): Current step.
        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return ((current_step - self.num_delay_steps) / self.num_warmup_steps) ** 2
        return 1.0


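# Warmup comparison (hypothetical numbers): with num_warmup_steps=20 and no
# delay, step 10 yields 10 / 20 = 0.5 under linear warmup but
# (10 / 20) ** 2 = 0.25 under quadratic warmup; both variants hold 1.0 from
# step 20 onwards.
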
class GenericScheduler(Callback, ABC):

    def __init__(self, *args, keys: Sequence[str] | None = None, **kwargs) -> None:
        """Callback that schedules arbitrary float attributes of a module during training.

        Args:
            keys (Sequence[str] | None): Dotted paths to the attributes to schedule.

        Raises:
            ValueError: If no keys are provided.
        """
        super().__init__(*args, **kwargs)
        if keys is None:
            raise ValueError("keys must be provided")
        self.keys = keys
        self.values: Dict[str, float] = {}

    def step(self, key: str, current_step: int) -> float:
        # Scale the attribute's original value by the scheduler's lambda.
        value = self.values[key]
        return value * self.value_lambda(current_step)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float: ...

    def get_value(self, sub_keys: Sequence[str], obj: object) -> object:
        # Walk a dotted path: numeric segments index into sequences, all other
        # segments are resolved via attribute access.
        for sub_key in sub_keys:
            try:
                obj = obj[int(sub_key)]
            except ValueError:
                obj = getattr(obj, sub_key)
        return obj

    def set_value(self, sub_keys: Sequence[str], obj: object, value: float) -> None:
        # Resolve the parent object of the path, then overwrite the final attribute.
        obj = self.get_value(sub_keys[:-1], obj)
        setattr(obj, sub_keys[-1], value)

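    # Path example (hypothetical key): "loss_functions.0.weight" resolves to
    # pl_module.loss_functions[0].weight in get_value; set_value overwrites
    # that attribute in place.
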
    def on_train_start(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        # Record the original value of every scheduled attribute before training starts.
        for key in self.keys:
            sub_keys = key.split(".")
            self.values[key] = float(self.get_value(sub_keys, pl_module))

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        # Update every scheduled attribute with its value for the upcoming step.
        step = trainer.global_step + 1
        for key in self.keys:
            value = self.step(key, step)
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Restore the original attribute values once training is finished.
        for key in self.keys:
            value = self.values[key]
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)


class GenericLinearSchedulerWithLinearWarmup(LinearSchedulerWithLinearWarmup, GenericScheduler):
    pass


class GenericConstantSchedulerWithLinearWarmup(ConstantSchedulerWithLinearWarmup, GenericScheduler):
    pass


class GenericConstantSchedulerWithQuadraticWarmup(ConstantSchedulerWithQuadraticWarmup, GenericScheduler):
    pass
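

# Minimal usage sketch (hypothetical key and objects; "module" is assumed to be
# a LightningIRModule exposing loss_functions[0].weight):
#
#     from lightning import Trainer
#
#     scheduler = GenericConstantSchedulerWithLinearWarmup(
#         keys=["loss_functions.0.weight"], num_warmup_steps=1000
#     )
#     trainer = Trainer(callbacks=[scheduler])
#     trainer.fit(module, datamodule=datamodule)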