Source code for lightning_ir.bi_encoder.bi_encoder_module

  1"""
  2Module module for bi-encoder models.
  3
  4This module defines the Lightning IR module class used to implement bi-encoder models.
  5"""
  6
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple

import torch
from transformers import BatchEncoding

from ..base import LightningIRModule, LightningIROutput
from ..data import IndexBatch, RankBatch, SearchBatch, TrainBatch
from ..loss.loss import EmbeddingLossFunction, InBatchLossFunction, LossFunction, ScoringLossFunction
from .bi_encoder_config import BiEncoderConfig
from .bi_encoder_model import BiEncoderEmbedding, BiEncoderModel, BiEncoderOutput
from .bi_encoder_tokenizer import BiEncoderTokenizer

if TYPE_CHECKING:
    from ..retrieve import SearchConfig, Searcher

class BiEncoderModule(LightningIRModule):
    def __init__(
        self,
        model_name_or_path: str | None = None,
        config: BiEncoderConfig | None = None,
        model: BiEncoderModel | None = None,
        loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None = None,
        evaluation_metrics: Sequence[str] | None = None,
        index_dir: Path | None = None,
        search_config: SearchConfig | None = None,
        model_kwargs: Mapping[str, Any] | None = None,
    ):
        """:class:`.LightningIRModule` for bi-encoder models. It contains a :class:`.BiEncoderModel` and a
        :class:`.BiEncoderTokenizer` and implements the training, validation, and testing steps for the model.

        .. _ir-measures: https://ir-measur.es/en/latest/index.html

        :param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
        :type model_name_or_path: str | None, optional
        :param config: BiEncoderConfig to apply when loading from backbone model, defaults to None
        :type config: BiEncoderConfig | None, optional
        :param model: Already instantiated BiEncoderModel, defaults to None
        :type model: BiEncoderModel | None, optional
        :param loss_functions: Loss functions to apply during fine-tuning; optional loss weights can be provided per
            loss function, defaults to None
        :type loss_functions: Sequence[LossFunction | Tuple[LossFunction, float]] | None, optional
        :param evaluation_metrics: Metrics corresponding to ir-measures_ measure strings to apply during validation or
            testing, defaults to None
        :type evaluation_metrics: Sequence[str] | None, optional
        :param index_dir: Path to an index used for retrieval, defaults to None
        :type index_dir: Path | None, optional
        :param search_config: Configuration to use during retrieval, defaults to None
        :type search_config: SearchConfig | None, optional
        :param model_kwargs: Additional keyword arguments to pass to `from_pretrained` when loading a model,
            defaults to None
        :type model_kwargs: Mapping[str, Any] | None, optional
        """
        super().__init__(model_name_or_path, config, model, loss_functions, evaluation_metrics, model_kwargs)
        self.model: BiEncoderModel
        self.config: BiEncoderConfig
        self.tokenizer: BiEncoderTokenizer
        if self.config.add_marker_tokens and len(self.tokenizer) > self.config.vocab_size:
            self.model.resize_token_embeddings(len(self.tokenizer), 8)
        self._searcher = None
        self.search_config = search_config
        self.index_dir = index_dir
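
    # Illustrative usage (comments added for clarity, not part of the original
    # source): a minimal sketch of constructing the module from a backbone
    # checkpoint. The checkpoint name and the metric string are assumptions,
    # not values prescribed by Lightning IR.
    #
    #   module = BiEncoderModule(
    #       model_name_or_path="bert-base-uncased",  # any Hugging Face backbone
    #       config=BiEncoderConfig(),
    #       evaluation_metrics=["nDCG@10"],  # ir-measures measure string
    #   )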

    @property
    def searcher(self) -> Searcher | None:
        """Searcher used for retrieval if `index_dir` and `search_config` are set.

        :return: The searcher instance, or None if no searcher is set
        :rtype: Searcher | None
        """
        return self._searcher

    @searcher.setter
    def searcher(self, searcher: Searcher):
        self._searcher = searcher

    def _init_searcher(self) -> None:
        if self.search_config is not None and self.index_dir is not None:
            self.searcher = self.search_config.search_class(self.index_dir, self.search_config, self)

    def on_test_start(self) -> None:
        """Called at the beginning of testing. Initializes the searcher if `index_dir` and `search_config` are set."""
        self._init_searcher()
        return super().on_test_start()

    def forward(self, batch: RankBatch | IndexBatch | SearchBatch) -> BiEncoderOutput:
        """Runs a forward pass of the model on a batch of data. The output will vary depending on the type of batch.
        If the batch is a :class:`.RankBatch`, query and document embeddings are computed and the relevance score is
        the similarity between the two embeddings. If the batch is an :class:`.IndexBatch`, only document embeddings
        are computed. If the batch is a :class:`.SearchBatch`, only query embeddings are computed and the model will
        additionally retrieve documents if :attr:`.searcher` is set.

        :param batch: Input batch containing queries and/or documents
        :type batch: RankBatch | IndexBatch | SearchBatch
        :raises ValueError: If the input batch contains neither queries nor documents
        :return: Output of the model
        :rtype: BiEncoderOutput
        """
        queries = getattr(batch, "queries", None)
        docs = getattr(batch, "docs", None)
        num_docs = None
        if isinstance(batch, RankBatch):
            num_docs = None if docs is None else [len(d) for d in docs]
            docs = [d for nested in docs for d in nested] if docs is not None else None
        encodings = self.prepare_input(queries, docs, num_docs)

        if not encodings:
            raise ValueError("No encodings were generated.")
        output = self.model.forward(
            encodings.get("query_encoding", None), encodings.get("doc_encoding", None), num_docs
        )
        doc_ids = getattr(batch, "doc_ids", None)
        if doc_ids is not None and output.doc_embeddings is not None:
            output.doc_embeddings.ids = doc_ids
        query_ids = getattr(batch, "query_ids", None)
        if query_ids is not None and output.query_embeddings is not None:
            output.query_embeddings.ids = query_ids
        if isinstance(batch, SearchBatch) and self.searcher is not None:
            scores, doc_ids = self.searcher.search(output)
            output.scores = scores
            if output.doc_embeddings is not None:
                output.doc_embeddings.ids = [doc_id for _doc_ids in doc_ids for doc_id in _doc_ids]
            batch.doc_ids = doc_ids
        return output

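    # Behavior summary (comments added for clarity, not part of the original
    # source), based solely on the attributes accessed above:
    #
    #   RankBatch   -> query and document embeddings plus similarity scores
    #   IndexBatch  -> document embeddings only (for building an index)
    #   SearchBatch -> query embeddings only; if `self.searcher` is set, the
    #                  searcher additionally fills in `scores` and `doc_ids`
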
    def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> BiEncoderOutput:
        """Computes relevance scores for queries and documents.

        :param queries: Queries to score
        :type queries: Sequence[str] | str
        :param docs: Documents to score
        :type docs: Sequence[Sequence[str]] | Sequence[str]
        :return: Model output
        :rtype: BiEncoderOutput
        """
        return super().score(queries, docs)
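
    # Illustrative call (not part of the original source); the query and
    # document strings are made-up examples:
    #
    #   output = module.score(
    #       "what is information retrieval",
    #       ["IR finds relevant documents for a query.", "Bananas are yellow."],
    #   )
    #   # output.scores holds one relevance score per query-document pair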

    def _compute_losses(self, batch: TrainBatch, output: BiEncoderOutput) -> List[torch.Tensor]:
        """Computes the losses for a training batch."""
        if self.loss_functions is None:
            raise ValueError("Loss function is not set")

        if (
            batch.targets is None
            or output.query_embeddings is None
            or output.doc_embeddings is None
            or output.scores is None
        ):
            raise ValueError(
                "targets, scores, query_embeddings, and doc_embeddings must be set in the output and batch"
            )

        num_queries = len(batch.queries)
        output.scores = output.scores.view(num_queries, -1)
        batch.targets = batch.targets.view(*output.scores.shape, -1)
        losses = []
        for loss_function, _ in self.loss_functions:
            if isinstance(loss_function, InBatchLossFunction):
                pos_idcs, neg_idcs = loss_function.get_ib_idcs(output, batch)
                ib_doc_embeddings = self._get_ib_doc_embeddings(output.doc_embeddings, pos_idcs, neg_idcs, num_queries)
                ib_scores = self.model.score(output.query_embeddings, ib_doc_embeddings)
                ib_scores = ib_scores.view(num_queries, -1)
                losses.append(loss_function.compute_loss(LightningIROutput(ib_scores)))
            elif isinstance(loss_function, EmbeddingLossFunction):
                losses.append(loss_function.compute_loss(output))
            elif isinstance(loss_function, ScoringLossFunction):
                losses.append(loss_function.compute_loss(output, batch))
            else:
                raise ValueError(f"Unknown loss function type {loss_function.__class__.__name__}")
        if self.config.sparsification is not None:
            query_num_nonzero = (
                torch.nonzero(output.query_embeddings.embeddings).shape[0] / output.query_embeddings.embeddings.shape[0]
            )
            doc_num_nonzero = (
                torch.nonzero(output.doc_embeddings.embeddings).shape[0] / output.doc_embeddings.embeddings.shape[0]
            )
            self.log("query_num_nonzero", query_num_nonzero)
            self.log("doc_num_nonzero", doc_num_nonzero)
        return losses

    def _get_ib_doc_embeddings(
        self,
        embeddings: BiEncoderEmbedding,
        pos_idcs: torch.Tensor,
        neg_idcs: torch.Tensor,
        num_queries: int,
    ) -> BiEncoderEmbedding:
        """Gets the in-batch document embeddings for a training batch."""
        _, num_embs, emb_dim = embeddings.embeddings.shape
        ib_embeddings = torch.cat(
            [
                embeddings.embeddings[pos_idcs].view(num_queries, -1, num_embs, emb_dim),
                embeddings.embeddings[neg_idcs].view(num_queries, -1, num_embs, emb_dim),
            ],
            dim=1,
        ).view(-1, num_embs, emb_dim)
        if embeddings.scoring_mask is None:
            ib_scoring_mask = None
        else:
            ib_scoring_mask = torch.cat(
                [
                    embeddings.scoring_mask[pos_idcs].view(num_queries, -1, num_embs),
                    embeddings.scoring_mask[neg_idcs].view(num_queries, -1, num_embs),
                ],
                dim=1,
            ).view(-1, num_embs)
        if embeddings.encoding is None:
            ib_encoding = None
        else:
            ib_encoding = {}
            for key, value in embeddings.encoding.items():
                seq_len = value.shape[-1]
                ib_encoding[key] = torch.cat(
                    [value[pos_idcs].view(num_queries, -1, seq_len), value[neg_idcs].view(num_queries, -1, seq_len)],
                    dim=1,
                ).view(-1, seq_len)
            ib_encoding = BatchEncoding(ib_encoding)
        return BiEncoderEmbedding(ib_embeddings, ib_scoring_mask, ib_encoding)

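    # Shape sketch (comment added for clarity, not part of the original source):
    # with num_queries = 2 and, say, 1 positive and 2 in-batch negatives per
    # query, document embeddings of shape (num_docs_total, num_embs, emb_dim)
    # are gathered into (2, 3, num_embs, emb_dim) by the concatenations above
    # and flattened back to (6, num_embs, emb_dim) so they can be scored like a
    # regular document batch.
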
    def validation_step(
        self,
        batch: TrainBatch | IndexBatch | SearchBatch | RankBatch,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> BiEncoderOutput:
        """Handles the validation step for the model.

        :param batch: Batch of validation or testing data
        :type batch: TrainBatch | IndexBatch | SearchBatch | RankBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
        :type dataloader_idx: int, optional
        :return: Model output
        :rtype: BiEncoderOutput
        """
        if isinstance(batch, IndexBatch):
            return self.forward(batch)
        if isinstance(batch, (RankBatch, TrainBatch, SearchBatch)):
            return super().validation_step(batch, batch_idx, dataloader_idx)
        raise ValueError(f"Unknown batch type {type(batch)}")
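

# Illustrative end-to-end sketch (not part of the original module). The backbone
# checkpoint name and the query/document strings are assumptions for
# demonstration only; loading a backbone this way assumes the default
# BiEncoderConfig is acceptable for the chosen model.
if __name__ == "__main__":
    module = BiEncoderModule(
        model_name_or_path="bert-base-uncased",  # hypothetical backbone choice
        config=BiEncoderConfig(),
    )
    output = module.score(
        "how do bi-encoders work",
        ["Bi-encoders embed queries and documents separately.", "Unrelated text."],
    )
    print(output.scores)  # one relevance score per document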