Source code for lightning_ir.loss.regularization

 1"""
 2Regularization loss functions for the Lightning IR framework.
 3
 4This module contains loss functions that apply regularization to embeddings
 5to prevent overfitting and improve generalization.
 6"""
 7
 8from __future__ import annotations
 9
10from typing import TYPE_CHECKING
11
12import torch
13
14from .base import RegularizationLossFunction
15
16if TYPE_CHECKING:
17    from ..bi_encoder import BiEncoderOutput
18
19
class L2Regularization(RegularizationLossFunction):
    """L2 regularization penalty on query and document embeddings.

    L2 Regularization (Ridge Regression) adds a penalty proportional to the
    magnitude of the embedding vectors, encouraging the model to keep
    embeddings small. Small embedding norms discourage over-reliance on any
    single feature and thereby help prevent overfitting. The penalty is
    differentiable everywhere, yielding a smooth optimization landscape.

    Originally proposed in: `Ridge Regression: Biased Estimation for Nonorthogonal Problems
    <https://homepages.math.uic.edu/~lreyzin/papers/ridge.pdf>`_"""

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the L2 regularization loss.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.
        Returns:
            torch.Tensor: The computed loss.
        """
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        # Mean Euclidean norm per embedding vector, weighted separately for
        # the query side and the document side.
        return (
            self.query_weight * query_embeddings.norm(dim=-1).mean()
            + self.doc_weight * doc_embeddings.norm(dim=-1).mean()
        )
45 46
class L1Regularization(RegularizationLossFunction):
    """L1 regularization penalty on query and document embeddings.

    L1 Regularization (Lasso Regression) penalizes the absolute values of the
    embedding dimensions, which pushes many dimensions toward exactly zero.
    The resulting sparsity can make models more interpretable and acts as a
    form of feature selection, effectively pruning less important dimensions.

    Originally proposed in: `Regression Shrinkage and Selection via the Lasso
    <https://academic.oup.com/jrsssb/article/58/1/267/7027929>`_"""

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the L1 regularization loss.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.
        Returns:
            torch.Tensor: The computed loss.
        """
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        # Mean L1 norm (sum of absolute values) per embedding vector,
        # weighted separately for the query side and the document side.
        return (
            self.query_weight * query_embeddings.norm(p=1, dim=-1).mean()
            + self.doc_weight * doc_embeddings.norm(p=1, dim=-1).mean()
        )
72 73
class FLOPSRegularization(RegularizationLossFunction):
    """FLOPS regularization penalty on query and document embeddings.

    FLOPS Regularization encourages sparse embeddings so that retrieval
    requires fewer floating-point operations at inference time — crucial for
    large-scale retrieval systems. For each side, the penalty is the sum over
    embedding dimensions of the squared mean absolute activation across the
    batch, which drives rarely-useful dimensions toward zero while keeping
    the most important features active.

    Originally proposed in: `Minimizing FLOPS to Learn Efficient Sparse Representations
    <https://arxiv.org/pdf/2004.05665>`_"""

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the FLOPS regularization loss.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.
        Returns:
            torch.Tensor: The computed loss.
        """
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        # Per side: average |activation| over the batch (dim=0), square it,
        # then sum across embedding dimensions.
        query_penalty = query_embeddings.abs().mean(dim=0).pow(2).sum()
        doc_penalty = doc_embeddings.abs().mean(dim=0).pow(2).sum()
        return self.query_weight * query_penalty + self.doc_weight * doc_penalty