Source code for lightning_ir.retrieve.pytorch.sparse_searcher
1from __future__ import annotations
2
3from pathlib import Path
4from typing import TYPE_CHECKING, Literal
5
6import torch
7
8from ...modeling_utils.batching import _batch_pairwise_scoring
9from ...models import SpladeConfig
10from ..base.searcher import ExactSearchConfig, ExactSearcher
11from .sparse_indexer import TorchSparseIndexConfig
12
13if TYPE_CHECKING:
14 from ...bi_encoder import BiEncoderEmbedding, BiEncoderModule
15
16
[docs]
class TorchSparseIndex:
    """Index tensor loaded from disk and scored against query embeddings.

    The tensor is read from ``index.pt`` inside the index directory and is
    compared with incoming embeddings using dot-product or cosine similarity.
    """

    def __init__(self, index_dir: Path, similarity_function: Literal["dot", "cosine"], use_gpu: bool = False) -> None:
        """Load the index and its configuration and pick the scoring function.

        Args:
            index_dir: Directory that contains ``index.pt`` and whatever
                ``TorchSparseIndexConfig.from_pretrained`` reads.
            similarity_function: Either ``"dot"`` or ``"cosine"``.
            use_gpu: Use CUDA as the target device when it is available.

        Raises:
            ValueError: If ``similarity_function`` is neither ``"dot"`` nor
                ``"cosine"``.
        """
        self.index = torch.load(index_dir / "index.pt", weights_only=True)
        self.config = TorchSparseIndexConfig.from_pretrained(index_dir)

        # Validate after loading so the failure point matches the load order.
        if similarity_function not in ("dot", "cosine"):
            raise ValueError("Unknown similarity function")
        self.similarity_function = (
            self.dot_similarity if similarity_function == "dot" else self.cosine_similarity
        )

        # Only the target device is recorded here; the index itself is moved
        # by an explicit call to ``to_gpu``.
        cuda_usable = use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda") if cuda_usable else torch.device("cpu")

    def score(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Return the similarity between ``embeddings`` and the index.

        The embeddings are moved to ``self.device`` first, and the result of
        the similarity function is converted with ``to_dense()``.
        """
        return self.similarity_function(embeddings.to(self.device), self.index).to_dense()

    @property
    def num_embeddings(self) -> int:
        """Size of the first dimension of the index tensor."""
        return self.index.shape[0]

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Cosine similarity of ``x`` against ``y``, run with CUDA autocast disabled.

        Computes ``y @ x.T`` transposed, then divides by the product of the
        norms taken over the last dimension of each input.
        NOTE(review): no guard against zero norms; how the
        ``_batch_pairwise_scoring`` wrapper splits the inputs is not visible
        here - confirm against its definition.
        """
        products = y.matmul(x.T).T
        x_norms = torch.norm(x, dim=-1)[:, None]
        y_norms = torch.norm(y, dim=-1)[None]
        return products / (x_norms * y_norms)

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Dot-product similarity of ``x`` against ``y``, run with CUDA autocast disabled."""
        products = y.matmul(x.T)
        return products.T

    def to_gpu(self) -> None:
        """Move the index tensor to ``self.device`` (which may be the CPU)."""
        self.index = self.index.to(self.device)
52
53
[docs]
class TorchSparseSearcher(ExactSearcher):
    """Exact searcher that scores query embeddings against a :class:`TorchSparseIndex`."""

    def __init__(
        self,
        index_dir: Path,
        search_config: TorchSparseSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = True,
    ) -> None:
        """Build the index, run the parent constructor, then record the device.

        Args:
            index_dir: Directory holding the stored index.
            search_config: Configuration forwarded to the parent constructor.
            module: Bi-encoder module; its ``config.similarity_function``
                selects how the index scores embeddings.
            use_gpu: Use CUDA as the target device when it is available.
        """
        # Annotation only - no value is assigned by this statement.
        self.search_config: TorchSparseSearchConfig
        # The index is created before delegating to the parent constructor.
        # NOTE(review): presumably the parent calls ``to_gpu()``, which needs
        # ``self.index`` to exist already - confirm before reordering.
        similarity = module.config.similarity_function
        self.index = TorchSparseIndex(index_dir, similarity, use_gpu)
        super().__init__(index_dir, search_config, module, use_gpu)
        cuda_usable = use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda") if cuda_usable else torch.device("cpu")

    def to_gpu(self) -> None:
        """Run the parent's ``to_gpu`` and then move the index as well."""
        super().to_gpu()
        self.index.to_gpu()

    def _score(self, query_embeddings: BiEncoderEmbedding) -> torch.Tensor:
        """Score the query embeddings against the index.

        With a scoring mask, the embeddings selected by the mask are scored;
        without one, the entry at position 0 of the second dimension is used.
        """
        mask = query_embeddings.scoring_mask
        if mask is not None:
            selected = query_embeddings.embeddings[mask]
        else:
            selected = query_embeddings.embeddings[:, 0]
        return self.index.score(selected)
78
79
[docs]
class TorchSparseSearchConfig(ExactSearchConfig):
    """Search configuration whose searcher class is :class:`TorchSparseSearcher`."""

    # Model types listed as supported; only ``SpladeConfig.model_type`` is included.
    SUPPORTED_MODELS = {SpladeConfig.model_type}
    # Searcher implementation associated with this configuration.
    search_class = TorchSparseSearcher