Source code for lightning_ir.retrieve.pytorch.dense_searcher

 1from __future__ import annotations
 2
 3from pathlib import Path
 4from typing import TYPE_CHECKING, Literal
 5
 6import torch
 7
 8from ...modeling_utils.batching import _batch_pairwise_scoring
 9from ...models import ColConfig, DprConfig
10from ..base.searcher import ExactSearchConfig, ExactSearcher
11from .dense_indexer import TorchDenseIndexConfig
12
13if TYPE_CHECKING:
14    from ...bi_encoder import BiEncoderEmbedding, BiEncoderModule
15
16
class TorchDenseIndex:
    """Dense index held in memory and scored exhaustively with PyTorch.

    The index is the object stored in ``index.pt`` inside ``index_dir``. The methods of this
    class use it like a tensor (``shape[0]``, ``.T``, ``.to``, ``.device``).
    """

    def __init__(self, index_dir: Path, similarity_function: Literal["dot", "cosine"], use_gpu: bool = False) -> None:
        """Load the index and its configuration from ``index_dir``.

        Args:
            index_dir: Directory containing ``index.pt`` and the index configuration.
            similarity_function: Name of the scoring function, ``"dot"`` or ``"cosine"``.
            use_gpu: Score on CUDA when it is available; otherwise the CPU is used.

        Raises:
            ValueError: If ``similarity_function`` is neither ``"dot"`` nor ``"cosine"``.
        """
        # Validate the cheap argument first so a typo does not cost a full index load.
        if similarity_function == "dot":
            self.similarity_function = self.dot_similarity
        elif similarity_function == "cosine":
            self.similarity_function = self.cosine_similarity
        else:
            raise ValueError("Unknown similarity function")
        # map_location="cpu" lets an index that was serialized from a CUDA tensor be loaded on a
        # host without CUDA; the index is moved to the scoring device by to_gpu() / score().
        self.index = torch.load(index_dir / "index.pt", map_location="cpu", weights_only=True)
        self.config = TorchDenseIndexConfig.from_pretrained(index_dir)
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

    def score(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Score ``embeddings`` against every entry of the index.

        Args:
            embeddings: Query embeddings; they are moved to ``self.device`` before scoring.

        Returns:
            The output of the configured similarity function applied to the queries and the index.
        """
        embeddings = embeddings.to(self.device)
        # Guard against score() being called before to_gpu(): queries and index must share a
        # device. Comparing device types avoids the "cuda" vs. "cuda:0" inequality.
        if self.index.device.type != self.device.type:
            self.index = self.index.to(self.device)
        similarity = self.similarity_function(embeddings, self.index)
        return similarity

    @property
    def num_embeddings(self) -> int:
        """Size of the first dimension of the index."""
        return self.index.shape[0]

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Cosine similarity between every row of ``x`` and every row of ``y``.

        CUDA autocast is disabled inside this function.
        NOTE(review): assumes ``x`` and ``y`` are 2-D ``(n, dim)`` / ``(m, dim)`` tensors and that
        ``_batch_pairwise_scoring`` only controls batching - confirm against its definition.
        """
        # Insert singleton axes so the two inputs broadcast against each other along dim=-1.
        return torch.nn.functional.cosine_similarity(x[:, None], y[None], dim=-1)

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Dot product between every row of ``x`` and every row of ``y`` (``x @ y.T``).

        CUDA autocast is disabled inside this function.
        """
        return torch.matmul(x, y.T)

    def to_gpu(self) -> None:
        """Move the index to ``self.device`` (the CPU when CUDA is not in use)."""
        self.index = self.index.to(self.device)
52 53
class TorchDenseSearcher(ExactSearcher):
    """Exact searcher that scores query embeddings against a ``TorchDenseIndex``."""

    def __init__(
        self,
        index_dir: Path,
        search_config: TorchDenseSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = True,
    ) -> None:
        """Build the dense index, then run the base searcher initialisation.

        Args:
            index_dir: Directory the index is loaded from.
            search_config: Configuration forwarded to the base searcher.
            module: Bi-encoder module; its ``config.similarity_function`` selects the scoring function.
            use_gpu: Use CUDA when it is available; otherwise the CPU is used.
        """
        self.search_config: TorchDenseSearchConfig
        similarity_name = module.config.similarity_function
        # NOTE(review): the index is created before the base initialiser runs; presumably the
        # base class relies on ``self.index`` already existing - confirm before reordering.
        self.index = TorchDenseIndex(index_dir, similarity_name, use_gpu)
        super().__init__(index_dir, search_config, module, use_gpu)
        if use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

    def to_gpu(self) -> None:
        """Run the base class ``to_gpu`` and then move the index to its device."""
        super().to_gpu()
        self.index.to_gpu()

    def _score(self, query_embeddings: BiEncoderEmbedding) -> torch.Tensor:
        """Score the query embeddings against the index.

        Without a scoring mask, the entry at position 0 of the second axis is used for each
        query; with a mask, the embeddings selected by the mask are used.
        """
        mask = query_embeddings.scoring_mask
        if mask is not None:
            selected = query_embeddings.embeddings[mask]
        else:
            selected = query_embeddings.embeddings[:, 0]
        return self.index.score(selected)
78 79
class TorchDenseSearchConfig(ExactSearchConfig):
    """Search configuration whose ``search_class`` is ``TorchDenseSearcher``."""

    # Searcher class associated with this configuration.
    search_class = TorchDenseSearcher
    # Model types listed as supported: those of ColConfig and DprConfig.
    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}