from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Tuple

import torch

if TYPE_CHECKING:
    from ..base import LightningIROutput
    from ..bi_encoder import BiEncoderOutput
    from ..data import TrainBatch


class LossFunction(ABC):
    @abstractmethod
    def compute_loss(self, output: LightningIROutput, *args, **kwargs) -> torch.Tensor: ...

    def process_scores(self, output: LightningIROutput) -> torch.Tensor:
        if output.scores is None:
            raise ValueError("Expected scores in LightningIROutput")
        return output.scores

    def process_targets(self, scores: torch.Tensor, batch: TrainBatch) -> torch.Tensor:
        targets = batch.targets
        if targets is None:
            raise ValueError("Expected targets in TrainBatch")
        if targets.ndim > scores.ndim:
            # collapse an extra trailing target dimension by taking the per-document maximum
            return targets.amax(-1)
        return targets


class ScoringLossFunction(LossFunction):
    @abstractmethod
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor: ...


class EmbeddingLossFunction(LossFunction):
    @abstractmethod
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor: ...


class PairwiseLossFunction(ScoringLossFunction):
    def get_pairwise_idcs(self, targets: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # a (query, positive, negative) index triple for every document pair within a sample
        # where the first document's label is greater than the second's
        return torch.nonzero(targets[..., None] > targets[:, None], as_tuple=True)


class ListwiseLossFunction(ScoringLossFunction):
    pass


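# Margin-MSE distillation (Hofstätter et al.): for every document pair whose targets
# differ, the score margin (pos - neg) is regressed onto a target margin via MSE.
# ConstantMarginMSE uses a fixed margin, SupervisedMarginMSE the difference of the
# teacher scores stored in the targets.
#
# Minimal usage sketch (illustrative; assumes `output.scores` and `batch.targets` are
# [num_queries, num_docs] tensors from a Lightning IR model and training batch):
#
#     loss_fn = SupervisedMarginMSE()
#     loss = loss_fn.compute_loss(output, batch)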
class MarginMSE(PairwiseLossFunction):
    def __init__(self, margin: float | Literal["scores"] = 1.0):
        self.margin = margin

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        if isinstance(self.margin, float):
            target_margin = torch.tensor(self.margin, device=scores.device)
        elif self.margin == "scores":
            target_margin = targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs]
        else:
            raise ValueError("invalid margin type")
        loss = torch.nn.functional.mse_loss(margin, target_margin)
        return loss


class ConstantMarginMSE(MarginMSE):
    def __init__(self, margin: float = 1.0):
        super().__init__(margin)


class SupervisedMarginMSE(MarginMSE):
    def __init__(self):
        super().__init__("scores")


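# RankNet (Burges et al.): binary cross-entropy on the score margin of every document
# pair whose targets differ, i.e. it minimizes log(1 + exp(-(pos - neg))) per pair.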
class RankNet(PairwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        loss = torch.nn.functional.binary_cross_entropy_with_logits(margin, torch.ones_like(margin))
        return loss


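# Listwise distillation via KL divergence: targets and scores are both turned into
# distributions over the candidate documents and the loss is
# KL(softmax(targets) || softmax(scores)), averaged over queries ("batchmean").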
class KLDivergence(ListwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        targets = torch.nn.functional.log_softmax(targets.to(scores), dim=-1)
        loss = torch.nn.functional.kl_div(scores, targets, log_target=True, reduction="batchmean")
        return loss


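# Pearson correlation loss: 1 - corr(scores, targets) per query; the cosine similarity
# of mean-centered vectors is exactly the Pearson correlation coefficient.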
class PearsonCorrelation(ListwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).to(scores)
        centered_scores = scores - scores.mean(dim=-1, keepdim=True)
        centered_targets = targets - targets.mean(dim=-1, keepdim=True)
        pearson = torch.nn.functional.cosine_similarity(centered_scores, centered_targets, dim=-1)
        loss = (1 - pearson).mean()
        return loss


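# InfoNCE / contrastive cross-entropy over each query's candidate documents, using the
# document with the highest target value as the positive class.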
class InfoNCE(ListwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        targets = targets.argmax(dim=1)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss


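# Smooth rank approximation shared by the Approx* losses (ApproxNDCG-style relaxation):
#     rank_i ≈ 1 + sum_{j != i} sigmoid((s_j - s_i) / temperature)
# which approaches the true rank for distinct scores as the temperature goes to 0.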
class ApproxLossFunction(ListwiseLossFunction):
    def __init__(self, temperature: float = 1) -> None:
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def get_approx_ranks(scores: torch.Tensor, temperature: float) -> torch.Tensor:
        score_diff = scores[:, None] - scores[..., None]
        normalized_score_diff = torch.sigmoid(score_diff / temperature)
        # set diagonal to 0
        normalized_score_diff = normalized_score_diff * (1 - torch.eye(scores.shape[1], device=scores.device))
        approx_ranks = normalized_score_diff.sum(-1) + 1
        return approx_ranks


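# ApproxNDCG: the smoothed ranks are plugged into DCG = sum_i gain_i / log2(1 + rank_i),
# with gain_i = 2^target_i - 1 when scale_gains is True, and the loss is 1 - DCG / IDCG.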
class ApproxNDCG(ApproxLossFunction):
    def __init__(self, temperature: float = 1, scale_gains: bool = True):
        super().__init__(temperature)
        self.scale_gains = scale_gains

    @staticmethod
    def get_dcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
    ) -> torch.Tensor:
        log_ranks = torch.log2(1 + ranks)
        discounts = 1 / log_ranks
        if scale_gains:
            gains = 2**targets - 1
        else:
            gains = targets
        dcgs = gains * discounts
        if k is not None:
            dcgs = dcgs.masked_fill(ranks > k, 0)
        return dcgs.sum(dim=-1)

    @staticmethod
    def get_ndcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
        optimal_targets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        targets = targets.clamp(min=0)
        if optimal_targets is None:
            optimal_targets = targets
        optimal_ranks = torch.argsort(torch.argsort(optimal_targets, descending=True))
        optimal_ranks = optimal_ranks + 1
        dcg = ApproxNDCG.get_dcg(ranks, targets, k, scale_gains)
        idcg = ApproxNDCG.get_dcg(optimal_ranks, optimal_targets, k, scale_gains)
        ndcg = dcg / (idcg.clamp(min=1e-12))
        return ndcg

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains)
        loss = 1 - ndcg
        return loss.mean()


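# ApproxMRR: with targets clamped to at most 1, the loss is 1 - max_i(target_i / rank_i)
# per query, i.e. one minus the smoothed reciprocal rank of the best relevant document.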
class ApproxMRR(ApproxLossFunction):
    def __init__(self, temperature: float = 1):
        super().__init__(temperature)

    @staticmethod
    def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> torch.Tensor:
        targets = targets.clamp(None, 1)
        reciprocal_ranks = 1 / ranks
        mrr = reciprocal_ranks * targets
        if k is not None:
            mrr = mrr.masked_fill(ranks > k, 0)
        mrr = mrr.max(dim=-1)[0]
        return mrr

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        mrr = self.get_mrr(approx_ranks, targets, k=None)
        loss = 1 - mrr
        return loss.mean()


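# Rank-MSE: mean squared error between the smoothed ranks and the ranks induced by the
# targets, optionally discounted so that errors at the top of the ranking weigh more.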
class ApproxRankMSE(ApproxLossFunction):
    def __init__(
        self,
        temperature: float = 1,
        discount: Literal["log2", "reciprocal"] | None = None,
    ):
        super().__init__(temperature)
        self.discount = discount

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
        loss = torch.nn.functional.mse_loss(approx_ranks, ranks.to(approx_ranks), reduction="none")
        if self.discount == "log2":
            weight = 1 / torch.log2(ranks + 1)
        elif self.discount == "reciprocal":
            weight = 1 / ranks
        else:
            weight = 1
        loss = loss * weight
        loss = loss.mean()
        return loss


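# NeuralSort-based losses: neural_sort computes a relaxed permutation matrix from the
# scores (Grover et al., NeuralSort), sinkhorn_scaling pushes it towards a doubly
# stochastic matrix, and get_sorted_targets applies it to the targets to obtain a
# differentiable, score-induced ordering of the relevance labels.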
class NeuralLossFunction(ListwiseLossFunction):
    # TODO add neural loss functions

    def __init__(self, temperature: float = 1, tol: float = 1e-5, max_iter: int = 50) -> None:
        super().__init__()
        self.temperature = temperature
        self.tol = tol
        self.max_iter = max_iter

    def neural_sort(self, scores: torch.Tensor) -> torch.Tensor:
        # https://github.com/ermongroup/neuralsort/blob/master/pytorch/neuralsort.py
        scores = scores.unsqueeze(-1)
        dim = scores.shape[1]
        one = torch.ones((dim, 1), device=scores.device)

        A_scores = torch.abs(scores - scores.permute(0, 2, 1))
        B = torch.matmul(A_scores, torch.matmul(one, torch.transpose(one, 0, 1)))
        scaling = dim + 1 - 2 * (torch.arange(dim, device=scores.device) + 1)
        C = torch.matmul(scores, scaling.to(scores).unsqueeze(0))

        P_max = (C - B).permute(0, 2, 1)
        P_hat = torch.nn.functional.softmax(P_max / self.temperature, dim=-1)

        P_hat = self.sinkhorn_scaling(P_hat)

        return P_hat

    def sinkhorn_scaling(self, mat: torch.Tensor) -> torch.Tensor:
        # https://github.com/allegro/allRank/blob/master/allrank/models/losses/loss_utils.py#L8
        idx = 0
        while True:
            if (
                torch.max(torch.abs(mat.sum(dim=2) - 1.0)) < self.tol
                and torch.max(torch.abs(mat.sum(dim=1) - 1.0)) < self.tol
            ) or idx > self.max_iter:
                break
            mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=1e-12)
            mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=1e-12)
            idx += 1

        return mat

    def get_sorted_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        permutation_matrix = self.neural_sort(scores)
        pred_sorted_targets = torch.matmul(permutation_matrix, targets[..., None].to(permutation_matrix)).squeeze(-1)
        return pred_sorted_targets


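# In-batch negatives: the batch's scores are viewed as one flat axis of length
# num_queries * num_docs and get_ib_idcs selects, per query, which positions act as
# positives and which as negatives. For positives, "first" uses only a query's first
# document and "all" uses all of its documents; for negatives, "all" uses the other
# queries' documents, "first" only their first documents, and "all_and_non_first"
# every document in the batch except the query's own first document.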
class InBatchLossFunction(LossFunction):
    def __init__(
        self,
        pos_sampling_technique: Literal["all", "first"] = "all",
        neg_sampling_technique: Literal["all", "first", "all_and_non_first"] = "all",
        max_num_neg_samples: int | None = None,
    ):
        super().__init__()
        self.pos_sampling_technique = pos_sampling_technique
        self.neg_sampling_technique = neg_sampling_technique
        self.max_num_neg_samples = max_num_neg_samples
        if self.neg_sampling_technique == "all_and_non_first" and self.pos_sampling_technique != "first":
            raise ValueError("all_and_non_first is only valid with pos_sampling_technique first")

    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        if self.pos_sampling_technique == "all":
            pos_mask = torch.arange(num_queries * num_docs)[None].greater_equal(min_idx) & torch.arange(
                num_queries * num_docs
            )[None].less(max_idx)
        elif self.pos_sampling_technique == "first":
            pos_mask = torch.arange(num_queries * num_docs)[None].eq(min_idx)
        else:
            raise ValueError("invalid pos sampling technique")
        return pos_mask

    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        if self.neg_sampling_technique == "all_and_non_first":
            neg_mask = torch.arange(num_queries * num_docs)[None].not_equal(min_idx)
        elif self.neg_sampling_technique == "all":
            neg_mask = torch.arange(num_queries * num_docs)[None].less(min_idx) | torch.arange(num_queries * num_docs)[
                None
            ].greater_equal(max_idx)
        elif self.neg_sampling_technique == "first":
            neg_mask = torch.arange(num_queries * num_docs)[None, None].eq(min_idx).any(1) & torch.arange(
                num_queries * num_docs
            )[None].ne(min_idx)
        else:
            raise ValueError("invalid neg sampling technique")
        return neg_mask

    def get_ib_idcs(self, output: LightningIROutput, batch: TrainBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        if output.scores is None:
            raise ValueError("Expected scores in LightningIROutput")
        num_queries, num_docs = output.scores.shape
        min_idx = torch.arange(num_queries)[:, None] * num_docs
        max_idx = min_idx + num_docs
        pos_mask = self._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        neg_mask = self._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        pos_idcs = pos_mask.nonzero(as_tuple=True)[1]
        neg_idcs = neg_mask.nonzero(as_tuple=True)[1]
        if self.max_num_neg_samples is not None:
            neg_idcs = neg_idcs.view(num_queries, -1)
            if neg_idcs.shape[-1] > 1:
                neg_idcs = neg_idcs[:, torch.randperm(neg_idcs.shape[-1])]
            neg_idcs = neg_idcs[:, : self.max_num_neg_samples]
            neg_idcs = neg_idcs.reshape(-1)
        return pos_idcs, neg_idcs


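# Score-based in-batch sampling: on top of the "first"/"all_and_non_first" scheme above,
# a document only counts as a negative if its target is at least min_target_diff below
# the best target of its query, and surplus negatives are dropped at random so every
# query keeps the same number of them.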
class ScoreBasedInBatchLossFunction(InBatchLossFunction):

    def __init__(self, min_target_diff: float, max_num_neg_samples: int | None = None):
        super().__init__(
            pos_sampling_technique="first",
            neg_sampling_technique="all_and_non_first",
            max_num_neg_samples=max_num_neg_samples,
        )
        self.min_target_diff = min_target_diff

    def _sort_mask(
        self, mask: torch.Tensor, num_queries: int, num_docs: int, output: LightningIROutput, batch: TrainBatch
    ) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        idcs = targets.argsort(descending=True).argsort().cpu()
        idcs = idcs + torch.arange(num_queries)[:, None] * num_docs
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        return mask.scatter(1, block_idcs, mask.gather(1, idcs))

    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        pos_mask = super()._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        pos_mask = self._sort_mask(pos_mask, num_queries, num_docs, output, batch)
        return pos_mask

    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, output, batch)
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).cpu()
        max_score, _ = targets.max(dim=-1, keepdim=True)
        score_diff = (max_score - targets).cpu()
        score_mask = score_diff.ge(self.min_target_diff)
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        neg_mask = neg_mask.scatter(1, block_idcs, score_mask)
        # num_neg_samples might be different between queries
        num_neg_samples = neg_mask.sum(dim=1)
        min_num_neg_samples = num_neg_samples.min()
        additional_neg_samples = num_neg_samples - min_num_neg_samples
        for query_idx, neg_samples in enumerate(additional_neg_samples):
            neg_idcs = neg_mask[query_idx].nonzero().squeeze(1)
            additional_neg_idcs = neg_idcs[torch.randperm(neg_idcs.shape[0])][:neg_samples]
            assert neg_mask[query_idx, additional_neg_idcs].all()
            neg_mask[query_idx, additional_neg_idcs] = False
            assert neg_mask[query_idx].sum().eq(min_num_neg_samples)
        return neg_mask


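# Cross-entropy over the in-batch candidates of each query. The zero targets assume the
# gathered score matrix places each query's positive document in the first column.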
class InBatchCrossEntropy(InBatchLossFunction):
    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss


class ScoreBasedInBatchCrossEntropy(ScoreBasedInBatchLossFunction):

    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss


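# Embedding regularizers for (sparse) bi-encoders. L1/L2 penalize the mean embedding
# norm; FLOPS (Paria et al.) penalizes sum_d (mean_batch |e_d|)^2, pushing activations
# towards sparsity by discouraging dimensions that are active across the whole batch.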
class RegularizationLossFunction(EmbeddingLossFunction):
    def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None:
        self.query_weight = query_weight
        self.doc_weight = doc_weight

    def process_embeddings(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, torch.Tensor]:
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        return query_embeddings.embeddings, doc_embeddings.embeddings


class L2Regularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = self.query_weight * query_embeddings.norm(dim=-1).mean()
        doc_loss = self.doc_weight * doc_embeddings.norm(dim=-1).mean()
        loss = query_loss + doc_loss
        return loss


class L1Regularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = self.query_weight * query_embeddings.norm(p=1, dim=-1).mean()
        doc_loss = self.doc_weight * doc_embeddings.norm(p=1, dim=-1).mean()
        loss = query_loss + doc_loss
        return loss


class FLOPSRegularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = torch.sum(torch.mean(torch.abs(query_embeddings), dim=0) ** 2)
        doc_loss = torch.sum(torch.mean(torch.abs(doc_embeddings), dim=0) ** 2)
        loss = self.query_weight * query_loss + self.doc_weight * doc_loss
        return loss