1"""Configuration and model for DPR (Dense Passage Retriever) type models. Originally proposed in \
2`Dense Passage Retrieval for Open-Domain Question Answering \
3<https://arxiv.org/abs/2004.04906>`_. This model type is also known as a SentenceTransformer model:
4`Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks \
5<https://arxiv.org/abs/1908.10084>`_.
6"""
7
8from typing import Literal
9
10import torch
11from transformers import BatchEncoding
12
13from ..bi_encoder import BiEncoderEmbedding, SingleVectorBiEncoderConfig, SingleVectorBiEncoderModel
14
15
class DprConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a DPR model."""

    model_type = "lir-dpr"
    """Model type for a DPR model."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        embedding_dim: int | None = None,
        projection: Literal["linear", "linear_no_bias"] | None = "linear",
        **kwargs,
    ) -> None:
36 """A DPR model encodes queries and documents separately. Before computing the similarity score, the
37 contextualized token embeddings are aggregated to obtain a single embedding using a pooling strategy.
38 Optionally, the pooled embeddings can be projected using a linear layer.
39
40 Args:
41 query_length (int): Maximum query length. Defaults to 32.
42 doc_length (int): Maximum document length. Defaults to 512.
43 similarity_function (Literal["cosine", "dot"]): Similarity function to compute scores between query and
44 document embeddings. Defaults to "dot".
45 normalize (bool): Whether to normalize the embeddings. Defaults to False.
46 sparsification (Literal["relu", "relu_log"] | None): Sparsification function to apply. Defaults to None.
47 add_marker_tokens (bool): Whether to add marker tokens to the input sequences. Defaults to False.
48 query_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for query embeddings.
49 Defaults to "first".
50 doc_pooling_strategy (Literal["first", "mean", "max", "sum"]): Pooling strategy for document embeddings.
51 Defaults to "first".
52 embedding_dim (int | None): Dimension of the final embeddings. If None, it will be set to the hidden size
53 of the backbone model. Defaults to None.
54 projection (Literal["linear", "linear_no_bias"] | None): Type of projection layer to apply on the pooled
55 embeddings. If None, no projection is applied. Defaults to "linear".
56 """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalize=normalize,
            sparsification=sparsification,
            add_marker_tokens=add_marker_tokens,
            query_pooling_strategy=query_pooling_strategy,
            doc_pooling_strategy=doc_pooling_strategy,
            **kwargs,
        )
        # Without a projection layer the embeddings keep the backbone hidden size; otherwise fall back to the
        # hidden size only when no explicit embedding dimension is given.
        hidden_size = getattr(self, "hidden_size", None)
        if projection is None:
            self.embedding_dim = hidden_size
        else:
            self.embedding_dim = embedding_dim or hidden_size
        self.projection = projection


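# Usage sketch (illustrative, not part of the module): a config with mean pooling, a 256-dimensional linear
# projection, and normalized embeddings could be created as
#
#     config = DprConfig(
#         query_pooling_strategy="mean",
#         doc_pooling_strategy="mean",
#         embedding_dim=256,
#         projection="linear",
#         normalize=True,
#     )
#
# With such a config, the pooled embeddings are assumed to be projected from the backbone hidden size down to
# 256 dimensions and L2-normalized before dot-product scoring.

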
class DprModel(SingleVectorBiEncoderModel):
    """A single-vector DPR model. See :class:`DprConfig` for configuration options."""

    config_class = DprConfig
    """Configuration class for a DPR model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a DPR model given a :class:`DprConfig`.

        Args:
            config (SingleVectorBiEncoderConfig): Configuration for the DPR model.

        Raises:
            ValueError: If the embedding dimension is not specified in the configuration.
        """
        super().__init__(config, *args, **kwargs)
        # Project the pooled embeddings to the configured embedding dimension; without a projection layer the
        # pooled embeddings are passed through unchanged.
        if self.config.projection is None:
            self.projection: torch.nn.Module = torch.nn.Identity()
        else:
            if self.config.embedding_dim is None:
                raise ValueError("Unable to determine embedding dimension.")
            self.projection = torch.nn.Linear(
                self.config.hidden_size,
                self.config.embedding_dim,
                bias="no_bias" not in self.config.projection,
            )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequences.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".

        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        pooling_strategy = getattr(self.config, f"{input_type}_pooling_strategy")
        # Contextualize the tokens with the backbone and pool them into a single embedding per sequence.
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        # Optionally project, sparsify, and normalize the pooled embeddings.
        embeddings = self.projection(embeddings)
        embeddings = self.sparsification(embeddings, self.config.sparsification)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return BiEncoderEmbedding(embeddings, None, encoding)
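

if __name__ == "__main__":
    # Standalone sketch (not part of the library API): mimics the steps of DprModel.encode with "first"
    # pooling, a linear projection, and L2 normalization on random tensors. Shapes and values are
    # illustrative only and do not reflect a real backbone.
    batch_size, seq_len, hidden_size, embedding_dim = 2, 8, 16, 4
    last_hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    # "first" pooling keeps the embedding of the first token (e.g. [CLS]) of each sequence.
    pooled = last_hidden_state[:, 0]
    # The optional projection maps the pooled embedding to the configured embedding dimension.
    projection = torch.nn.Linear(hidden_size, embedding_dim)
    projected = projection(pooled)
    # With normalize=True, embeddings are L2-normalized before computing dot-product similarities.
    normalized = torch.nn.functional.normalize(projected, p=2, dim=-1)
    print(normalized.shape)  # torch.Size([2, 4])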