Source code for lightning_ir.loss.regularization

 1"""
 2Regularization loss functions for the Lightning IR framework.
 3
 4This module contains loss functions that apply regularization to embeddings
 5to prevent overfitting and improve generalization.
 6"""
 7
 8from __future__ import annotations
 9
10from typing import TYPE_CHECKING
11
12import torch
13
14from .base import RegularizationLossFunction
15
16if TYPE_CHECKING:
17    from ..bi_encoder import BiEncoderOutput
18
19
class L2Regularization(RegularizationLossFunction):
    """Regularizer that penalizes the L2 norm of query and document embeddings.

    Originally proposed in: `Ridge Regression: Biased Estimation for Nonorthogonal Problems
    <https://homepages.math.uic.edu/~lreyzin/papers/ridge.pdf>`_
    """

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Return the weighted sum of the mean query and mean document L2 norms.

        Args:
            output (BiEncoderOutput): Model output; it is passed to ``process_embeddings`` to obtain the
                query and document embeddings.
        Returns:
            torch.Tensor: ``query_weight * mean(||q||_2) + doc_weight * mean(||d||_2)``, where each norm is
                taken along the last dimension of the respective embeddings.
        """
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        # Reduce the last dimension to a norm, then average over every remaining entry.
        mean_query_norm = torch.norm(query_embeddings, dim=-1).mean()
        mean_doc_norm = torch.norm(doc_embeddings, dim=-1).mean()
        return self.query_weight * mean_query_norm + self.doc_weight * mean_doc_norm
38 39
class L1Regularization(RegularizationLossFunction):
    """Regularizer that penalizes the L1 norm of query and document embeddings.

    Originally proposed in: `Regression Shrinkage and Selection via the Lasso
    <https://academic.oup.com/jrsssb/article/58/1/267/7027929>`_
    """

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Return the weighted sum of the mean query and mean document L1 norms.

        Args:
            output (BiEncoderOutput): Model output; it is passed to ``process_embeddings`` to obtain the
                query and document embeddings.
        Returns:
            torch.Tensor: ``query_weight * mean(||q||_1) + doc_weight * mean(||d||_1)``, where each norm is
                taken along the last dimension of the respective embeddings.
        """
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        # p=1 selects the sum of absolute values along the last dimension.
        mean_query_norm = torch.norm(query_embeddings, p=1, dim=-1).mean()
        mean_doc_norm = torch.norm(doc_embeddings, p=1, dim=-1).mean()
        return self.query_weight * mean_query_norm + self.doc_weight * mean_doc_norm
58 59
class FLOPSRegularization(RegularizationLossFunction):
    """FLOPS regularizer for query and document embeddings.

    Originally proposed in: `Minimizing FLOPS to Learn Efficient Sparse Representations
    <https://arxiv.org/pdf/2004.05665>`_
    """

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Return the weighted sum of the query and document FLOPS penalties.

        For each side, the absolute values of the embeddings are averaged over dimension 0, the averages
        are squared, and the squares are summed into a scalar.

        Args:
            output (BiEncoderOutput): Model output; it is passed to ``process_embeddings`` to obtain the
                query and document embeddings.
        Returns:
            torch.Tensor: ``query_weight * query_penalty + doc_weight * doc_penalty``.
        """
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        # Dimension 0 is averaged away first (presumably the batch axis -- confirm against
        # ``process_embeddings``); squaring happens after the mean, not before.
        query_penalty = query_embeddings.abs().mean(dim=0).pow(2).sum()
        doc_penalty = doc_embeddings.abs().mean(dim=0).pow(2).sum()
        return self.query_weight * query_penalty + self.doc_weight * doc_penalty