# Source: lightning_ir.schedulers.lr_schedulers
"""Learning rate schedulers for LightningIR."""

import torch

from .schedulers import ConstantSchedulerWithLinearWarmup, LambdaWarmupScheduler, LinearSchedulerWithLinearWarmup
class WarmupLRScheduler(LambdaWarmupScheduler, torch.optim.lr_scheduler.LambdaLR):
    """Base class for learning rate schedulers with a warmup phase.

    Combines the project's warmup-value logic (:class:`LambdaWarmupScheduler`)
    with PyTorch's :class:`~torch.optim.lr_scheduler.LambdaLR` machinery.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_warmup_steps: int,
        **kwargs,
    ) -> None:
        """Initialize the warmup learning rate scheduler.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer to adjust the learning rate for.
            num_warmup_steps (int): Number of warmup steps.
        """
        # The scheduler is stepped once per training step (not per epoch).
        self.interval = "step"
        super().__init__(
            optimizer=optimizer,
            # value_lambda (from LambdaWarmupScheduler) supplies the LR multiplier.
            lr_lambda=self.value_lambda,
            num_warmup_steps=num_warmup_steps,
            last_epoch=-1,
            **kwargs,
        )
class LinearLRSchedulerWithLinearWarmup(WarmupLRScheduler, LinearSchedulerWithLinearWarmup):
    """Learning rate scheduler that warms up linearly and then decays linearly."""
class ConstantLRSchedulerWithLinearWarmup(WarmupLRScheduler, ConstantSchedulerWithLinearWarmup):
    """Learning rate scheduler that warms up linearly and then stays constant."""
# Registry of learning rate scheduler classes exposed by this module.
LR_SCHEDULERS = (
    LinearLRSchedulerWithLinearWarmup,
    ConstantLRSchedulerWithLinearWarmup,
    WarmupLRScheduler,
    torch.optim.lr_scheduler.LRScheduler,
)