MultiVectorBiEncoderConfig
- class lightning_ir.bi_encoder.bi_encoder_config.MultiVectorBiEncoderConfig(query_length: int = 32, doc_length: int = 512, similarity_function: Literal['cosine', 'dot'] = 'dot', normalize: bool = False, sparsification: None | Literal['relu', 'relu_log'] = None, add_marker_tokens: bool = False, query_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None = None, doc_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None = None, query_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'] = 'sum', doc_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'] = 'max', **kwargs)[source]
Bases: BiEncoderConfig
Configuration class for a multi-vector bi-encoder model.
- __init__(query_length: int = 32, doc_length: int = 512, similarity_function: Literal['cosine', 'dot'] = 'dot', normalize: bool = False, sparsification: None | Literal['relu', 'relu_log'] = None, add_marker_tokens: bool = False, query_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None = None, doc_mask_scoring_tokens: Sequence[str] | Literal['punctuation'] | None = None, query_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'] = 'sum', doc_aggregation_function: Literal['sum', 'mean', 'max', 'harmonic_mean'] = 'max', **kwargs)[source]
A multi-vector bi-encoder model keeps a representation for every token in the query and in the document and computes a relevance score by aggregating the similarities of query-document token pairs. Optionally, some tokens can be masked out during scoring.
- Parameters:
query_length (int, optional) – Maximum query length, defaults to 32
doc_length (int, optional) – Maximum document length, defaults to 512
similarity_function (Literal['cosine', 'dot'], optional) – Similarity function to compute scores between query and document embeddings, defaults to “dot”
normalize (bool, optional) – Whether to normalize query and document embeddings, defaults to False
sparsification (Literal['relu', 'relu_log'] | None, optional) – Whether and which sparsification function to apply, defaults to None
add_marker_tokens (bool, optional) – Whether to prepend extra marker tokens [Q] / [D] to queries / documents, defaults to False
query_mask_scoring_tokens (Sequence[str] | Literal['punctuation'] | None, optional) – Whether and which query tokens to ignore during scoring, defaults to None
doc_mask_scoring_tokens (Sequence[str] | Literal['punctuation'] | None, optional) – Whether and which document tokens to ignore during scoring, defaults to None
doc_aggregation_function (Literal['sum', 'mean', 'max', 'harmonic_mean'], optional) – How to aggregate similarity scores over doc tokens, defaults to “max”
query_aggregation_function (Literal['sum', 'mean', 'max', 'harmonic_mean'], optional) – How to aggregate similarity scores over query tokens, defaults to “sum”
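The scoring scheme described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: the function name score, the boolean query_mask / doc_mask arguments, and the order of aggregation (first over document tokens, then over query tokens) are assumptions made for this example. The harmonic_mean option is left out because its exact definition is not given here.

```python
import math
from typing import Callable, Dict, List, Optional, Sequence

Vector = Sequence[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine(a: Vector, b: Vector) -> float:
    norm = math.sqrt(dot(a, a)) * math.sqrt(dot(b, b))
    return dot(a, b) / norm if norm > 0 else 0.0


# Only three of the four documented aggregation options are sketched here.
AGGREGATIONS: Dict[str, Callable[[List[float]], float]] = {
    "sum": sum,
    "mean": lambda values: sum(values) / len(values),
    "max": max,
}


def score(
    query_vecs: Sequence[Vector],
    doc_vecs: Sequence[Vector],
    similarity_function: str = "dot",
    query_aggregation_function: str = "sum",
    doc_aggregation_function: str = "max",
    query_mask: Optional[Sequence[bool]] = None,  # hypothetical: True keeps a token
    doc_mask: Optional[Sequence[bool]] = None,  # hypothetical: True keeps a token
) -> float:
    """Aggregate token-pair similarities into one relevance score.

    Assumes at least one query token and one document token remain unmasked.
    """
    sim = dot if similarity_function == "dot" else cosine
    if query_mask is None:
        query_mask = [True] * len(query_vecs)
    if doc_mask is None:
        doc_mask = [True] * len(doc_vecs)
    kept_query = [q for q, keep in zip(query_vecs, query_mask) if keep]
    kept_doc = [d for d, keep in zip(doc_vecs, doc_mask) if keep]
    aggregate_doc = AGGREGATIONS[doc_aggregation_function]
    aggregate_query = AGGREGATIONS[query_aggregation_function]
    # Assumed order: aggregate over document tokens for each query token first,
    # then aggregate the per-query-token values.
    per_query_token = [aggregate_doc([sim(q, d) for d in kept_doc]) for q in kept_query]
    return aggregate_query(per_query_token)


query_vecs = [[1.0, 0.0], [0.0, 1.0]]
doc_vecs = [[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]]
print(score(query_vecs, doc_vecs))  # 3.0
print(score(query_vecs, doc_vecs, doc_mask=[True, True, False]))  # 1.5
```

With the default settings ("dot", "max" over document tokens, "sum" over query tokens), each query token contributes the similarity of its best-matching document token. Masking the third document token removes it from every maximum, which lowers the score in the second call.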
Methods
__init__([query_length, doc_length, ...]) – A multi-vector bi-encoder model keeps the representation of all tokens in query or document and computes a relevance score by aggregating the similarities of query-document token pairs.
Attributes
Model type for multi-vector bi-encoder models.
- backbone_model_type: str | None = None
Backbone model type for the configuration. Set by LightningIRModelClassFactory().
- classmethod from_pretrained(pretrained_model_name_or_path: str | Path, *args, **kwargs) LightningIRConfig
Loads the configuration from a pretrained model. Wraps the transformers.PretrainedConfig.from_pretrained method.
- Parameters:
pretrained_model_name_or_path (str | Path) – Pretrained model name or path
- Raises:
ValueError – If pretrained_model_name_or_path is not a Lightning IR model and no LightningIRConfig is passed
- Returns:
Derived LightningIRConfig class
- Return type:
LightningIRConfig
- get_tokenizer_kwargs(Tokenizer: Type[LightningIRTokenizer]) Dict[str, Any]
Returns the keyword arguments for the tokenizer. This method is used to pass the configuration parameters to the tokenizer.
- Parameters:
Tokenizer (Type[LightningIRTokenizer]) – Class of the tokenizer to be used
- Returns:
Keyword arguments for the tokenizer
- Return type:
Dict[str, Any]
- to_dict() Dict[str, Any]
Overrides the transformers.PretrainedConfig.to_dict method to include the added arguments and the backbone model type.
- Returns:
Configuration dictionary
- Return type:
Dict[str, Any]
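The override pattern described for to_dict can be illustrated with stand-in classes. This is a sketch under assumptions, not the library's code: StandInPretrainedConfig, SketchIRConfig, and ADDED_ARGS are hypothetical names invented for this example, and the stand-in base class is far simpler than transformers.PretrainedConfig.

```python
from typing import Any, Dict, Optional


class StandInPretrainedConfig:
    """Hypothetical stand-in for transformers.PretrainedConfig (not the real class)."""

    def __init__(self, hidden_size: int = 768) -> None:
        self.hidden_size = hidden_size

    def to_dict(self) -> Dict[str, Any]:
        return {"hidden_size": self.hidden_size}


class SketchIRConfig(StandInPretrainedConfig):
    """Hypothetical config that adds retrieval-specific arguments."""

    # Names of the arguments this sketch adds on top of the base config.
    ADDED_ARGS = ("query_length", "doc_length")
    backbone_model_type: Optional[str] = None

    def __init__(self, query_length: int = 32, doc_length: int = 512, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

    def to_dict(self) -> Dict[str, Any]:
        # Start from the base dictionary, then add the extra arguments
        # and the backbone model type.
        output = super().to_dict()
        for name in self.ADDED_ARGS:
            output[name] = getattr(self, name)
        output["backbone_model_type"] = self.backbone_model_type
        return output


config = SketchIRConfig(query_length=16)
config.backbone_model_type = "bert"  # illustrative value only
config_dict = config.to_dict()
print(config_dict)
```

Extending the base dictionary rather than rebuilding it keeps the serialized form compatible with whatever the parent class already emits, so the added arguments survive a save and reload.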