# Source code for lightning_ir.loss.pairwise
"""
Pairwise loss functions for the Lightning IR framework.

This module contains loss functions that operate on pairs of items,
comparing positive and negative examples.
"""
7
8from __future__ import annotations
9
10from typing import TYPE_CHECKING, Literal
11
12import torch
13
14from .base import PairwiseLossFunction
15
16if TYPE_CHECKING:
17 from ..base import LightningIROutput
18 from ..data import TrainBatch
19
20
class MarginMSE(PairwiseLossFunction):
    """Mean Squared Error loss with a margin for pairwise ranking tasks.

    For every (positive, negative) pair returned by ``get_pairwise_idcs`` the difference between the positive and
    the negative score is regressed onto a target margin. The target margin is either a constant or the difference
    between the targets of the positive and the negative item.

    Originally proposed in: `Improving Efficient Neural Ranking Models with Cross-Architecture Knowledge Distillation \
    <https://arxiv.org/abs/2010.02666>`_
    """

    def __init__(self, margin: float | Literal["scores"] = 1.0):
        """Initialize the MarginMSE loss function.

        Args:
            margin (float | Literal["scores"]): A constant target margin, or ``"scores"`` to use the difference
                between the targets of the positive and the negative item as target margin. Defaults to 1.0.
        """
        super().__init__()
        self.margin = margin

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the MarginMSE loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        Raises:
            ValueError: If the margin is neither a number nor ``"scores"``.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        # Accept ints as well as floats (e.g. ``margin=1``); bool is excluded explicitly because it subclasses int.
        if isinstance(self.margin, (int, float)) and not isinstance(self.margin, bool):
            # Build the constant target with the shape, dtype and device of the predicted margins so that
            # ``mse_loss`` neither has to broadcast a 0-dim target (which triggers a warning) nor mix dtypes.
            target_margin = torch.full_like(margin, self.margin)
        elif self.margin == "scores":
            target_margin = targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs]
        else:
            raise ValueError("invalid margin type")
        loss = torch.nn.functional.mse_loss(margin, target_margin)
        return loss
60
61
class ConstantMarginMSE(MarginMSE):
    """MarginMSE variant that uses one fixed target margin for all pairs."""

    def __init__(self, margin: float = 1.0):
        """Set up the loss with a constant target margin.

        Args:
            margin (float): The constant target margin passed on to ``MarginMSE``. Defaults to 1.0.
        """
        super().__init__(margin=margin)
72
73
class SupervisedMarginMSE(MarginMSE):
    """MarginMSE variant whose target margin is derived from the targets instead of being a constant."""

    def __init__(self):
        """Set up the loss by configuring ``MarginMSE`` with the ``"scores"`` margin mode."""
        super().__init__(margin="scores")
80
81
class RankNet(PairwiseLossFunction):
    """RankNet loss function for pairwise ranking tasks.

    Originally proposed in: `Learning to Rank using Gradient Descent \
    <https://dl.acm.org/doi/10.1145/1102351.1102363>`_
    """

    def __init__(self, temperature: float = 1) -> None:
        """Initialize the RankNet loss function.

        Args:
            temperature (float): Factor the scores are multiplied with before the pairwise score differences are
                computed. Defaults to 1.
        """
        super().__init__()
        self.temperature = temperature

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the RankNet loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output) * self.temperature
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        # Every pair is labelled 1: the positive item should be scored above the negative one.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(margin, torch.ones_like(margin))
        return loss