Source code for lightning_ir.models.bi_encoders.dpr

"""Configuration and model for DPR (Dense Passage Retriever) type models. Originally proposed in \
`Dense Passage Retrieval for Open-Domain Question Answering \
<https://arxiv.org/abs/2004.04906>`_. This model type is also known as a SentenceTransformer model:
`Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks \
<https://arxiv.org/abs/1908.10084>`_.
"""

from typing import Literal

import torch
from transformers import BatchEncoding

from ...bi_encoder import BiEncoderEmbedding, SingleVectorBiEncoderConfig, SingleVectorBiEncoderModel
from ...modeling_utils.embedding_post_processing import Pooler, Sparsifier


class DprConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a DPR model."""

    model_type = "lir-dpr"
    """Model type for a DPR model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        embedding_dim: int | None = None,
        projection: Literal["linear", "linear_no_bias"] | None = "linear",
        **kwargs,
    ) -> None:
        """A DPR model encodes queries and documents separately. Before computing the similarity score, the
        contextualized token embeddings are aggregated to obtain a single embedding using a pooling strategy.
        Optionally, the pooled embeddings can be projected using a linear layer.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query
                and document embeddings. Defaults to "dot".
            normalization_strategy (Literal["l2"] | None): Whether and how to normalize query and document
                embeddings. Defaults to None.
            sparsification_strategy (Literal["relu", "relu_log", "relu_2xlog"] | None): Whether and which
                sparsification function to apply. Defaults to None.
            add_marker_tokens (bool): Whether to add marker tokens to the input sequences. Defaults to False.
            pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query and document
                embeddings. Defaults to "first".
            embedding_dim (int | None): Dimension of the final embeddings. If None, it is set to the hidden size
                of the backbone model. Defaults to None.
            projection (Literal["linear", "linear_no_bias"] | None): Type of projection layer to apply on the
                pooled embeddings. If None, no projection is applied. Defaults to "linear".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            sparsification_strategy=sparsification_strategy,
            add_marker_tokens=add_marker_tokens,
            pooling_strategy=pooling_strategy,
            **kwargs,
        )
        hidden_size = getattr(self, "hidden_size", None)
        if projection is None:
            self.embedding_dim = hidden_size
        else:
            self.embedding_dim = embedding_dim or hidden_size
        self.projection = projection

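# Hypothetical usage sketch (not part of this module): constructing a DprConfig for a
# cosine-similarity encoder with mean pooling, L2-normalized embeddings, and a linear
# projection down to 128 dimensions. All argument names come from the signature above;
# the concrete values are illustrative only.
example_config = DprConfig(
    query_length=32,
    doc_length=256,
    similarity_function="cosine",
    normalization_strategy="l2",
    pooling_strategy="mean",
    embedding_dim=128,
    projection="linear",
)

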
class DprModel(SingleVectorBiEncoderModel):
    """A single-vector DPR model. See :class:`DprConfig` for configuration options."""

    config_class = DprConfig
    """Configuration class for a DPR model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a DPR model given a :class:`DprConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the DPR model.

        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
        """
        super().__init__(config, *args, **kwargs)
        if self.config.projection is None:
            self.projection: torch.nn.Module = torch.nn.Identity()
        else:
            if self.config.embedding_dim is None:
                raise ValueError("Unable to determine embedding dimension.")
            self.projection = torch.nn.Linear(
                self.config.hidden_size,
                self.config.embedding_dim,
                bias="no_bias" not in self.config.projection,
            )
        self.pooler = Pooler(config)
        self.sparsifier = Sparsifier(config)

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".

        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.pooler(embeddings, encoding["attention_mask"])
        embeddings = self.projection(embeddings)
        embeddings = self.sparsifier(embeddings)
        if self.config.normalization_strategy == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return BiEncoderEmbedding(embeddings, None, encoding)
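

# Hypothetical end-to-end sketch (not part of this module): encoding a query and a document
# and scoring them. It assumes ``DprModel.from_pretrained`` behaves like a standard Hugging Face
# ``from_pretrained``, that a compatible tokenizer is available, and that the pooled vectors are
# exposed via the ``embeddings`` attribute of the returned :class:`BiEncoderEmbedding` (the first
# argument passed in ``encode`` above). The checkpoint path is a placeholder.
from transformers import AutoTokenizer

model = DprModel.from_pretrained("path/to/dpr-checkpoint")
tokenizer = AutoTokenizer.from_pretrained("path/to/dpr-checkpoint")

query_encoding = tokenizer("what is dense retrieval?", return_tensors="pt")
doc_encoding = tokenizer("Dense retrieval maps queries and documents to vectors.", return_tensors="pt")

with torch.no_grad():
    query_embedding = model.encode(query_encoding, input_type="query")
    doc_embedding = model.encode(doc_encoding, input_type="doc")

# With the default "dot" similarity, the relevance score is the inner product of the pooled vectors.
score = torch.einsum("...d,...d->...", query_embedding.embeddings, doc_embedding.embeddings)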