1"""Base searcher class and configuration for retrieval tasks."""
2
3from __future__ import annotations
4
5from abc import ABC, abstractmethod
6from pathlib import Path
7from typing import TYPE_CHECKING, List, Literal, Set, Tuple, Type
8
9import torch
10
11from ...bi_encoder.bi_encoder_model import BiEncoderEmbedding, SingleVectorBiEncoderConfig
12from .packed_tensor import PackedTensor
13
14if TYPE_CHECKING:
15 from ...bi_encoder import BiEncoderModule, BiEncoderOutput
16
17
def cat_arange(arange_starts: torch.Tensor, arange_ends: torch.Tensor) -> torch.Tensor:
    """Concatenates arange tensors into a single tensor.

    Args:
        arange_starts (torch.Tensor): The start indices of the ranges.
        arange_ends (torch.Tensor): The end indices of the ranges.
    Returns:
        torch.Tensor: A tensor containing the concatenated ranges.
    """
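
    Example (a minimal illustration; equivalent to concatenating
    ``torch.arange(0, 3)`` and ``torch.arange(2, 5)``)::

        >>> cat_arange(torch.tensor([0, 2]), torch.tensor([3, 5]))
        tensor([0, 1, 2, 2, 3, 4])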
26 """
27 arange_lengths = arange_ends - arange_starts
28 offsets = torch.cumsum(arange_lengths, dim=0) - arange_lengths - arange_starts
29 return torch.arange(arange_lengths.sum()) - torch.repeat_interleave(offsets, arange_lengths)
30
31
class Searcher(ABC):
    """Base class for searchers in the Lightning IR framework."""

    def __init__(
        self, index_dir: Path | str, search_config: SearchConfig, module: BiEncoderModule, use_gpu: bool = True
    ) -> None:
        """Initialize the Searcher.

        Args:
            index_dir (Path | str): The directory containing the index files.
            search_config (SearchConfig): The configuration for the search.
            module (BiEncoderModule): The bi-encoder module to use for scoring.
            use_gpu (bool): Whether to use GPU for computations. Defaults to True.
        Raises:
            ValueError: If the document lengths do not match the index.
        """
        super().__init__()
        self.index_dir = Path(index_dir)
        self.search_config = search_config
        self.use_gpu = use_gpu
        self.module = module
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

        self.doc_ids = (self.index_dir / "doc_ids.txt").read_text().split()
        self.doc_lengths = torch.load(self.index_dir / "doc_lengths.pt", weights_only=True)

        self.to_gpu()

        self.num_docs = len(self.doc_ids)
        self.cumulative_doc_lengths = torch.cumsum(self.doc_lengths, dim=0)
        self.num_embeddings = int(self.cumulative_doc_lengths[-1].item())

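        # a side is "single vector" when every query/document contributes exactly one
        # embedding; single-vector sides let the searchers skip token-level aggregation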
        self.doc_is_single_vector = self.num_docs == self.num_embeddings
        self.query_is_single_vector = isinstance(module.config, SingleVectorBiEncoderConfig) or getattr(
            module.config, "query_pooling_strategy", None
        ) in {"first", "mean", "min", "max"}

        if self.doc_lengths.shape[0] != self.num_docs or self.doc_lengths.sum() != self.num_embeddings:
            raise ValueError("doc_lengths do not match index")

    def to_gpu(self) -> None:
        """Move the searcher to the GPU if available."""
        self.doc_lengths = self.doc_lengths.to(self.device)

    def _filter_and_sort(
        self, doc_scores: PackedTensor, doc_idcs: PackedTensor, k: int | None = None
    ) -> Tuple[PackedTensor, PackedTensor]:
        """Filter and sort the document scores and indices.

        Args:
            doc_scores (PackedTensor): The document scores.
            doc_idcs (PackedTensor): The document indices.
            k (int | None): The number of top documents to return. If None, use the configured k from search_config.
                Defaults to None.
        Returns:
            Tuple[PackedTensor, PackedTensor]: The filtered and sorted document scores and indices.
        """
        k = k or self.search_config.k
        per_query_doc_scores = torch.split(doc_scores, doc_scores.lengths)
        per_query_doc_idcs = torch.split(doc_idcs, doc_idcs.lengths)
        num_docs = []
        new_doc_scores = []
        new_doc_idcs = []
        for _scores, _idcs in zip(per_query_doc_scores, per_query_doc_idcs):
            _k = min(k, _scores.shape[0])
            top_values, top_idcs = torch.topk(_scores, _k)
            new_doc_scores.append(top_values)
            new_doc_idcs.append(_idcs[top_idcs])
            num_docs.append(_k)
        return PackedTensor(torch.cat(new_doc_scores), lengths=num_docs), PackedTensor(
            torch.cat(new_doc_idcs), lengths=num_docs
        )

    @abstractmethod
    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        """Search for documents based on the output of the bi-encoder model.

        Args:
            output (BiEncoderOutput): The output from the bi-encoder model containing query and document embeddings.
        Returns:
            Tuple[PackedTensor, List[List[str]]]: The top-k scores and corresponding document IDs.
        """
        ...


class ExactSearcher(Searcher):
    """Searcher that retrieves documents by exhaustively scoring the query embeddings against all indexed documents."""

    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        """Search for documents based on the output of the bi-encoder model.

        Args:
            output (BiEncoderOutput): The output from the bi-encoder model containing query and document embeddings.
        Returns:
            Tuple[PackedTensor, List[List[str]]]: The top-k scores and corresponding document IDs.
        """
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        query_embeddings = query_embeddings.to(self.device)

        scores = self._score(query_embeddings)

        # aggregate doc token scores
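        # (illustrative) for multi-vector documents, `scores` has shape
        # `num_query_vecs x num_doc_vecs`; max-pooling over each document's token
        # columns reduces it to `num_query_vecs x num_docs`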
        if not self.doc_is_single_vector:
            scores = torch.scatter_reduce(
                torch.zeros(scores.shape[0], self.num_docs, device=scores.device),
                1,
                self.doc_token_idcs[None].long().expand_as(scores),
                scores,
                "amax",
            )

        # aggregate query token scores
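        # (illustrative) each query token contributes one row of per-document scores;
        # rows belonging to the same query are reduced with the configured aggregation
        # function (e.g. summed for ColBERT-style late interaction)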
        if not self.query_is_single_vector:
            if query_embeddings.scoring_mask is None:
                raise ValueError("Expected scoring_mask in multi-vector query_embeddings")
            query_lengths = query_embeddings.scoring_mask.sum(-1)
            query_token_idcs = torch.arange(query_lengths.shape[0]).to(query_lengths).repeat_interleave(query_lengths)
            scores = torch.scatter_reduce(
                torch.zeros(query_lengths.shape[0], self.num_docs, device=scores.device),
                0,
                query_token_idcs[:, None].expand_as(scores),
                scores,
                self.module.config.query_aggregation_function,
            )
        top_scores, top_idcs = torch.topk(scores, self.search_config.k)
        doc_ids = [[self.doc_ids[idx] for idx in _doc_idcs] for _doc_idcs in top_idcs.tolist()]
        return PackedTensor(top_scores.view(-1), lengths=[self.search_config.k] * len(doc_ids)), doc_ids

    @property
    def doc_token_idcs(self) -> torch.Tensor:
        """Get the document token indices for scoring.

        Returns:
            torch.Tensor: The document token indices.
        """
        if not hasattr(self, "_doc_token_idcs"):
            self._doc_token_idcs = (
                torch.arange(self.doc_lengths.shape[0])
                .to(device=self.doc_lengths.device)
                .repeat_interleave(self.doc_lengths)
            )
        return self._doc_token_idcs

    @abstractmethod
    def _score(self, query_embeddings: BiEncoderEmbedding) -> torch.Tensor: ...


class ApproximateSearcher(Searcher):
    """Searcher that retrieves approximate candidates and aggregates their scores into document rankings."""

    def search(self, output: BiEncoderOutput) -> Tuple[PackedTensor, List[List[str]]]:
        """Search for documents based on the output of the bi-encoder model.

        Args:
            output (BiEncoderOutput): The output from the bi-encoder model containing query and document embeddings.
        Returns:
            Tuple[PackedTensor, List[List[str]]]: The top-k scores and corresponding document IDs.
        """
        query_embeddings = output.query_embeddings
        if query_embeddings is None:
            raise ValueError("Expected query_embeddings in BiEncoderOutput")
        query_embeddings = query_embeddings.to(self.device)

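        # pipeline: (1) approximate candidate retrieval per query vector,
        # (2) aggregate vector-level scores per document, (3) aggregate per query,
        # (4) keep the top-k documents per query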
        candidate_scores, candidate_idcs = self._candidate_retrieval(query_embeddings)
        scores, doc_idcs = self._aggregate_doc_scores(candidate_scores, candidate_idcs, query_embeddings)
        scores = self._aggregate_query_scores(scores, query_embeddings)
        scores, doc_idcs = self._filter_and_sort(scores, doc_idcs)
        doc_ids = [
            [self.doc_ids[doc_idx] for doc_idx in _doc_ids.tolist()] for _doc_ids in doc_idcs.split(doc_idcs.lengths)
        ]

        return scores, doc_ids

    def _aggregate_doc_scores(
        self, candidate_scores: PackedTensor, candidate_idcs: PackedTensor, query_embeddings: BiEncoderEmbedding
    ) -> Tuple[PackedTensor, PackedTensor]:
        if self.doc_is_single_vector:
            return candidate_scores, candidate_idcs

        query_lengths = query_embeddings.scoring_mask.sum(-1)
        num_query_vecs = query_lengths.sum()

        # map vec_idcs to doc_idcs
        candidate_doc_idcs = torch.searchsorted(
            self.cumulative_doc_lengths,
            candidate_idcs.to(self.cumulative_doc_lengths.device),
            side="right",
        )
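        # e.g. with doc_lengths [2, 3] the cumulative lengths are [2, 5]; flat vector
        # index 3 falls into the half-open range [2, 5) and is mapped to document 1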

        # convert candidate_scores `num_query_vecs x candidate_k` to `num_query_doc_pairs x num_query_vecs`
        # and aggregate the maximum doc_vector score per query_vector
        max_query_length = query_lengths.max()
        num_docs_per_query_candidate = torch.tensor(candidate_scores.lengths)

        query_idcs = (
            torch.arange(query_lengths.shape[0], device=query_lengths.device)
            .repeat_interleave(query_lengths)
            .repeat_interleave(num_docs_per_query_candidate)
        )
        query_vector_idcs = cat_arange(torch.zeros_like(query_lengths), query_lengths).repeat_interleave(
            num_docs_per_query_candidate
        )

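        # deduplicate (query, document) pairs so that each candidate document is
        # scored once per query, however many of its vectors were retrieved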
        stacked = torch.stack([query_idcs, candidate_doc_idcs])
        unique_idcs, ranking_doc_idcs = stacked.unique(return_inverse=True, dim=1)
        num_docs = unique_idcs[0].bincount()
        doc_idcs = PackedTensor(unique_idcs[1], lengths=num_docs.tolist())
        total_num_docs = num_docs.sum()

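        # build a `total_num_docs x max_query_length` score matrix, leaving NaN where a
        # (document, query vector) pair received no candidate score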
        unpacked_scores = torch.full((total_num_docs * max_query_length,), float("nan"), device=query_lengths.device)
        index = ranking_doc_idcs * max_query_length + query_vector_idcs
        unpacked_scores = torch.scatter_reduce(
            unpacked_scores, 0, index, candidate_scores, "amax", include_self=False
        ).view(total_num_docs, max_query_length)

        # impute the missing values
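        # a cell is NaN when the document never appeared in that query vector's
        # candidate list; the configured strategy fills these gaps before aggregation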
        if self.search_config.imputation_strategy == "gather":
            # reconstruct the doc embeddings and re-compute the scores; all scores are
            # recomputed exactly, so no values are left to impute
            imputation_values = torch.empty_like(unpacked_scores)
            doc_embeddings = self._reconstruct_doc_embeddings(doc_idcs)
            similarity = self.module.model.compute_similarity(query_embeddings, doc_embeddings, doc_idcs.lengths)
            unpacked_scores = self.module.model._aggregate(
                similarity, doc_embeddings.scoring_mask, "max", dim=-1
            ).squeeze(-1)
        elif self.search_config.imputation_strategy == "min":
            per_query_vec_min = torch.scatter_reduce(
                torch.empty(num_query_vecs),
                0,
                torch.arange(query_lengths.sum()).repeat_interleave(num_docs_per_query_candidate),
                candidate_scores,
                "amin",
                include_self=False,
            )
            imputation_values = torch.nn.utils.rnn.pad_sequence(
                per_query_vec_min.split(query_lengths.tolist()), batch_first=True
            ).repeat_interleave(num_docs, dim=0)
        elif self.search_config.imputation_strategy == "zero":
            imputation_values = torch.zeros_like(unpacked_scores)
        else:
            raise ValueError(f"Invalid imputation strategy: {self.search_config.imputation_strategy}")

        is_nan = torch.isnan(unpacked_scores)
        unpacked_scores[is_nan] = imputation_values[is_nan]

        return PackedTensor(unpacked_scores, lengths=num_docs.tolist()), doc_idcs

    def _aggregate_query_scores(self, scores: PackedTensor, query_embeddings: BiEncoderEmbedding) -> PackedTensor:
        if self.query_is_single_vector:
            return scores
        query_scoring_mask = query_embeddings.scoring_mask.repeat_interleave(torch.tensor(scores.lengths), dim=0)
        scores = PackedTensor(
            self.module.model._aggregate(
                scores, query_scoring_mask, self.module.config.query_aggregation_function, dim=1
            ).squeeze(-1),
            lengths=scores.lengths,
        )
        return scores

    @abstractmethod
    def _candidate_retrieval(self, query_embeddings: BiEncoderEmbedding) -> Tuple[PackedTensor, PackedTensor]:
        """Retrieves initial candidates using the query embeddings. Returns candidate scores and candidate vector
        indices of shape `num_query_vecs x candidate_k` (packed). Candidate indices are None if all doc vectors are
        scored.

        Args:
            query_embeddings (BiEncoderEmbedding): The query embeddings to use for candidate retrieval.
        Returns:
            Tuple[PackedTensor, PackedTensor]: The candidate scores and candidate vector indices.
        """
        ...

    @abstractmethod
    def _gather_doc_embeddings(self, idcs: torch.Tensor) -> torch.Tensor:
        """Gather document embeddings based on the provided indices.

        Args:
            idcs (torch.Tensor): The indices of the document embeddings to gather.
        Returns:
            torch.Tensor: The gathered document embeddings.
        """
        ...

    def _reconstruct_doc_embeddings(self, doc_idcs: PackedTensor) -> BiEncoderEmbedding:
        """Reconstruct document embeddings based on the provided document indices.

        Args:
            doc_idcs (PackedTensor): The packed tensor containing document indices.
        Returns:
            BiEncoderEmbedding: The reconstructed document embeddings.
        """
        # unique doc_idcs per query
        unique_doc_idcs, inverse_idcs = torch.unique(doc_idcs, return_inverse=True)

        # gather all vectors for unique doc_idcs
        doc_lengths = self.doc_lengths[unique_doc_idcs]
        start_doc_idcs = self.cumulative_doc_lengths[unique_doc_idcs - 1]
        start_doc_idcs[unique_doc_idcs == 0] = 0
        all_doc_idcs = cat_arange(start_doc_idcs, start_doc_idcs + doc_lengths)
        all_doc_embeddings = self._gather_doc_embeddings(all_doc_idcs)
        unique_embeddings = torch.nn.utils.rnn.pad_sequence(
            list(torch.split(all_doc_embeddings, doc_lengths.tolist())),
            batch_first=True,
        ).to(inverse_idcs.device)
        embeddings = unique_embeddings[inverse_idcs]

        # mask out padding
        doc_lengths = doc_lengths[inverse_idcs]
        scoring_mask = torch.arange(embeddings.shape[1], device=embeddings.device) < doc_lengths[:, None]
        doc_embeddings = BiEncoderEmbedding(embeddings=embeddings, scoring_mask=scoring_mask, encoding=None)
        return doc_embeddings


class SearchConfig:
    """Configuration class for searchers in the Lightning IR framework."""

    search_class: Type[Searcher]

    SUPPORTED_MODELS: Set[str]

    def __init__(self, k: int = 10) -> None:
        """Initialize the SearchConfig.

        Args:
            k (int): The number of top documents to retrieve. Defaults to 10.
        """
        self.k = k


class ExactSearchConfig(SearchConfig):
    """Configuration class for exact searchers in the Lightning IR framework."""

    search_class = ExactSearcher


class ApproximateSearchConfig(SearchConfig):
    """Configuration class for approximate searchers in the Lightning IR framework."""

    search_class = ApproximateSearcher

    def __init__(
        self, k: int = 10, candidate_k: int = 100, imputation_strategy: Literal["min", "gather", "zero"] = "gather"
    ) -> None:
        """Initialize the ApproximateSearchConfig.

        Args:
            k (int): The number of top documents to retrieve. Defaults to 10.
            candidate_k (int): The number of candidate documents to consider for scoring. Defaults to 100.
            imputation_strategy (Literal["min", "gather", "zero"]): Strategy for imputing missing scores. Defaults to
                "gather".
        """
382 """
383 super().__init__(k)
384 self.k = k
385 self.candidate_k = candidate_k
386 self.imputation_strategy = imputation_strategy