Source code for lightning_ir.loss.in_batch

  1"""
  2In-batch loss functions for the Lightning IR framework.
  3
  4This module contains loss functions that operate on batches of data,
  5comparing examples within the same batch for training.
  6"""
  7
  8from __future__ import annotations
  9
 10from typing import TYPE_CHECKING
 11
 12import torch
 13
 14from .base import InBatchLossFunction
 15
 16if TYPE_CHECKING:
 17    from ..base import LightningIROutput
 18    from ..data import TrainBatch
 19
 20
class ScoreBasedInBatchLossFunction(InBatchLossFunction):
    """Base class for in-batch loss functions that compute in-batch indices based on scores."""

    def __init__(self, min_target_diff: float, max_num_neg_samples: int | None = None):
        """Initialize the ScoreBasedInBatchLossFunction.

        Args:
            min_target_diff (float): Minimum target difference for negative sampling.
            max_num_neg_samples (int | None): Maximum number of negative samples.
        """
        # Fixed sampling strategy: the positive is the "first" document per query,
        # negatives are "all_and_non_first" (all remaining in-batch documents);
        # they are further filtered by target difference in _get_neg_mask below.
        super().__init__(
            pos_sampling_technique="first",
            neg_sampling_technique="all_and_non_first",
            max_num_neg_samples=max_num_neg_samples,
        )
        # Minimum (max_target - target) gap a document must have to count as a negative.
        self.min_target_diff = min_target_diff

    def _sort_mask(
        self, mask: torch.Tensor, num_queries: int, num_docs: int, output: LightningIROutput, batch: TrainBatch
    ) -> torch.Tensor:
        """Sort the mask based on the scores and targets.

        Args:
            mask (torch.Tensor): The initial mask tensor.
            num_queries (int): Number of queries in the batch.
            num_docs (int): Number of documents per query.
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The sorted mask tensor.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        # Double argsort: argsort(descending) gives the sorted order, a second
        # argsort converts that into each element's rank within its row, i.e.
        # the position each entry would occupy after sorting targets descending.
        idcs = targets.argsort(descending=True).argsort().cpu()
        # Offset per-query ranks into flat column indices of the mask
        # (query q owns columns [q*num_docs, (q+1)*num_docs)).
        # NOTE(review): this presumes mask columns are laid out query-block by
        # query-block over num_queries * num_docs entries — confirm against base class.
        idcs = idcs + torch.arange(num_queries)[:, None] * num_docs
        # Contiguous destination block of columns for each query.
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        # Gather mask entries in target-sorted order and scatter them back into
        # each query's contiguous block, reordering within blocks.
        return mask.scatter(1, block_idcs, mask.gather(1, idcs))

    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        """Get the mask for positive samples and sort it based on scores.

        Args:
            num_queries (int): Number of queries in the batch.
            num_docs (int): Number of documents per query.
            max_idx (torch.Tensor): Maximum index for each query.
            min_idx (torch.Tensor): Minimum index for each query.
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: A mask tensor indicating the positions of positive samples, sorted by scores.
        """
        # Base-class positive mask, then reorder entries by descending target.
        pos_mask = super()._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        pos_mask = self._sort_mask(pos_mask, num_queries, num_docs, output, batch)
        return pos_mask

    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        """Get the mask for negative samples and sort it based on scores.

        Args:
            num_queries (int): Number of queries in the batch.
            num_docs (int): Number of documents per query.
            max_idx (torch.Tensor): Maximum index for each query.
            min_idx (torch.Tensor): Minimum index for each query.
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: A mask tensor indicating the positions of negative samples, sorted by scores.
        """
        neg_mask = super()._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        neg_mask = self._sort_mask(neg_mask, num_queries, num_docs, output, batch)
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).cpu()
        max_score, _ = targets.max(dim=-1, keepdim=True)
        # NOTE(review): targets is already on CPU here, so this second .cpu() is a no-op.
        score_diff = (max_score - targets).cpu()
        # Only documents whose target lies at least min_target_diff below the
        # per-query maximum are eligible negatives.
        score_mask = score_diff.ge(self.min_target_diff)
        block_idcs = torch.arange(num_docs)[None] + torch.arange(num_queries)[:, None] * num_docs
        # Overwrite each query's contiguous column block with the eligibility mask.
        neg_mask = neg_mask.scatter(1, block_idcs, score_mask)
        # num_neg_samples might be different between queries
        num_neg_samples = neg_mask.sum(dim=1)
        min_num_neg_samples = num_neg_samples.min()
        # Surplus negatives per query beyond the batch-wide minimum; randomly
        # drop that many so every query keeps exactly min_num_neg_samples.
        additional_neg_samples = num_neg_samples - min_num_neg_samples
        for query_idx, neg_samples in enumerate(additional_neg_samples):
            neg_idcs = neg_mask[query_idx].nonzero().squeeze(1)
            # Random subset of the active negatives to disable for this query.
            additional_neg_idcs = neg_idcs[torch.randperm(neg_idcs.shape[0])][:neg_samples]
            # NOTE(review): sanity checks only — stripped when running with python -O.
            assert neg_mask[query_idx, additional_neg_idcs].all()
            neg_mask[query_idx, additional_neg_idcs] = False
            assert neg_mask[query_idx].sum().eq(min_num_neg_samples)
        return neg_mask
class InBatchCrossEntropy(InBatchLossFunction):
    """In-batch cross-entropy loss function for ranking tasks.
    Originally proposed in: `Fast Single-Class Classification and the Principle of Logit Separation
    <https://arxiv.org/pdf/1705.10246v1>`_"""

    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        """Compute the in-batch cross-entropy loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
        Returns:
            torch.Tensor: The computed loss.
        """
        in_batch_scores = self.process_scores(output)
        # Cross-entropy with target class 0 for every row: the correct
        # (positive) document sits in the first column of each score row.
        num_rows = in_batch_scores.shape[0]
        positive_idcs = torch.zeros(num_rows, dtype=torch.long, device=in_batch_scores.device)
        return torch.nn.functional.cross_entropy(in_batch_scores, positive_idcs)
class ScoreBasedInBatchCrossEntropy(ScoreBasedInBatchLossFunction):
    """In-batch cross-entropy loss function based on scores for ranking tasks."""

    def compute_loss(self, output: LightningIROutput) -> torch.Tensor:
        """Compute the in-batch cross-entropy loss based on scores.

        Args:
            output (LightningIROutput): The output from the model containing scores.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        # Every row's label is index 0 — the first column holds the positive.
        # new_zeros keeps the targets on the same device as the scores.
        targets = scores.new_zeros(scores.shape[0], dtype=torch.long)
        return torch.nn.functional.cross_entropy(scores, targets)