Source code for lightning_ir.loss.approximate

"""
Approximate ranking loss functions for the Lightning IR framework.

This module contains loss functions that use approximation techniques to compute
ranking-based metrics like NDCG, MRR, and RankMSE in a differentiable manner.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import torch

from .base import ListwiseLossFunction

if TYPE_CHECKING:
    from ..base import LightningIROutput
    from ..data import TrainBatch

class ApproxLossFunction(ListwiseLossFunction):
    """Base class for approximate loss functions that compute ranks from scores."""

    def __init__(self, temperature: float = 1) -> None:
        """Initialize the ApproxLossFunction.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
        """
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def get_approx_ranks(scores: torch.Tensor, temperature: float) -> torch.Tensor:
        """Compute differentiable approximate ranks from a batch of scores.

        Each item's rank is one plus the soft count of items that score higher,
        obtained from a temperature-scaled sigmoid over pairwise score differences.

        Args:
            scores (torch.Tensor): The input scores.
            temperature (float): Temperature parameter for scaling the scores.
        Returns:
            torch.Tensor: The computed approximate ranks.
        """
        pairwise_diff = scores[:, None] - scores[..., None]
        pairwise_prob = torch.sigmoid(pairwise_diff / temperature)
        # zero out self-comparisons along the diagonal
        identity = torch.eye(scores.shape[1], device=scores.device)
        pairwise_prob = pairwise_prob * (1 - identity)
        return pairwise_prob.sum(-1) + 1
class ApproxNDCG(ApproxLossFunction):
    """Approximate NDCG loss function for ranking tasks.

    Standard NDCG relies on non-differentiable sorting operations that prevent the use of gradient descent for direct
    optimization. Approximate NDCG overcomes this limitation by replacing the sorting step with a smooth,
    differentiable surrogate function that estimates the rank of each document based on its score. This approach allows
    the model to optimize a loss that is mathematically aligned with the final evaluation metric, reducing the mismatch
    between training objectives and testing performance.

    Originally proposed in: `Cumulated Gain-Based Evaluation of IR Techniques \
    <https://dl.acm.org/doi/10.1145/582415.582418>`_"""

    def __init__(self, temperature: float = 1, scale_gains: bool = True):
        """Initialize the ApproxNDCG loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
        """
        super().__init__(temperature)
        self.scale_gains = scale_gains

    @staticmethod
    def get_dcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
    ) -> torch.Tensor:
        """Compute the Discounted Cumulative Gain (DCG) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Optional cutoff for the ranks. If provided, only computes DCG for the top k items.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
        Returns:
            torch.Tensor: The computed DCG values.
        """
        # exponential gains emphasize highly relevant items; linear gains keep them as-is
        gains = (2**targets - 1) if scale_gains else targets
        discounts = 1 / torch.log2(1 + ranks)
        per_item = gains * discounts
        if k is not None:
            # drop contributions of items ranked below the cutoff
            per_item = per_item.masked_fill(ranks > k, 0)
        return per_item.sum(dim=-1)

    @staticmethod
    def get_ndcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
        optimal_targets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the Normalized Discounted Cumulative Gain (NDCG) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Cutoff for the ranks. If provided, only computes NDCG for the top k items. Defaults to None.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
            optimal_targets (torch.Tensor | None): Optional tensor of optimal targets for normalization. If None, uses
                the targets. Defaults to None.
        Returns:
            torch.Tensor: The computed NDCG values.
        """
        targets = targets.clamp(min=0)
        if optimal_targets is None:
            optimal_targets = targets
        # 1-based rank positions of the ideal (descending-relevance) ordering
        ideal_ranks = torch.argsort(torch.argsort(optimal_targets, descending=True)) + 1
        dcg = ApproxNDCG.get_dcg(ranks, targets, k, scale_gains)
        idcg = ApproxNDCG.get_dcg(ideal_ranks, optimal_targets, k, scale_gains)
        # clamp the normalizer to avoid division by zero when no relevant items exist
        return dcg / idcg.clamp(min=1e-12)

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxNDCG loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains)
        # NDCG is a gain metric in [0, 1]; minimize its complement
        return (1 - ndcg).mean()
class ApproxMRR(ApproxLossFunction):
    """Approximate Mean Reciprocal Rank (MRR) loss function for ranking tasks.

    Mean Reciprocal Rank (MRR) is a metric used to evaluate ranking systems by focusing on the position of the
    first relevant result, making it ideal for tasks like question answering where the user wants one correct answer
    immediately. It assigns a score of 1/k, where k is the rank of the first relevant document; for example, if the
    correct result is at position 1, the score is 1, but if it is at position 10, the score drops to 0.1. The final
    MRR is simply the average of these reciprocal scores across all queries in the dataset.
    Approximate MRR replaces the non-differentiable discrete ranking operation with a smooth, differentiable surrogate
    function based on pairwise score comparisons, enabling the model to directly maximize the reciprocal rank of the
    relevant document via gradient descent.
    """

    def __init__(self, temperature: float = 1):
        """Initialize the ApproxMRR loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
        """
        super().__init__(temperature)

    @staticmethod
    def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> torch.Tensor:
        """Compute the Mean Reciprocal Rank (MRR) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Optional cutoff for the ranks. If provided, only computes MRR for the top k items.
        Returns:
            torch.Tensor: The computed MRR values.
        """
        # cap graded relevance at 1 so every relevant item contributes equally
        capped_targets = targets.clamp(None, 1)
        per_item = (1 / ranks) * capped_targets
        if k is not None:
            per_item = per_item.masked_fill(ranks > k, 0)
        # the best (smallest-rank) relevant item determines the reciprocal rank
        return per_item.max(dim=-1)[0]

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxMRR loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        mrr = self.get_mrr(approx_ranks, targets, k=None)
        # MRR is a gain metric in [0, 1]; minimize its complement
        return (1 - mrr).mean()
class ApproxRankMSE(ApproxLossFunction):
    """Approximate Rank Mean Squared Error (RankMSE) loss function for ranking tasks.

    Rank Mean Squared Error (RankMSE) penalizes the squared differences between predicted document ranks and their
    ground truth ranks. Because standard discrete sorting prevents gradient descent, Approximate RankMSE uses a
    smooth, differentiable approximation of these ranks. It computes the Mean Squared Error between the continuous
    approximate ranks and the true target ranks, optionally applying position-based discounting (like log2 or
    reciprocal weights) to penalize errors at the top of the list more heavily.

    Originally proposed in: `Rank-DistiLLM: Closing the Effectiveness Gap Between Cross-Encoders and LLMs
    for Passage Re-ranking <https://link.springer.com/chapter/10.1007/978-3-031-88714-7_31>`_
    """

    def __init__(
        self,
        temperature: float = 1,
        discount: Literal["log2", "reciprocal"] | None = None,
    ):
        """Initialize the ApproxRankMSE loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
            discount (Literal["log2", "reciprocal"] | None): Discounting strategy for the loss. If None, no discounting
                is applied. Defaults to None.
        """
        super().__init__(temperature)
        self.discount = discount

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxRankMSE loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        # 1-based ground-truth ranks from a double argsort over the targets
        true_ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
        per_item = torch.nn.functional.mse_loss(approx_ranks, true_ranks.to(approx_ranks), reduction="none")
        # optionally down-weight errors further down the ranking
        if self.discount == "log2":
            weight = 1 / torch.log2(true_ranks + 1)
        elif self.discount == "reciprocal":
            weight = 1 / true_ranks
        else:
            weight = 1
        return (per_item * weight).mean()