1"""Generic schedulers for LightningIR."""
2
3from abc import ABC, abstractmethod
4from collections.abc import Sequence
5from typing import Any
6
7from lightning import Callback, LightningModule, Trainer
8
9from ..base import LightningIRModule
10
11# TODO add final value to all schedulers
12# TODO add cosine decay scheduler
13
14
class LambdaWarmupScheduler(ABC):
    def __init__(
        self,
        num_warmup_steps: int,
        num_delay_steps: int = 0,
        *args,
        **kwargs,
    ) -> None:
        """Base class for schedulers with warmup.

        Args:
            num_warmup_steps (int): Number of warmup steps.
            num_delay_steps (int): Number of steps to delay the scheduler for. Defaults to 0.
        """
        self.num_warmup_steps = num_warmup_steps
        self.num_delay_steps = num_delay_steps
        super().__init__(*args, **kwargs)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float:
        """Lambda function to adjust the value at each step.

        Args:
            current_step (int): Current step.

        Returns:
            float: Value at the current step.
        """
        ...

    def _check_delay(self, current_step: int) -> bool:
        """Return True while the scheduler is still delayed."""
        return current_step < self.num_delay_steps

    def _check_warmup(self, current_step: int) -> bool:
        """Return True while the scheduler is still warming up (delay included)."""
        return current_step < self.num_warmup_steps + self.num_delay_steps


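# Phase timeline (illustrative numbers, not from the library docs): with
# num_delay_steps=5 and num_warmup_steps=10, subclasses see three phases:
# steps 0-4 are delayed (_check_delay is True), steps 5-14 are warmup
# (_check_warmup is True), and from step 15 on the post-warmup value applies.

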
class LinearSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    def __init__(
        self,
        num_warmup_steps: int,
        num_training_steps: int,
        *args,
        final_value: float = 0.0,
        num_delay_steps: int = 0,
        **kwargs,
    ) -> None:
        """Scheduler for linearly decreasing values with linear warmup.

        Args:
            num_warmup_steps (int): Number of warmup steps.
            num_training_steps (int): Total number of training steps.
            final_value (float): Final value that should be reached at the end of decay. Defaults to 0.0.
            num_delay_steps (int): Number of steps to delay warmup / decay. Defaults to 0.
        """
        self.num_training_steps = num_training_steps
        self.final_value = final_value
        super().__init__(num_warmup_steps, num_delay_steps, *args, **kwargs)

    def value_lambda(self, current_step: int) -> float:
        """Lambda function for linearly decreasing values with linear warmup.

        Args:
            current_step (int): Current step.

        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        current_step = current_step - self.num_delay_steps - self.num_warmup_steps
        remaining_steps = self.num_training_steps - self.num_delay_steps - self.num_warmup_steps
        step_size = (1 - self.final_value) / remaining_steps
        return max(self.final_value, 1 - step_size * current_step)


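# Worked example (illustrative values): with num_warmup_steps=10,
# num_training_steps=100, and final_value=0.0,
#
#     value_lambda(0)   == 0.0   # start of warmup
#     value_lambda(5)   == 0.5   # halfway through warmup
#     value_lambda(10)  == 1.0   # warmup done, decay starts
#     value_lambda(55)  == 0.5   # halfway through the remaining 90 steps
#     value_lambda(100) == 0.0   # decayed to final_value

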
class ConstantSchedulerWithLinearWarmup(LambdaWarmupScheduler):
    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with linear warmup.

        Args:
            current_step (int): Current step.

        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return (current_step - self.num_delay_steps) / self.num_warmup_steps
        return 1.0


class ConstantSchedulerWithQuadraticWarmup(LambdaWarmupScheduler):
    def value_lambda(self, current_step: int) -> float:
        """Lambda function for no decay with quadratic warmup.

        Args:
            current_step (int): Current step.

        Returns:
            float: Value at the current step.
        """
        if self._check_delay(current_step):
            return 0.0
        if self._check_warmup(current_step):
            return ((current_step - self.num_delay_steps) / self.num_warmup_steps) ** 2
        return 1.0


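# Worked comparison of the two constant schedulers (illustrative values):
# with num_warmup_steps=10, halfway through warmup the linear variant yields
# 5 / 10 == 0.5 while the quadratic variant yields (5 / 10) ** 2 == 0.25;
# both reach 1.0 at step 10 and stay there.

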
class GenericScheduler(Callback, ABC):
    def __init__(self, *args, keys: Sequence[str] | None = None, **kwargs) -> None:
        """Callback that schedules arbitrary float attributes of the module during training.

        Args:
            keys (Sequence[str]): Dotted attribute paths to schedule. Must be provided.
        """
        super().__init__(*args, **kwargs)
        if keys is None:
            raise ValueError("keys must be provided")
        self.keys = keys
        self.values: dict[str, float] = {}

    def step(self, key: str, current_step: int) -> float:
        """Scale the recorded initial value of ``key`` by the lambda value for the current step."""
        value = self.values[key]
        return value * self.value_lambda(current_step)

    @abstractmethod
    def value_lambda(self, current_step: int) -> float: ...

    def get_value(self, sub_keys: Sequence[str], obj: object) -> object:
        """Resolve a dotted attribute path; parts that parse as integers are used as indices."""
        for sub_key in sub_keys:
            try:
                obj = obj[int(sub_key)]
            except ValueError:
                obj = getattr(obj, sub_key)
        return obj

    def set_value(self, sub_keys: Sequence[str], obj: object, value: float) -> None:
        """Set the attribute addressed by a dotted path to ``value``."""
        obj = self.get_value(sub_keys[:-1], obj)
        setattr(obj, sub_keys[-1], value)

    def on_train_start(self, trainer: Trainer, pl_module: LightningIRModule) -> None:
        """Record the initial value of each key before training starts."""
        for key in self.keys:
            sub_keys = key.split(".")
            self.values[key] = float(self.get_value(sub_keys, pl_module))

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
        """Update each key with its scheduled value for the upcoming step."""
        step = trainer.global_step + 1
        for key in self.keys:
            value = self.step(key, step)
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Restore the original value of each key after training."""
        for key in self.keys:
            value = self.values[key]
            sub_keys = key.split(".")
            self.set_value(sub_keys, pl_module, value)


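# Keys are dotted paths resolved against the LightningModule. For example the
# (hypothetical) key "loss_functions.0.temperature" walks
# pl_module.loss_functions[0].temperature: the numeric part "0" indexes into
# a sequence, the other parts are attribute lookups.

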
class GenericLinearSchedulerWithLinearWarmup(LinearSchedulerWithLinearWarmup, GenericScheduler):
    pass


class GenericConstantSchedulerWithLinearWarmup(ConstantSchedulerWithLinearWarmup, GenericScheduler):
    pass


class GenericConstantSchedulerWithQuadraticWarmup(ConstantSchedulerWithQuadraticWarmup, GenericScheduler):
    pass
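

# A minimal sketch of using a generic scheduler outside a Trainer. The key
# "model.temperature" is a hypothetical attribute path chosen for
# illustration; in practice the keys must point at real (nested) float
# attributes of the LightningModule being trained. Note that, because of the
# relative import above, this only runs in package context (python -m ...).
if __name__ == "__main__":
    scheduler = GenericConstantSchedulerWithLinearWarmup(
        keys=["model.temperature"], num_warmup_steps=10
    )
    # Before on_train_start has recorded initial values, only value_lambda is
    # meaningful: it ramps linearly from 0.0 to 1.0 over the first 10 steps.
    for step in (0, 5, 10, 20):
        print(step, scheduler.value_lambda(step))
    # In real use the scheduler is passed to the Trainer as a callback, e.g.
    # Trainer(callbacks=[scheduler]), which then drives the on_train_* hooks.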