Source code for lightning_ir.cross_encoder.cross_encoder_config

 1"""
 2Configuration module for cross-encoder models.
 3
 4This module defines the configuration class used to instantiate cross-encoder models.
 5"""
 6
 7from typing import Literal
 8
 9from ..base import LightningIRConfig
10
11
class CrossEncoderConfig(LightningIRConfig):
    """Configuration for cross-encoder models."""

    model_type: str = "cross-encoder"
    """Model type for cross-encoder models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        linear_bias: bool = False,
        **kwargs
    ):
        """Build the configuration for a cross-encoder model.

        The query/document length limits and any extra keyword arguments are
        forwarded unchanged to :class:`LightningIRConfig`. The pooling strategy
        and the linear-bias flag are stored on the instance afterwards.

        :param query_length: Upper bound on the number of query tokens, defaults to 32
        :type query_length: int, optional
        :param doc_length: Upper bound on the number of document tokens, defaults to 512
        :type doc_length: int, optional
        :param pooling_strategy: How the contextualized embeddings are reduced to a single vector from which the
            relevance score is computed, defaults to "first"
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        :param linear_bias: Whether the prediction linear layer includes a bias term, defaults to False
        :type linear_bias: bool, optional
        """
        # Let the base configuration handle the shared length settings and any
        # remaining options before the cross-encoder specific fields are set.
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            **kwargs,
        )
        # Cross-encoder specific hyperparameters (assigned left to right:
        # pooling_strategy first, then linear_bias).
        self.pooling_strategy, self.linear_bias = pooling_strategy, linear_bias