Source code for lightning_ir.loss.loss

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Tuple

import torch

if TYPE_CHECKING:
    from ..base import LightningIROutput
    from ..bi_encoder import BiEncoderOutput
    from ..data import TrainBatch

class LossFunction(ABC):
    @abstractmethod
    def compute_loss(self, output: LightningIROutput, *args, **kwargs) -> torch.Tensor: ...

    def process_scores(self, output: LightningIROutput) -> torch.Tensor:
        if output.scores is None:
            raise ValueError("Expected scores in LightningIROutput")
        return output.scores

    def process_targets(self, scores: torch.Tensor, batch: TrainBatch) -> torch.Tensor:
        targets = batch.targets
        if targets is None:
            raise ValueError("Expected targets in TrainBatch")
        if targets.ndim > scores.ndim:
            return targets.amax(-1)
        return targets

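The two helpers above only validate and normalize inputs: process_scores checks that the model output carries a score matrix, and process_targets collapses a trailing extra target dimension with a max reduction. A minimal sketch of that reduction with plain tensors (not the library classes):

import torch

scores = torch.rand(2, 3)                  # (num_queries, num_docs)
targets = torch.randint(0, 4, (2, 3, 2))   # one extra trailing dimension

# mirrors LossFunction.process_targets: reduce the extra dimension with amax
if targets.ndim > scores.ndim:
    targets = targets.amax(-1)
print(targets.shape)  # torch.Size([2, 3])
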
class ScoringLossFunction(LossFunction):
    @abstractmethod
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor: ...

class EmbeddingLossFunction(LossFunction):
    @abstractmethod
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor: ...

class PairwiseLossFunction(ScoringLossFunction):
    def get_pairwise_idcs(self, targets: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # positive items are items where label is greater than other label in sample
        return torch.nonzero(targets[..., None] > targets[:, None], as_tuple=True)

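get_pairwise_idcs enumerates every (query, positive, negative) index triple in which one document's label is strictly greater than another's. A small standalone illustration of the same expression with plain tensors:

import torch

targets = torch.tensor([[2, 0, 1]])  # one query, three documents
query_idcs, pos_idcs, neg_idcs = torch.nonzero(
    targets[..., None] > targets[:, None], as_tuple=True
)
# doc 0 (label 2) beats docs 1 and 2; doc 2 (label 1) beats doc 1
print(query_idcs.tolist(), pos_idcs.tolist(), neg_idcs.tolist())
# [0, 0, 0] [0, 0, 2] [1, 2, 1]
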
class ListwiseLossFunction(ScoringLossFunction):
    pass

class MarginMSE(PairwiseLossFunction):
    def __init__(self, margin: float | Literal["scores"] = 1.0):
        self.margin = margin

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        if isinstance(self.margin, float):
            target_margin = torch.tensor(self.margin, device=scores.device)
        elif self.margin == "scores":
            target_margin = targets[query_idcs, pos_idcs] - targets[query_idcs, neg_idcs]
        else:
            raise ValueError("invalid margin type")
        loss = torch.nn.functional.mse_loss(margin, target_margin)
        return loss

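MarginMSE regresses each pairwise score difference onto a target margin: either a fixed constant or the difference of the supervision scores (for example distilled cross-encoder scores), which is what the two subclasses below select. A standalone sketch of the two margin modes with plain tensors:

import torch

pos = torch.tensor([3.0, 2.5])          # scores of the higher-labelled docs
neg = torch.tensor([1.0, 2.0])          # scores of the lower-labelled docs
teacher_pos = torch.tensor([4.0, 3.0])  # illustrative teacher scores
teacher_neg = torch.tensor([1.5, 2.5])

constant_loss = torch.nn.functional.mse_loss(pos - neg, torch.ones_like(pos))          # margin = 1.0
supervised_loss = torch.nn.functional.mse_loss(pos - neg, teacher_pos - teacher_neg)   # margin = "scores"
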
class ConstantMarginMSE(MarginMSE):
    def __init__(self, margin: float = 1.0):
        super().__init__(margin)

class SupervisedMarginMSE(MarginMSE):
    def __init__(self):
        super().__init__("scores")

class RankNet(PairwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        query_idcs, pos_idcs, neg_idcs = self.get_pairwise_idcs(targets)
        pos = scores[query_idcs, pos_idcs]
        neg = scores[query_idcs, neg_idcs]
        margin = pos - neg
        loss = torch.nn.functional.binary_cross_entropy_with_logits(margin, torch.ones_like(margin))
        return loss

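RankNet treats each pairwise score difference as the logit of "the positive outranks the negative", so the binary cross-entropy against an all-ones target reduces to -log sigmoid(pos - neg). A quick standalone check with plain tensors:

import torch

pos = torch.tensor([3.0, 2.5])
neg = torch.tensor([1.0, 2.0])
margin = pos - neg

bce = torch.nn.functional.binary_cross_entropy_with_logits(margin, torch.ones_like(margin))
manual = -torch.nn.functional.logsigmoid(margin).mean()
print(torch.allclose(bce, manual))  # True
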
class KLDivergence(ListwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        targets = torch.nn.functional.log_softmax(targets.to(scores), dim=-1)
        loss = torch.nn.functional.kl_div(scores, targets, log_target=True, reduction="batchmean")
        return loss

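KLDivergence softmax-normalizes model scores and graded targets per query and matches the two distributions with a batch-mean KL divergence. A standalone sketch of the same computation:

import torch

scores = torch.tensor([[2.0, 0.5, -1.0]])
targets = torch.tensor([[3.0, 1.0, 0.0]])

log_p = torch.nn.functional.log_softmax(scores, dim=-1)
log_q = torch.nn.functional.log_softmax(targets, dim=-1)
loss = torch.nn.functional.kl_div(log_p, log_q, log_target=True, reduction="batchmean")
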
class PearsonCorrelation(ListwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).to(scores)
        centered_scores = scores - scores.mean(dim=-1, keepdim=True)
        centered_targets = targets - targets.mean(dim=-1, keepdim=True)
        pearson = torch.nn.functional.cosine_similarity(centered_scores, centered_targets, dim=-1)
        loss = (1 - pearson).mean()
        return loss

class InfoNCE(ListwiseLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        targets = targets.argmax(dim=1)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss

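InfoNCE reduces the graded targets to the index of the highest-labelled document per query and applies a standard softmax cross-entropy over the candidate list. Standalone sketch with plain tensors:

import torch

scores = torch.tensor([[2.0, 0.5, -1.0], [0.1, 1.2, 0.3]])
targets = torch.tensor([[1, 0, 0], [0, 1, 0]])

labels = targets.argmax(dim=1)  # index of the positive per query
loss = torch.nn.functional.cross_entropy(scores, labels)
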
class ApproxLossFunction(ListwiseLossFunction):
    def __init__(self, temperature: float = 1) -> None:
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def get_approx_ranks(scores: torch.Tensor, temperature: float) -> torch.Tensor:
        score_diff = scores[:, None] - scores[..., None]
        normalized_score_diff = torch.sigmoid(score_diff / temperature)
        # set diagonal to 0
        normalized_score_diff = normalized_score_diff * (1 - torch.eye(scores.shape[1], device=scores.device))
        approx_ranks = normalized_score_diff.sum(-1) + 1
        return approx_ranks

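get_approx_ranks turns scores into differentiable ranks: for each document it sums sigmoid((s_j - s_i) / temperature) over all other documents j and adds 1, so as the temperature shrinks a document's approximate rank approaches 1 plus the number of higher-scoring documents. The static method can be called on a plain score matrix, assuming lightning_ir is installed:

import torch
from lightning_ir.loss.loss import ApproxLossFunction

scores = torch.tensor([[3.0, 1.0, 2.0]])
approx_ranks = ApproxLossFunction.get_approx_ranks(scores, temperature=1e-3)
print(approx_ranks)  # roughly tensor([[1., 3., 2.]]) for a small temperature
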
class ApproxNDCG(ApproxLossFunction):
    def __init__(self, temperature: float = 1, scale_gains: bool = True):
        super().__init__(temperature)
        self.scale_gains = scale_gains

    @staticmethod
    def get_dcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
    ) -> torch.Tensor:
        log_ranks = torch.log2(1 + ranks)
        discounts = 1 / log_ranks
        if scale_gains:
            gains = 2**targets - 1
        else:
            gains = targets
        dcgs = gains * discounts
        if k is not None:
            dcgs = dcgs.masked_fill(ranks > k, 0)
        return dcgs.sum(dim=-1)

    @staticmethod
    def get_ndcg(
        ranks: torch.Tensor,
        targets: torch.Tensor,
        k: int | None = None,
        scale_gains: bool = True,
        optimal_targets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        targets = targets.clamp(min=0)
        if optimal_targets is None:
            optimal_targets = targets
        optimal_ranks = torch.argsort(torch.argsort(optimal_targets, descending=True))
        optimal_ranks = optimal_ranks + 1
        dcg = ApproxNDCG.get_dcg(ranks, targets, k, scale_gains)
        idcg = ApproxNDCG.get_dcg(optimal_ranks, optimal_targets, k, scale_gains)
        ndcg = dcg / (idcg.clamp(min=1e-12))
        return ndcg

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ndcg = self.get_ndcg(approx_ranks, targets, k=None, scale_gains=self.scale_gains)
        loss = 1 - ndcg
        return loss.mean()

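ApproxNDCG plugs the approximate ranks into a standard DCG/IDCG ratio. With exact integer ranks the static helpers reproduce the usual nDCG, as in this small check (assuming lightning_ir is installed):

import torch
from lightning_ir.loss.loss import ApproxNDCG

targets = torch.tensor([[3.0, 0.0, 1.0]])
ranks = torch.tensor([[1.0, 3.0, 2.0]])  # ranks induced by a perfect ordering

ndcg = ApproxNDCG.get_ndcg(ranks, targets, k=None, scale_gains=True)
print(ndcg)  # tensor([1.]) for a perfect ordering
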
class ApproxMRR(ApproxLossFunction):
    def __init__(self, temperature: float = 1):
        super().__init__(temperature)

    @staticmethod
    def get_mrr(ranks: torch.Tensor, targets: torch.Tensor, k: int | None = None) -> torch.Tensor:
        targets = targets.clamp(None, 1)
        reciprocal_ranks = 1 / ranks
        mrr = reciprocal_ranks * targets
        if k is not None:
            mrr = mrr.masked_fill(ranks > k, 0)
        mrr = mrr.max(dim=-1)[0]
        return mrr

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        mrr = self.get_mrr(approx_ranks, targets, k=None)
        loss = 1 - mrr
        return loss.mean()

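get_mrr clamps graded labels to at most 1, multiplies the reciprocal (approximate) ranks by these binary labels, and keeps the largest value per query, i.e. the reciprocal rank of the best-ranked relevant document. Quick check with exact ranks (assuming lightning_ir is installed):

import torch
from lightning_ir.loss.loss import ApproxMRR

targets = torch.tensor([[0, 1, 1]])
ranks = torch.tensor([[1.0, 3.0, 2.0]])  # the first relevant document sits at rank 2

mrr = ApproxMRR.get_mrr(ranks, targets, k=None)
print(mrr)  # tensor([0.5000])
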
class ApproxRankMSE(ApproxLossFunction):
    def __init__(
        self,
        temperature: float = 1,
        discount: Literal["log2", "reciprocal"] | None = None,
    ):
        super().__init__(temperature)
        self.discount = discount

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        approx_ranks = self.get_approx_ranks(scores, self.temperature)
        ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
        loss = torch.nn.functional.mse_loss(approx_ranks, ranks.to(approx_ranks), reduction="none")
        if self.discount == "log2":
            weight = 1 / torch.log2(ranks + 1)
        elif self.discount == "reciprocal":
            weight = 1 / ranks
        else:
            weight = 1
        loss = loss * weight
        loss = loss.mean()
        return loss

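ApproxRankMSE regresses the approximate ranks onto the ranks induced by the labels, optionally down-weighting errors at lower ranks. The label-to-rank conversion uses the double-argsort trick, shown here on its own with plain tensors:

import torch

targets = torch.tensor([[3.0, 0.0, 1.0]])
# double argsort: 1-based position of each document in the label-sorted ordering
ranks = torch.argsort(torch.argsort(targets, descending=True)) + 1
print(ranks)  # tensor([[1, 3, 2]])
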
class NeuralLossFunction(ListwiseLossFunction):
    # TODO add neural loss functions

    def __init__(self, temperature: float = 1, tol: float = 1e-5, max_iter: int = 50) -> None:
        super().__init__()
        self.temperature = temperature
        self.tol = tol
        self.max_iter = max_iter

    def neural_sort(self, scores: torch.Tensor) -> torch.Tensor:
        # https://github.com/ermongroup/neuralsort/blob/master/pytorch/neuralsort.py
        scores = scores.unsqueeze(-1)
        dim = scores.shape[1]
        one = torch.ones((dim, 1), device=scores.device)

        A_scores = torch.abs(scores - scores.permute(0, 2, 1))
        B = torch.matmul(A_scores, torch.matmul(one, torch.transpose(one, 0, 1)))
        scaling = dim + 1 - 2 * (torch.arange(dim, device=scores.device) + 1)
        C = torch.matmul(scores, scaling.to(scores).unsqueeze(0))

        P_max = (C - B).permute(0, 2, 1)
        P_hat = torch.nn.functional.softmax(P_max / self.temperature, dim=-1)

        P_hat = self.sinkhorn_scaling(P_hat)

        return P_hat

    def sinkhorn_scaling(self, mat: torch.Tensor) -> torch.Tensor:
        # https://github.com/allegro/allRank/blob/master/allrank/models/losses/loss_utils.py#L8
        idx = 0
        while True:
            if (
                torch.max(torch.abs(mat.sum(dim=2) - 1.0)) < self.tol
                and torch.max(torch.abs(mat.sum(dim=1) - 1.0)) < self.tol
            ) or idx > self.max_iter:
                break
            mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=1e-12)
            mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=1e-12)
            idx += 1

        return mat

    def get_sorted_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        permutation_matrix = self.neural_sort(scores)
        pred_sorted_targets = torch.matmul(permutation_matrix, targets[..., None].to(permutation_matrix)).squeeze(-1)
        return pred_sorted_targets

class InBatchLossFunction(LossFunction):
    def __init__(
        self,
        pos_sampling_technique: Literal["all", "first"] = "all",
        neg_sampling_technique: Literal["all", "first", "all_and_non_first"] = "all",
        max_num_neg_samples: int | None = None,
    ):
        super().__init__()
        self.pos_sampling_technique = pos_sampling_technique
        self.neg_sampling_technique = neg_sampling_technique
        self.max_num_neg_samples = max_num_neg_samples
        if self.neg_sampling_technique == "all_and_non_first" and self.pos_sampling_technique != "first":
            raise ValueError("all_and_non_first is only valid with pos_sampling_technique first")

    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        if self.pos_sampling_technique == "all":
            pos_mask = torch.arange(num_queries * num_docs)[None].greater_equal(min_idx) & torch.arange(
                num_queries * num_docs
            )[None].less(max_idx)
        elif self.pos_sampling_technique == "first":
            pos_mask = torch.arange(num_queries * num_docs)[None].eq(min_idx)
        else:
            raise ValueError("invalid pos sampling technique")
        return pos_mask

    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        if self.neg_sampling_technique == "all_and_non_first":
            neg_mask = torch.arange(num_queries * num_docs)[None].not_equal(min_idx)
        elif self.neg_sampling_technique == "all":
            neg_mask = torch.arange(num_queries * num_docs)[None].less(min_idx) | torch.arange(num_queries * num_docs)[
                None
            ].greater_equal(max_idx)
        elif self.neg_sampling_technique == "first":
            neg_mask = torch.arange(num_queries * num_docs)[None, None].eq(min_idx).any(1) & torch.arange(
                num_queries * num_docs
            )[None].ne(min_idx)
        else:
            raise ValueError("invalid neg sampling technique")
        return neg_mask

    def get_ib_idcs(self, output: LightningIROutput, batch: TrainBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        if output.scores is None:
            raise ValueError("Expected scores in LightningIROutput")
        num_queries, num_docs = output.scores.shape
        min_idx = torch.arange(num_queries)[:, None] * num_docs
        max_idx = min_idx + num_docs
        pos_mask = self._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        neg_mask = self._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        pos_idcs = pos_mask.nonzero(as_tuple=True)[1]
        neg_idcs = neg_mask.nonzero(as_tuple=True)[1]
        if self.max_num_neg_samples is not None:
            neg_idcs = neg_idcs.view(num_queries, -1)
            if neg_idcs.shape[-1] > 1:
                neg_idcs = neg_idcs[:, torch.randperm(neg_idcs.shape[-1])]
                neg_idcs = neg_idcs[:, : self.max_num_neg_samples]
            neg_idcs = neg_idcs.reshape(-1)
        return pos_idcs, neg_idcs

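get_ib_idcs works on a flattened view of the batch: with num_queries queries and num_docs documents per query there are num_queries * num_docs candidate columns, and each query's own block spans [min_idx, max_idx). The masks mark which columns count as positives and which as in-batch negatives. A minimal sketch of the default "all"/"all" setting for 2 queries with 2 documents each, with plain tensors:

import torch

num_queries, num_docs = 2, 2
min_idx = torch.arange(num_queries)[:, None] * num_docs   # [[0], [2]]
max_idx = min_idx + num_docs                               # [[2], [4]]
cols = torch.arange(num_queries * num_docs)[None]          # [[0, 1, 2, 3]]

pos_mask = cols.greater_equal(min_idx) & cols.less(max_idx)    # a query's own block
neg_mask = cols.less(min_idx) | cols.greater_equal(max_idx)    # every other query's block
print(pos_mask.int())  # tensor([[1, 1, 0, 0], [0, 0, 1, 1]])
print(neg_mask.int())  # tensor([[0, 0, 1, 1], [1, 1, 0, 0]])
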
class ScoreBasedInBatchLossFunction(InBatchLossFunction):

    def __init__(self, min_target_diff: float, max_num_neg_samples: int | None = None):
        super().__init__(
            pos_sampling_technique="first",
            neg_sampling_technique="all_and_non_first",
            max_num_neg_samples=max_num_neg_samples,
        )
        self.min_target_diff = min_target_diff

    def _sort_mask(
        self, mask: torch.Tensor, num_queries: int, num_docs: int, output: LightningIROutput, batch: TrainBatch
    ) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        idcs = targets.argsort(descending=True).argsort().cpu()
        idcs = idcs + torch.arange(num_queries)[:, None] * num_docs
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        return mask.scatter(1, block_idcs, mask.gather(1, idcs))

    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        pos_mask = super()._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        pos_mask = self._sort_mask(pos_mask, num_queries, num_docs, output, batch)
        return pos_mask

    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, output, batch)
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).cpu()
        max_score, _ = targets.max(dim=-1, keepdim=True)
        score_diff = (max_score - targets).cpu()
        score_mask = score_diff.ge(self.min_target_diff)
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        neg_mask = neg_mask.scatter(1, block_idcs, score_mask)
        # num_neg_samples might be different between queries
        num_neg_samples = neg_mask.sum(dim=1)
        min_num_neg_samples = num_neg_samples.min()
        additional_neg_samples = num_neg_samples - min_num_neg_samples
        for query_idx, neg_samples in enumerate(additional_neg_samples):
            neg_idcs = neg_mask[query_idx].nonzero().squeeze(1)
            additional_neg_idcs = neg_idcs[torch.randperm(neg_idcs.shape[0])][:neg_samples]
            assert neg_mask[query_idx, additional_neg_idcs].all()
            neg_mask[query_idx, additional_neg_idcs] = False
            assert neg_mask[query_idx].sum().eq(min_num_neg_samples)
        return neg_mask

class InBatchCrossEntropy(InBatchLossFunction):
    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss

class ScoreBasedInBatchCrossEntropy(ScoreBasedInBatchLossFunction):

    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        scores = self.process_scores(output)
        targets = torch.zeros(scores.shape[0], dtype=torch.long, device=scores.device)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss

class RegularizationLossFunction(EmbeddingLossFunction):
    def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None:
        self.query_weight = query_weight
        self.doc_weight = doc_weight

    def process_embeddings(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, torch.Tensor]:
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        return query_embeddings.embeddings, doc_embeddings.embeddings

class L2Regularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = self.query_weight * query_embeddings.norm(dim=-1).mean()
        doc_loss = self.doc_weight * doc_embeddings.norm(dim=-1).mean()
        loss = query_loss + doc_loss
        return loss

class L1Regularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = self.query_weight * query_embeddings.norm(p=1, dim=-1).mean()
        doc_loss = self.doc_weight * doc_embeddings.norm(p=1, dim=-1).mean()
        loss = query_loss + doc_loss
        return loss

class FLOPSRegularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        query_loss = torch.sum(torch.mean(torch.abs(query_embeddings), dim=0) ** 2)
        doc_loss = torch.sum(torch.mean(torch.abs(doc_embeddings), dim=0) ** 2)
        loss = self.query_weight * query_loss + self.doc_weight * doc_loss
        return loss
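
The FLOPS regularizer, commonly used for SPLADE-style sparse embeddings, penalizes the squared mean absolute activation of each embedding dimension over the batch, pushing rarely used dimensions toward zero. A standalone sketch of the query term with plain tensors (the embedding size here is only illustrative):

import torch

query_embeddings = torch.relu(torch.randn(8, 30522))  # (batch, vocab-sized sparse embedding)

# mean absolute activation per dimension, squared, summed over dimensions
flops = torch.sum(torch.mean(torch.abs(query_embeddings), dim=0) ** 2)
loss = 1e-4 * flops  # weighted as with query_weight above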