Source code for lightning_ir.loss.pairwise

  1"""
  2Pairwise loss functions for the Lightning IR framework.
  3
  4This module contains loss functions that operate on pairs of items,
  5comparing positive and negative examples.
  6"""
  7
  8from __future__ import annotations
  9
 10from typing import TYPE_CHECKING, Literal
 11
 12import torch
 13
 14from .base import PairwiseLossFunction
 15
 16if TYPE_CHECKING:
 17    from ..base import LightningIROutput
 18    from ..data import TrainBatch
 19
 20
class MarginMSE(PairwiseLossFunction):
    """Mean Squared Error loss with a margin for pairwise ranking tasks.
    Originally proposed in: `Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation \
    <https://arxiv.org/abs/2010.02666>`_
    """

    def __init__(self, margin: float | Literal["scores"] = 1.0):
        """Initialize the MarginMSE loss function.

        Args:
            margin (float | Literal["scores"]): The target margin for the loss. A number (``int`` or ``float``) is
                used as a constant target margin for every (positive, negative) pair. ``"scores"`` uses the
                difference between the target values of each pair as its target margin.
        """
        self.margin = margin

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the MarginMSE loss.

        The predicted margin of a pair is ``score(positive) - score(negative)``. The loss is the mean squared error
        between the predicted margins and the target margins selected by ``self.margin``.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        Raises:
            ValueError: If the margin type is invalid.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        if self.margin == "scores":
            target_margin = targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs]
        elif isinstance(self.margin, (int, float)) and not isinstance(self.margin, bool):
            # Materialize the constant with the same shape, dtype and device as the predicted margins. A 0-d tensor
            # would be broadcast by mse_loss, which warns about the size mismatch and may differ in dtype.
            target_margin = torch.full_like(margin, float(self.margin))
        else:
            raise ValueError("invalid margin type")
        loss = torch.nn.functional.mse_loss(margin, target_margin)
        return loss
class ConstantMarginMSE(MarginMSE):
    """Constant Margin MSE loss for pairwise ranking tasks with a fixed margin.

    Every (positive, negative) pair is regressed towards the same constant target margin.
    """

    def __init__(self, margin: float = 1.0):
        """Initialize the ConstantMarginMSE loss function.

        Args:
            margin (float): The fixed margin value for the loss. Integer values are accepted and converted to
                ``float``.
        """
        # The parent class chooses the margin mode from the margin's type. Coercing to float guarantees that an
        # integer margin (e.g. ``margin=1`` from a config) is treated as the constant margin this class promises.
        super().__init__(float(margin))
class SupervisedMarginMSE(MarginMSE):
    """Supervised Margin MSE loss for pairwise ranking tasks with a dynamic margin.

    Instead of a fixed value, the target margin of each (positive, negative) pair is the difference between the
    pair's target values.
    """

    def __init__(self):
        """Initialize the SupervisedMarginMSE loss function with target-derived ("scores") margins."""
        super().__init__(margin="scores")
class RankNet(PairwiseLossFunction):
    """RankNet loss function for pairwise ranking tasks.
    Originally proposed in: `Learning to Rank using Gradient Descent \
    <https://dl.acm.org/doi/10.1145/1102351.1102363>`_
    """

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the RankNet loss.

        Each (positive, negative) pair's score difference is treated as a logit whose label is always 1. The loss
        is the mean binary cross-entropy over all pairs.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        score_diff = scores[query_idcs, pos_idcs] - scores[query_idcs, neg_idcs]
        # Every pair's label is 1: the positive item should be ranked above the negative one.
        labels = torch.ones_like(score_diff)
        return torch.nn.functional.binary_cross_entropy_with_logits(score_diff, labels)