Source code for lightning_ir.models.bi_encoders.dpr

  1"""Configuration and model for DPR (Dense Passage Retriever) type models. Originally proposed in \
  2`Dense Passage Retrieval for Open-Domain Question Answering \
  3<https://arxiv.org/abs/2004.04906>`_. This model type is also known as a SentenceTransformer model:
  4`Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks \
  5<https://arxiv.org/abs/1908.10084>`_.
  6"""
  7
  8from typing import Literal
  9
 10import torch
 11from transformers import BatchEncoding
 12
 13from ...bi_encoder import BiEncoderEmbedding, SingleVectorBiEncoderConfig, SingleVectorBiEncoderModel
 14
 15
class DprConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a DPR model."""

    model_type = "lir-dpr"
    """Model type for a DPR model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        embedding_dim: int | None = None,
        projection: Literal["linear", "linear_no_bias"] | None = "linear",
        **kwargs,
    ) -> None:
        """A DPR model encodes queries and documents separately. Before computing the similarity score, the
        contextualized token embeddings are aggregated into a single embedding using a pooling strategy.
        Optionally, the pooled embeddings can be projected using a linear layer.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query
                and document embeddings. Defaults to "dot".
            normalize (bool): Whether to normalize the embeddings. Defaults to False.
            sparsification (Literal["relu", "relu_log", "relu_2xlog"] | None): Whether and which sparsification
                function to apply. Defaults to None.
            add_marker_tokens (bool): Whether to add marker tokens to the input sequences. Defaults to False.
            query_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query embeddings.
                Defaults to "first".
            doc_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for document embeddings.
                Defaults to "first".
            embedding_dim (int | None): Dimension of the final embeddings. If None, it is set to the hidden size
                of the backbone model. Defaults to None.
            projection (Literal["linear", "linear_no_bias"] | None): Type of projection layer to apply to the
                pooled embeddings. If None, no projection is applied. Defaults to "linear".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalize=normalize,
            sparsification=sparsification,
            add_marker_tokens=add_marker_tokens,
            query_pooling_strategy=query_pooling_strategy,
            doc_pooling_strategy=doc_pooling_strategy,
            **kwargs,
        )
        hidden_size = getattr(self, "hidden_size", None)
        if projection is None:
            self.embedding_dim = hidden_size
        else:
            self.embedding_dim = embedding_dim or hidden_size
        self.projection = projection
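# --- Illustrative configuration sketch, not part of the original module ------------
# A minimal example of how the ``DprConfig`` defined above might be parameterized; the
# values below are arbitrary. Note that ``embedding_dim`` only takes effect together
# with a ``projection``; with ``projection=None`` the embedding dimension always falls
# back to the backbone's hidden size.
example_config = DprConfig(
    query_length=32,
    doc_length=256,
    similarity_function="dot",
    query_pooling_strategy="mean",
    doc_pooling_strategy="mean",
    projection="linear_no_bias",
    embedding_dim=128,  # pooled embeddings are projected to 128 dimensions
)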
class DprModel(SingleVectorBiEncoderModel):
    """A single-vector DPR model. See :class:`DprConfig` for configuration options."""

    config_class = DprConfig
    """Configuration class for a DPR model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a DPR model given a :class:`DprConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the DPR model.
        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
        """
        super().__init__(config, *args, **kwargs)
        if self.config.projection is None:
            self.projection: torch.nn.Module = torch.nn.Identity()
        else:
            if self.config.embedding_dim is None:
                raise ValueError("Unable to determine embedding dimension.")
            self.projection = torch.nn.Linear(
                self.config.hidden_size,
                self.config.embedding_dim,
                bias="no_bias" not in self.config.projection,
            )
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequences.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        pooling_strategy = getattr(self.config, f"{input_type}_pooling_strategy")
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        embeddings = self.projection(embeddings)
        embeddings = self.sparsification(embeddings, self.config.sparsification)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return BiEncoderEmbedding(embeddings, None, encoding)
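# --- Illustrative usage sketch, not part of the original module --------------------
# A rough, self-contained example of encoding a query and a document with ``DprModel``.
# The backbone checkpoint name is a placeholder, and both loading a plain backbone this
# way and the ``embeddings`` attribute of ``BiEncoderEmbedding`` are assumptions rather
# than guarantees.
import torch
from transformers import AutoTokenizer

from lightning_ir.models.bi_encoders.dpr import DprModel

model = DprModel.from_pretrained("bert-base-uncased")  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

query_encoding = tokenizer("what is dense retrieval?", return_tensors="pt")
doc_encoding = tokenizer("Dense retrieval maps queries and documents to vectors.", return_tensors="pt")

with torch.no_grad():
    query_embedding = model.encode(query_encoding, input_type="query")
    doc_embedding = model.encode(doc_encoding, input_type="doc")

# One pooled vector per input sequence; the dot product between the two vectors
# mirrors the default "dot" similarity function of ``DprConfig``.
score = (query_embedding.embeddings * doc_embedding.embeddings).sum(-1)
print(score)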