Source code for lightning_ir.models.dpr

  1"""Configuration and model for DPR (Dense Passage Retriever) type models. Originally proposed in \
  2`Dense Passage Retrieval for Open-Domain Question Answering \
  3<https://arxiv.org/abs/2004.04906>`_. This model type is also known as a SentenceTransformer model:
  4`Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks \
  5<https://arxiv.org/abs/1908.10084>`_.
  6"""
  7
  8from typing import Literal
  9
 10import torch
 11from transformers import BatchEncoding
 12
 13from ..bi_encoder import BiEncoderEmbedding, SingleVectorBiEncoderConfig, SingleVectorBiEncoderModel
 14
 15
[docs] 16class DprConfig(SingleVectorBiEncoderConfig): 17 """Configuration class for a DPR model.""" 18 19 model_type = "lir-dpr" 20 """Model type for a DPR model.""" 21
    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        embedding_dim: int | None = None,
        projection: Literal["linear", "linear_no_bias"] | None = "linear",
        **kwargs,
    ) -> None:
        """A DPR model encodes queries and documents separately. Before computing the similarity score, the
        contextualized token embeddings are aggregated into a single embedding using a pooling strategy.
        Optionally, the pooled embeddings can be projected using a linear layer.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param normalize: Whether to L2-normalize the pooled embeddings, defaults to False
        :type normalize: bool, optional
        :param sparsification: Whether and which sparsification function to apply, defaults to None
        :type sparsification: Literal['relu', 'relu_log'] | None, optional
        :param add_marker_tokens: Whether to prepend marker tokens to queries and documents, defaults to False
        :type add_marker_tokens: bool, optional
        :param query_pooling_strategy: Whether and how to pool the query token embeddings, defaults to "first"
        :type query_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        :param doc_pooling_strategy: Whether and how to pool document token embeddings, defaults to "first"
        :type doc_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        :param embedding_dim: Dimension of the final embeddings; defaults to the backbone's hidden size if None
        :type embedding_dim: int | None, optional
        :param projection: Whether and how to project the pooled embeddings, defaults to "linear"
        :type projection: Literal['linear', 'linear_no_bias'] | None, optional
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalize=normalize,
            sparsification=sparsification,
            add_marker_tokens=add_marker_tokens,
            query_pooling_strategy=query_pooling_strategy,
            doc_pooling_strategy=doc_pooling_strategy,
            **kwargs,
        )
        hidden_size = getattr(self, "hidden_size", None)
        if projection is None:
            self.embedding_dim = hidden_size
        else:
            self.embedding_dim = embedding_dim or hidden_size
        self.projection = projection

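
# Illustrative note (not part of the module): how ``projection`` and ``embedding_dim`` interact
# in ``DprConfig.__init__``. The hidden size of 768 below is an assumption for a BERT-style
# backbone; a bare config created without a backbone may not have ``hidden_size`` set yet.
#
#   DprConfig(projection=None)                                 -> embedding_dim == hidden_size (768),
#                                                                 no projection layer is added
#   DprConfig(projection="linear")                             -> embedding_dim == hidden_size (768),
#                                                                 a linear projection with bias is added
#   DprConfig(projection="linear", embedding_dim=128)          -> pooled 768-dim vectors are projected
#                                                                 down to 128 dimensions
#   DprConfig(projection="linear_no_bias", embedding_dim=128)  -> same, but the projection has no bias
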
class DprModel(SingleVectorBiEncoderModel):
    """A single-vector DPR model. See :class:`DprConfig` for configuration options."""

    config_class = DprConfig
    """Configuration class for a DPR model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a DPR model given a :class:`DprConfig`.

        :param config: Configuration for the DPR model
        :type config: SingleVectorBiEncoderConfig
        """
        super().__init__(config, *args, **kwargs)
        if self.config.projection is None:
            self.projection: torch.nn.Module = torch.nn.Identity()
        else:
            if self.config.embedding_dim is None:
                raise ValueError("Unable to determine embedding dimension.")
            self.projection = torch.nn.Linear(
                self.config.hidden_size,
                self.config.embedding_dim,
                bias="no_bias" not in self.config.projection,
            )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        :param encoding: Tokenizer encodings for the text sequences
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        pooling_strategy = getattr(self.config, f"{input_type}_pooling_strategy")
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        embeddings = self.projection(embeddings)
        embeddings = self.sparsification(embeddings, self.config.sparsification)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return BiEncoderEmbedding(embeddings, None, encoding)
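

# Minimal scoring sketch (illustration only, not part of lightning_ir's API). It mimics what a
# DPR-style bi-encoder does around ``encode``: with the default "first" pooling strategy the
# first (CLS) token embedding serves as the single vector, and with similarity_function="dot"
# the relevance score is the inner product of query and document vectors. All tensors below are
# random stand-ins for real backbone outputs, and all names are hypothetical.
if __name__ == "__main__":
    num_queries, num_docs, seq_len, hidden_size = 2, 4, 16, 768

    # Hypothetical contextualized token embeddings from the backbone encoder
    query_token_embeddings = torch.randn(num_queries, seq_len, hidden_size)
    doc_token_embeddings = torch.randn(num_docs, seq_len, hidden_size)

    # "first" pooling: keep only the first token's embedding per sequence
    query_embeddings = query_token_embeddings[:, 0]
    doc_embeddings = doc_token_embeddings[:, 0]

    # similarity_function="dot": one score per query-document pair
    dot_scores = query_embeddings @ doc_embeddings.transpose(0, 1)

    # similarity_function="cosine": inner product of L2-normalized vectors
    cosine_scores = torch.nn.functional.normalize(query_embeddings, dim=-1) @ torch.nn.functional.normalize(
        doc_embeddings, dim=-1
    ).transpose(0, 1)

    print(dot_scores.shape, cosine_scores.shape)  # both (num_queries, num_docs)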