"""
Approximate ranking loss functions for the Lightning IR framework.

This module contains loss functions that use approximation techniques to compute
ranking-based metrics like NDCG, MRR, and RankMSE in a differentiable manner.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import torch

from .base import ListwiseLossFunction

if TYPE_CHECKING:
    from ..base import LightningIROutput
    from ..data import TrainBatch


class ApproxLossFunction(ListwiseLossFunction):
    """Base class for approximate loss functions that compute ranks from scores."""

    def __init__(self, temperature: float = 1) -> None:
        """Initialize the ApproxLossFunction.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
        """
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def get_approx_ranks(scores: torch.Tensor, temperature: float) -> torch.Tensor:
        """Compute differentiable approximate ranks from scores.

        Each item's rank is estimated as one plus the sum of sigmoid-smoothed
        pairwise comparisons against every other item in the same list.

        Args:
            scores (torch.Tensor): The input scores.
            temperature (float): Temperature parameter for scaling the scores.
        Returns:
            torch.Tensor: The computed approximate ranks.
        """
        pairwise_diff = scores[:, None] - scores[..., None]
        comparison_probs = torch.sigmoid(pairwise_diff / temperature)
        # remove self-comparisons by zeroing out the diagonal
        off_diagonal_mask = 1 - torch.eye(scores.shape[1], device=scores.device)
        comparison_probs = comparison_probs * off_diagonal_mask
        # rank = 1 + expected number of items scored higher than this one
        return comparison_probs.sum(-1) + 1


class ApproxNDCG(ApproxLossFunction):
    """Approximate NDCG loss function for ranking tasks.

    Standard NDCG relies on non-differentiable sorting operations that prevent the use of gradient descent for direct
    optimization. Approximate NDCG overcomes this limitation by replacing the sorting step with a smooth,
    differentiable surrogate function that estimates the rank of each document based on its score. This approach allows
    the model to optimize a loss that is mathematically aligned with the final evaluation metric, reducing the mismatch
    between training objectives and testing performance.

    Originally proposed in: `Cumulated Gain-Based Evaluation of IR Techniques \
    <https://dl.acm.org/doi/10.1145/582415.582418>`_"""

    def __init__(self, temperature: float = 1, scale_gains: bool = True):
        """Initialize the ApproxNDCG loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
        """
        super().__init__(temperature)
        self.scale_gains = scale_gains

    @staticmethod
    def get_dcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
    ) -> torch.Tensor:
        """Compute the Discounted Cumulative Gain (DCG) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Optional cutoff for the ranks. If provided, only computes DCG for the top k items.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
        Returns:
            torch.Tensor: The computed DCG values.
        """
        # exponential gains emphasize highly relevant items; otherwise use targets as-is
        gains = 2**targets - 1 if scale_gains else targets
        discounts = 1 / torch.log2(1 + ranks)
        discounted_gains = gains * discounts
        if k is not None:
            # items ranked below the cutoff contribute nothing
            discounted_gains = discounted_gains.masked_fill(ranks > k, 0)
        return discounted_gains.sum(dim=-1)

    @staticmethod
    def get_ndcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
        optimal_targets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the Normalized Discounted Cumulative Gain (NDCG) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Cutoff for the ranks. If provided, only computes NDCG for the top k items. Defaults to None.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
            optimal_targets (torch.Tensor | None): Optional tensor of optimal targets for normalization. If None, uses
                the targets. Defaults to None.
        Returns:
            torch.Tensor: The computed NDCG values.
        """
        targets = targets.clamp(min=0)
        if optimal_targets is None:
            optimal_targets = targets
        # double argsort yields each item's position in the ideal (sorted) ordering
        ideal_ranks = torch.argsort(torch.argsort(optimal_targets, descending=True)) + 1
        dcg = ApproxNDCG.get_dcg(ranks, targets, k, scale_gains)
        idcg = ApproxNDCG.get_dcg(ideal_ranks, optimal_targets, k, scale_gains)
        # clamp the normalizer so all-zero targets do not divide by zero
        return dcg / idcg.clamp(min=1e-12)

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxNDCG loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains)
        # maximizing NDCG == minimizing 1 - NDCG, averaged over the batch
        return (1 - ndcg).mean()


class ApproxMRR(ApproxLossFunction):
    """Approximate Mean Reciprocal Rank (MRR) loss function for ranking tasks.

    Mean Reciprocal Rank (MRR) is a metric used to evaluate ranking systems by focusing on the position of the
    first relevant result, making it ideal for tasks like question answering where the user wants one correct answer
    immediately. It assigns a score of 1/k, where k is the rank of the first relevant document; for example, if the
    correct result is at position 1, the score is 1, but if it is at position 10, the score drops to 0.1. The final
    MRR is simply the average of these reciprocal scores across all queries in the dataset.
    Approximate MRR replaces the non-differentiable discrete ranking operation with a smooth, differentiable surrogate
    function based on pairwise score comparisons, enabling the model to directly maximize the reciprocal rank of the
    relevant document via gradient descent.
    """

    def __init__(self, temperature: float = 1):
        """Initialize the ApproxMRR loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
        """
        super().__init__(temperature)

    @staticmethod
    def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> torch.Tensor:
        """Compute the Mean Reciprocal Rank (MRR) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Optional cutoff for the ranks. If provided, only computes MRR for the top k items.
        Returns:
            torch.Tensor: The computed MRR values.
        """
        # Binarize graded relevance into [0, 1]: clamping the minimum to 0 treats
        # negative relevance labels as non-relevant instead of letting them produce
        # negative reciprocal-rank contributions (consistent with
        # ApproxNDCG.get_ndcg, which clamps targets to a minimum of 0).
        targets = targets.clamp(0, 1)
        reciprocal_ranks = 1 / ranks
        mrr = reciprocal_ranks * targets
        if k is not None:
            # items ranked below the cutoff contribute nothing
            mrr = mrr.masked_fill(ranks > k, 0)
        # keep only the best (highest) reciprocal rank of any relevant item per query
        mrr = mrr.max(dim=-1)[0]
        return mrr

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxMRR loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        mrr = self.get_mrr(approx_ranks, targets, k=None)
        loss = 1 - mrr
        return loss.mean()


class ApproxRankMSE(ApproxLossFunction):
    """Approximate Rank Mean Squared Error (RankMSE) loss function for ranking tasks.

    Rank Mean Squared Error (RankMSE) penalizes the squared differences between predicted document ranks and their
    ground truth ranks. Because standard discrete sorting prevents gradient descent, Approximate RankMSE uses a
    smooth, differentiable approximation of these ranks. It computes the Mean Squared Error between the continuous
    approximate ranks and the true target ranks, optionally applying position-based discounting (like log2 or
    reciprocal weights) to penalize errors at the top of the list more heavily.

    Originally proposed in: `Rank-DistiLLM: Closing the Effectiveness Gap Between Cross-Encoders and LLMs
    for Passage Re-ranking <https://link.springer.com/chapter/10.1007/978-3-031-88714-7_31>`_
    """

    def __init__(
        self,
        temperature: float = 1,
        discount: Literal["log2", "reciprocal"] | None = None,
    ):
        """Initialize the ApproxRankMSE loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
            discount (Literal["log2", "reciprocal"] | None): Discounting strategy for the loss. If None, no discounting
                is applied. Defaults to None.
        """
        super().__init__(temperature)
        self.discount = discount

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxRankMSE loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        # double argsort converts relevance targets into 1-based ground-truth ranks
        target_ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
        squared_errors = torch.nn.functional.mse_loss(
            approx_ranks, target_ranks.to(approx_ranks), reduction="none"
        )
        # optional position-based weighting: errors at the top of the list count more
        if self.discount == "log2":
            position_weight = 1 / torch.log2(target_ranks + 1)
        elif self.discount == "reciprocal":
            position_weight = 1 / target_ranks
        else:
            position_weight = 1
        return (squared_errors * position_weight).mean()