Source code for lightning_ir.loss.pairwise

  1"""
  2Pairwise loss functions for the Lightning IR framework.
  3
  4This module contains loss functions that operate on pairs of items,
  5comparing positive and negative examples.
  6"""
  7
  8from __future__ import annotations
  9
 10from typing import TYPE_CHECKING, Literal
 11
 12import torch
 13
 14from .base import PairwiseLossFunction
 15
 16if TYPE_CHECKING:
 17    from ..base import LightningIROutput
 18    from ..data import TrainBatch
 19
 20
class MarginMSE(PairwiseLossFunction):
    """Mean Squared Error loss with a margin for pairwise ranking tasks.

    MarginMSE optimizes pairwise ranking by penalizing the squared difference between the predicted score margin of a
    positive and negative document and a target margin. This target margin can be a fixed constant or dynamically
    derived from the difference in ground truth or teacher scores, making it particularly effective for knowledge
    distillation tasks where a student model learns to replicate the score distances of a stronger teacher model.

    Originally proposed in: `Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation \
    <https://arxiv.org/abs/2010.02666>`_
    """

    def __init__(self, margin: float | Literal["scores"] = 1.0):
        """Initialize the MarginMSE loss function.

        Args:
            margin (float | Literal["scores"]): Either a number, used as a fixed target margin for every pair, or
                the string ``"scores"``, in which case the target margin of a pair is the difference between the
                pair's target values. Any other value is rejected when the loss is computed.
        """
        # Consistent with the other loss functions in this module, which initialize the base class.
        super().__init__()
        self.margin = margin

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the MarginMSE loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        Raises:
            ValueError: If the margin type is invalid.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        # Predicted margin of every (positive, negative) pair selected by the index triples.
        margin = scores[query_idcs, pos_idcs] - scores[query_idcs, neg_idcs]
        # Accept any real number (not only ``float``) so that e.g. ``MarginMSE(margin=1)`` works. ``bool`` is a
        # subclass of ``int`` but is not a meaningful margin, so it is still rejected below.
        if isinstance(self.margin, (int, float)) and not isinstance(self.margin, bool):
            # Build the target with the same shape, dtype and device as the prediction; a 0-dim target would make
            # ``mse_loss`` warn about mismatching input/target sizes on every call.
            target_margin = torch.full_like(margin, float(self.margin))
        elif self.margin == "scores":
            target_margin = targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs]
        else:
            raise ValueError("invalid margin type")
        return torch.nn.functional.mse_loss(margin, target_margin)
66 67
class ConstantMarginMSE(MarginMSE):
    """MarginMSE variant that always regresses the pairwise score difference onto one fixed margin."""

    def __init__(self, margin: float = 1.0):
        """Initialize the ConstantMarginMSE loss function.

        Args:
            margin (float): The constant target margin, forwarded unchanged to :class:`MarginMSE`.
        """
        super().__init__(margin=margin)
78 79
class SupervisedMarginMSE(MarginMSE):
    """MarginMSE variant whose target margin is derived from the targets instead of a constant.

    The parent class is configured with ``margin="scores"``, so the target margin of each pair is the difference
    between the pair's target values.
    """

    def __init__(self):
        """Initialize the SupervisedMarginMSE loss function."""
        super().__init__(margin="scores")
86 87
class RankNet(PairwiseLossFunction):
    """RankNet loss function for pairwise ranking tasks.

    For every (positive, negative) pair the difference of the two scores is treated as a logit for the event "the
    positive document is ranked above the negative one". A binary cross-entropy loss against an all-ones target
    then maximizes the likelihood of the correct pairwise ordering, so the model learns from relative comparisons
    rather than from absolute score values.

    Originally proposed in: `Learning to Rank using Gradient Descent \
    <https://dl.acm.org/doi/10.1145/1102351.1102363>`_
    """

    def __init__(self, temperature: float = 1) -> None:
        """Initialize the RankNet loss function.

        Args:
            temperature (float): Factor the processed scores are multiplied with before pairs are formed.
        """
        super().__init__()

        self.temperature = temperature

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the RankNet loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scaled_scores = self.process_scores(output) * self.temperature
        targets = self.process_targets(scaled_scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        # Logit of each pair: score of the preferred document minus score of the other one.
        pair_logits = scaled_scores[query_idcs, pos_idcs] - scaled_scores[query_idcs, neg_idcs]
        # Every selected pair is ordered (positive, negative), hence the label is always 1.
        pair_labels = torch.ones_like(pair_logits)
        return torch.nn.functional.binary_cross_entropy_with_logits(pair_logits, pair_labels)