Source code for lightning_ir.schedulers.lr_schedulers

 1"""Learning rate schedulers for LightningIR."""
 2
 3import torch
 4
 5from .schedulers import ConstantSchedulerWithLinearWarmup, LambdaWarmupScheduler, LinearSchedulerWithLinearWarmup
 6
 7
class WarmupLRScheduler(LambdaWarmupScheduler, torch.optim.lr_scheduler.LambdaLR):
    """Base class for learning rate schedulers with warmup.

    Mixes a LightningIR ``LambdaWarmupScheduler`` into PyTorch's ``LambdaLR``:
    the scheduler's own ``value_lambda`` is handed to ``LambdaLR`` as the
    ``lr_lambda`` that produces the learning-rate multiplier for each step.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_warmup_steps: int,
        **kwargs,
    ) -> None:
        """Create a warmup learning rate scheduler.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer to adjust the learning rate for.
            num_warmup_steps (int): Number of warmup steps.
            **kwargs: Additional keyword arguments, forwarded unchanged to the
                next ``__init__`` in the method resolution order.
        """
        # Scheduling granularity marker; presumably read by the Lightning
        # integration to step the scheduler per optimizer step — confirm.
        self.interval = "step"
        super().__init__(
            optimizer=optimizer,
            lr_lambda=self.value_lambda,
            num_warmup_steps=num_warmup_steps,
            # -1 tells LambdaLR to begin a fresh schedule (no resumption).
            last_epoch=-1,
            **kwargs,
        )
class LinearLRSchedulerWithLinearWarmup(WarmupLRScheduler, LinearSchedulerWithLinearWarmup):
    """Scheduler for linearly decreasing learning rate with linear warmup.

    All behavior comes from the two base classes: ``WarmupLRScheduler`` wires
    the schedule into PyTorch's ``LambdaLR``, and
    ``LinearSchedulerWithLinearWarmup`` supplies the value schedule.
    """
36 37
class ConstantLRSchedulerWithLinearWarmup(WarmupLRScheduler, ConstantSchedulerWithLinearWarmup):
    """Scheduler for constant learning rate with linear warmup.

    All behavior comes from the two base classes: ``WarmupLRScheduler`` wires
    the schedule into PyTorch's ``LambdaLR``, and
    ``ConstantSchedulerWithLinearWarmup`` supplies the value schedule.
    """
# Learning rate scheduler classes known to this module, with PyTorch's generic
# LRScheduler base as the final entry. Presumably used for type/isinstance
# checks by callers — confirm before reordering.
LR_SCHEDULERS = (
    LinearLRSchedulerWithLinearWarmup,
    ConstantLRSchedulerWithLinearWarmup,
    WarmupLRScheduler,
    torch.optim.lr_scheduler.LRScheduler,
)