Source code for lightning_ir.loss.pairwise
1"""
2Pairwise loss functions for the Lightning IR framework.
3
4This module contains loss functions that operate on pairs of items,
5comparing positive and negative examples.
6"""
7
8from __future__ import annotations
9
10from typing import TYPE_CHECKING, Literal
11
12import torch
13
14from .base import PairwiseLossFunction
15
16if TYPE_CHECKING:
17 from ..base import LightningIROutput
18 from ..data import TrainBatch
19
20
[docs]
class MarginMSE(PairwiseLossFunction):
    """Mean Squared Error loss with a margin for pairwise ranking tasks.

    MarginMSE optimizes pairwise ranking by penalizing the squared difference between the predicted score margin of a
    positive and negative document and a target margin. This target margin can be a fixed constant or dynamically
    derived from the difference in ground truth or teacher scores, making it particularly effective for knowledge
    distillation tasks where a student model learns to replicate the score distances of a stronger teacher model.

    Originally proposed in: `Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation \
    <https://arxiv.org/abs/2010.02666>`_
    """

    def __init__(self, margin: float | Literal["scores"] = 1.0):
        """Initialize the MarginMSE loss function.

        Args:
            margin (float | Literal["scores"]): Either a constant target margin (an ``int`` or ``float``) that every
                positive-negative score difference is regressed towards, or the string ``"scores"`` to use the
                difference between the target scores of the positive and the negative document as the target margin.
        """
        super().__init__()
        self.margin = margin

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the MarginMSE loss.

        For every (query, positive, negative) index triple returned by ``get_pairwise_idcs``, the predicted margin
        ``score(positive) - score(negative)`` is compared against the target margin with a mean squared error.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss (mean over all pairs).
        Raises:
            ValueError: If the margin type is invalid, i.e. neither a number nor ``"scores"``.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        # Accept ints as well as floats; bool is excluded because it is an int subclass but not a meaningful margin.
        if isinstance(self.margin, (int, float)) and not isinstance(self.margin, bool):
            # Build the constant target with the same shape/dtype/device as the predicted margins so mse_loss does
            # not have to broadcast a 0-d tensor (which triggers a size-mismatch warning on every call).
            target_margin = torch.full_like(margin, float(self.margin))
        elif self.margin == "scores":
            target_margin = targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs]
        else:
            raise ValueError("invalid margin type")
        loss = torch.nn.functional.mse_loss(margin, target_margin)
        return loss
66
67
[docs]
class ConstantMarginMSE(MarginMSE):
    """Constant Margin MSE loss for pairwise ranking tasks with a fixed margin."""

    def __init__(self, margin: float = 1.0):
        """Initialize the ConstantMarginMSE loss function.

        Args:
            margin (float): The fixed target margin for every positive-negative pair. Integer values are accepted and
                converted to ``float``.
        Raises:
            ValueError: If ``margin`` is a string that cannot be converted to ``float``.
            TypeError: If ``margin`` is of a type that cannot be converted to ``float``.
        """
        # Coerce to float so integer margins (e.g. ``2``) are handled as a constant margin and invalid values fail
        # here at construction time instead of during training.
        super().__init__(float(margin))
78
79
[docs]
class SupervisedMarginMSE(MarginMSE):
    """Supervised Margin MSE loss for pairwise ranking tasks with a dynamic margin.

    This is ``MarginMSE`` configured with ``margin="scores"``. The target margin of each pair is therefore derived
    from the batch targets rather than from a fixed constant.
    """

    def __init__(self):
        """Initialize the SupervisedMarginMSE loss function (``MarginMSE`` with ``margin="scores"``)."""
        super().__init__(margin="scores")
86
87
[docs]
class RankNet(PairwiseLossFunction):
    """RankNet loss function for pairwise ranking tasks.

    RankNet optimizes pairwise ranking by modeling the probability that a positive document should be ranked higher
    than a negative document using a logistic function. It computes the margin between the scores of positive and
    negative pairs and applies a binary cross-entropy loss to maximize the likelihood of correct pairwise orderings.
    This approach allows the model to learn from relative comparisons rather than absolute score values.

    Originally proposed in: `Learning to Rank using Gradient Descent \
    <https://dl.acm.org/doi/10.1145/1102351.1102363>`_
    """

    def __init__(self, temperature: float = 1) -> None:
        """Initialize the RankNet loss function.

        Args:
            temperature (float): Factor the processed model scores are multiplied by before pairs are formed.
        """
        super().__init__()
        self.temperature = temperature

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the RankNet loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: Binary cross-entropy (with logits) of the positive-minus-negative score differences against
            an all-ones label, i.e. "the positive should outrank the negative".
        """
        # NOTE(review): the scores are *multiplied* by ``temperature``, so larger values sharpen the logistic. The
        # conventional definition divides logits by the temperature; confirm which is intended.
        scaled_scores = self.process_scores(output) * self.temperature
        pair_targets = self.process_targets(scaled_scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(pair_targets)
        score_diffs = scaled_scores[query_idcs, pos_idcs] - scaled_scores[query_idcs, neg_idcs]
        # Every pair carries the label 1: the positive document is always the one that should be ranked higher.
        pair_labels = torch.ones_like(score_diffs)
        return torch.nn.functional.binary_cross_entropy_with_logits(score_diffs, pair_labels)