"""
Configuration module for bi-encoder models.

This module defines the configuration class used to instantiate bi-encoder models.
"""

from collections.abc import Sequence
from typing import Any, Literal

from ..base import LightningIRConfig

class BiEncoderConfig(LightningIRConfig):
    """Configuration class for a bi-encoder model."""

    model_type: str = "bi-encoder"
    """Model type for bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """Initialize the configuration of a bi-encoder model.

        A bi-encoder encodes queries and documents independently and scores a
        query-document pair by the similarity of their embeddings. The embeddings
        may optionally be normalized and/or sparsified before the similarity
        score is computed.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate.
                Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between
                query and document embeddings. Defaults to "dot".
            normalization_strategy (Literal['l2'] | None): Whether to normalize query and document
                embeddings. Defaults to None.
            sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which
                sparsification function to apply. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries /
                documents. Defaults to False.
        """
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
        self.similarity_function = similarity_function
        self.add_marker_tokens = add_marker_tokens
        self.normalization_strategy = normalization_strategy
        self.sparsification_strategy = sparsification_strategy
        # Mirror the backbone's hidden_size (set via kwargs / base __init__, if present)
        # as the embedding dimensionality; None when no backbone size is available.
        self.embedding_dim: int | None = getattr(self, "hidden_size", None)

    def to_diff_dict(self) -> dict[str, Any]:
        """Serialize this configuration to a dictionary of non-default attributes.

        Attributes matching the default configuration are dropped for readability,
        while the `config` attribute from the class is always retained. The derived
        ``embedding_dim`` attribute is always excluded because it is mirrored from
        the backbone rather than set by the user.

        Returns:
            dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
        """
        serialized = super().to_diff_dict()
        serialized.pop("embedding_dim", None)  # derived value, not user-configurable
        return serialized


class SingleVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a single-vector bi-encoder model."""

    model_type: str = "single-vector-bi-encoder"
    """Model type for single-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "mean",
        **kwargs,
    ):
        """Initialize the configuration of a single-vector bi-encoder model.

        A single-vector bi-encoder pools the token representations of queries and
        documents into one vector each before computing a similarity score.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate.
                Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between
                query and document embeddings. Defaults to "dot".
            normalization_strategy (Literal['l2'] | None): Whether to normalize query and document
                embeddings. Defaults to None.
            sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which
                sparsification function to apply. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries /
                documents. Defaults to False.
            pooling_strategy (Literal['first', 'mean', 'max', 'sum']): How to pool the token embeddings.
                Defaults to "mean".
        """
        # Everything except pooling_strategy is handled by the parent config.
        shared_settings = dict(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            sparsification_strategy=sparsification_strategy,
            add_marker_tokens=add_marker_tokens,
        )
        super().__init__(**shared_settings, **kwargs)
        self.pooling_strategy = pooling_strategy


class MultiVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a multi-vector bi-encoder model."""

    model_type: str = "multi-vector-bi-encoder"
    """Model type for multi-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max"] = "max",
        **kwargs,
    ):
        """Initialize the configuration of a multi-vector bi-encoder model.

        A multi-vector bi-encoder keeps the representations of all tokens in a
        query or document and computes a relevance score by aggregating the
        similarities of query-document token pairs. Optionally, some tokens can
        be masked out during scoring.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate.
                Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between
                query and document embeddings. Defaults to "dot".
            normalization_strategy (Literal['l2'] | None): Whether to normalize query and document
                embeddings. Defaults to None.
            sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which
                sparsification function to apply. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries /
                documents. Defaults to False.
            query_mask_scoring_tokens (Sequence[str] | Literal['punctuation'] | None): Whether and which
                query tokens to ignore during scoring. Defaults to None.
            doc_mask_scoring_tokens (Sequence[str] | Literal['punctuation'] | None): Whether and which
                document tokens to ignore during scoring. Defaults to None.
            query_aggregation_function (Literal['sum', 'mean', 'max']): How to aggregate similarity scores
                over query tokens. Defaults to "sum".
            doc_aggregation_function (Literal['sum', 'mean', 'max']): How to aggregate similarity scores
                over doc tokens. Defaults to "max".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            sparsification_strategy=sparsification_strategy,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        # Scoring masks: which tokens (if any) to exclude from the pairwise similarities.
        self.query_mask_scoring_tokens = query_mask_scoring_tokens
        self.doc_mask_scoring_tokens = doc_mask_scoring_tokens
        # Defaults ("sum" over query tokens, "max" over doc tokens) correspond to
        # the declared parameter defaults above.
        self.query_aggregation_function = query_aggregation_function
        self.doc_aggregation_function = doc_aggregation_function