1"""Configuration and model for DPR (Dense Passage Retriever) type models. Originally proposed in \
2`Dense Passage Retrieval for Open-Domain Question Answering \
3<https://arxiv.org/abs/2004.04906>`_. This model type is also known as a SentenceTransformer model:
4`Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks \
5<https://arxiv.org/abs/1908.10084>`_.
6"""
7
8from typing import Literal
9
10import torch
11from transformers import BatchEncoding
12
13from ..bi_encoder import BiEncoderEmbedding, SingleVectorBiEncoderConfig, SingleVectorBiEncoderModel
14
15
class DprConfig(SingleVectorBiEncoderConfig):
    """Configuration class for a DPR model."""

    model_type = "lir-dpr"
    """Model type for a DPR model."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        embedding_dim: int | None = None,
        projection: Literal["linear", "linear_no_bias"] | None = "linear",
        **kwargs,
    ) -> None:
36 """A DPR model encodes queries and documents separately. Before computing the similarity score, the
37 contextualized token embeddings are aggregated to obtain a single embedding using a pooling strategy.
38 Optionally, the pooled embeddings can be projected using a linear layer.
39
40 :param query_length: Maximum query length, defaults to 32
41 :type query_length: int, optional
42 :param doc_length: Maximum document length, defaults to 512
43 :type doc_length: int, optional
44 :param similarity_function: Similarity function to compute scores between query and document embeddings,
45 defaults to "dot"
46 :type similarity_function: Literal['cosine', 'dot'], optional
47 :param sparsification: Whether and which sparsification function to apply, defaults to None
48 :type sparsification: Literal['relu', 'relu_log'] | None, optional
49 :param query_pooling_strategy: Whether and how to pool the query token embeddings, defaults to "first"
50 :type query_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
51 :param doc_pooling_strategy: Whether and how to pool document token embeddings, defaults to "first"
52 :type doc_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
53
54 """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalize=normalize,
            sparsification=sparsification,
            add_marker_tokens=add_marker_tokens,
            query_pooling_strategy=query_pooling_strategy,
            doc_pooling_strategy=doc_pooling_strategy,
            **kwargs,
        )
        hidden_size = getattr(self, "hidden_size", None)
        if projection is None:
            self.embedding_dim = hidden_size
        else:
            self.embedding_dim = embedding_dim or hidden_size
        self.projection = projection


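# A minimal configuration sketch (kept as a comment so nothing runs at import time). The values
# below are illustrative assumptions, not recommended defaults; ``embedding_dim`` only takes
# effect because ``projection`` is not None.
#
#     config = DprConfig(
#         query_length=32,
#         doc_length=256,
#         similarity_function="dot",
#         query_pooling_strategy="first",
#         doc_pooling_strategy="first",
#         projection="linear",
#         embedding_dim=128,
#     )
#     # config.embedding_dim == 128; with projection=None it falls back to the backbone's hidden size.

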
class DprModel(SingleVectorBiEncoderModel):
    """A single-vector DPR model. See :class:`DprConfig` for configuration options."""

    config_class = DprConfig
    """Configuration class for a DPR model."""

    def __init__(self, config: SingleVectorBiEncoderConfig, *args, **kwargs) -> None:
        """Initializes a DPR model given a :class:`DprConfig`.

        :param config: Configuration for the DPR model
        :type config: SingleVectorBiEncoderConfig
        """
        super().__init__(config, *args, **kwargs)
        if self.config.projection is None:
            # No projection: the pooled backbone embeddings are used directly.
            self.projection: torch.nn.Module = torch.nn.Identity()
        else:
            if self.config.embedding_dim is None:
                raise ValueError("Unable to determine embedding dimension.")
            # Linear projection from the backbone hidden size to the configured embedding dimension.
            self.projection = torch.nn.Linear(
                self.config.hidden_size,
                self.config.embedding_dim,
                bias="no_bias" not in self.config.projection,
            )

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batch of tokenized text sequences and returns the embeddings and scoring mask.

        :param encoding: Tokenizer encodings for the text sequences
        :type encoding: BatchEncoding
        :param input_type: Type of input, either "query" or "doc"
        :type input_type: Literal["query", "doc"]
        :return: Embeddings and scoring mask
        :rtype: BiEncoderEmbedding
        """
        pooling_strategy = getattr(self.config, f"{input_type}_pooling_strategy")
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        # Pool the contextualized token embeddings into a single vector per sequence.
        embeddings = self.pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        embeddings = self.projection(embeddings)
        embeddings = self.sparsification(embeddings, self.config.sparsification)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return BiEncoderEmbedding(embeddings, None, encoding)
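

# A rough usage sketch (comment-only; the attribute name ``embeddings`` on the returned
# :class:`BiEncoderEmbedding` is inferred from the constructor call above, and how the model itself
# is instantiated depends on the surrounding framework). It illustrates the intended flow:
# tokenize, encode queries and documents separately, then compare the pooled single-vector
# embeddings with the configured similarity function.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # hypothetical backbone
#     model = ...  # a DprModel built from a DprConfig and a matching backbone checkpoint
#
#     query_encoding = tokenizer("what is dense retrieval?", return_tensors="pt")
#     doc_encoding = tokenizer("Dense retrieval maps queries and passages to vectors ...", return_tensors="pt")
#
#     query_embedding = model.encode(query_encoding, input_type="query")
#     doc_embedding = model.encode(doc_encoding, input_type="doc")
#     # With similarity_function="dot", the relevance score is the dot product of the two vectors.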