Source code for lightning_ir.bi_encoder.bi_encoder_config

  1"""
  2Configuration module for bi-encoder models.
  3
  4This module defines the configuration class used to instantiate bi-encoder models.
  5"""
  6
  7from typing import Literal, Sequence
  8
  9from ..base import LightningIRConfig
 10
 11
class BiEncoderConfig(LightningIRConfig):
    """Configuration class for a bi-encoder model."""

    model_type: str = "bi-encoder"
    """Model type for bi-encoder models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """A bi-encoder model encodes queries and documents separately and computes a relevance score based on the
        similarity of the query and document embeddings. Normalization and sparsification can be applied to the
        embeddings before computing the similarity score.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param normalize: Whether to normalize query and document embeddings, defaults to False
        :type normalize: bool, optional
        :param sparsification: Whether and which sparsification function to apply, defaults to None
        :type sparsification: Literal['relu', 'relu_log'] | None, optional
        :param add_marker_tokens: Whether to prepend extra marker tokens [Q] / [D] to queries / documents,
            defaults to False
        :type add_marker_tokens: bool, optional
        """
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
        self.similarity_function = similarity_function
        self.normalize = normalize
        self.sparsification = sparsification
        self.add_marker_tokens = add_marker_tokens
        self.embedding_dim: int | None = getattr(self, "hidden_size", None)
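
# Illustrative usage (a sketch, not part of the original module): the options above map directly
# onto constructor keyword arguments, so a configuration can be built and inspected as shown. The
# import path assumes the class is re-exported from the top-level ``lightning_ir`` package;
# otherwise it can be imported from ``lightning_ir.bi_encoder.bi_encoder_config``.
#
#     >>> from lightning_ir import BiEncoderConfig
#     >>> config = BiEncoderConfig(query_length=32, doc_length=256, similarity_function="cosine", normalize=True)
#     >>> config.similarity_function
#     'cosine'
#     >>> config.normalize
#     True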

class SingleVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a single-vector bi-encoder model."""

    model_type: str = "single-vector-bi-encoder"
    """Model type for single-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        query_pooling_strategy: Literal["first", "mean", "max", "sum"] = "mean",
        doc_pooling_strategy: Literal["first", "mean", "max", "sum"] = "mean",
        **kwargs,
    ):
        """Configuration class for a single-vector bi-encoder model. A single-vector bi-encoder model pools the
        representations of queries and documents into a single vector before computing a similarity score.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param normalize: Whether to normalize query and document embeddings, defaults to False
        :type normalize: bool, optional
        :param sparsification: Whether and which sparsification function to apply, defaults to None
        :type sparsification: Literal['relu', 'relu_log'] | None, optional
        :param add_marker_tokens: Whether to prepend extra marker tokens [Q] / [D] to queries / documents,
            defaults to False
        :type add_marker_tokens: bool, optional
        :param query_pooling_strategy: How to pool the query token embeddings, defaults to "mean"
        :type query_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        :param doc_pooling_strategy: How to pool the document token embeddings, defaults to "mean"
        :type doc_pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalize=normalize,
            sparsification=sparsification,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_pooling_strategy = query_pooling_strategy
        self.doc_pooling_strategy = doc_pooling_strategy
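
# Illustrative usage (a sketch, not part of the original module): pooling strategies are set per
# input type, e.g. CLS-style ("first") pooling for queries and mean pooling for documents. The
# import uses the module path of this file; a shorter top-level import may also be available.
#
#     >>> from lightning_ir.bi_encoder.bi_encoder_config import SingleVectorBiEncoderConfig
#     >>> config = SingleVectorBiEncoderConfig(query_pooling_strategy="first", doc_pooling_strategy="mean")
#     >>> (config.query_pooling_strategy, config.doc_pooling_strategy)
#     ('first', 'mean')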

class MultiVectorBiEncoderConfig(BiEncoderConfig):
    """Configuration class for a multi-vector bi-encoder model."""

    model_type: str = "multi-vector-bi-encoder"
    """Model type for multi-vector bi-encoder models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalize: bool = False,
        sparsification: Literal["relu", "relu_log"] | None = None,
        add_marker_tokens: bool = False,
        query_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        doc_mask_scoring_tokens: Sequence[str] | Literal["punctuation"] | None = None,
        query_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "sum",
        doc_aggregation_function: Literal["sum", "mean", "max", "harmonic_mean"] = "max",
        **kwargs,
    ):
        """A multi-vector bi-encoder model keeps the representations of all tokens in a query or document and
        computes a relevance score by aggregating the similarities of query-document token pairs. Optionally,
        some tokens can be masked out during scoring.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param similarity_function: Similarity function to compute scores between query and document embeddings,
            defaults to "dot"
        :type similarity_function: Literal['cosine', 'dot'], optional
        :param normalize: Whether to normalize query and document embeddings, defaults to False
        :type normalize: bool, optional
        :param sparsification: Whether and which sparsification function to apply, defaults to None
        :type sparsification: Literal['relu', 'relu_log'] | None, optional
        :param add_marker_tokens: Whether to prepend extra marker tokens [Q] / [D] to queries / documents,
            defaults to False
        :type add_marker_tokens: bool, optional
        :param query_mask_scoring_tokens: Whether and which query tokens to ignore during scoring, defaults to None
        :type query_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param doc_mask_scoring_tokens: Whether and which document tokens to ignore during scoring, defaults to None
        :type doc_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None, optional
        :param query_aggregation_function: How to aggregate similarity scores over query tokens, defaults to "sum"
        :type query_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'], optional
        :param doc_aggregation_function: How to aggregate similarity scores over document tokens, defaults to "max"
        :type doc_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'], optional
        """
        super().__init__(
            query_length, doc_length, similarity_function, normalize, sparsification, add_marker_tokens, **kwargs
        )
        self.query_mask_scoring_tokens = query_mask_scoring_tokens
        self.doc_mask_scoring_tokens = doc_mask_scoring_tokens
        self.query_aggregation_function = query_aggregation_function
        self.doc_aggregation_function = doc_aggregation_function
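
# Illustrative usage (a sketch, not part of the original module): the defaults, summing over query
# tokens the maximum similarity over document tokens, correspond to a ColBERT-style late-interaction
# score; passing "punctuation" for the mask-scoring tokens excludes punctuation from scoring. The
# import uses the module path of this file; a shorter top-level import may also be available.
#
#     >>> from lightning_ir.bi_encoder.bi_encoder_config import MultiVectorBiEncoderConfig
#     >>> config = MultiVectorBiEncoderConfig(
#     ...     query_aggregation_function="sum",
#     ...     doc_aggregation_function="max",
#     ...     doc_mask_scoring_tokens="punctuation",
#     ... )
#     >>> config.doc_aggregation_function
#     'max'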