1"""Configuration and model for DPR (Dense Passage Retriever) type models. Originally proposed in \
2`Dense Passage Retrieval for Open-Domain Question Answering \
3<https://arxiv.org/abs/2004.04906>`_. This model type is also known as a SentenceTransformer model:
4`Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks \
5<https://arxiv.org/abs/1908.10084>`_.
6"""
7
8from typing import Literal
9
10import torch
11from transformers import BatchEncoding
12
13from ...bi_encoder import BiEncoderEmbedding, SingleVectorBiEncoderConfig, SingleVectorBiEncoderModel
14
15
[docs]
16class DprConfig(SingleVectorBiEncoderConfig):
17 """Configuration class for a DPR model."""
18
19 model_type = "lir-dpr"
20 """Model type for a DPR model."""
21
[docs]
22 def __init__(
23 self,
24 query_length: int | None = 32,
25 doc_length: int | None = 512,
26 similarity_function: Literal["cosine", "dot"] = "dot",
27 normalize: bool = False,
28 sparsification: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
29 add_marker_tokens: bool = False,
30 query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
31 doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
32 embedding_dim: int | None = None,
33 projection: Literal["linear", "linear_no_bias"] | None = "linear",
34 **kwargs,
35 ) -> None:
36 """A DPR model encodes queries and documents separately. Before computing the similarity score, the
37 contextualized token embeddings are aggregated to obtain a single embedding using a pooling strategy.
38 Optionally, the pooled embeddings can be projected using a linear layer.
39
40 Args:
41 query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
42 doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
43 similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
44 document embeddings. Defaults to "dot".
45 normalize (bool): Whether to normalize the embeddings. Defaults to False.
46 sparsification (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which sparsification
47 function to apply. Defaults to None.
48 add_marker_tokens (bool): Whether to add marker tokens to the input sequences. Defaults to False.
49 query_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query embeddings.
50 Defaults to "first".
51 doc_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for document embeddings.
52 Defaults to "first".
53 embedding_dim (int | None): Dimension of the final embeddings. If None, it will be set to the hidden size
54 of the backbone model. Defaults to None.
55 projection (Literal["linear", "linear_no_bias"] | None): Type of projection layer to apply on the pooled
56 embeddings. If None, no projection is applied. Defaults to "linear".
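
        Example:
            A minimal configuration sketch; the values below are illustrative rather than recommended settings::

                config = DprConfig(
                    query_length=32,
                    doc_length=256,
                    similarity_function="dot",
                    query_pooling_strategy="mean",
                    doc_pooling_strategy="mean",
                    embedding_dim=128,
                    projection="linear",
                )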
57 """
58 super().__init__(
59 query_length=query_length,
60 doc_length=doc_length,
61 similarity_function=similarity_function,
62 normalize=normalize,
63 sparsification=sparsification,
64 add_marker_tokens=add_marker_tokens,
65 query_pooling_strategy=query_pooling_strategy,
66 doc_pooling_strategy=doc_pooling_strategy,
67 **kwargs,
68 )
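        # Without a projection layer, the embedding dimension is tied to the backbone's hidden size;
        # otherwise the requested embedding_dim is used, falling back to the hidden size if it is unset.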
        hidden_size = getattr(self, "hidden_size", None)
        if projection is None:
            self.embedding_dim = hidden_size
        else:
            self.embedding_dim = embedding_dim or hidden_size
        self.projection = projection


class DprModel(SingleVectorBiEncoderModel):
    """A single-vector DPR model. See :class:`DprConfig` for configuration options."""

    config_class = DprConfig
    """Configuration class for a DPR model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a DPR model given a :class:`DprConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the DPR model.

        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
        """
        super().__init__(config, *args, **kwargs)
        if self.config.projection is None:
            self.projection: torch.nn.Module = torch.nn.Identity()
        else:
            if self.config.embedding_dim is None:
                raise ValueError("Unable to determine embedding dimension.")
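            # Linear layer projecting the pooled backbone output (hidden_size) to the configured
            # embedding dimension; "linear_no_bias" disables the bias term.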
            self.projection = torch.nn.Linear(
                self.config.hidden_size,
                self.config.embedding_dim,
                bias="no_bias" not in self.config.projection,
            )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequences.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".

        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
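
        Example:
            A minimal usage sketch. The checkpoint name and the ``from_pretrained`` call below are
            assumptions for illustration, not part of this module::

                from transformers import AutoTokenizer

                model = DprModel.from_pretrained("bert-base-uncased", config=DprConfig(embedding_dim=128))
                tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
                encoding = tokenizer("what is dense passage retrieval?", return_tensors="pt")
                query_embedding = model.encode(encoding, input_type="query")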
111 """
112 pooling_strategy = getattr(self.config, f"{input_type}_pooling_strategy")
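        # Run the backbone and pool the contextualized token embeddings into a single vector per sequence.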
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        embeddings = self.projection(embeddings)
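        # Optionally sparsify and L2-normalize the pooled embeddings, depending on the configuration.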
        embeddings = self.sparsification(embeddings, self.config.sparsification)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return BiEncoderEmbedding(embeddings, None, encoding)