Source code for lightning_ir.retrieve.pytorch.sparse_searcher

  1"""Torch-based Sparse Searcher for Lightning IR Framework"""
  2
  3from __future__ import annotations
  4
  5from pathlib import Path
  6from typing import TYPE_CHECKING, Literal
  7
  8import torch
  9
 10from ...modeling_utils.batching import _batch_pairwise_scoring
 11from ...models import SpladeConfig
 12from ..base.searcher import ExactSearchConfig, ExactSearcher
 13from .sparse_indexer import TorchSparseIndexConfig
 14
 15if TYPE_CHECKING:
 16    from ...bi_encoder import BiEncoderEmbedding, BiEncoderModule
 17
 18
class TorchSparseIndex:
    """Torch-based sparse index for efficient retrieval.

    Loads the index tensor from ``index_dir / "index.pt"`` together with its
    :class:`TorchSparseIndexConfig`. It scores query embeddings against every row of the index
    using either dot-product or cosine similarity.
    """

    def __init__(self, index_dir: Path, similarity_function: Literal["dot", "cosine"], use_gpu: bool = False) -> None:
        """Initialize the TorchSparseIndex.

        Args:
            index_dir (Path): Directory containing ``index.pt`` and the index configuration.
            similarity_function (Literal["dot", "cosine"]): The similarity function to use.
            use_gpu (bool): Whether to use GPU for computations. CUDA is only selected when it is
                also available; otherwise the CPU is used. Defaults to False.
        Raises:
            ValueError: If the similarity function is not recognized.
        """
        self.index = torch.load(index_dir / "index.pt", weights_only=True)
        self.config = TorchSparseIndexConfig.from_pretrained(index_dir)
        if similarity_function == "dot":
            self.similarity_function = self.dot_similarity
        elif similarity_function == "cosine":
            self.similarity_function = self.cosine_similarity
        else:
            raise ValueError("Unknown similarity function")
        # Target device for ``to_gpu``; the index is not moved here, it stays where torch.load put it.
        self.device = torch.device("cuda") if use_gpu and torch.cuda.is_available() else torch.device("cpu")

    def score(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Compute scores for the given embeddings against every indexed embedding.

        Args:
            embeddings (torch.Tensor): The query embeddings to score, one embedding per row.
        Returns:
            torch.Tensor: Dense score matrix with one row per query embedding and one column per
            indexed embedding (assuming ``_batch_pairwise_scoring`` preserves that pairwise layout).
        """
        # Tensor.to(other) matches both the device and the dtype of the index tensor.
        embeddings = embeddings.to(self.index)
        similarity = self.similarity_function(embeddings, self.index).to_dense()
        return similarity

    @property
    def num_embeddings(self) -> int:
        """Get the number of embeddings in the index.

        Returns:
            int: The number of embeddings (rows of the index tensor).
        """
        return self.index.shape[0]

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute cosine similarity between two tensors.

        A row whose L2 norm is zero (e.g. an all-zero sparse representation) would otherwise give
        ``0 / 0 = NaN``. The norm product is therefore clamped to a small epsilon, so such pairs
        score 0 instead of NaN.

        Args:
            x (torch.Tensor): The first tensor, one embedding per row.
            y (torch.Tensor): The second tensor, one embedding per row.
        Returns:
            torch.Tensor: The cosine similarity scores, one row per row of ``x`` and one column per
            row of ``y``.
        """
        # Same default epsilon as torch.nn.functional.cosine_similarity.
        eps = 1e-8
        dot = y.matmul(x.T).T
        norms = torch.norm(x, dim=-1)[:, None] * torch.norm(y, dim=-1)[None]
        return dot / norms.clamp_min(eps)

    @staticmethod
    @_batch_pairwise_scoring
    @torch.autocast(device_type="cuda", enabled=False)
    def dot_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute dot product similarity between two tensors.

        Args:
            x (torch.Tensor): The first tensor, one embedding per row.
            y (torch.Tensor): The second tensor, one embedding per row.
        Returns:
            torch.Tensor: The dot product similarity scores, one row per row of ``x`` and one
            column per row of ``y``.
        """
        return y.matmul(x.T).T

    def to_gpu(self) -> None:
        """Move the index to ``self.device`` (CUDA when requested and available, otherwise CPU)."""
        self.index = self.index.to(self.device)


class TorchSparseSearcher(ExactSearcher):
    """Exhaustive searcher that scores queries against a torch sparse index for Lightning IR."""

    def __init__(
        self,
        index_dir: Path,
        search_config: TorchSparseSearchConfig,
        module: BiEncoderModule,
        use_gpu: bool = True,
    ) -> None:
        """Build the sparse index, then initialize the base exact searcher.

        Args:
            index_dir (Path): Directory holding the index files.
            search_config (TorchSparseSearchConfig): Searcher configuration.
            module (BiEncoderModule): Bi-encoder module; its config supplies the similarity function.
            use_gpu (bool): Prefer CUDA when it is available. Defaults to True.
        """
        self.search_config: TorchSparseSearchConfig
        # The index must exist before the base initializer runs. NOTE(review): presumably the base
        # __init__ touches ``self.index`` (e.g. via to_gpu) — verify before reordering.
        similarity = module.config.similarity_function
        self.index = TorchSparseIndex(index_dir, similarity, use_gpu)
        super().__init__(index_dir, search_config, module, use_gpu)
        cuda_usable = use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_usable else "cpu")

    def to_gpu(self) -> None:
        """Run the base class's ``to_gpu`` first, then move the sparse index to its target device."""
        super().to_gpu()
        self.index.to_gpu()

    def _score(self, query_embeddings: BiEncoderEmbedding) -> torch.Tensor:
        # Without a scoring mask, only the first vector of each query is used; with a mask,
        # all masked vectors are selected (flattened by boolean indexing).
        mask = query_embeddings.scoring_mask
        vectors = query_embeddings.embeddings[:, 0] if mask is None else query_embeddings.embeddings[mask]
        return self.index.score(vectors)


class TorchSparseSearchConfig(ExactSearchConfig):
    """Search configuration that selects :class:`TorchSparseSearcher`.

    Only SPLADE models are declared as supported.
    """

    SUPPORTED_MODELS = {SpladeConfig.model_type}
    search_class = TorchSparseSearcher