Source code for lightning_ir.retrieve.pytorch.dense_searcher

  1"""Torch-based Dense Searcher for Lightning IR Framework"""
  2
  3from __future__ import annotations
  4
  5from pathlib import Path
  6from typing import TYPE_CHECKING, Literal
  7
  8import torch
  9
 10from ...modeling_utils.batching import _batch_pairwise_scoring
 11from ...models import ColConfig, DprConfig
 12from ..base.searcher import ExactSearchConfig, ExactSearcher
 13from .dense_indexer import TorchDenseIndexConfig
 14
 15if TYPE_CHECKING:
 16    from ...bi_encoder import BiEncoderEmbedding, BiEncoderModule
 17
 18
class TorchDenseIndex:
    """Torch-based dense index for embeddings."""

    def __init__(self, index_dir: Path, similarity_function: Literal["dot", "cosine"], use_gpu: bool = False) -> None:
        """Initialize the TorchDenseIndex.

        The index tensor is read from ``index_dir / "index.pt"``. It is not moved to :attr:`device` here; that only
        happens when :meth:`to_gpu` is called.

        Args:
            index_dir (Path): Directory where the index is stored.
            similarity_function (Literal["dot", "cosine"]): Similarity function to use for scoring.
            use_gpu (bool): Whether :attr:`device` should be CUDA (only honored when CUDA is available).
                Defaults to False.
        Raises:
            ValueError: If the similarity function is not recognized.
        """
        self.index = torch.load(index_dir / "index.pt", weights_only=True)
        self.config = TorchDenseIndexConfig.from_pretrained(index_dir)
        if similarity_function == "dot":
            self.similarity_function = self.dot_similarity
        elif similarity_function == "cosine":
            self.similarity_function = self.cosine_similarity
        else:
            raise ValueError("Unknown similarity function")
        # Target device used by `to_gpu`; falls back to CPU when CUDA was requested but is unavailable.
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

    def score(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Score the embeddings against the index.

        The embeddings are moved to whatever device the index tensor currently lives on, so scoring works both
        before and after :meth:`to_gpu` has been called.

        Args:
            embeddings (torch.Tensor): The embeddings to score.
        Returns:
            torch.Tensor: The scores for the embeddings.
        """
        # Follow the index, not `self.device`: the index only reaches `self.device` once `to_gpu` has run, and
        # scoring a CUDA query against a CPU index would fail with a device-mismatch error.
        embeddings = embeddings.to(self.index.device)
        return self.similarity_function(embeddings, self.index)

    @property
    def num_embeddings(self) -> int:
        """Get the number of embeddings in the index."""
        # Size of the first dimension; presumably the index is (num_embeddings, dim) -- TODO confirm.
        return self.index.shape[0]

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the cosine similarity between two tensors.

        Args:
            x (torch.Tensor): First tensor.
            y (torch.Tensor): Second tensor.
        Returns:
            torch.Tensor: Cosine similarity scores.
        """
        # Broadcast every row of x against every row of y; for 2-D inputs this yields a (len(x), len(y)) matrix
        # but materializes a (len(x), len(y), dim) intermediate.
        return torch.nn.functional.cosine_similarity(x[:, None], y[None], dim=-1)

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the dot product similarity between two tensors.

        Args:
            x (torch.Tensor): First tensor.
            y (torch.Tensor): Second tensor.
        Returns:
            torch.Tensor: Dot product similarity scores.
        """
        # Autocast is disabled (see decorator) so the product is not downcast to a reduced-precision dtype.
        return torch.matmul(x, y.T)

    def to_gpu(self) -> None:
        """Move the index tensor to :attr:`device`.

        :attr:`device` is CUDA only when ``use_gpu`` was set and CUDA is available; otherwise this keeps the index
        on the CPU.
        """
        self.index = self.index.to(self.device)
90 91
class TorchDenseSearcher(ExactSearcher):
    """Exact dense searcher that scores queries against a :class:`TorchDenseIndex`."""

    def __init__(
        self,
        index_dir: Path,
        search_config: TorchDenseSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = True,
    ) -> None:
        """Set up the searcher and load its dense index.

        Args:
            index_dir (Path): Directory where the index is stored.
            search_config (TorchDenseSearchConfig): Configuration for the dense search.
            module (BiEncoderModule): Bi-encoder module; its config supplies the similarity function.
            use_gpu (bool): Whether to use GPU for searching. Defaults to True.
        """
        self.search_config: TorchDenseSearchConfig
        # The index is created before the base initializer runs; presumably the base class may touch
        # `self.index` (e.g. through `to_gpu`) -- confirm against ExactSearcher before reordering.
        similarity = module.config.similarity_function
        self.index = TorchDenseIndex(index_dir, similarity, use_gpu)
        super().__init__(index_dir, search_config, module, use_gpu)
        want_cuda = use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda" if want_cuda else "cpu")

    def to_gpu(self) -> None:
        """Move the searcher, then its dense index, to the GPU if available."""
        super().to_gpu()
        self.index.to_gpu()

    def _score(self, query_embeddings: BiEncoderEmbedding) -> torch.Tensor:
        """Score query embeddings against the dense index.

        Without a scoring mask only the first vector of each query is scored; with a mask, every vector the mask
        selects is scored.
        """
        mask = query_embeddings.scoring_mask
        selected = query_embeddings.embeddings[:, 0] if mask is None else query_embeddings.embeddings[mask]
        return self.index.score(selected)
127 128
class TorchDenseSearchConfig(ExactSearchConfig):
    """Exact-search configuration bound to :class:`TorchDenseSearcher`."""

    # Model types (``model_type`` values) this configuration declares as supported.
    SUPPORTED_MODELS = {ColConfig.model_type, DprConfig.model_type}
    # Searcher class associated with this configuration.
    search_class = TorchDenseSearcher