Source code for lightning_ir.loss.pairwise

  1"""
  2Pairwise loss functions for the Lightning IR framework.
  3
  4This module contains loss functions that operate on pairs of items,
  5comparing positive and negative examples.
  6"""
  7
  8from __future__ import annotations
  9
 10from typing import TYPE_CHECKING, Literal
 11
 12import torch
 13
 14from .base import PairwiseLossFunction
 15
 16if TYPE_CHECKING:
 17    from ..base import LightningIROutput
 18    from ..data import TrainBatch
 19
 20
class MarginMSE(PairwiseLossFunction):
    """Mean Squared Error loss with a margin for pairwise ranking tasks.

    The loss is the MSE between the predicted margin ``score(pos) - score(neg)`` and a target margin, over all
    (query, positive, negative) index triples returned by ``get_pairwise_idcs``.

    Originally proposed in: `Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation \
    <https://arxiv.org/abs/2010.02666>`_
    """

    def __init__(self, margin: float | Literal["scores"] = 1.0):
        """Initialize the MarginMSE loss function.

        Args:
            margin (float | Literal["scores"]): The target margin. A number (int or float) is used as a constant
                target margin for every pair. ``"scores"`` uses the difference of the pair's target values as the
                target margin.
        """
        super().__init__()
        self.margin = margin

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the MarginMSE loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        Raises:
            ValueError: If the margin type is invalid.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        # bool is a subclass of int but is not a meaningful margin, so it is rejected like any other invalid value.
        if isinstance(self.margin, (int, float)) and not isinstance(self.margin, bool):
            # full_like gives the target the same shape, dtype and device as the predicted margin. A 0-d tensor
            # would make mse_loss warn about mismatched sizes and could mix dtypes under reduced precision.
            target_margin = torch.full_like(margin, float(self.margin))
        elif self.margin == "scores":
            target_margin = targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs]
            # Keep the target in the same dtype as the prediction so mse_loss does not see mixed dtypes.
            target_margin = target_margin.to(margin.dtype)
        else:
            raise ValueError("invalid margin type")
        loss = torch.nn.functional.mse_loss(margin, target_margin)
        return loss


class ConstantMarginMSE(MarginMSE):
    """Constant Margin MSE loss for pairwise ranking tasks with a fixed margin.

    Every (positive, negative) pair is regressed towards the same target margin.
    """

    def __init__(self, margin: float = 1.0):
        """Initialize the ConstantMarginMSE loss function.

        Args:
            margin (float): The fixed margin value for the loss. Integers are accepted and converted to float.
        """
        # Coerce to float so that integer margins (e.g. ``ConstantMarginMSE(2)``) take the constant-margin path
        # instead of being rejected as an invalid margin type.
        super().__init__(float(margin))


class SupervisedMarginMSE(MarginMSE):
    """Supervised Margin MSE loss for pairwise ranking tasks with a dynamic margin.

    The target margin of each pair is taken from the targets rather than a fixed constant.
    """

    def __init__(self):
        """Initialize the SupervisedMarginMSE loss function with ``margin="scores"``."""
        super().__init__(margin="scores")


class RankNet(PairwiseLossFunction):
    """RankNet loss function for pairwise ranking tasks.

    For every (query, positive, negative) triple the loss is ``-log(sigmoid(s_pos - s_neg))``, averaged over
    all pairs.

    Originally proposed in: `Learning to Rank using Gradient Descent \
    <https://dl.acm.org/doi/10.1145/1102351.1102363>`_
    """

    def __init__(self, temperature: float = 1) -> None:
        """Initialize the RankNet loss function.

        Args:
            temperature (float): Temperature parameter for scaling the scores. Scores are multiplied by this value
                before pairwise differences are taken.
        """
        super().__init__()
        self.temperature = temperature

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the RankNet loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output) * self.temperature
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        # BCE-with-logits against an all-ones target equals -log(sigmoid(margin)) per pair (mean-reduced), i.e. the
        # RankNet objective that the positive should outrank the negative.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(margin, torch.ones_like(margin))
        return loss