"""
Configuration module for bi-encoder models.

This module defines the configuration class used to instantiate bi-encoder models.
"""

from typing import Literal, Sequence

from ..base import LightningIRConfig
11
class BiEncoderConfig(LightningIRConfig):
    """Configuration class for a bi-encoder model."""

    model_type: str = "bi-encoder"
    """Model type for bi-encoder models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """A bi-encoder model encodes queries and documents separately and computes a relevance score based on the
        similarity of the query and document embeddings. Normalization and sparsification can be applied to the
        embeddings before computing the similarity score.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param normalize: Whether to normalize query and document embeddings, defaults to False
        :type normalize: bool, optional
        :param sparsification: Whether and which sparsification function to apply, defaults to None
        :type sparsification: Literal['relu', 'relu_log'] | None, optional
        :param add_marker_tokens: Whether to prepend extra marker tokens [Q] / [D] to queries / documents,
            defaults to False
        :type add_marker_tokens: bool, optional
        """
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
        self.similarity_function = similarity_function
        self.normalize = normalize
        self.sparsification = sparsification
        self.add_marker_tokens = add_marker_tokens
        # Mirror the backbone's hidden size when one is set by the base config; None until then.
        self.embedding_dim: int | None = getattr(self, "hidden_size", None)
54
class SingleVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a single-vector bi-encoder model."""

    model_type: str = "single-vector-bi-encoder"
    """Model type for single-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "mean",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "mean",
        **kwargs,
    ):
        """Configuration class for a single-vector bi-encoder model. A single-vector bi-encoder model pools the
        representations of queries and documents into a single vector before computing a similarity score.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param normalize: Whether to normalize query and document embeddings, defaults to False
        :type normalize: bool, optional
        :param sparsification: Whether and which sparsification function to apply, defaults to None
        :type sparsification: Literal['relu', 'relu_log'] | None, optional
        :param add_marker_tokens: Whether to prepend extra marker tokens [Q] / [D] to queries / documents,
            defaults to False
        :type add_marker_tokens: bool, optional
        :param query_pooling_strategy: How to pool the query token embeddings, defaults to "mean"
        :type query_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        :param doc_pooling_strategy: How to pool document token embeddings, defaults to "mean"
        :type doc_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalize=normalize,
            sparsification=sparsification,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_pooling_strategy = query_pooling_strategy
        self.doc_pooling_strategy = doc_pooling_strategy
107
class MultiVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a multi-vector bi-encoder model."""

    model_type: str = "multi-vector-bi-encoder"
    """Model type for multi-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "max",
        **kwargs,
    ):
        """A multi-vector bi-encoder model keeps the representation of all tokens in query or document and computes a
        relevance score by aggregating the similarities of query-document token pairs. Optionally, some tokens can be
        masked out during scoring.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param normalize: Whether to normalize query and document embeddings, defaults to False
        :type normalize: bool, optional
        :param sparsification: Whether and which sparsification function to apply, defaults to None
        :type sparsification: Literal['relu', 'relu_log'] | None, optional
        :param add_marker_tokens: Whether to prepend extra marker tokens [Q] / [D] to queries / documents,
            defaults to False
        :type add_marker_tokens: bool, optional
        :param query_mask_scoring_tokens: Whether and which query tokens to ignore during scoring, defaults to None
        :type query_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param doc_mask_scoring_tokens: Whether and which document tokens to ignore during scoring, defaults to None
        :type doc_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param query_aggregation_function: How to aggregate similarity scores over query tokens, defaults to "sum"
        :type query_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'], optional
        :param doc_aggregation_function: How to aggregate similarity scores over doc tokens, defaults to "max"
        :type doc_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'], optional
        """
        # Pass shared options by keyword (matches SingleVectorBiEncoderConfig and is robust to
        # signature reordering in the base class).
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalize=normalize,
            sparsification=sparsification,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_mask_scoring_tokens = query_mask_scoring_tokens
        self.doc_mask_scoring_tokens = doc_mask_scoring_tokens
        self.query_aggregation_function = query_aggregation_function
        self.doc_aggregation_function = doc_aggregation_function