Source code for lightning_ir.retrieve.pytorch.sparse_searcher
1"""Torch-based Sparse Searcher for Lightning IR Framework"""
2
3from __future__ import annotations
4
5from pathlib import Path
6from typing import TYPE_CHECKING, Literal
7
8import torch
9
10from ...modeling_utils.batching import _batch_pairwise_scoring
11from ...models import SpladeConfig
12from ..base.searcher import ExactSearchConfig, ExactSearcher
13from .sparse_indexer import TorchSparseIndexConfig
14
15if TYPE_CHECKING:
16 from ...bi_encoder import BiEncoderEmbedding, BiEncoderModule
17
18
[docs]
class TorchSparseIndex:
    """Torch-based sparse index for efficient retrieval.

    The index tensor is loaded from ``index_dir / "index.pt"`` and query embeddings are scored
    against all of its rows with either a dot-product or a cosine similarity.
    """

    def __init__(self, index_dir: Path, similarity_function: Literal["dot", "cosine"], use_gpu: bool = False) -> None:
        """Initialize the TorchSparseIndex.

        Args:
            index_dir (Path): Directory containing ``index.pt`` and the index configuration.
            similarity_function (Literal["dot", "cosine"]): The similarity function to use.
            use_gpu (bool): Whether to use GPU for computations. Defaults to False. CUDA is only
                selected when it is also available.
        Raises:
            ValueError: If the similarity function is not recognized.
        """
        # Validate the cheap argument first so a bad value fails before the (possibly large)
        # index file is deserialized.
        if similarity_function == "dot":
            self.similarity_function = self.dot_similarity
        elif similarity_function == "cosine":
            self.similarity_function = self.cosine_similarity
        else:
            raise ValueError("Unknown similarity function")
        # weights_only=True restricts deserialization to tensors and plain containers
        # instead of arbitrary pickled objects.
        self.index = torch.load(index_dir / "index.pt", weights_only=True)
        self.config = TorchSparseIndexConfig.from_pretrained(index_dir)
        # The constructor does not move the loaded index; to_gpu() transfers it to this device.
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

    def score(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Compute scores for the given embeddings against every row of the index.

        Args:
            embeddings (torch.Tensor): The embeddings to score.
        Returns:
            torch.Tensor: The computed scores as a dense tensor.
        """
        # Tensor.to(other) matches both the dtype and the device of the index tensor.
        embeddings = embeddings.to(self.index)
        similarity = self.similarity_function(embeddings, self.index).to_dense()
        return similarity

    @property
    def num_embeddings(self) -> int:
        """Get the number of embeddings in the index.

        Returns:
            int: The number of embeddings (size of the index's first dimension).
        """
        return self.index.shape[0]

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute cosine similarity between two tensors.

        For 2-D inputs, entry ``[i, j]`` of the result is the cosine similarity of ``x[i]`` and
        ``y[j]``. A row with zero norm (e.g. an empty sparse vector) scores 0 instead of NaN.

        Args:
            x (torch.Tensor): The first tensor.
            y (torch.Tensor): The second tensor.
        Returns:
            torch.Tensor: The cosine similarity scores.
        """
        # Computed as y @ x.T and transposed, presumably so the (sparse) index is the left operand.
        dot = y.matmul(x.T).T
        denominator = torch.norm(x, dim=-1)[:, None] * torch.norm(y, dim=-1)[None]
        # Clamp to the dtype's smallest positive normal value: an all-zero row then yields
        # 0 / tiny = 0 rather than 0 / 0 = NaN, while every denominator >= tiny is unchanged.
        denominator = denominator.clamp_min(torch.finfo(denominator.dtype).tiny)
        return dot / denominator

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute dot product similarity between two tensors.

        For 2-D inputs, entry ``[i, j]`` of the result is the dot product of ``x[i]`` and ``y[j]``.

        Args:
            x (torch.Tensor): The first tensor.
            y (torch.Tensor): The second tensor.
        Returns:
            torch.Tensor: The dot product similarity scores.
        """
        # Computed as y @ x.T and transposed, presumably so the (sparse) index is the left operand.
        return y.matmul(x.T).T

    def to_gpu(self) -> None:
        """Move the index to ``self.device`` (CUDA only if requested at init and available)."""
        self.index = self.index.to(self.device)
94
95
[docs]
class TorchSparseSearcher(ExactSearcher):
    """Exact searcher that scores queries against a :class:`TorchSparseIndex` using PyTorch."""

    def __init__(
        self,
        index_dir: Path,
        search_config: TorchSparseSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = True,
    ) -> None:
        """Set up the sparse searcher.

        Args:
            index_dir (Path): Directory holding the index files.
            search_config (TorchSparseSearchConfig): Searcher configuration.
            module (BiEncoderModule): Bi-encoder module; its config's ``similarity_function``
                selects how the index scores queries.
            use_gpu (bool): Whether computations should run on GPU when CUDA is available.
                Defaults to True.
        """
        self.search_config: TorchSparseSearchConfig
        # The index is built before the base initializer runs; presumably the base class may
        # rely on it, so keep this ordering.
        similarity = module.config.similarity_function
        self.index = TorchSparseIndex(index_dir, similarity, use_gpu)
        super().__init__(index_dir, search_config, module, use_gpu)
        wants_cuda = use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda" if wants_cuda else "cpu")

    def to_gpu(self) -> None:
        """Move both the searcher (via the base class) and its index to the target device."""
        super().to_gpu()
        self.index.to_gpu()

    def _score(self, query_embeddings: BiEncoderEmbedding) -> torch.Tensor:
        """Score query embeddings against every entry of the sparse index.

        When a scoring mask is present, the masked positions are scored; otherwise only the
        first position along the second axis of each query's embeddings is used.
        """
        mask = query_embeddings.scoring_mask
        if mask is not None:
            selected = query_embeddings.embeddings[mask]
        else:
            selected = query_embeddings.embeddings[:, 0]
        return self.index.score(selected)
131
132
[docs]
class TorchSparseSearchConfig(ExactSearchConfig):
    """Search configuration that pairs exact sparse retrieval with :class:`TorchSparseSearcher`."""

    # Model types this configuration declares support for (SPLADE only).
    SUPPORTED_MODELS = {SpladeConfig.model_type}
    # Searcher class associated with this configuration.
    search_class = TorchSparseSearcher