Source code for lightning_ir.loss.base

  1"""
  2Base classes and abstract interfaces for loss functions in the Lightning IR framework.
  3
  4This module defines the abstract base classes and common functionality for all loss functions
  5used in the Lightning IR framework.
  6"""
  7
  8from __future__ import annotations
  9
 10from abc import ABC, abstractmethod
 11from typing import TYPE_CHECKING, Literal, Tuple
 12
 13import torch
 14
 15if TYPE_CHECKING:
 16    from ..base import LightningIROutput
 17    from ..bi_encoder import BiEncoderOutput
 18    from ..data import TrainBatch
 19
 20
class LossFunction(ABC):
    """Base class for loss functions in the Lightning IR framework."""

    @abstractmethod
    def compute_loss(self, output: LightningIROutput, *args, **kwargs) -> torch.Tensor:
        """Compute the loss for the given output.

        Args:
            output (LightningIROutput): The output from the model.
        Returns:
            torch.Tensor: The computed loss.
        """
        ...

    def process_scores(self, output: LightningIROutput) -> torch.Tensor:
        """Process the scores from the output.

        Args:
            output (LightningIROutput): The output from the model.
        Returns:
            torch.Tensor: The scores tensor.
        Raises:
            ValueError: If scores are not present in the output.
        """
        if output.scores is None:
            raise ValueError("Expected scores in LightningIROutput")
        return output.scores

    def process_targets(self, scores: torch.Tensor, batch: TrainBatch) -> torch.Tensor:
        """Process the targets from the batch. If the targets have more dimensions than the scores, the
        trailing (label) dimension is reduced with a per-document maximum.

        Args:
            scores (torch.Tensor): The scores tensor.
            batch (TrainBatch): The training batch.
        Returns:
            torch.Tensor: The processed targets tensor.
        Raises:
            ValueError: If targets are not present in the batch.
        """
        targets = batch.targets
        if targets is None:
            raise ValueError("Expected targets in TrainBatch")
        if targets.ndim > scores.ndim:
            return targets.amax(-1)
        return targets

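# Usage sketch (illustrative, not part of the module source). `process_targets`
# collapses an extra trailing label dimension with a per-document maximum; the
# tensors below are made up and `loss_fn` stands in for any concrete subclass,
# since `LossFunction` itself is abstract.
#
#   scores = torch.tensor([[0.9, 0.1], [0.2, 0.8]])     # (num_queries, num_docs)
#   batch.targets = torch.tensor([[[1, 0], [0, 0]],
#                                 [[0, 1], [2, 1]]])    # (num_queries, num_docs, num_labels)
#   loss_fn.process_targets(scores, batch)
#   # targets.ndim (3) > scores.ndim (2), so amax(-1) is applied:
#   # tensor([[1, 0], [1, 2]])
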
class ScoringLossFunction(LossFunction):
    """Base class for loss functions that operate on scores."""

    @abstractmethod
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the loss based on the scores and targets in the output and batch.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        ...

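# Usage sketch (illustrative, not part of the module source): a minimal concrete
# scoring loss. The class name `_ExampleMSELoss` and the mean-squared-error
# objective are made up for demonstration; the actual losses shipped with
# Lightning IR live in the sibling modules of this package.
class _ExampleMSELoss(ScoringLossFunction):
    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        # Validate and fetch scores and targets via the base-class helpers.
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).to(scores)
        return torch.nn.functional.mse_loss(scores, targets)
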
class EmbeddingLossFunction(LossFunction):
    """Base class for loss functions that operate on embeddings."""

    @abstractmethod
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the loss based on the embeddings in the output.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.
        Returns:
            torch.Tensor: The computed loss.
        """
        ...

class RegularizationLossFunction(EmbeddingLossFunction):
    """Base class for regularization loss functions that operate on embeddings."""

    def __init__(self, query_weight: float = 1e-4, doc_weight: float = 1e-4) -> None:
        """Initialize the RegularizationLossFunction.

        Args:
            query_weight (float): Weight for the query embeddings regularization. Defaults to 1e-4.
            doc_weight (float): Weight for the document embeddings regularization. Defaults to 1e-4.
        """
        self.query_weight = query_weight
        self.doc_weight = doc_weight

    def process_embeddings(self, output: BiEncoderOutput) -> Tuple[torch.Tensor, torch.Tensor]:
        """Process the embeddings from the output.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The processed query and document embeddings.
        Raises:
            ValueError: If query_embeddings are not present in the output.
            ValueError: If doc_embeddings are not present in the output.
        """
        query_embeddings = output.query_embeddings
        doc_embeddings = output.doc_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        if doc_embeddings is None:
            raise ValueError("Expected doc_embeddings in BiEncoderOutput")
        return query_embeddings.embeddings, doc_embeddings.embeddings

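# Usage sketch (illustrative, not part of the module source): a minimal concrete
# regularizer. The class name `_ExampleL2Regularization` and the mean-L2-norm
# penalty are made up for demonstration; concrete regularizers in the package
# may weight or norm the embeddings differently.
class _ExampleL2Regularization(RegularizationLossFunction):
    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        query_embeddings, doc_embeddings = self.process_embeddings(output)
        # Penalize the average L2 norm of the embeddings, scaled per side.
        query_loss = self.query_weight * query_embeddings.norm(p=2, dim=-1).mean()
        doc_loss = self.doc_weight * doc_embeddings.norm(p=2, dim=-1).mean()
        return query_loss + doc_loss
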
class PairwiseLossFunction(ScoringLossFunction):
    """Base class for pairwise loss functions."""

    def get_pairwise_idcs(self, targets: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Get pairwise indices for positive and negative samples based on targets.

        Args:
            targets (torch.Tensor): The targets tensor containing relevance labels.
        Returns:
            Tuple[torch.Tensor, ...]: Indices of positive and negative samples.
        """
        # positive items are items whose label is greater than another label in the same sample
        return torch.nonzero(targets[..., None] > targets[:, None], as_tuple=True)

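# Usage sketch (illustrative, not part of the module source). For targets of
# shape (num_queries, num_docs), the broadcasted comparison builds a
# (num_queries, num_docs, num_docs) matrix whose entry [q, i, j] is True when
# document i outranks document j for query q; `torch.nonzero` then yields the
# (query, positive, negative) index triples. The tensor below is made up.
#
#   targets = torch.tensor([[1, 0, 0],
#                           [2, 1, 0]])
#   query_idcs, pos_idcs, neg_idcs = loss_fn.get_pairwise_idcs(targets)
#   # query_idcs -> tensor([0, 0, 1, 1, 1])
#   # pos_idcs   -> tensor([0, 0, 0, 0, 1])
#   # neg_idcs   -> tensor([1, 2, 1, 2, 2])
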
class ListwiseLossFunction(ScoringLossFunction):
    """Base class for listwise loss functions."""

    pass

class InBatchLossFunction(LossFunction):
    """Base class for in-batch loss functions that compute in-batch indices for positive and negative samples."""

    def __init__(
        self,
        pos_sampling_technique: Literal["all", "first"] = "all",
        neg_sampling_technique: Literal["all", "first", "all_and_non_first"] = "all",
        max_num_neg_samples: int | None = None,
    ):
        """Initialize the InBatchLossFunction.

        Args:
            pos_sampling_technique (Literal["all", "first"]): Technique for positive sample sampling.
            neg_sampling_technique (Literal["all", "first", "all_and_non_first"]): Technique for negative sample
                sampling.
            max_num_neg_samples (int | None): Maximum number of negative samples to consider. If None, all negative
                samples are considered.
        Raises:
            ValueError: If the negative sampling technique is invalid for the given positive sampling technique.
        """
        super().__init__()
        self.pos_sampling_technique = pos_sampling_technique
        self.neg_sampling_technique = neg_sampling_technique
        self.max_num_neg_samples = max_num_neg_samples
        if self.neg_sampling_technique == "all_and_non_first" and self.pos_sampling_technique != "first":
            raise ValueError("all_and_non_first is only valid with pos_sampling_technique first")

    def _get_pos_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        """Get the mask for positive samples based on the sampling technique.

        Args:
            num_queries (int): Number of queries in the batch.
            num_docs (int): Number of documents per query.
            max_idx (torch.Tensor): Maximum index for each query.
            min_idx (torch.Tensor): Minimum index for each query.
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: A mask tensor indicating the positions of positive samples.
        Raises:
            ValueError: If the positive sampling technique is invalid.
        """
        if self.pos_sampling_technique == "all":
            pos_mask = torch.arange(num_queries * num_docs)[None].greater_equal(min_idx) & torch.arange(
                num_queries * num_docs
            )[None].less(max_idx)
        elif self.pos_sampling_technique == "first":
            pos_mask = torch.arange(num_queries * num_docs)[None].eq(min_idx)
        else:
            raise ValueError("invalid pos sampling technique")
        return pos_mask

    def _get_neg_mask(
        self,
        num_queries: int,
        num_docs: int,
        max_idx: torch.Tensor,
        min_idx: torch.Tensor,
        output: LightningIROutput,
        batch: TrainBatch,
    ) -> torch.Tensor:
        """Get the mask for negative samples based on the sampling technique.

        Args:
            num_queries (int): Number of queries in the batch.
            num_docs (int): Number of documents per query.
            max_idx (torch.Tensor): Maximum index for each query.
            min_idx (torch.Tensor): Minimum index for each query.
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: A mask tensor indicating the positions of negative samples.
        Raises:
            ValueError: If the negative sampling technique is invalid.
        """
        if self.neg_sampling_technique == "all_and_non_first":
            neg_mask = torch.arange(num_queries * num_docs)[None].not_equal(min_idx)
        elif self.neg_sampling_technique == "all":
            neg_mask = torch.arange(num_queries * num_docs)[None].less(min_idx) | torch.arange(num_queries * num_docs)[
                None
            ].greater_equal(max_idx)
        elif self.neg_sampling_technique == "first":
            neg_mask = torch.arange(num_queries * num_docs)[None, None].eq(min_idx).any(1) & torch.arange(
                num_queries * num_docs
            )[None].ne(min_idx)
        else:
            raise ValueError("invalid neg sampling technique")
        return neg_mask

    def get_ib_idcs(self, output: LightningIROutput, batch: TrainBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get in-batch indices for positive and negative samples.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Indices of positive and negative samples.
        Raises:
            ValueError: If scores are not present in the output.
        """
        if output.scores is None:
            raise ValueError("Expected scores in LightningIROutput")
        num_queries, num_docs = output.scores.shape
        min_idx = torch.arange(num_queries)[:, None] * num_docs
        max_idx = min_idx + num_docs
        pos_mask = self._get_pos_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        neg_mask = self._get_neg_mask(num_queries, num_docs, max_idx, min_idx, output, batch)
        pos_idcs = pos_mask.nonzero(as_tuple=True)[1]
        neg_idcs = neg_mask.nonzero(as_tuple=True)[1]
        if self.max_num_neg_samples is not None:
            neg_idcs = neg_idcs.view(num_queries, -1)
            if neg_idcs.shape[-1] > 1:
                neg_idcs = neg_idcs[:, torch.randperm(neg_idcs.shape[-1])]
            neg_idcs = neg_idcs[:, : self.max_num_neg_samples]
            neg_idcs = neg_idcs.reshape(-1)
        return pos_idcs, neg_idcs
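
# Usage sketch (illustrative, not part of the module source). With the default
# "all"/"all" sampling techniques and scores of shape (2, 2), every document in
# a query's own group is a positive and every other group's document is a
# negative; both index tensors address flat positions in the flattened
# (num_queries * num_docs) score matrix. The values below are made up, and a
# concrete subclass implementing `compute_loss` is assumed, since
# `InBatchLossFunction` is abstract.
#
#   in_batch_loss = SomeConcreteInBatchLoss()   # hypothetical subclass
#   # output.scores.shape == (2, 2)
#   pos_idcs, neg_idcs = in_batch_loss.get_ib_idcs(output, batch)
#   # pos_idcs -> tensor([0, 1, 2, 3])   (each query's own documents)
#   # neg_idcs -> tensor([2, 3, 0, 1])   (the other query's documents)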