Source code for lightning_ir.loss.approximate

  1"""
  2Approximate ranking loss functions for the Lightning IR framework.
  3
  4This module contains loss functions that use approximation techniques to compute
  5ranking-based metrics like NDCG, MRR, and RankMSE in a differentiable manner.
  6"""
  7
  8from __future__ import annotations
  9
 10from typing import TYPE_CHECKING, Literal
 11
 12import torch
 13
 14from .base import ListwiseLossFunction
 15
 16if TYPE_CHECKING:
 17    from ..base import LightningIROutput
 18    from ..data import TrainBatch
 19
 20


class ApproxLossFunction(ListwiseLossFunction):
    """Base class for approximate loss functions that compute ranks from scores."""

    def __init__(self, temperature: float = 1) -> None:
        """Initialize the ApproxLossFunction.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
        """
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def get_approx_ranks(scores: torch.Tensor, temperature: float) -> torch.Tensor:
        """Compute approximate ranks from scores.

        Args:
            scores (torch.Tensor): The input scores.
            temperature (float): Temperature parameter for scaling the scores.
        Returns:
            torch.Tensor: The computed approximate ranks.
        """
        score_diff = scores[:, None] - scores[..., None]
        normalized_score_diff = torch.sigmoid(score_diff / temperature)
        # set diagonal to 0
        normalized_score_diff = normalized_score_diff * (1 - torch.eye(scores.shape[1], device=scores.device))
        approx_ranks = normalized_score_diff.sum(-1) + 1
        return approx_ranks
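

# Usage sketch (illustrative example, not part of the original module): approximate
# ranks for a single query with three documents. The true ranks of the scores
# [2.0, 0.5, 1.0] are [1, 3, 2]; a small temperature recovers them almost exactly,
# while a larger temperature gives a softer, fully differentiable estimate.
example_scores = torch.tensor([[2.0, 0.5, 1.0]])
sharp_ranks = ApproxLossFunction.get_approx_ranks(example_scores, temperature=0.01)
# ~ tensor([[1., 3., 2.]])
soft_ranks = ApproxLossFunction.get_approx_ranks(example_scores, temperature=1.0)
# ~ tensor([[1.45, 2.44, 2.11]])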


class ApproxNDCG(ApproxLossFunction):
    """Approximate NDCG loss function for ranking tasks.
    Originally proposed in: `Cumulated Gain-Based Evaluation of IR Techniques \
    <https://dl.acm.org/doi/10.1145/582415.582418>`_"""

    def __init__(self, temperature: float = 1, scale_gains: bool = True):
        """Initialize the ApproxNDCG loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
        """
        super().__init__(temperature)
        self.scale_gains = scale_gains

    @staticmethod
    def get_dcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
    ) -> torch.Tensor:
        """Compute the Discounted Cumulative Gain (DCG) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Optional cutoff for the ranks. If provided, only computes DCG for the top k items.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
        Returns:
            torch.Tensor: The computed DCG values.
        """
        log_ranks = torch.log2(1 + ranks)
        discounts = 1 / log_ranks
        if scale_gains:
            gains = 2**targets - 1
        else:
            gains = targets
        dcgs = gains * discounts
        if k is not None:
            dcgs = dcgs.masked_fill(ranks > k, 0)
        return dcgs.sum(dim=-1)

    @staticmethod
    def get_ndcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
        optimal_targets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the Normalized Discounted Cumulative Gain (NDCG) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Cutoff for the ranks. If provided, only computes NDCG for the top k items. Defaults to None.
            scale_gains (bool): Whether to scale the gains. Defaults to True.
            optimal_targets (torch.Tensor | None): Optional tensor of optimal targets for normalization. If None, uses
                the targets. Defaults to None.
        Returns:
            torch.Tensor: The computed NDCG values.
        """
        targets = targets.clamp(min=0)
        if optimal_targets is None:
            optimal_targets = targets
        optimal_ranks = torch.argsort(torch.argsort(optimal_targets, descending=True))
        optimal_ranks = optimal_ranks + 1
        dcg = ApproxNDCG.get_dcg(ranks, targets, k, scale_gains)
        idcg = ApproxNDCG.get_dcg(optimal_ranks, optimal_targets, k, scale_gains)
        ndcg = dcg / (idcg.clamp(min=1e-12))
        return ndcg

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxNDCG loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains)
        loss = 1 - ndcg
        return loss.mean()
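

# Usage sketch (illustrative example, not part of the original module): NDCG for a
# single query with graded relevance [1, 0, 2]. Ranking the documents in the given
# order (ranks [1, 2, 3]) yields DCG = 1/log2(2) + 0 + 3/log2(4) = 2.5, the ideal
# ordering yields IDCG ~ 3.63, so NDCG ~ 0.69. In compute_loss the hard ranks are
# replaced by get_approx_ranks to keep the loss differentiable.
example_ndcg = ApproxNDCG.get_ndcg(
    ranks=torch.tensor([[1.0, 2.0, 3.0]]),
    targets=torch.tensor([[1.0, 0.0, 2.0]]),
)
# ~ tensor([0.6885])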


class ApproxMRR(ApproxLossFunction):
    """Approximate Mean Reciprocal Rank (MRR) loss function for ranking tasks."""

    def __init__(self, temperature: float = 1):
        """Initialize the ApproxMRR loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
        """
        super().__init__(temperature)

    @staticmethod
    def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> torch.Tensor:
        """Compute the Mean Reciprocal Rank (MRR) for the given ranks and targets.

        Args:
            ranks (torch.Tensor): The ranks of the items.
            targets (torch.Tensor): The relevance scores of the items.
            k (int | None): Optional cutoff for the ranks. If provided, only computes MRR for the top k items.
        Returns:
            torch.Tensor: The computed MRR values.
        """
        targets = targets.clamp(None, 1)
        reciprocal_ranks = 1 / ranks
        mrr = reciprocal_ranks * targets
        if k is not None:
            mrr = mrr.masked_fill(ranks > k, 0)
        mrr = mrr.max(dim=-1)[0]
        return mrr

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxMRR loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        mrr = self.get_mrr(approx_ranks, targets, k=None)
        loss = 1 - mrr
        return loss.mean()
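

# Usage sketch (illustrative example, not part of the original module): reciprocal rank
# for a single query whose only relevant document (target 1) is ranked second, giving
# MRR = 1/2. Graded labels are clipped to 1, so any positive grade counts as relevant.
example_mrr = ApproxMRR.get_mrr(
    ranks=torch.tensor([[3.0, 1.0, 2.0]]),
    targets=torch.tensor([[0.0, 0.0, 1.0]]),
)
# tensor([0.5000])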


class ApproxRankMSE(ApproxLossFunction):
    """Approximate Rank Mean Squared Error (RankMSE) loss function for ranking tasks.
    Originally proposed in: `Rank-DistiLLM: Closing the Effectiveness Gap Between Cross-Encoders and LLMs
    for Passage Re-ranking <https://link.springer.com/chapter/10.1007/978-3-031-88714-7_31>`_"""

    def __init__(
        self,
        temperature: float = 1,
        discount: Literal["log2", "reciprocal"] | None = None,
    ):
        """Initialize the ApproxRankMSE loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Defaults to 1.
            discount (Literal["log2", "reciprocal"] | None): Discounting strategy for the loss. If None, no discounting
                is applied. Defaults to None.
        """
        super().__init__(temperature)
        self.discount = discount

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the ApproxRankMSE loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
        loss = torch.nn.functional.mse_loss(approx_ranks, ranks.to(approx_ranks), reduction="none")
        if self.discount == "log2":
            weight = 1 / torch.log2(ranks + 1)
        elif self.discount == "reciprocal":
            weight = 1 / ranks
        else:
            weight = 1
        loss = loss * weight
        loss = loss.mean()
        return loss
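

# Usage sketch (illustrative example, not part of the original module): the core of
# ApproxRankMSE.compute_loss on toy data. Scores [2.0, 0.5, 1.0] and graded targets
# [2, 0, 1] imply the same ordering, so the approximate ranks are regressed onto the
# target ranks [1, 3, 2]; the "log2" discount down-weights errors at lower ranks.
example_rank_scores = torch.tensor([[2.0, 0.5, 1.0]])
example_rank_targets = torch.tensor([[2.0, 0.0, 1.0]])
approx = ApproxLossFunction.get_approx_ranks(example_rank_scores, temperature=1.0)
target_ranks = torch.argsort(torch.argsort(example_rank_targets, descending=True)) + 1
squared_error = torch.nn.functional.mse_loss(approx, target_ranks.to(approx), reduction="none")
example_rank_mse = (squared_error / torch.log2(target_ranks + 1)).mean()  # discount="log2"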