Source code for lightning_ir.retrieve.base.searcher

  1"""Base searcher class and configuration for retrieval tasks."""
  2
  3from __future__ import annotations
  4
  5from abc import ABC, abstractmethod
  6from pathlib import Path
  7from typing import TYPE_CHECKING, List, Literal, Set, Tuple, Type
  8
  9import torch
 10
 11from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding, SingleVectorBiEncoderConfig
 12from .packed_tensor import PackedTensor
 13
 14if TYPE_CHECKING:
 15    from ...bi_encoder import BiEncoderModule, BiEncoderOutput
 16
 17
def cat_arange(arange_starts: torch.Tensor, arange_ends: torch.Tensor) -> torch.Tensor:
    """Concatenates arange tensors into a single tensor.

    Args:
        arange_starts (torch.Tensor): The start indices of the ranges.
        arange_ends (torch.Tensor): The end indices of the ranges.
    Returns:
        torch.Tensor: A tensor containing the concatenated ranges.
    """
    arange_lengths = arange_ends - arange_starts
    offsets = torch.cumsum(arange_lengths, dim=0) - arange_lengths - arange_starts
    return torch.arange(arange_lengths.sum()) - torch.repeat_interleave(offsets, arange_lengths)
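
# Illustration (added for this listing, not part of the original source): the
# helper concatenates several aranges without a Python loop, e.g.
#   cat_arange(torch.tensor([0, 2]), torch.tensor([3, 5]))
#   # -> tensor([0, 1, 2, 2, 3, 4]), i.e. arange(0, 3) followed by arange(2, 5)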

class Searcher(ABC):
    """Base class for searchers in the Lightning IR framework."""
    def __init__(
        self, index_dir: Path | str, search_config: SearchConfig, module: BiEncoderModule, use_gpu: bool = True
    ) -> None:
        """Initialize the Searcher.

        Args:
            index_dir (Path | str): The directory containing the index files.
            search_config (SearchConfig): The configuration for the search.
            module (BiEncoderModule): The bi-encoder module to use for scoring.
            use_gpu (bool): Whether to use GPU for computations. Defaults to True.
        Raises:
            ValueError: If the document lengths do not match the index.
        """
        super().__init__()
        self.index_dir = Path(index_dir)
        self.search_config = search_config
        self.use_gpu = use_gpu
        self.module = module
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

        self.doc_ids = (self.index_dir / "doc_ids.txt").read_text().split()
        self.doc_lengths = torch.load(self.index_dir / "doc_lengths.pt", weights_only=True)

        self.to_gpu()

        self.num_docs = len(self.doc_ids)
        self.cumulative_doc_lengths = torch.cumsum(self.doc_lengths, dim=0)
        self.num_embeddings = int(self.cumulative_doc_lengths[-1].item())

        self.doc_is_single_vector = self.num_docs == self.num_embeddings
        self.query_is_single_vector = isinstance(module.config, SingleVectorBiEncoderConfig) or getattr(
            module.config, "query_pooling_strategy", None
        ) in {"first", "mean", "min", "max"}

        if self.doc_lengths.shape[0] != self.num_docs or self.doc_lengths.sum() != self.num_embeddings:
            raise ValueError("doc_lengths do not match index")
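
    # Note (added, inferred from __init__ above): the index directory is expected
    # to contain at least `doc_ids.txt` (whitespace-separated document ids) and
    # `doc_lengths.pt` (a 1-D tensor with the number of stored vectors per document).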
    def to_gpu(self) -> None:
        """Move the searcher to the GPU if available."""
        self.doc_lengths = self.doc_lengths.to(self.device)
    def _filter_and_sort(
        self, doc_scores: PackedTensor, doc_idcs: PackedTensor, k: int | None = None
    ) -> Tuple[PackedTensor, PackedTensor]:
        """Filter and sort the document scores and indices.

        Args:
            doc_scores (PackedTensor): The document scores.
            doc_idcs (PackedTensor): The document indices.
            k (int | None): The number of top documents to return. If None, use the configured k from search_config.
                Defaults to None.
        Returns:
            Tuple[PackedTensor, PackedTensor]: The filtered and sorted document scores and indices.
        """
        k = k or self.search_config.k
        per_query_doc_scores = torch.split(doc_scores, doc_scores.lengths)
        per_query_doc_idcs = torch.split(doc_idcs, doc_idcs.lengths)
        num_docs = []
        new_doc_scores = []
        new_doc_idcs = []
        for _scores, _idcs in zip(per_query_doc_scores, per_query_doc_idcs):
            _k = min(k, _scores.shape[0])
            top_values, top_idcs = torch.topk(_scores, _k)
            new_doc_scores.append(top_values)
            new_doc_idcs.append(_idcs[top_idcs])
            num_docs.append(_k)
        return PackedTensor(torch.cat(new_doc_scores), lengths=num_docs), PackedTensor(
            torch.cat(new_doc_idcs), lengths=num_docs
        )
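
    # Example (illustrative values, not from the original source): with k=2 and
    #   doc_scores = PackedTensor(torch.tensor([0.9, 0.1, 0.5, 0.7, 0.8]), lengths=[3, 2])
    # the first query keeps [0.9, 0.5] and the second keeps [0.8, 0.7], so both
    # returned PackedTensors have lengths=[2, 2].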
    @abstractmethod
    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        """Search for documents based on the output of the bi-encoder model.

        Args:
            output (BiEncoderOutput): The output from the bi-encoder model containing query and document embeddings.
        Returns:
            Tuple[PackedTensor, List[List[str]]]: The top-k scores and corresponding document IDs.
        """
        ...

class ExactSearcher(Searcher):
    """Searcher that exhaustively scores all indexed document embeddings against the query embeddings."""
    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        """Search for documents based on the output of the bi-encoder model.

        Args:
            output (BiEncoderOutput): The output from the bi-encoder model containing query and document embeddings.
        Returns:
            Tuple[PackedTensor, List[List[str]]]: The top-k scores and corresponding document IDs.
        """
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        query_embeddings = query_embeddings.to(self.device)

        scores = self._score(query_embeddings)

        # aggregate doc token scores
        if not self.doc_is_single_vector:
            scores = torch.scatter_reduce(
                torch.zeros(scores.shape[0], self.num_docs, device=scores.device),
                1,
                self.doc_token_idcs[None].long().expand_as(scores),
                scores,
                "amax",
            )

        # aggregate query token scores
        if not self.query_is_single_vector:
            if query_embeddings.scoring_mask is None:
                raise ValueError("Expected scoring_mask in multi-vector query_embeddings")
            query_lengths = query_embeddings.scoring_mask.sum(-1)
            query_token_idcs = torch.arange(query_lengths.shape[0]).to(query_lengths).repeat_interleave(query_lengths)
            scores = torch.scatter_reduce(
                torch.zeros(query_lengths.shape[0], self.num_docs, device=scores.device),
                0,
                query_token_idcs[:, None].expand_as(scores),
                scores,
                self.module.config.query_aggregation_function,
            )
        top_scores, top_idcs = torch.topk(scores, self.search_config.k)
        doc_ids = [[self.doc_ids[idx] for idx in _doc_idcs] for _doc_idcs in top_idcs.tolist()]
        return PackedTensor(top_scores.view(-1), lengths=[self.search_config.k] * len(doc_ids)), doc_ids
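
    # Note (added): for multi-vector queries, the second scatter_reduce above
    # collapses the per-query-token rows into a single row per query using the
    # model's configured query_aggregation_function, yielding a
    # num_queries x num_docs score matrix before the final top-k selection.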
    @property
    def doc_token_idcs(self) -> torch.Tensor:
        """Get the document token indices for scoring.

        Returns:
            torch.Tensor: The document token indices.
        """
        if not hasattr(self, "_doc_token_idcs"):
            self._doc_token_idcs = (
                torch.arange(self.doc_lengths.shape[0])
                .to(device=self.doc_lengths.device)
                .repeat_interleave(self.doc_lengths)
            )
        return self._doc_token_idcs

    @abstractmethod
    def _score(self, query_embeddings: BiEncoderEmbedding) -> torch.Tensor: ...
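
    # Illustration (added): for doc_lengths = tensor([2, 3]) the property yields
    # tensor([0, 0, 1, 1, 1]), i.e. one document index per stored token vector,
    # which search() uses as the scatter target for the "amax" doc aggregation.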

class ApproximateSearcher(Searcher):
    """Searcher that first retrieves candidate document vectors and then aggregates their scores per document."""
    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        """Search for documents based on the output of the bi-encoder model.

        Args:
            output (BiEncoderOutput): The output from the bi-encoder model containing query and document embeddings.
        Returns:
            Tuple[PackedTensor, List[List[str]]]: The top-k scores and corresponding document IDs.
        """
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        query_embeddings = query_embeddings.to(self.device)

        candidate_scores, candidate_idcs = self._candidate_retrieval(query_embeddings)
        scores, doc_idcs = self._aggregate_doc_scores(candidate_scores, candidate_idcs, query_embeddings)
        scores = self._aggregate_query_scores(scores, query_embeddings)
        scores, doc_idcs = self._filter_and_sort(scores, doc_idcs)
        doc_ids = [
            [self.doc_ids[doc_idx] for doc_idx in _doc_ids.tolist()] for _doc_ids in doc_idcs.split(doc_idcs.lengths)
        ]

        return scores, doc_ids
    def _aggregate_doc_scores(
        self, candidate_scores: PackedTensor, candidate_idcs: PackedTensor, query_embeddings: BiEncoderEmbedding
    ) -> Tuple[PackedTensor, PackedTensor]:
        if self.doc_is_single_vector:
            return candidate_scores, candidate_idcs

        query_lengths = query_embeddings.scoring_mask.sum(-1)
        num_query_vecs = query_lengths.sum()

        # map vec_idcs to doc_idcs
        candidate_doc_idcs = torch.searchsorted(
            self.cumulative_doc_lengths,
            candidate_idcs.to(self.cumulative_doc_lengths.device),
            side="right",
        )

        # convert candidate_scores `num_query_vecs x candidate_k` to `num_query_doc_pairs x num_query_vecs`
        # and aggregate the maximum doc_vector score per query_vector
        max_query_length = query_lengths.max()
        num_docs_per_query_candidate = torch.tensor(candidate_scores.lengths)

        query_idcs = (
            torch.arange(query_lengths.shape[0], device=query_lengths.device)
            .repeat_interleave(query_lengths)
            .repeat_interleave(num_docs_per_query_candidate)
        )
        query_vector_idcs = cat_arange(torch.zeros_like(query_lengths), query_lengths).repeat_interleave(
            num_docs_per_query_candidate
        )

        stacked = torch.stack([query_idcs, candidate_doc_idcs])
        unique_idcs, ranking_doc_idcs = stacked.unique(return_inverse=True, dim=1)
        num_docs = unique_idcs[0].bincount()
        doc_idcs = PackedTensor(unique_idcs[1], lengths=num_docs.tolist())
        total_num_docs = num_docs.sum()

        unpacked_scores = torch.full((total_num_docs * max_query_length,), float("nan"), device=query_lengths.device)
        index = ranking_doc_idcs * max_query_length + query_vector_idcs
        unpacked_scores = torch.scatter_reduce(
            unpacked_scores, 0, index, candidate_scores, "max", include_self=False
        ).view(total_num_docs, max_query_length)

        # impute the missing values
        if self.search_config.imputation_strategy == "gather":
            # reconstruct the doc embeddings and re-compute the scores
            imputation_values = torch.empty_like(unpacked_scores)
            doc_embeddings = self._reconstruct_doc_embeddings(doc_idcs)
            similarity = self.module.model.compute_similarity(query_embeddings, doc_embeddings, doc_idcs.lengths)
            unpacked_scores = self.module.model._aggregate(
                similarity, doc_embeddings.scoring_mask, "max", dim=-1
            ).squeeze(-1)
        elif self.search_config.imputation_strategy == "min":
            per_query_vec_min = torch.scatter_reduce(
                torch.empty(num_query_vecs),
                0,
                torch.arange(query_lengths.sum()).repeat_interleave(num_docs_per_query_candidate),
                candidate_scores,
                "min",
                include_self=False,
            )
            imputation_values = torch.nn.utils.rnn.pad_sequence(
                per_query_vec_min.split(query_lengths.tolist()), batch_first=True
            ).repeat_interleave(num_docs, dim=0)
        elif self.search_config.imputation_strategy == "zero":
            imputation_values = torch.zeros_like(unpacked_scores)
        else:
            raise ValueError(f"Invalid imputation strategy: {self.search_config.imputation_strategy}")

        is_nan = torch.isnan(unpacked_scores)
        unpacked_scores[is_nan] = imputation_values[is_nan]

        return PackedTensor(unpacked_scores, lengths=num_docs.tolist()), doc_idcs

    def _aggregate_query_scores(self, scores: PackedTensor, query_embeddings: BiEncoderEmbedding) -> PackedTensor:
        if self.query_is_single_vector:
            return scores
        query_scoring_mask = query_embeddings.scoring_mask.repeat_interleave(torch.tensor(scores.lengths), dim=0)
        scores = PackedTensor(
            self.module.model._aggregate(
                scores, query_scoring_mask, self.module.config.query_aggregation_function, dim=1
            ).squeeze(-1),
            lengths=scores.lengths,
        )
        return scores

    @abstractmethod
    def _candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        """Retrieves initial candidates using the query embeddings. Returns candidate scores and candidate vector
        indices of shape `num_query_vecs x candidate_k` (packed). Candidate indices are None if all doc vectors are
        scored.

        Args:
            query_embeddings (BiEncoderEmbedding): The query embeddings to use for candidate retrieval.
        Returns:
            Tuple[PackedTensor, PackedTensor]: The candidate scores and candidate vector indices.
        """
        ...

    @abstractmethod
    def _gather_doc_embeddings(self, idcs: torch.Tensor) -> torch.Tensor:
        """Gather document embeddings based on the provided indices.

        Args:
            idcs (torch.Tensor): The indices of the document embeddings to gather.
        Returns:
            torch.Tensor: The gathered document embeddings.
        """
        ...

    def _reconstruct_doc_embeddings(self, doc_idcs: PackedTensor) -> BiEncoderEmbedding:
        """Reconstruct document embeddings based on the provided document indices.

        Args:
            doc_idcs (PackedTensor): The packed tensor containing document indices.
        Returns:
            BiEncoderEmbedding: The reconstructed document embeddings.
        """
        # unique doc_idcs per query
        unique_doc_idcs, inverse_idcs = torch.unique(doc_idcs, return_inverse=True)

        # gather all vectors for unique doc_idcs
        doc_lengths = self.doc_lengths[unique_doc_idcs]
        start_doc_idcs = self.cumulative_doc_lengths[unique_doc_idcs - 1]
        start_doc_idcs[unique_doc_idcs == 0] = 0
        all_doc_idcs = cat_arange(start_doc_idcs, start_doc_idcs + doc_lengths)
        all_doc_embeddings = self._gather_doc_embeddings(all_doc_idcs)
        unique_embeddings = torch.nn.utils.rnn.pad_sequence(
            list(torch.split(all_doc_embeddings, doc_lengths.tolist())),
            batch_first=True,
        ).to(inverse_idcs.device)
        embeddings = unique_embeddings[inverse_idcs]

        # mask out padding
        doc_lengths = doc_lengths[inverse_idcs]
        scoring_mask = torch.arange(embeddings.shape[1], device=embeddings.device) < doc_lengths[:, None]
        doc_embeddings = BiEncoderEmbedding(embeddings=embeddings, scoring_mask=scoring_mask, encoding=None)
        return doc_embeddings
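
    # Worked example (added, hypothetical values): with cumulative_doc_lengths =
    # tensor([2, 5, 6]) (documents holding 2, 3, and 1 vectors), a candidate
    # vector index of 3 maps to document 1 in _aggregate_doc_scores, because
    # torch.searchsorted(cumulative_doc_lengths, 3, side="right") == 1 and
    # document 1 owns the vectors with indices 2, 3, and 4.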

class SearchConfig:
    """Configuration class for searchers in the Lightning IR framework."""

    search_class: Type[Searcher]

    SUPPORTED_MODELS: Set[str]
    def __init__(self, k: int = 10) -> None:
        """Initialize the SearchConfig.

        Args:
            k (int): The number of top documents to retrieve. Defaults to 10.
        """
        self.k = k

class ExactSearchConfig(SearchConfig):
    """Configuration class for exact searchers in the Lightning IR framework."""

    search_class = ExactSearcher

class ApproximateSearchConfig(SearchConfig):
    """Configuration class for approximate searchers in the Lightning IR framework."""

    search_class = ApproximateSearcher
    def __init__(
        self, k: int = 10, candidate_k: int = 100, imputation_strategy: Literal["min", "gather", "zero"] = "gather"
    ) -> None:
        """Initialize the ApproximateSearchConfig.

        Args:
            k (int): The number of top documents to retrieve. Defaults to 10.
            candidate_k (int): The number of candidate documents to consider for scoring. Defaults to 100.
            imputation_strategy (Literal["min", "gather", "zero"]): Strategy for imputing missing scores. Defaults to
                "gather".
        """
        super().__init__(k)
        self.k = k
        self.candidate_k = candidate_k
        self.imputation_strategy = imputation_strategy
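
# Usage sketch (added; paths and module are hypothetical, and the concrete
# searchers implementing _score / _candidate_retrieval live in sibling modules):
#
#   config = ApproximateSearchConfig(k=10, candidate_k=100, imputation_strategy="min")
#   searcher = config.search_class("path/to/index", config, module)
#   scores, doc_ids = searcher.search(output)  # output: BiEncoderOutput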