# Source code for lightning_ir.loss.regularization
"""
Regularization loss functions for the Lightning IR framework.

This module contains loss functions that apply regularization to embeddings
to prevent overfitting and improve generalization.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from .base import RegularizationLossFunction

# Imported only for type annotations to avoid a circular import at runtime.
if TYPE_CHECKING:
    from ..bi_encoder import BiEncoderOutput

class L2Regularization(RegularizationLossFunction):
    """L2 regularization loss for query and document embeddings.

    Adds a penalty proportional to the magnitude of the embedding vectors,
    encouraging the model to keep embeddings small. This discourages reliance
    on any single feature and helps prevent overfitting, while the penalty
    remains differentiable and yields a smooth optimization landscape.

    Originally proposed in: `Ridge Regression: Biased Estimation for Nonorthogonal Problems
    <https://homepages.math.uic.edu/~lreyzin/papers/ridge.pdf>`_"""

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the L2 regularization loss.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.
        Returns:
            torch.Tensor: The computed loss.
        """
        query_emb, doc_emb = self.process_embeddings(output)
        # NOTE(review): the penalty is the mean (unsquared) L2 norm of each
        # embedding; classic ridge regression squares the norm — confirm this
        # is intentional.
        query_penalty = query_emb.norm(dim=-1).mean() * self.query_weight
        doc_penalty = doc_emb.norm(dim=-1).mean() * self.doc_weight
        return query_penalty + doc_penalty

class L1Regularization(RegularizationLossFunction):
    """L1 regularization loss for query and document embeddings.

    Adds a penalty proportional to the absolute values of the embedding
    entries (the lasso penalty). This promotes sparsity — many embedding
    dimensions are pushed toward exactly zero — which can make models more
    interpretable and act as a form of feature selection.

    Originally proposed in: `Regression Shrinkage and Selection via the Lasso
    <https://academic.oup.com/jrsssb/article/58/1/267/7027929>`_"""

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the L1 regularization loss.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.
        Returns:
            torch.Tensor: The computed loss.
        """
        query_emb, doc_emb = self.process_embeddings(output)
        # Mean L1 norm per embedding, weighted separately for queries and docs.
        penalty = (
            self.query_weight * query_emb.norm(p=1, dim=-1).mean()
            + self.doc_weight * doc_emb.norm(p=1, dim=-1).mean()
        )
        return penalty

class FLOPSRegularization(RegularizationLossFunction):
    """FLOPS regularization loss for query and document embeddings.

    Penalizes the squared mean absolute activation of each embedding
    dimension across the batch. Driving per-dimension average activations
    toward zero encourages sparse embeddings, which reduces the number of
    floating-point operations (FLOPS) needed at inference time — a key
    efficiency win for large-scale retrieval systems.

    Originally proposed in: `Minimizing FLOPS to Learn Efficient Sparse Representations
    <https://arxiv.org/pdf/2004.05665>`_"""

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the FLOPS regularization loss.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.
        Returns:
            torch.Tensor: The computed loss.
        """
        query_emb, doc_emb = self.process_embeddings(output)
        # Per-dimension mean |activation| over the batch (dim=0), squared and
        # summed over dimensions — the FLOPS penalty of Paria et al.
        query_flops = query_emb.abs().mean(dim=0).pow(2).sum()
        doc_flops = doc_emb.abs().mean(dim=0).pow(2).sum()
        return self.query_weight * query_flops + self.doc_weight * doc_flops