1"""Configuration and model for DPR (Dense Passage Retriever) type models. Originally proposed in \
2`Dense Passage Retrieval for Open-Domain Question Answering \
3<https://arxiv.org/abs/2004.04906>`_. This model type is also known as a SentenceTransformer model:
4`Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks \
5<https://arxiv.org/abs/1908.10084>`_.
6"""
7
8from typing import Literal
9
10import torch
11from transformers import BatchEncoding
12
13from ...bi_encoder import BiEncoderEmbedding, SingleVectorBiEncoderConfig, SingleVectorBiEncoderModel
14from ...modeling_utils.embedding_post_processing import Pooler, Sparsifier
15
16
class DprConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a DPR model."""

    model_type = "lir-dpr"
    """Model type for a DPR model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        embedding_dim: int | None = None,
        projection: Literal["linear", "linear_no_bias"] | None = "linear",
        **kwargs,
    ) -> None:
36 """A DPR model encodes queries and documents separately. Before computing the similarity score, the
37 contextualized token embeddings are aggregated to obtain a single embedding using a pooling strategy.
38 Optionally, the pooled embeddings can be projected using a linear layer.
39
40 Args:
41 query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
42 doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
43 similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
44 document embeddings. Defaults to "dot".
45 normalization_strategy (Literal['l2'] | None): Whether to normalization_strategy query and document
46 embeddings.
47 Defaults to None.
48 sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which
49 sparsification_strategy function to apply. Defaults to None.
50 add_marker_tokens (bool): Whether to add marker tokens to the input sequences. Defaults to False.
51 pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query and document
52 embeddings. Defaults to "first".
53 embedding_dim (int | None): Dimension of the final embeddings. If None, it will be set to the hidden size
54 of the backbone model. Defaults to None.
55 projection (Literal["linear", "linear_no_bias"] | None): type of projection layer to apply on the pooled
56 embeddings. If None, no projection is applied. Defaults to "linear".
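
        Example:
            A minimal configuration sketch; the chosen values are illustrative, not recommendations::

                config = DprConfig(query_length=32, doc_length=256, pooling_strategy="mean", embedding_dim=128)
                # With the default "linear" projection, pooled embeddings are mapped from the
                # backbone's hidden size down to 128 dimensions.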
57 """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            sparsification_strategy=sparsification_strategy,
            add_marker_tokens=add_marker_tokens,
            pooling_strategy=pooling_strategy,
            **kwargs,
        )
        # The projection layer's output size determines the final embedding dimension; without a
        # projection layer the embeddings keep the backbone's hidden size.
        hidden_size = getattr(self, "hidden_size", None)
        if projection is None:
            self.embedding_dim = hidden_size
        else:
            self.embedding_dim = embedding_dim or hidden_size
        self.projection = projection


class DprModel(SingleVectorBiEncoderModel):
    """A single-vector DPR model. See :class:`DprConfig` for configuration options."""

    config_class = DprConfig
    """Configuration class for a DPR model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a DPR model given a :class:`DprConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the DPR model.

        Raises:
            ValueError: If a projection layer is configured but the embedding dimension cannot be determined.
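
        Example:
            A minimal instantiation sketch. It assumes this class exposes the standard Hugging Face
            ``from_pretrained`` interface and uses ``bert-base-uncased`` as a placeholder backbone checkpoint::

                model = DprModel.from_pretrained("bert-base-uncased", config=DprConfig(embedding_dim=128))
                # With the default "linear" projection, model.projection maps hidden_size -> 128.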
89 """
90 super().__init__(config, *args, **kwargs)
91 if self.config.projection is None:
92 self.projection: torch.nn.Module = torch.nn.Identity()
93 else:
94 if self.config.embedding_dim is None:
95 raise ValueError("Unable to determine embedding dimension.")
96 self.projection = torch.nn.Linear(
97 self.config.hidden_size,
98 self.config.embedding_dim,
99 bias="no_bias" not in self.config.projection,
100 )
101 self.pooler = Pooler(config)
102 self.sparsifier = Sparsifier(config)
103
    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequences.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".

        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
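
        Example:
            An illustrative sketch of encoding a query. It assumes a plain Hugging Face tokenizer for the backbone
            (``bert-base-uncased`` is a placeholder) and an already instantiated ``model``; in practice the tokenizer
            matching this model's configuration (e.g. respecting ``query_length`` and marker tokens) should be used::

                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
                encoding = tokenizer(["what is dense retrieval?"], return_tensors="pt")
                query_embeddings = model.encode(encoding, input_type="query")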
112 """
113 embeddings = self._backbone_forward(**encoding).last_hidden_state
114 embeddings = self.pooler(embeddings, encoding["attention_mask"])
115 embeddings = self.projection(embeddings)
116 embeddings = self.sparsifier(embeddings)
117 if self.config.normalization_strategy == "l2":
118 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
119 return BiEncoderEmbedding(embeddings, None, encoding)