Source code for lightning_ir.cross_encoder.cross_encoder_config
1"""
2Configuration module for cross-encoder models.
3
4This module defines the configuration class used to instantiate cross-encoder models.
5"""
6
7from typing import Literal
8
9from ..base import LightningIRConfig
10
11
[docs]
class CrossEncoderConfig(LightningIRConfig):
    """Configuration class for cross-encoder models.

    Extends :class:`LightningIRConfig` with two cross-encoder specific options:
    the strategy used to pool contextualized embeddings into a single vector,
    and whether the prediction linear layer carries a bias term.
    """

    model_type: str = "cross-encoder"
    """Model type identifier for cross-encoder models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        linear_bias: bool = False,
        **kwargs,
    ):
        """Create a cross-encoder configuration.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        :param pooling_strategy: How the contextualized embeddings are aggregated into a single vector from which a
            relevance score is computed, defaults to "first"
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'], optional
        :param linear_bias: Whether the prediction linear layer uses a bias term, defaults to False
        :type linear_bias: bool, optional
        """
        # The length limits and every other keyword argument are forwarded to the base configuration.
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            **kwargs,
        )
        # Cross-encoder specific options are stored after the base class has been initialised.
        self.pooling_strategy, self.linear_bias = pooling_strategy, linear_bias