Source code for lightning_ir.retrieve.base.searcher

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Set, Tuple, Type

import torch

from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding, SingleVectorBiEncoderConfig
from .packed_tensor import PackedTensor

if TYPE_CHECKING:
    from ...bi_encoder import BiEncoderModule, BiEncoderOutput


def cat_arange(arange_starts: torch.Tensor, arange_ends: torch.Tensor) -> torch.Tensor:
    """Concatenates the ranges ``arange(start, end)`` for each pair of start and end values."""
    arange_lengths = arange_ends - arange_starts
    offsets = torch.cumsum(arange_lengths, dim=0) - arange_lengths - arange_starts
    return torch.arange(arange_lengths.sum()) - torch.repeat_interleave(offsets, arange_lengths)
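

# A minimal usage sketch (illustrative, not part of the original module):
# >>> cat_arange(torch.tensor([0, 2]), torch.tensor([3, 5]))
# tensor([0, 1, 2, 2, 3, 4])
# i.e. the concatenation of arange(0, 3) and arange(2, 5).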


class Searcher(ABC):
    """Base class for searching an index of bi-encoder embeddings."""

    def __init__(
        self, index_dir: Path | str, search_config: SearchConfig, module: BiEncoderModule, use_gpu: bool = True
    ) -> None:
        super().__init__()
        self.index_dir = Path(index_dir)
        self.search_config = search_config
        self.use_gpu = use_gpu
        self.module = module
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

        self.doc_ids = (self.index_dir / "doc_ids.txt").read_text().split()
        self.doc_lengths = torch.load(self.index_dir / "doc_lengths.pt", weights_only=True)

        self.to_gpu()

        self.num_docs = len(self.doc_ids)
        self.cumulative_doc_lengths = torch.cumsum(self.doc_lengths, dim=0)
        self.num_embeddings = int(self.cumulative_doc_lengths[-1].item())

        self.doc_is_single_vector = self.num_docs == self.num_embeddings
        self.query_is_single_vector = isinstance(module.config, SingleVectorBiEncoderConfig) or getattr(
            module.config, "query_pooling_strategy", None
        ) in {"first", "mean", "min", "max"}

        if self.doc_lengths.shape[0] != self.num_docs or self.doc_lengths.sum() != self.num_embeddings:
            raise ValueError("doc_lengths do not match index")

    def to_gpu(self) -> None:
        self.doc_lengths = self.doc_lengths.to(self.device)

    def _filter_and_sort(
        self, doc_scores: PackedTensor, doc_idcs: PackedTensor, k: int | None = None
    ) -> Tuple[PackedTensor, PackedTensor]:
        k = k or self.search_config.k
        per_query_doc_scores = torch.split(doc_scores, doc_scores.lengths)
        per_query_doc_idcs = torch.split(doc_idcs, doc_idcs.lengths)
        num_docs = []
        new_doc_scores = []
        new_doc_idcs = []
        for _scores, _idcs in zip(per_query_doc_scores, per_query_doc_idcs):
            _k = min(k, _scores.shape[0])
            top_values, top_idcs = torch.topk(_scores, _k)
            new_doc_scores.append(top_values)
            new_doc_idcs.append(_idcs[top_idcs])
            num_docs.append(_k)
        return PackedTensor(torch.cat(new_doc_scores), lengths=num_docs), PackedTensor(
            torch.cat(new_doc_idcs), lengths=num_docs
        )

    @abstractmethod
    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]: ...
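

# A minimal sketch of ``_filter_and_sort`` (illustrative; assumes a searcher
# instance whose ``search_config.k`` is 2):
# >>> scores = PackedTensor(torch.tensor([0.3, 0.9, 0.5, 0.8, 0.1]), lengths=[3, 2])
# >>> idcs = PackedTensor(torch.tensor([4, 7, 2, 0, 5]), lengths=[3, 2])
# >>> searcher._filter_and_sort(scores, idcs, k=2)
# keeps the top-2 entries per query: scores [0.9, 0.5, 0.8, 0.1] with indices
# [7, 2, 0, 5], both packed with lengths [2, 2].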


class ExactSearcher(Searcher):
    """Searcher that exhaustively scores the query embeddings against every vector in the index."""

    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        query_embeddings = query_embeddings.to(self.device)

        scores = self._score(query_embeddings)

        # aggregate doc token scores
        if not self.doc_is_single_vector:
            scores = torch.scatter_reduce(
                torch.zeros(scores.shape[0], self.num_docs, device=scores.device),
                1,
                self.doc_token_idcs[None].long().expand_as(scores),
                scores,
                "amax",
            )

        # aggregate query token scores
        if not self.query_is_single_vector:
            if query_embeddings.scoring_mask is None:
                raise ValueError("Expected scoring_mask in multi-vector query_embeddings")
            query_lengths = query_embeddings.scoring_mask.sum(-1)
            query_token_idcs = torch.arange(query_lengths.shape[0]).to(query_lengths).repeat_interleave(query_lengths)
            scores = torch.scatter_reduce(
                torch.zeros(query_lengths.shape[0], self.num_docs, device=scores.device),
                0,
                query_token_idcs[:, None].expand_as(scores),
                scores,
                self.module.config.query_aggregation_function,
            )
        top_scores, top_idcs = torch.topk(scores, self.search_config.k)
        doc_ids = [[self.doc_ids[idx] for idx in _doc_idcs] for _doc_idcs in top_idcs.tolist()]
        return PackedTensor(top_scores.view(-1), lengths=[self.search_config.k] * len(doc_ids)), doc_ids

    @property
    def doc_token_idcs(self) -> torch.Tensor:
        if not hasattr(self, "_doc_token_idcs"):
            self._doc_token_idcs = (
                torch.arange(self.doc_lengths.shape[0])
                .to(device=self.doc_lengths.device)
                .repeat_interleave(self.doc_lengths)
            )
        return self._doc_token_idcs

    @abstractmethod
    def _score(self, query_embeddings: BiEncoderEmbedding) -> torch.Tensor: ...
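

# Illustrative sketch of the doc-token aggregation above (not part of the module):
# with doc_lengths = [2, 1], ``doc_token_idcs`` is tensor([0, 0, 1]); a row of
# per-vector scores [0.2, 0.7, 0.4] is then max-reduced per document via
# scatter_reduce("amax"), yielding [0.7, 0.4] — one score per document.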


class ApproximateSearcher(Searcher):
    """Searcher that retrieves candidate document vectors, aggregates their scores per document,
    and imputes scores for query-document vector pairs missing from the candidate set."""

    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        query_embeddings = query_embeddings.to(self.device)

        candidate_scores, candidate_idcs = self._candidate_retrieval(query_embeddings)
        scores, doc_idcs = self._aggregate_doc_scores(candidate_scores, candidate_idcs, query_embeddings)
        scores = self._aggregate_query_scores(scores, query_embeddings)
        scores, doc_idcs = self._filter_and_sort(scores, doc_idcs)
        doc_ids = [
            [self.doc_ids[doc_idx] for doc_idx in _doc_ids.tolist()] for _doc_ids in doc_idcs.split(doc_idcs.lengths)
        ]

        return scores, doc_ids

    def _aggregate_doc_scores(
        self, candidate_scores: PackedTensor, candidate_idcs: PackedTensor, query_embeddings: BiEncoderEmbedding
    ) -> Tuple[PackedTensor, PackedTensor]:
        if self.doc_is_single_vector:
            return candidate_scores, candidate_idcs

        query_lengths = query_embeddings.scoring_mask.sum(-1)
        num_query_vecs = query_lengths.sum()

        # map vector indices to doc indices
        candidate_doc_idcs = torch.searchsorted(
            self.cumulative_doc_lengths,
            candidate_idcs.to(self.cumulative_doc_lengths.device),
            side="right",
        )

        # convert candidate_scores `num_query_vecs x candidate_k` to `num_query_doc_pairs x num_query_vecs`
        # and aggregate the maximum doc_vector score per query_vector
        max_query_length = query_lengths.max()
        num_docs_per_query_candidate = torch.tensor(candidate_scores.lengths)

        query_idcs = (
            torch.arange(query_lengths.shape[0], device=query_lengths.device)
            .repeat_interleave(query_lengths)
            .repeat_interleave(num_docs_per_query_candidate)
        )
        query_vector_idcs = cat_arange(torch.zeros_like(query_lengths), query_lengths).repeat_interleave(
            num_docs_per_query_candidate
        )

        stacked = torch.stack([query_idcs, candidate_doc_idcs])
        unique_idcs, ranking_doc_idcs = stacked.unique(return_inverse=True, dim=1)
        num_docs = unique_idcs[0].bincount()
        doc_idcs = PackedTensor(unique_idcs[1], lengths=num_docs.tolist())
        total_num_docs = num_docs.sum()

        unpacked_scores = torch.full((total_num_docs * max_query_length,), float("nan"), device=query_lengths.device)
        index = ranking_doc_idcs * max_query_length + query_vector_idcs
        unpacked_scores = torch.scatter_reduce(
            unpacked_scores, 0, index, candidate_scores, "max", include_self=False
        ).view(total_num_docs, max_query_length)

        # impute the missing values
        if self.search_config.imputation_strategy == "gather":
            # reconstruct the doc embeddings and re-compute the scores
            imputation_values = torch.empty_like(unpacked_scores)
            doc_embeddings = self._reconstruct_doc_embeddings(doc_idcs)
            similarity = self.module.model.compute_similarity(query_embeddings, doc_embeddings, doc_idcs.lengths)
            unpacked_scores = self.module.model._aggregate(
                similarity, doc_embeddings.scoring_mask, "max", dim=-1
            ).squeeze(-1)
        elif self.search_config.imputation_strategy == "min":
            per_query_vec_min = torch.scatter_reduce(
                torch.empty(num_query_vecs),
                0,
                torch.arange(query_lengths.sum()).repeat_interleave(num_docs_per_query_candidate),
                candidate_scores,
                "min",
                include_self=False,
            )
            imputation_values = torch.nn.utils.rnn.pad_sequence(
                per_query_vec_min.split(query_lengths.tolist()), batch_first=True
            ).repeat_interleave(num_docs, dim=0)
        elif self.search_config.imputation_strategy == "zero":
            imputation_values = torch.zeros_like(unpacked_scores)
        else:
            raise ValueError(f"Invalid imputation strategy: {self.search_config.imputation_strategy}")

        is_nan = torch.isnan(unpacked_scores)
        unpacked_scores[is_nan] = imputation_values[is_nan]

        return PackedTensor(unpacked_scores, lengths=num_docs.tolist()), doc_idcs

    def _aggregate_query_scores(self, scores: PackedTensor, query_embeddings: BiEncoderEmbedding) -> PackedTensor:
        if self.query_is_single_vector:
            return scores
        query_scoring_mask = query_embeddings.scoring_mask.repeat_interleave(torch.tensor(scores.lengths), dim=0)
        scores = PackedTensor(
            self.module.model._aggregate(
                scores, query_scoring_mask, self.module.config.query_aggregation_function, dim=1
            ).squeeze(-1),
            lengths=scores.lengths,
        )
        return scores

    @abstractmethod
    def _candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        """Retrieves initial candidates using the query embeddings. Returns candidate scores and candidate vector
        indices of shape `num_query_vecs x candidate_k` (packed). Candidate indices are None if all doc vectors are
        scored.

        :return: Candidate scores and candidate vector indices
        :rtype: Tuple[PackedTensor, PackedTensor]
        """
        ...

    @abstractmethod
    def _gather_doc_embeddings(self, idcs: torch.Tensor) -> torch.Tensor:
        """Gathers document vectors from the index for the given flat vector indices.

        :param idcs: Flat indices of document vectors in the index
        :type idcs: torch.Tensor
        :return: Gathered document vectors
        :rtype: torch.Tensor
        """
        ...

    def _reconstruct_doc_embeddings(self, doc_idcs: PackedTensor) -> BiEncoderEmbedding:
        # unique doc_idcs per query
        unique_doc_idcs, inverse_idcs = torch.unique(doc_idcs, return_inverse=True)

        # gather all vectors for unique doc_idcs
        doc_lengths = self.doc_lengths[unique_doc_idcs]
        start_doc_idcs = self.cumulative_doc_lengths[unique_doc_idcs - 1]
        start_doc_idcs[unique_doc_idcs == 0] = 0
        all_doc_idcs = cat_arange(start_doc_idcs, start_doc_idcs + doc_lengths)
        all_doc_embeddings = self._gather_doc_embeddings(all_doc_idcs)
        unique_embeddings = torch.nn.utils.rnn.pad_sequence(
            [embeddings for embeddings in torch.split(all_doc_embeddings, doc_lengths.tolist())],
            batch_first=True,
        ).to(inverse_idcs.device)
        embeddings = unique_embeddings[inverse_idcs]

        # mask out padding
        doc_lengths = doc_lengths[inverse_idcs]
        scoring_mask = torch.arange(embeddings.shape[1], device=embeddings.device) < doc_lengths[:, None]
        doc_embeddings = BiEncoderEmbedding(embeddings=embeddings, scoring_mask=scoring_mask, encoding=None)
        return doc_embeddings
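

# Illustrative sketch of the vector-to-document index mapping above (not part of
# the module): with cumulative_doc_lengths = [2, 3, 7], document 0 owns vectors
# 0-1, document 1 owns vector 2, and document 2 owns vectors 3-6, so
# >>> torch.searchsorted(torch.tensor([2, 3, 7]), torch.tensor([0, 2, 4]), side="right")
# tensor([0, 1, 2])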


class SearchConfig:
    """Configuration for a :class:`Searcher`."""

    search_class: Type[Searcher]

    SUPPORTED_MODELS: Set[str]

    def __init__(self, k: int = 10) -> None:
        self.k = k


class ExactSearchConfig(SearchConfig):
    search_class = ExactSearcher


class ApproximateSearchConfig(SearchConfig):
    search_class = ApproximateSearcher

    def __init__(
        self, k: int = 10, candidate_k: int = 100, imputation_strategy: Literal["min", "gather", "zero"] = "gather"
    ) -> None:
        super().__init__(k)
        self.candidate_k = candidate_k
        self.imputation_strategy = imputation_strategy
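

# A minimal usage sketch (illustrative; the concrete, non-abstract searcher
# subclasses live in other lightning_ir modules):
# >>> config = ApproximateSearchConfig(k=10, candidate_k=100, imputation_strategy="min")
# >>> searcher = config.search_class(index_dir, config, module)  # hypothetical instantiation
# >>> scores, doc_ids = searcher.search(output)  # output: BiEncoderOutput with query_embeddings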