Source code for lightning_ir.cross_encoder.cross_encoder_config

"""
Configuration module for cross-encoder models.

This module defines the configuration class used to instantiate cross-encoder models.
"""
 6
 7from typing import Literal
 8
 9from ..base import LightningIRConfig
10
11
class CrossEncoderConfig(LightningIRConfig):
    """Configuration for cross-encoder models.

    Extends :class:`LightningIRConfig` with two cross-encoder specific options:
    the pooling strategy that aggregates the contextualized embeddings into a
    single vector for relevance scoring, and whether the prediction linear
    layer uses a bias term.
    """

    model_type: str = "cross-encoder"
    """Model type for cross-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        linear_bias: bool = False,
        **kwargs,
    ):
        """Create a cross-encoder configuration.

        Args:
            query_length (int | None): Maximum number of tokens per query. ``None`` disables
                query truncation. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. ``None`` disables
                document truncation. Defaults to 512.
            pooling_strategy (Literal['first', 'mean', 'max', 'sum']): How the contextualized
                embeddings are aggregated into the single vector used to compute a relevance
                score. Defaults to "first". The value is stored as given; it is not validated here.
            linear_bias (bool): Whether the prediction linear layer has a bias term.
                Defaults to False.
            **kwargs: Forwarded unchanged to :class:`LightningIRConfig`.
        """
        # Length limits and any remaining options are handed to the base class first.
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            **kwargs,
        )
        # Cross-encoder specific options are stored as plain attributes afterwards.
        self.pooling_strategy = pooling_strategy
        self.linear_bias = linear_bias