Source code for lightning_ir.retrieve.pytorch.dense_searcher
1"""Torch-based Dense Searcher for Lightning IR Framework"""
2
3from __future__ import annotations
4
5from pathlib import Path
6from typing import TYPE_CHECKING, Literal
7
8import torch
9
10from ...modeling_utils.batching import _batch_pairwise_scoring
11from ...models import ColConfig, DprConfig
12from ..base.searcher import ExactSearchConfig, ExactSearcher
13from .dense_indexer import TorchDenseIndexConfig
14
15if TYPE_CHECKING:
16 from ...bi_encoder import BiEncoderEmbedding, BiEncoderModule
17
18
[docs]
class TorchDenseIndex:
    """Dense index that keeps all stored embeddings in a single torch tensor."""

    def __init__(self, index_dir: Path, similarity_function: Literal["dot", "cosine"], use_gpu: bool = False) -> None:
        """Load the index tensor and its config and select the scoring function.

        Args:
            index_dir (Path): Directory containing ``index.pt`` and the index configuration.
            similarity_function (Literal["dot", "cosine"]): Name of the similarity used for scoring.
            use_gpu (bool): Request CUDA as the target device. CUDA is only selected when it is
                also available; otherwise the CPU is used. Defaults to False.
        Raises:
            ValueError: If ``similarity_function`` is neither ``"dot"`` nor ``"cosine"``.
        """
        # Loading happens first so that I/O errors surface before argument validation.
        self.index = torch.load(index_dir / "index.pt", weights_only=True)
        self.config = TorchDenseIndexConfig.from_pretrained(index_dir)

        if similarity_function not in ("dot", "cosine"):
            raise ValueError("Unknown similarity function")
        self.similarity_function = (
            self.dot_similarity if similarity_function == "dot" else self.cosine_similarity
        )

        # Only the target device is recorded here; the index tensor is moved by `to_gpu`.
        cuda_requested = use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_requested else "cpu")

    def score(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Compute similarities between the given embeddings and the stored index.

        The embeddings are moved to ``self.device`` before scoring.

        NOTE(review): the index tensor itself is not moved here -- presumably callers
        invoke ``to_gpu`` beforehand when a CUDA device is selected; confirm.

        Args:
            embeddings (torch.Tensor): Embeddings to compare against the index.
        Returns:
            torch.Tensor: Output of the configured similarity function.
        """
        on_device = embeddings.to(self.device)
        return self.similarity_function(on_device, self.index)

    @property
    def num_embeddings(self) -> int:
        """Size of the first dimension of the index tensor."""
        return self.index.shape[0]

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Pairwise cosine similarity between the rows of ``x`` and the rows of ``y``.

        CUDA autocast is disabled for this computation.

        Args:
            x (torch.Tensor): First tensor.
            y (torch.Tensor): Second tensor.
        Returns:
            torch.Tensor: Cosine similarity scores, reduced over the last dimension.
        """
        # Insert singleton axes so every row of x is broadcast against every row of y.
        x_rows = x.unsqueeze(1)
        y_rows = y.unsqueeze(0)
        return torch.nn.functional.cosine_similarity(x_rows, y_rows, dim=-1)

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Pairwise dot-product similarity, computed as ``x`` times the transpose of ``y``.

        CUDA autocast is disabled for this computation.

        Args:
            x (torch.Tensor): First tensor.
            y (torch.Tensor): Second tensor.
        Returns:
            torch.Tensor: Dot product similarity scores.
        """
        return x @ y.T

    def to_gpu(self) -> None:
        """Move the index tensor to ``self.device`` (the CPU when CUDA was not selected)."""
        target_device = self.device
        self.index = self.index.to(target_device)
90
91
[docs]
class TorchDenseSearcher(ExactSearcher):
    """Exact searcher that scores query embeddings against a ``TorchDenseIndex``."""

    def __init__(
        self,
        index_dir: Path,
        search_config: TorchDenseSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = True,
    ) -> None:
        """Build the underlying index and initialize the base searcher.

        Args:
            index_dir (Path): Directory where the index is stored.
            search_config (TorchDenseSearchConfig): Configuration for the dense search.
            module (BiEncoderModule): Bi-encoder module; its ``config.similarity_function``
                determines how the index scores embeddings.
            use_gpu (bool): Request CUDA for searching. CUDA is only selected when it is
                also available. Defaults to True.
        """
        self.search_config: TorchDenseSearchConfig
        # The index is created before the base-class initializer runs, so it already
        # exists in case that initializer touches it (e.g. through `to_gpu`).
        similarity_name = module.config.similarity_function
        self.index = TorchDenseIndex(index_dir, similarity_name, use_gpu)
        super().__init__(index_dir, search_config, module, use_gpu)
        cuda_requested = use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_requested else "cpu")

    def to_gpu(self) -> None:
        """Run the base-class ``to_gpu`` and then move the index to its device."""
        super().to_gpu()
        self.index.to_gpu()

    def _score(self, query_embeddings: BiEncoderEmbedding) -> torch.Tensor:
        """Select the query vectors to score and delegate scoring to the index.

        Without a scoring mask, the element at position 0 of the second dimension is
        taken for every query; otherwise the embeddings are indexed with the mask.
        """
        mask = query_embeddings.scoring_mask
        all_embeddings = query_embeddings.embeddings
        selected = all_embeddings[:, 0] if mask is None else all_embeddings[mask]
        return self.index.score(selected)
127
128
[docs]
class TorchDenseSearchConfig(ExactSearchConfig):
    """Search configuration associated with ``TorchDenseSearcher``."""

    # Searcher class that this configuration refers to.
    search_class = TorchDenseSearcher
    # Model types listed as supported by this configuration.
    SUPPORTED_MODELS = {
        ColConfig.model_type,
        DprConfig.model_type,
    }