Source code for lightning_ir.bi_encoder.bi_encoder_module

  1"""
  2Module module for bi-encoder models.
  3
  4This module defines the Lightning IR module class used to implement bi-encoder models.
  5"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Type

import torch
from transformers import BatchEncoding, PreTrainedModel

from ..base import LightningIRModule, LightningIROutput
from ..data import IndexBatch, RankBatch, SearchBatch, TrainBatch
from ..loss.base import EmbeddingLossFunction, LossFunction, ScoringLossFunction
from ..loss.in_batch import InBatchLossFunction
from .bi_encoder_config import BiEncoderConfig
from .bi_encoder_model import BiEncoderEmbedding, BiEncoderModel, BiEncoderOutput
from .bi_encoder_tokenizer import BiEncoderTokenizer

if TYPE_CHECKING:
    from ..retrieve import SearchConfig, Searcher


class BiEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: BiEncoderConfig | None = None,
        model: BiEncoderModel | None = None,
        BackboneModel: Type[PreTrainedModel] | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        index_dir: Path | None = None,
        search_config: SearchConfig | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """:class:`.LightningIRModule` for bi-encoder models. It contains a :class:`.BiEncoderModel` and a
        :class:`.BiEncoderTokenizer` and implements the training, validation, and testing steps for the model.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        Args:
            model_name_or_path (str | None): Name or path of backbone model or fine-tuned Lightning IR model.
                Defaults to None.
            config (BiEncoderConfig | None): BiEncoderConfig to apply when loading from backbone model.
                Defaults to None.
            model (BiEncoderModel | None): Already instantiated BiEncoderModel. Defaults to None.
            BackboneModel (Type[PreTrainedModel] | None): Huggingface PreTrainedModel class to use as backbone
                instead of the default AutoModel. Defaults to None.
            loss_functions (Sequence[LossFunction | Tuple[LossFunction, float]] | None): Loss functions to apply
                during fine-tuning; optional loss weights can be provided per loss function. Defaults to None.
            evaluation_metrics (Sequence[str] | None): Metrics corresponding to ir-measures_ measure strings
                to apply during validation or testing. Defaults to None.
            index_dir (Path | None): Path to an index used for retrieval. Defaults to None.
            search_config (SearchConfig | None): Configuration to use during retrieval. Defaults to None.
            model_kwargs (Mapping[str, Any] | None): Additional keyword arguments to pass to `from_pretrained`
                when loading a model. Defaults to None.
        """
        super().__init__(
            model_name_or_path=model_name_or_path,
            config=config,
            model=model,
            BackboneModel=BackboneModel,
            loss_functions=loss_functions,
            evaluation_metrics=evaluation_metrics,
            model_kwargs=model_kwargs,
        )
        self.model: BiEncoderModel
        self.config: BiEncoderConfig
        self.tokenizer: BiEncoderTokenizer
        if len(self.tokenizer) > self.config.vocab_size:
            # Resize the embedding matrix for tokenizer-added marker tokens,
            # padding the new vocabulary size to a multiple of 8.
            self.model.resize_token_embeddings(len(self.tokenizer), 8)
        self._searcher = None
        self.search_config = search_config
        self.index_dir = index_dir
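
    # A minimal construction sketch, not part of the original module; the
    # model name and index path are hypothetical placeholders:
    #
    #   module = BiEncoderModule(
    #       model_name_or_path="bert-base-uncased",
    #       config=BiEncoderConfig(),
    #       index_dir=Path("path/to/index"),  # optional, only needed for retrieval
    #   )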

    @property
    def searcher(self) -> Searcher | None:
        """Searcher used for retrieval if `index_dir` and `search_config` are set.

        Returns:
            Searcher | None: The searcher instance, or None if no searcher is set.
        """
        return self._searcher

    @searcher.setter
    def searcher(self, searcher: Searcher):
        self._searcher = searcher

    def _init_searcher(self) -> None:
        if self.search_config is not None and self.index_dir is not None:
            self.searcher = self.search_config.search_class(self.index_dir, self.search_config, self)
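
    # Retrieval wiring sketch: when both `index_dir` and `search_config` are
    # set, `_init_searcher` instantiates the searcher at test time. The path
    # and config object below are illustrative placeholders:
    #
    #   module.index_dir = Path("path/to/index")
    #   module.search_config = search_config  # a SearchConfig instance
    #   module.on_test_start()                # builds module.searcher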

    def on_test_start(self) -> None:
        """Called at the beginning of testing. Initializes the searcher if `index_dir` and `search_config` are set."""
        self._init_searcher()
        return super().on_test_start()

    def forward(self, batch: RankBatch | IndexBatch | SearchBatch) -> BiEncoderOutput:
        """Runs a forward pass of the model on a batch of data. The output will vary depending on the type of batch.
        If the batch is a :class:`.RankBatch`, query and document embeddings are computed and the relevance score is
        the similarity between the two embeddings. If the batch is an :class:`.IndexBatch`, only document embeddings
        are computed. If the batch is a :class:`.SearchBatch`, only query embeddings are computed and
        the model will additionally retrieve documents if :attr:`.searcher` is set.

        Args:
            batch (RankBatch | IndexBatch | SearchBatch): Input batch containing queries and/or documents.
        Returns:
            BiEncoderOutput: Output of the model.
        Raises:
            ValueError: If the input batch contains neither queries nor documents.
        """
        queries = getattr(batch, "queries", None)
        docs = getattr(batch, "docs", None)
        num_docs = None
        if isinstance(batch, RankBatch):
            # Flatten the nested per-query document lists and remember how many
            # documents belong to each query.
            num_docs = None if docs is None else [len(d) for d in docs]
            docs = [d for nested in docs for d in nested] if docs is not None else None
        encodings = self.prepare_input(queries, docs, num_docs)

        if not encodings:
            raise ValueError("No encodings were generated.")
        output = self.model.forward(
            encodings.get("query_encoding", None), encodings.get("doc_encoding", None), num_docs
        )
        doc_ids = getattr(batch, "doc_ids", None)
        if doc_ids is not None and output.doc_embeddings is not None:
            output.doc_embeddings.ids = doc_ids
        query_ids = getattr(batch, "query_ids", None)
        if query_ids is not None and output.query_embeddings is not None:
            output.query_embeddings.ids = query_ids
        if isinstance(batch, SearchBatch) and self.searcher is not None:
            scores, doc_ids = self.searcher.search(output)
            output.scores = scores
            if output.doc_embeddings is not None:
                output.doc_embeddings.ids = [doc_id for _doc_ids in doc_ids for doc_id in _doc_ids]
            batch.doc_ids = doc_ids
        return output
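
    # Rough sketch of the three call patterns handled by `forward` (batch
    # construction elided; see lightning_ir.data for the exact dataclass
    # fields):
    #
    #   output = module.forward(rank_batch)    # query + doc embeddings and scores
    #   output = module.forward(index_batch)   # document embeddings only
    #   output = module.forward(search_batch)  # query embeddings, plus retrieval if a searcher is set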

    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> BiEncoderOutput:
        """Computes relevance scores for queries and documents.

        Args:
            queries (Sequence[str] | str): Queries to score.
            docs (Sequence[Sequence[str]] | Sequence[str]): Documents to score.
        Returns:
            BiEncoderOutput: Output of the model.
        """
        return super().score(queries, docs)
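
    # Convenience scoring sketch (query and document strings are illustrative
    # only):
    #
    #   output = module.score("query text", ["first document", "second document"])
    #   print(output.scores)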

    def _compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        if self.loss_functions is None:
            raise ValueError("Loss function is not set")

        if (
            batch.targets is None
            or output.query_embeddings is None
            or output.doc_embeddings is None
            or output.scores is None
        ):
            raise ValueError(
                "targets, scores, query_embeddings, and doc_embeddings must be set in the output and batch"
            )

        num_queries = len(batch.queries)
        output.scores = output.scores.view(num_queries, -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)
        losses = []
        for loss_function, _ in self.loss_functions:
            if isinstance(loss_function, InBatchLossFunction):
                pos_idcs, neg_idcs = loss_function.get_ib_idcs(output, batch)
                ib_doc_embeddings = self._get_ib_doc_embeddings(output.doc_embeddings, pos_idcs, neg_idcs, num_queries)
                ib_scores = self.model.score(
                    BiEncoderOutput(query_embeddings=output.query_embeddings, doc_embeddings=ib_doc_embeddings)
                ).scores
                if ib_scores is None:
                    raise ValueError("In-batch scores cannot be None")
                ib_scores = ib_scores.view(num_queries, -1)
                losses.append(loss_function.compute_loss(LightningIROutput(ib_scores)))
            elif isinstance(loss_function, EmbeddingLossFunction):
                losses.append(loss_function.compute_loss(output))
            elif isinstance(loss_function, ScoringLossFunction):
                losses.append(loss_function.compute_loss(output, batch))
            else:
                raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
        if self.config.sparsification is not None:
            # Log the average number of non-zero entries per query / document
            # embedding to monitor sparsity during training.
            query_num_nonzero = (
                torch.nonzero(output.query_embeddings.embeddings).shape[0] / output.query_embeddings.embeddings.shape[0]
            )
            doc_num_nonzero = (
                torch.nonzero(output.doc_embeddings.embeddings).shape[0] / output.doc_embeddings.embeddings.shape[0]
            )
            self.log("query_num_nonzero", query_num_nonzero)
            self.log("doc_num_nonzero", doc_num_nonzero)
        return losses

    def _get_ib_doc_embeddings(
        self,
        embeddings: BiEncoderEmbedding,
        pos_idcs: torch.Tensor,
        neg_idcs: torch.Tensor,
        num_queries: int,
    ) -> BiEncoderEmbedding:
        """Gets the in-batch document embeddings for a training batch."""
        _, num_embs, emb_dim = embeddings.embeddings.shape
        # Gather positive and negative documents per query, then flatten back
        # to a (num_queries * num_docs, num_embs, emb_dim) tensor.
        ib_embeddings = torch.cat(
            [
                embeddings.embeddings[pos_idcs].view(num_queries, -1, num_embs, emb_dim),
                embeddings.embeddings[neg_idcs].view(num_queries, -1, num_embs, emb_dim),
            ],
            dim=1,
        ).view(-1, num_embs, emb_dim)
        if embeddings.scoring_mask is None:
            ib_scoring_mask = None
        else:
            ib_scoring_mask = torch.cat(
                [
                    embeddings.scoring_mask[pos_idcs].view(num_queries, -1, num_embs),
                    embeddings.scoring_mask[neg_idcs].view(num_queries, -1, num_embs),
                ],
                dim=1,
            ).view(-1, num_embs)
        if embeddings.encoding is None:
            ib_encoding = None
        else:
            ib_encoding = {}
            for key, value in embeddings.encoding.items():
                seq_len = value.shape[-1]
                ib_encoding[key] = torch.cat(
                    [value[pos_idcs].view(num_queries, -1, seq_len), value[neg_idcs].view(num_queries, -1, seq_len)],
                    dim=1,
                ).view(-1, seq_len)
            ib_encoding = BatchEncoding(ib_encoding)
        return BiEncoderEmbedding(ib_embeddings, ib_scoring_mask, ib_encoding)
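
    # Sketch of how configured losses reach `_compute_losses`: each entry is a
    # loss function with an optional weight, dispatched by type above. The
    # loss instances here are placeholders, not concrete classes from
    # lightning_ir.loss:
    #
    #   loss_functions = [(scoring_loss, 1.0), (in_batch_loss, 0.5)]
    #   module = BiEncoderModule(model_name_or_path="...", loss_functions=loss_functions)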

    def validation_step(
        self,
        batch: TrainBatch | IndexBatch | SearchBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> BiEncoderOutput:
        """Handles the validation step for the model.

        Args:
            batch (TrainBatch | IndexBatch | SearchBatch | RankBatch): Batch of validation or testing data.
            batch_idx (int): Index of the batch.
            dataloader_idx (int): Index of the dataloader. Defaults to 0.
        Returns:
            BiEncoderOutput: Output of the model.
        """
        if isinstance(batch, IndexBatch):
            return self.forward(batch)
        if isinstance(batch, (RankBatch, TrainBatch, SearchBatch)):
            return super().validation_step(batch, batch_idx, dataloader_idx)
        raise ValueError(f"Unknown batch type {type(batch)}")