Source code for lightning_ir.loss.regularization
1"""
2Regularization loss functions for the Lightning IR framework.
3
4This module contains loss functions that apply regularization to embeddings
5to prevent overfitting and improve generalization.
6"""
7
8from __future__ import annotations
9
10from typing import TYPE_CHECKING
11
12import torch
13
14from .base import RegularizationLossFunction
15
16if TYPE_CHECKING:
17 from ..bi_encoder import BiEncoderOutput
18
19
[docs]
class L2Regularization(RegularizationLossFunction):
    """Regularizer that penalizes the L2 norm of query and document embeddings.

    Originally proposed in: `Ridge Regression: Biased Estimation for Nonorthogonal Problems
    <https://homepages.math.uic.edu/~lreyzin/papers/ridge.pdf>`_
    """

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the L2 regularization loss.

        Args:
            output (BiEncoderOutput): The model output from which the query and document
                embeddings are extracted via ``process_embeddings``.
        Returns:
            torch.Tensor: ``query_weight`` times the mean query norm plus ``doc_weight``
            times the mean document norm.
        """
        queries, docs = self.process_embeddings(output)

        # Default (p=2) norm along the last dimension, averaged over every remaining entry.
        mean_query_norm = queries.norm(dim=-1).mean()
        mean_doc_norm = docs.norm(dim=-1).mean()

        weighted_query_term = self.query_weight * mean_query_norm
        weighted_doc_term = self.doc_weight * mean_doc_norm
        return weighted_query_term + weighted_doc_term
38
39
[docs]
class L1Regularization(RegularizationLossFunction):
    """Regularizer that penalizes the L1 norm of query and document embeddings.

    Originally proposed in: `Regression Shrinkage and Selection via the Lasso
    <https://academic.oup.com/jrsssb/article/58/1/267/7027929>`_
    """

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the L1 regularization loss.

        Args:
            output (BiEncoderOutput): The model output from which the query and document
                embeddings are extracted via ``process_embeddings``.
        Returns:
            torch.Tensor: ``query_weight`` times the mean query L1 norm plus ``doc_weight``
            times the mean document L1 norm.
        """
        queries, docs = self.process_embeddings(output)

        # p=1 norm along the last dimension, averaged over every remaining entry.
        mean_query_l1 = queries.norm(p=1, dim=-1).mean()
        mean_doc_l1 = docs.norm(p=1, dim=-1).mean()

        weighted_query_term = self.query_weight * mean_query_l1
        weighted_doc_term = self.doc_weight * mean_doc_l1
        return weighted_query_term + weighted_doc_term
58
59
[docs]
class FLOPSRegularization(RegularizationLossFunction):
    """FLOPS regularizer for query and document embeddings.

    Originally proposed in: `Minimizing FLOPS to Learn Efficient Sparse Representations
    <https://arxiv.org/pdf/2004.05665>`_
    """

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the FLOPS regularization loss.

        Args:
            output (BiEncoderOutput): The model output from which the query and document
                embeddings are extracted via ``process_embeddings``.
        Returns:
            torch.Tensor: The weighted sum of the query and document FLOPS penalties.
        """
        queries, docs = self.process_embeddings(output)

        # Penalty per side: mean absolute activation along dim 0, squared, then summed.
        # NOTE(review): dim 0 is presumably the batch dimension -- confirm against
        # what process_embeddings returns.
        query_penalty = queries.abs().mean(dim=0).pow(2).sum()
        doc_penalty = docs.abs().mean(dim=0).pow(2).sum()

        return self.query_weight * query_penalty + self.doc_weight * doc_penalty