"""
Configuration module for bi-encoder models.

This module defines the configuration classes used to instantiate bi-encoder models.
"""
6
7from typing import Any, Literal, Sequence
8
9from ..base import LightningIRConfig
10
11
class BiEncoderConfig(LightningIRConfig):
    """Configuration for bi-encoder models.

    A bi-encoder encodes queries and documents independently and derives a relevance
    score from the similarity of the resulting embeddings. The embeddings may be
    normalized and/or sparsified before the similarity is computed.
    """

    model_type: str = "bi-encoder"
    """Model type identifier for bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization: Literal["l2"] | None = None,
        sparsification: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """Initialize a bi-encoder configuration.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate.
                Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between
                query and document embeddings. Defaults to "dot".
            normalization (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            sparsification (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which
                sparsification function to apply. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries /
                documents. Defaults to False.
        """
        # Length limits are handled by the base config; remaining kwargs pass through.
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
        self.add_marker_tokens = add_marker_tokens
        self.similarity_function = similarity_function
        self.normalization = normalization
        self.sparsification = sparsification
        # Mirror the backbone's hidden size if the base config defines one;
        # otherwise the embedding dimension is unknown at this point.
        self.embedding_dim: int | None = getattr(self, "hidden_size", None)

    def to_diff_dict(self) -> dict[str, Any]:
        """Serialize to a dictionary containing only non-default attributes.

        Attributes equal to the default configuration are dropped for readability.
        The derived ``embedding_dim`` attribute is always excluded, since it is
        inferred from the backbone rather than set by the user.

        Returns:
            dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
        """
        serialized = super().to_diff_dict()
        # embedding_dim is derived state, not a user-facing setting — keep it out of the diff.
        serialized.pop("embedding_dim", None)
        return serialized
62
63
class SingleVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a single-vector bi-encoder model."""

    model_type: str = "single-vector-bi-encoder"
    """Model type for single-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization: Literal["l2"] | None = None,
        sparsification: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "mean",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "mean",
        **kwargs,
    ):
        """Configuration class for a single-vector bi-encoder model. A single-vector bi-encoder model pools the
        representations of queries and documents into a single vector before computing a similarity score.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate.
                Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between
                query and document embeddings. Defaults to "dot".
            normalization (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            sparsification (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which
                sparsification function to apply. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries /
                documents. Defaults to False.
            query_pooling_strategy (Literal['first', 'mean', 'max', 'sum']): How to pool the query
                token embeddings. Defaults to "mean".
            doc_pooling_strategy (Literal['first', 'mean', 'max', 'sum']): How to pool document
                token embeddings. Defaults to "mean".
        """
        # Shared bi-encoder options are delegated to the parent config.
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization=normalization,
            sparsification=sparsification,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_pooling_strategy = query_pooling_strategy
        self.doc_pooling_strategy = doc_pooling_strategy
111
112
class MultiVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a multi-vector bi-encoder model."""

    model_type: str = "multi-vector-bi-encoder"
    """Model type for multi-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization: Literal["l2"] | None = None,
        sparsification: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max"] = "max",
        **kwargs,
    ):
        """A multi-vector bi-encoder model keeps the representation of all tokens in query or document and computes a
        relevance score by aggregating the similarities of query-document token pairs. Optionally, some tokens can be
        masked out during scoring.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate.
                Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between
                query and document embeddings. Defaults to "dot".
            normalization (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            sparsification (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which
                sparsification function to apply. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries /
                documents. Defaults to False.
            query_mask_scoring_tokens (Sequence[str] | Literal['punctuation'] | None): Whether and which
                query tokens to ignore during scoring. Defaults to None.
            doc_mask_scoring_tokens (Sequence[str] | Literal['punctuation'] | None): Whether and which
                document tokens to ignore during scoring. Defaults to None.
            query_aggregation_function (Literal['sum', 'mean', 'max']): How to aggregate similarity
                scores over query tokens. Defaults to "sum".
            doc_aggregation_function (Literal['sum', 'mean', 'max']): How to aggregate similarity
                scores over doc tokens. Defaults to "max".
        """
        # Shared bi-encoder options are delegated to the parent config.
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization=normalization,
            sparsification=sparsification,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_mask_scoring_tokens = query_mask_scoring_tokens
        self.doc_mask_scoring_tokens = doc_mask_scoring_tokens
        self.query_aggregation_function = query_aggregation_function
        self.doc_aggregation_function = doc_aggregation_function