Source code for lightning_ir.bi_encoder.bi_encoder_config

  1"""
  2Configuration module for bi-encoder models.
  3
  4This module defines the configuration class used to instantiate bi-encoder models.
  5"""
  6
  7from collections.abc import Sequence
  8from typing import Any, Literal
  9
 10from ..base import LightningIRConfig
 11
 12
class BiEncoderConfig(LightningIRConfig):
    """Configuration class for a bi-encoder model."""

    model_type: str = "bi-encoder"
    """Model type for bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """Bi-encoder models encode queries and documents independently and derive a relevance score
        from the similarity of the resulting embeddings. The embeddings may optionally be normalized
        and/or sparsified before the similarity is computed.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization_strategy (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which sparsification
                function to apply. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
        """
        # Length limits are handled by the base configuration; everything else is stored here.
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
        self.similarity_function = similarity_function
        self.normalization_strategy = normalization_strategy
        self.sparsification_strategy = sparsification_strategy
        self.add_marker_tokens = add_marker_tokens
        # Mirror the backbone's hidden size (if the base config defines one) as the embedding dimension.
        self.embedding_dim: int | None = getattr(self, "hidden_size", None)

    def to_diff_dict(self) -> dict[str, Any]:
        """
        Removes all attributes from the configuration that correspond to the default config attributes for
        better readability, while always retaining the `config` attribute from the class. Serializes to a
        Python dictionary.

        Returns:
            dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
        """
        serialized = super().to_diff_dict()
        # embedding_dim is derived from hidden_size, so it is redundant in the serialized diff.
        serialized.pop("embedding_dim", None)
        return serialized
64 65
class SingleVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a single-vector bi-encoder model."""

    model_type: str = "single-vector-bi-encoder"
    """Model type for single-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "mean",
        **kwargs,
    ):
        """Configuration class for a single-vector bi-encoder model. A single-vector bi-encoder model pools the
        representations of queries and documents into a single vector before computing a similarity score.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization_strategy (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which sparsification
                function to apply. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            pooling_strategy (Literal['first', 'mean', 'max', 'sum']): How to pool the token embeddings.
                Defaults to "mean".
        """
        # Shared bi-encoder options are forwarded to the parent class; only pooling is specific here.
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            sparsification_strategy=sparsification_strategy,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.pooling_strategy = pooling_strategy
110 111
class MultiVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a multi-vector bi-encoder model."""

    model_type: str = "multi-vector-bi-encoder"
    """Model type for multi-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        # NOTE: annotation normalized to `Literal[...] | None` for consistency with the sibling configs.
        sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max"] = "max",
        **kwargs,
    ):
        """A multi-vector bi-encoder model keeps the representation of all tokens in query or document and computes a
        relevance score by aggregating the similarities of query-document token pairs. Optionally, some tokens can be
        masked out during scoring.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization_strategy (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None): Whether and which sparsification
                function to apply. Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            query_mask_scoring_tokens (Sequence[str] | Literal['punctuation'] | None): Whether and which query tokens
                to ignore during scoring. Defaults to None.
            doc_mask_scoring_tokens (Sequence[str] | Literal['punctuation'] | None): Whether and which document tokens
                to ignore during scoring. Defaults to None.
            query_aggregation_function (Literal['sum', 'mean', 'max']): How to aggregate similarity
                scores over query tokens. Defaults to "sum".
            doc_aggregation_function (Literal['sum', 'mean', 'max']): How to aggregate similarity
                scores over doc tokens. Defaults to "max".
        """
        # Shared bi-encoder options are forwarded to the parent class.
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            sparsification_strategy=sparsification_strategy,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        # Multi-vector-specific scoring options.
        self.query_mask_scoring_tokens = query_mask_scoring_tokens
        self.doc_mask_scoring_tokens = doc_mask_scoring_tokens
        self.query_aggregation_function = query_aggregation_function
        self.doc_aggregation_function = doc_aggregation_function