Source code for lightning_ir.schedulers.lr_schedulers
1"""Learning rate schedulers for LightningIR."""
2
3import torch
4
5from .schedulers import ConstantSchedulerWithLinearWarmup, LambdaWarmupScheduler, LinearSchedulerWithLinearWarmup
6
7
class WarmupLRScheduler(LambdaWarmupScheduler, torch.optim.lr_scheduler.LambdaLR):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_warmup_steps: int,
        **kwargs,
    ) -> None:
        """Base class for learning rate schedulers with warmup.

        :param optimizer: Optimizer to adjust the learning rate for.
        :type optimizer: torch.optim.Optimizer
        :param num_warmup_steps: Number of warmup steps.
        :type num_warmup_steps: int
        """
        last_epoch = -1
        self.interval = "step"
        super().__init__(
            optimizer=optimizer,
            lr_lambda=self.value_lambda,
            num_warmup_steps=num_warmup_steps,
            last_epoch=last_epoch,
            **kwargs,
        )

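# Note: ``value_lambda`` is supplied by the ``LambdaWarmupScheduler`` mixin and
# is passed to ``LambdaLR`` as the multiplicative factor for the base learning
# rate; ``self.interval = "step"`` appears to follow the PyTorch Lightning
# convention that the scheduler is stepped after every optimizer step rather
# than once per epoch.
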
class LinearLRSchedulerWithLinearWarmup(WarmupLRScheduler, LinearSchedulerWithLinearWarmup):
    """Scheduler for linearly decreasing learning rate with linear warmup."""

    pass

class ConstantLRSchedulerWithLinearWarmup(WarmupLRScheduler, ConstantSchedulerWithLinearWarmup):
    """Scheduler for constant learning rate with linear warmup."""

    pass

LR_SCHEDULERS = (
    LinearLRSchedulerWithLinearWarmup,
    ConstantLRSchedulerWithLinearWarmup,
    WarmupLRScheduler,
    torch.optim.lr_scheduler.LRScheduler,
)
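
# A minimal usage sketch, kept commented out so the module is unchanged on
# import. It assumes ``ConstantLRSchedulerWithLinearWarmup`` needs only an
# optimizer and ``num_warmup_steps``; the model, learning rate, and step
# counts below are illustrative, not taken from the library.
#
#     model = torch.nn.Linear(16, 1)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#     scheduler = ConstantLRSchedulerWithLinearWarmup(optimizer, num_warmup_steps=100)
#     for _ in range(1_000):
#         optimizer.step()
#         scheduler.step()  # multiplier ramps linearly to 1 over the first 100 steps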