1"""
2Base classes and abstract interfaces for loss functions in the Lightning IR framework.
3
4This module defines the abstract base classes and common functionality for all loss functions
5used in the Lightning IR framework.
6"""
7
8from __future__ import annotations
9
10from abc import ABC, abstractmethod
11from typing import TYPE_CHECKING, Literal, Tuple
12
13import torch
14
15if TYPE_CHECKING:
16 from ..base import LightningIROutput
17 from ..bi_encoder import BiEncoderOutput
18 from ..data import TrainBatch
19
20
class LossFunction(ABC):
    """Base class for loss functions in the Lightning IR framework."""

    @abstractmethod
    def compute_loss(self, output: LightningIROutput, *args, **kwargs) -> torch.Tensor:
        """Compute the loss for the given output.

        Args:
            output (LightningIROutput): The output from the model.
        Returns:
            torch.Tensor: The computed loss.
        """
        ...

    def process_scores(self, output: LightningIROutput) -> torch.Tensor:
        """Process the scores from the output.

        Args:
            output (LightningIROutput): The output from the model.
        Returns:
            torch.Tensor: The scores tensor.
        Raises:
            ValueError: If scores are not present in the output.
        """
        if output.scores is None:
            raise ValueError("Expected scores in LightningIROutput")
        return output.scores

    def process_targets(self, scores: torch.Tensor, batch: TrainBatch) -> torch.Tensor:
        """Process the targets from the batch.

        Args:
            scores (torch.Tensor): The scores tensor.
            batch (TrainBatch): The training batch.
        Returns:
            torch.Tensor: The processed targets tensor.
        Raises:
            ValueError: If targets are not present in the batch.
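
        Example:
            An illustrative sketch of the shape handling (tensor values are
            made up). Targets with one more dimension than the scores are
            collapsed with ``amax`` over the last dimension:

            >>> targets = torch.tensor([[[1, 0], [0, 2], [0, 0]]])
            >>> targets.amax(-1)
            tensor([[1, 2, 0]])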
55 """
56 targets = batch.targets
57 if targets is None:
58 raise ValueError("Expected targets in TrainBatch")
59 if targets.ndim > scores.ndim:
60 return targets.amax(-1)
61 return targets
62
63
class ScoringLossFunction(LossFunction):
    """Base class for loss functions that operate on scores."""

    @abstractmethod
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the loss based on the scores and targets in the output and batch.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
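
        Example:
            A minimal sketch of a concrete subclass (``MSELossSketch`` is a
            hypothetical name, not a loss shipped with the framework):

            >>> class MSELossSketch(ScoringLossFunction):
            ...     def compute_loss(self, output, batch):
            ...         scores = self.process_scores(output)
            ...         targets = self.process_targets(scores, batch).to(scores)
            ...         return torch.nn.functional.mse_loss(scores, targets)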
76 """
77 ...
78
79
class EmbeddingLossFunction(LossFunction):
    """Base class for loss functions that operate on embeddings."""

    @abstractmethod
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the loss based on the embeddings in the output.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.
        Returns:
            torch.Tensor: The computed loss.
        """
        ...


class RegularizationLossFunction(EmbeddingLossFunction):
    """Base class for regularization loss functions that operate on embeddings."""

    def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None:
        """Initialize the RegularizationLossFunction.

        Args:
            query_weight (float): Weight for the query embeddings regularization. Defaults to 1e-4.
            doc_weight (float): Weight for the document embeddings regularization. Defaults to 1e-4.
        """
        self.query_weight = query_weight
        self.doc_weight = doc_weight

    def process_embeddings(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process the embeddings from the output.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The processed query and document embeddings.
        Raises:
            ValueError: If query_embeddings are not present in the output.
            ValueError: If doc_embeddings are not present in the output.
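
        Example:
            A minimal sketch of a concrete regularizer built on this helper
            (``L2RegularizationSketch`` is hypothetical, not necessarily a
            loss shipped with the framework):

            >>> class L2RegularizationSketch(RegularizationLossFunction):
            ...     def compute_loss(self, output):
            ...         query_embeddings, doc_embeddings = self.process_embeddings(output)
            ...         return (
            ...             self.query_weight * query_embeddings.norm(dim=-1).mean()
            ...             + self.doc_weight * doc_embeddings.norm(dim=-1).mean()
            ...         )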
118 """
119 query_embeddings = output.query_embeddings
120 doc_embeddings = output.doc_embeddings
121 if query_embeddings is None:
122 raise ValueError("Expected query_embeddings in BiEncoderOutput")
123 if doc_embeddings is None:
124 raise ValueError("Expected doc_embeddings in BiEncoderOutput")
125 return query_embeddings.embeddings, doc_embeddings.embeddings
126
127
class PairwiseLossFunction(ScoringLossFunction):
    """Base class for pairwise loss functions."""

    def get_pairwise_idcs(self, targets: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Get pairwise indices for positive and negative samples based on targets.

        Args:
            targets (torch.Tensor): The targets tensor containing relevance labels.
        Returns:
            Tuple[torch.Tensor, ...]: Indices of positive and negative samples.
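
        Example:
            An illustrative run of the underlying comparison (label values are
            made up). For one query with labels ``[2, 1, 0]``, the returned
            tensors are the sample, positive-document, and negative-document
            indices:

            >>> targets = torch.tensor([[2, 1, 0]])
            >>> torch.nonzero(targets[..., None] > targets[:, None], as_tuple=True)
            (tensor([0, 0, 0]), tensor([0, 0, 1]), tensor([1, 2, 2]))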
138 """
139 # positive items are items where label is greater than other label in sample
140 return torch.nonzero(targets[..., None] > targets[:, None], as_tuple=True)
141
142
class ListwiseLossFunction(ScoringLossFunction):
    """Base class for listwise loss functions."""


class InBatchLossFunction(LossFunction):
    """Base class for in-batch loss functions that compute in-batch indices for positive and negative samples."""

    def __init__(
        self,
        pos_sampling_technique: Literal["all", "first"] = "all",
        neg_sampling_technique: Literal["all", "first", "all_and_non_first"] = "all",
        max_num_neg_samples: int | None = None,
    ):
        """Initialize the InBatchLossFunction.

        Args:
            pos_sampling_technique (Literal["all", "first"]): Technique for sampling positive examples.
                Defaults to "all".
            neg_sampling_technique (Literal["all", "first", "all_and_non_first"]): Technique for sampling negative
                examples. Defaults to "all".
            max_num_neg_samples (int | None): Maximum number of negative samples to consider. If None, all negative
                samples are considered. Defaults to None.
        Raises:
            ValueError: If the negative sampling technique is invalid for the given positive sampling technique.
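
        Example:
            An illustrative check of the constraint via a hypothetical
            concrete subclass (``_SketchLoss`` is not part of the framework):

            >>> class _SketchLoss(InBatchLossFunction):
            ...     def compute_loss(self, output, batch):
            ...         ...
            >>> _SketchLoss(neg_sampling_technique="all_and_non_first")
            Traceback (most recent call last):
                ...
            ValueError: all_and_non_first is only valid with pos_sampling_technique first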
168 """
169 super().__init__()
170 self.pos_sampling_technique = pos_sampling_technique
171 self.neg_sampling_technique = neg_sampling_technique
172 self.max_num_neg_samples = max_num_neg_samples
173 if self.neg_sampling_technique == "all_and_non_first" and self.pos_sampling_technique != "first":
174 raise ValueError("all_and_non_first is only valid with pos_sampling_technique first")
175
    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        """Get the mask for positive samples based on the sampling technique.

        Args:
            num_queries (int): Number of queries in the batch.
            num_docs (int): Number of documents per query.
            max_idx (torch.Tensor): Maximum index for each query.
            min_idx (torch.Tensor): Minimum index for each query.
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: A mask tensor indicating the positions of positive samples.
        Raises:
            ValueError: If the positive sampling technique is invalid.
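
        Example:
            An illustrative evaluation of the ``"first"`` technique for two
            queries with two documents each (flattened document indices
            ``0..3``, so the first documents sit at indices ``0`` and ``2``):

            >>> idcs = torch.arange(4)[None]
            >>> min_idx = torch.tensor([[0], [2]])
            >>> idcs.eq(min_idx)
            tensor([[ True, False, False, False],
                    [False, False,  True, False]])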
198 """
199 if self.pos_sampling_technique == "all":
200 pos_mask = torch.arange(num_queries * num_docs)[None].greater_equal(min_idx) & torch.arange(
201 num_queries * num_docs
202 )[None].less(max_idx)
203 elif self.pos_sampling_technique == "first":
204 pos_mask = torch.arange(num_queries * num_docs)[None].eq(min_idx)
205 else:
206 raise ValueError("invalid pos sampling technique")
207 return pos_mask
208
    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        """Get the mask for negative samples based on the sampling technique.

        Args:
            num_queries (int): Number of queries in the batch.
            num_docs (int): Number of documents per query.
            max_idx (torch.Tensor): Maximum index for each query.
            min_idx (torch.Tensor): Minimum index for each query.
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: A mask tensor indicating the positions of negative samples.
        Raises:
            ValueError: If the negative sampling technique is invalid.
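
        Example:
            An illustrative evaluation of the ``"first"`` technique for two
            queries with two documents each: each query's negatives are the
            first documents of the other queries:

            >>> idcs = torch.arange(4)[None]
            >>> min_idx = torch.tensor([[0], [2]])
            >>> idcs[None].eq(min_idx).any(1) & idcs.ne(min_idx)
            tensor([[False, False,  True, False],
                    [ True, False, False, False]])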
231 """
232 if self.neg_sampling_technique == "all_and_non_first":
233 neg_mask = torch.arange(num_queries * num_docs)[None].not_equal(min_idx)
234 elif self.neg_sampling_technique == "all":
235 neg_mask = torch.arange(num_queries * num_docs)[None].less(min_idx) | torch.arange(num_queries * num_docs)[
236 None
237 ].greater_equal(max_idx)
238 elif self.neg_sampling_technique == "first":
239 neg_mask = torch.arange(num_queries * num_docs)[None, None].eq(min_idx).any(1) & torch.arange(
240 num_queries * num_docs
241 )[None].ne(min_idx)
242 else:
243 raise ValueError("invalid neg sampling technique")
244 return neg_mask
245
    def get_ib_idcs(self, output: LightningIROutput, batch: TrainBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get in-batch indices for positive and negative samples.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Indices of positive and negative samples.
        Raises:
            ValueError: If scores are not present in the output.
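
        Example:
            An illustrative sketch of the default (``"all"``/``"all"``) index
            computation for two queries with two documents each (``output``
            and ``batch`` omitted; only the tensor arithmetic is shown):

            >>> num_queries, num_docs = 2, 2
            >>> min_idx = torch.arange(num_queries)[:, None] * num_docs
            >>> max_idx = min_idx + num_docs
            >>> flat = torch.arange(num_queries * num_docs)[None]
            >>> (flat.greater_equal(min_idx) & flat.less(max_idx)).nonzero(as_tuple=True)[1]
            tensor([0, 1, 2, 3])
            >>> (flat.less(min_idx) | flat.greater_equal(max_idx)).nonzero(as_tuple=True)[1]
            tensor([2, 3, 0, 1])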
256 """
257 if output.scores is None:
258 raise ValueError("Expected scores in LightningIROutput")
259 num_queries, num_docs = output.scores.shape
260 min_idx = torch.arange(num_queries)[:, None] * num_docs
261 max_idx = min_idx + num_docs
262 pos_mask = self._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
263 neg_mask = self._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
264 pos_idcs = pos_mask.nonzero(as_tuple=True)[1]
265 neg_idcs = neg_mask.nonzero(as_tuple=True)[1]
266 if self.max_num_neg_samples is not None:
267 neg_idcs = neg_idcs.view(num_queries, -1)
268 if neg_idcs.shape[-1] > 1:
269 neg_idcs = neg_idcs[:, torch.randperm(neg_idcs.shape[-1])]
270 neg_idcs = neg_idcs[:, : self.max_num_neg_samples]
271 neg_idcs = neg_idcs.reshape(-1)
272 return pos_idcs, neg_idcs