Source code for lightning_ir.bi_encoder.bi_encoder_config

  1"""
  2Configuration module for bi-encoder models.
  3
  4This module defines the configuration class used to instantiate bi-encoder models.
  5"""
  6
  7from typing import Any, Literal, Sequence
  8
  9from ..base import LightningIRConfig
 10
 11
class BiEncoderConfig(LightningIRConfig):
    """Configuration class for a bi-encoder model."""

    model_type: str = "bi-encoder"
    """Model type for bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization: Literal["l2"] | None = None,
        sparsification: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """A bi-encoder model encodes queries and documents separately and computes a relevance score based on the
        similarity of the query and document embeddings. Normalization and sparsification can be applied to the
        embeddings before the similarity score is computed.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, queries are not truncated.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, documents are not truncated.
                Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function used to compute scores between query
                and document embeddings. Defaults to "dot".
            normalization (Literal['l2'] | None): Whether to apply L2 normalization to query and document embeddings.
                Defaults to None.
            sparsification (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which sparsification
                function to apply to the embeddings. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
        """
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
        self.similarity_function = similarity_function
        self.normalization = normalization
        self.sparsification = sparsification
        self.add_marker_tokens = add_marker_tokens
        self.embedding_dim: int | None = getattr(self, "hidden_size", None)

    def to_diff_dict(self) -> dict[str, Any]:
        """
        Removes all attributes from the configuration that correspond to the default config attributes for better
        readability, while always retaining the `config` attribute from the class. Serializes to a Python dictionary.

        Returns:
            dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
        """
        diff_dict = super().to_diff_dict()
        diff_dict.pop("embedding_dim", None)  # Exclude embedding_dim from diff_dict
        return diff_dict
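# Usage sketch (illustrative, not part of the module source): a dense configuration with
# cosine similarity over L2-normalized embeddings, and a sparse configuration (e.g.
# SPLADE-style) that keeps the default dot-product similarity but sparsifies the
# embeddings. Remaining keyword arguments are passed through to ``LightningIRConfig``.
#
# >>> dense = BiEncoderConfig(similarity_function="cosine", normalization="l2")
# >>> sparse = BiEncoderConfig(sparsification="relu_log")
# >>> dense.normalization, sparse.sparsification
# ('l2', 'relu_log')
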
class SingleVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a single-vector bi-encoder model."""

    model_type: str = "single-vector-bi-encoder"
    """Model type for single-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization: Literal["l2"] | None = None,
        sparsification: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "mean",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "mean",
        **kwargs,
    ):
        """Configuration class for a single-vector bi-encoder model. A single-vector bi-encoder model pools the
        representations of queries and documents into a single vector before computing a similarity score.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, queries are not truncated.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, documents are not truncated.
                Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function used to compute scores between query
                and document embeddings. Defaults to "dot".
            normalization (Literal['l2'] | None): Whether to apply L2 normalization to query and document embeddings.
                Defaults to None.
            sparsification (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which sparsification
                function to apply to the embeddings. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_pooling_strategy (Literal['first', 'mean', 'max', 'sum']): How to pool the query token embeddings.
                Defaults to "mean".
            doc_pooling_strategy (Literal['first', 'mean', 'max', 'sum']): How to pool the document token embeddings.
                Defaults to "mean".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization=normalization,
            sparsification=sparsification,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_pooling_strategy = query_pooling_strategy
        self.doc_pooling_strategy = doc_pooling_strategy
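# Pooling sketch (illustrative, not this library's implementation): with
# query_pooling_strategy="mean", the query token embeddings are averaged (typically
# over non-padding tokens only) into a single vector; "first" would instead keep the
# embedding of the first ([CLS]) token.
#
# >>> import torch
# >>> token_embeddings = torch.randn(32, 768)   # 32 query token embeddings
# >>> attention_mask = torch.ones(32, 1)        # 1 for real tokens, 0 for padding
# >>> mean_pooled = (token_embeddings * attention_mask).sum(0) / attention_mask.sum()
# >>> first_pooled = token_embeddings[0]
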
class MultiVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a multi-vector bi-encoder model."""

    model_type: str = "multi-vector-bi-encoder"
    """Model type for multi-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization: Literal["l2"] | None = None,
        sparsification: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max"] = "max",
        **kwargs,
    ):
        """A multi-vector bi-encoder model keeps the representations of all tokens in the query and document and
        computes a relevance score by aggregating the similarities of query-document token pairs. Optionally, some
        tokens can be masked out during scoring.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, queries are not truncated.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, documents are not truncated.
                Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function used to compute scores between query
                and document embeddings. Defaults to "dot".
            normalization (Literal['l2'] | None): Whether to apply L2 normalization to query and document embeddings.
                Defaults to None.
            sparsification (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which sparsification
                function to apply to the embeddings. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_mask_scoring_tokens (Sequence[str] | Literal['punctuation'] | None): Whether and which query tokens
                to ignore during scoring. Defaults to None.
            doc_mask_scoring_tokens (Sequence[str] | Literal['punctuation'] | None): Whether and which document tokens
                to ignore during scoring. Defaults to None.
            query_aggregation_function (Literal['sum', 'mean', 'max']): How to aggregate similarity scores over query
                tokens. Defaults to "sum".
            doc_aggregation_function (Literal['sum', 'mean', 'max']): How to aggregate similarity scores over document
                tokens. Defaults to "max".
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization=normalization,
            sparsification=sparsification,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_mask_scoring_tokens = query_mask_scoring_tokens
        self.doc_mask_scoring_tokens = doc_mask_scoring_tokens
        self.query_aggregation_function = query_aggregation_function
        self.doc_aggregation_function = doc_aggregation_function
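# Aggregation sketch (illustrative, not this library's exact implementation): with the
# default doc_aggregation_function="max" and query_aggregation_function="sum", scoring
# reduces to ColBERT-style MaxSim: for each query token, take the maximum similarity
# over all document tokens, then sum over the query tokens.
#
# >>> import torch
# >>> query_embeddings = torch.randn(4, 128)               # 4 query token embeddings
# >>> doc_embeddings = torch.randn(180, 128)               # 180 document token embeddings
# >>> similarity = query_embeddings @ doc_embeddings.T     # "dot" similarity, shape (4, 180)
# >>> score = similarity.max(dim=-1).values.sum()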