1"""
2In-batch loss functions for the Lightning IR framework.
3
4This module contains loss functions that operate on batches of data,
5comparing examples within the same batch for training.
6"""
7
8from __future__ import annotations
9
10from typing import TYPE_CHECKING
11
12import torch
13
14from .base import InBatchLossFunction
15
16if TYPE_CHECKING:
17 from ..base import LightningIROutput
18 from ..data import TrainBatch
19
20
class ScoreBasedInBatchLossFunction(InBatchLossFunction):
    """Base class for in-batch loss functions that compute in-batch indices based on scores."""

    def __init__(self, min_target_diff: float, max_num_neg_samples: int | None = None):
        """Initialize the ScoreBasedInBatchLossFunction.

        Args:
            min_target_diff (float): Minimum difference between a query's maximum target value and a
                document's target value for that document to remain eligible as a negative sample.
            max_num_neg_samples (int | None): Maximum number of negative samples per query; ``None``
                means no explicit cap. Defaults to None.
        """
        # Fixed sampling scheme: the first document of each query is the positive,
        # all remaining (non-first) documents are candidate negatives.
        super().__init__(
            pos_sampling_technique="first",
            neg_sampling_technique="all_and_non_first",
            max_num_neg_samples=max_num_neg_samples,
        )
        # threshold applied in _get_neg_mask to filter out near-positive documents
        self.min_target_diff = min_target_diff

    def _sort_mask(
        self, mask: torch.Tensor, num_queries: int, num_docs: int, output: LightningIROutput, batch: TrainBatch
    ) -> torch.Tensor:
        """Sort each query's block of the mask based on the scores and targets.

        Args:
            mask (torch.Tensor): The initial mask tensor.
            num_queries (int): Number of queries in the batch.
            num_docs (int): Number of documents per query.
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The sorted mask tensor.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        # argsort-of-argsort yields each document's rank under descending target order;
        # moved to CPU — presumably to match the mask's device, TODO confirm against base class
        idcs = targets.argsort(descending=True).argsort().cpu()
        # shift per-query ranks to flat column indices: query q occupies
        # columns [q * num_docs, (q + 1) * num_docs)
        idcs = idcs + torch.arange(num_queries)[:, None] * num_docs
        # flat column indices of each query's own contiguous block
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        # read mask values in rank order and write them back into the block positions
        return mask.scatter(1, block_idcs, mask.gather(1, idcs))

    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        """Get the mask for positive samples and sort it based on scores.

        Args:
            num_queries (int): Number of queries in the batch.
            num_docs (int): Number of documents per query.
            max_idx (torch.Tensor): Maximum index for each query.
            min_idx (torch.Tensor): Minimum index for each query.
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: A mask tensor indicating the positions of positive samples, sorted by scores.
        """
        # base-class mask selection, then reorder each query's block into target-rank order
        pos_mask = super()._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        pos_mask = self._sort_mask(pos_mask, num_queries, num_docs, output, batch)
        return pos_mask

    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        """Get the mask for negative samples, filtered by target difference and balanced per query.

        Args:
            num_queries (int): Number of queries in the batch.
            num_docs (int): Number of documents per query.
            max_idx (torch.Tensor): Maximum index for each query.
            min_idx (torch.Tensor): Minimum index for each query.
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: A mask tensor indicating the positions of negative samples, sorted by scores.
        """
        neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, output, batch)
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).cpu()
        max_score, _ = targets.max(dim=-1, keepdim=True)
        # NOTE(review): targets is already on CPU here, so this second .cpu() is a no-op
        score_diff = (max_score - targets).cpu()
        # keep only documents whose target is at least min_target_diff below the query's maximum
        score_mask = score_diff.ge(self.min_target_diff)
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        # overwrite each query's own block with the threshold mask; columns outside that
        # block keep their value from the sorted neg_mask
        neg_mask = neg_mask.scatter(1, block_idcs, score_mask)
        # num_neg_samples might be different between queries
        num_neg_samples = neg_mask.sum(dim=1)
        min_num_neg_samples = num_neg_samples.min()
        # surplus negatives per query, relative to the smallest count in the batch
        additional_neg_samples = num_neg_samples - min_num_neg_samples
        for query_idx, neg_samples in enumerate(additional_neg_samples):
            neg_idcs = neg_mask[query_idx].nonzero().squeeze(1)
            # randomly disable `neg_samples` negatives so every query ends up with the same count
            additional_neg_idcs = neg_idcs[torch.randperm(neg_idcs.shape[0])][:neg_samples]
            assert neg_mask[query_idx, additional_neg_idcs].all()
            neg_mask[query_idx, additional_neg_idcs] = False
            assert neg_mask[query_idx].sum().eq(min_num_neg_samples)
        return neg_mask
126
class InBatchCrossEntropy(InBatchLossFunction):
    """In-batch cross-entropy loss function for ranking tasks.
    Originally proposed in: `Fast Single-Class Classification and the Principle of Logit Separation
    <https://arxiv.org/pdf/1705.10246v1>`_"""

    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        """Compute the in-batch cross-entropy loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        # cross-entropy with target class 0 for every row: column 0 holds the correct document
        labels = scores.new_zeros(scores.shape[0], dtype=torch.long)
        return torch.nn.functional.cross_entropy(scores, labels)
145
class ScoreBasedInBatchCrossEntropy(ScoreBasedInBatchLossFunction):
    """In-batch cross-entropy loss function based on scores for ranking tasks."""

    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        """Compute the in-batch cross-entropy loss based on scores.

        Args:
            output (LightningIROutput): The output from the model containing scores.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        num_rows = scores.shape[0]
        # every row's correct class is column 0
        target_classes = torch.zeros(num_rows, dtype=torch.long, device=scores.device)
        loss = torch.nn.functional.cross_entropy(scores, target_classes)
        return loss