1"""
2Approximate ranking loss functions for the Lightning IR framework.
3
4This module contains loss functions that use approximation techniques to compute
5ranking-based metrics like NDCG, MRR, and RankMSE in a differentiable manner.
6"""
7
8from __future__ import annotations
9
10from typing import TYPE_CHECKING, Literal
11
12import torch
13
14from .base import ListwiseLossFunction
15
16if TYPE_CHECKING:
17 from ..base import LightningIROutput
18 from ..data import TrainBatch
19
20
class ApproxLossFunction(ListwiseLossFunction):
    """Base class for approximate loss functions that compute ranks from scores."""

    def __init__(self, temperature: float = 1) -> None:
        """Initialize the ApproxLossFunction.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
        """
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def get_approx_ranks(scores: torch.Tensor, temperature: float) -> torch.Tensor:
        """Compute approximate ranks from scores.

        Args:
            scores (torch.Tensor): The input scores.
            temperature (float): Temperature parameter for scaling the scores.
        Returns:
            torch.Tensor: The computed approximate ranks.
        """
        score_diff = scores[:, None] - scores[..., None]
        normalized_score_diff = torch.sigmoid(score_diff / temperature)
        # set the diagonal to 0 so a document does not contribute to its own rank
        normalized_score_diff = normalized_score_diff * (1 - torch.eye(scores.shape[1], device=scores.device))
        approx_ranks = normalized_score_diff.sum(-1) + 1
        return approx_ranks
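
    # Illustrative sketch (not part of the module's API), showing how the soft ranks behave:
    # for a single query with scores [3.0, 1.0, 2.0] the true ranks are [1, 3, 2]. With the
    # default temperature of 1 the approximate ranks stay soft, roughly [1.39, 2.61, 2.00];
    # as the temperature approaches 0 the sigmoid hardens and the soft ranks approach the
    # true ranks:
    #
    #   scores = torch.tensor([[3.0, 1.0, 2.0]])
    #   ApproxLossFunction.get_approx_ranks(scores, temperature=1.0)   # ~ tensor([[1.39, 2.61, 2.00]])
    #   ApproxLossFunction.get_approx_ranks(scores, temperature=1e-3)  # ~ tensor([[1., 3., 2.]])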


class ApproxNDCG(ApproxLossFunction):
    """Approximate NDCG loss function for ranking tasks. The underlying NDCG metric was originally proposed in:
    `Cumulated Gain-Based Evaluation of IR Techniques <https://dl.acm.org/doi/10.1145/582415.582418>`_"""

    def __init__(self, temperature: float = 1, scale_gains: bool = True):
        """Initialize the ApproxNDCG loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
        """
        super().__init__(temperature)
        self.scale_gains = scale_gains

    @staticmethod
    def get_dcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
    ) -> torch.Tensor:
        """Compute the Discounted Cumulative Gain (DCG) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Optional cutoff for the ranks. If provided, only computes DCG for the top k items.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
        Returns:
            torch.Tensor: The computed DCG values.
        """
        log_ranks = torch.log2(1 + ranks)
        discounts = 1 / log_ranks
        if scale_gains:
            gains = 2**targets - 1
        else:
            gains = targets
        dcgs = gains * discounts
        if k is not None:
            dcgs = dcgs.masked_fill(ranks > k, 0)
        return dcgs.sum(dim=-1)
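
    # Worked example (illustrative): with scale_gains=True a graded relevance label of 3
    # contributes a gain of 2**3 - 1 = 7 while a label of 1 contributes 2**1 - 1 = 1, so
    # highly relevant documents dominate the DCG; with scale_gains=False the labels are
    # used as gains directly.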

    @staticmethod
    def get_ndcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
        optimal_targets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the Normalized Discounted Cumulative Gain (NDCG) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Cutoff for the ranks. If provided, only computes NDCG for the top k items.
                Defaults to None.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
            optimal_targets (torch.Tensor | None): Optional tensor of optimal targets for normalization. If None,
                uses the targets. Defaults to None.
        Returns:
            torch.Tensor: The computed NDCG values.
        """
        targets = targets.clamp(min=0)
        if optimal_targets is None:
            optimal_targets = targets
        optimal_ranks = torch.argsort(torch.argsort(optimal_targets, descending=True))
        optimal_ranks = optimal_ranks + 1
        dcg = ApproxNDCG.get_dcg(ranks, targets, k, scale_gains)
        idcg = ApproxNDCG.get_dcg(optimal_ranks, optimal_targets, k, scale_gains)
        ndcg = dcg / (idcg.clamp(min=1e-12))
        return ndcg
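
    # Worked example (illustrative): for ranks [1, 2, 3] and graded targets [0, 2, 1] with
    # scale_gains=True, the gains are [0, 3, 1] and the discounts 1 / log2(rank + 1) are
    # [1.0, ~0.63, 0.5], giving DCG ~ 2.39. The ideal ordering ranks the target-2 document
    # first (optimal ranks [3, 1, 2]), giving IDCG ~ 3.63 and hence NDCG ~ 0.66.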

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxNDCG loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains)
        loss = 1 - ndcg
        return loss.mean()
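
    # Usage sketch (illustrative; assumes `output` carries per-document scores and `batch`
    # carries relevance targets, which the loss function's base class extracts via
    # process_scores / process_targets):
    #
    #   loss_fn = ApproxNDCG(temperature=1, scale_gains=True)
    #   loss = loss_fn.compute_loss(output, batch)  # scalar tensor, 1 - mean approximate NDCG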


class ApproxMRR(ApproxLossFunction):
    """Approximate Mean Reciprocal Rank (MRR) loss function for ranking tasks."""

    def __init__(self, temperature: float = 1):
        """Initialize the ApproxMRR loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
        """
        super().__init__(temperature)

    @staticmethod
    def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> torch.Tensor:
        """Compute the Mean Reciprocal Rank (MRR) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Optional cutoff for the ranks. If provided, only computes MRR for the top k items.
        Returns:
            torch.Tensor: The computed MRR values.
        """
        targets = targets.clamp(None, 1)
        reciprocal_ranks = 1 / ranks
        mrr = reciprocal_ranks * targets
        if k is not None:
            mrr = mrr.masked_fill(ranks > k, 0)
        mrr = mrr.max(dim=-1)[0]
        return mrr
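
    # Worked example (illustrative): for ranks [2, 1, 3] and binary targets [1, 0, 1], the
    # reciprocal ranks are [0.5, 1.0, ~0.33]; multiplying by the targets masks out the
    # non-relevant document, leaving [0.5, 0.0, ~0.33], and the maximum, 0.5, is the
    # reciprocal rank of the highest-ranked relevant document.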

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxMRR loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        mrr = self.get_mrr(approx_ranks, targets, k=None)
        loss = 1 - mrr
        return loss.mean()


class ApproxRankMSE(ApproxLossFunction):
    """Approximate Rank Mean Squared Error (RankMSE) loss function for ranking tasks.
    Originally proposed in: `Rank-DistiLLM: Closing the Effectiveness Gap Between Cross-Encoders and LLMs
    for Passage Re-ranking <https://link.springer.com/chapter/10.1007/978-3-031-88714-7_31>`_"""

    def __init__(
        self,
        temperature: float = 1,
        discount: Literal["log2", "reciprocal"] | None = None,
    ):
        """Initialize the ApproxRankMSE loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
            discount (Literal["log2", "reciprocal"] | None): Discounting strategy for the loss. If None, no
                discounting is applied. Defaults to None.
        """
        super().__init__(temperature)
        self.discount = discount

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxRankMSE loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
        loss = torch.nn.functional.mse_loss(approx_ranks, ranks.to(approx_ranks), reduction="none")
        if self.discount == "log2":
            weight = 1 / torch.log2(ranks + 1)
        elif self.discount == "reciprocal":
            weight = 1 / ranks
        else:
            weight = 1
        loss = loss * weight
        loss = loss.mean()
        return loss
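
    # Illustrative note on the discount weights (derived from the code above, not prescriptive):
    # with discount="log2" a squared rank error at target rank 1 is weighted 1 / log2(2) = 1.0
    # while an error at target rank 3 is weighted 1 / log2(4) = 0.5, so mistakes near the top of
    # the target ranking dominate the loss; discount="reciprocal" uses weights 1 / rank instead.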