Source code for lightning_ir.cross_encoder.cross_encoder_config
"""
Configuration module for cross-encoder models.

This module defines the configuration class used to instantiate cross-encoder models.
"""
6
from typing import Any, Literal

from ..base import LightningIRConfig
10
11
class CrossEncoderConfig(LightningIRConfig):
    """Configuration class for cross-encoder models.

    Query and document lengths (plus any extra keyword arguments) are passed through to
    :class:`LightningIRConfig`; the pooling strategy and linear-bias flag are stored on the instance.
    """

    model_type: str = "cross-encoder"
    """Model type for cross-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        linear_bias: bool = False,
        **kwargs: Any,
    ) -> None:
        """Configuration class for a cross-encoder model.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, queries are not truncated.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, documents are not truncated.
                Defaults to 512.
            pooling_strategy (Literal['first', 'mean', 'max', 'sum']): Pooling strategy to aggregate the
                contextualized embeddings into a single vector for computing a relevance score. Defaults to "first".
            linear_bias (bool): Whether to use a bias in the prediction linear layer. Defaults to False.
            **kwargs (Any): Additional keyword arguments forwarded unchanged to :class:`LightningIRConfig`.
        """
        # Truncation lengths and any remaining options are handed to the base configuration.
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
        # Cross-encoder specific settings.
        self.pooling_strategy = pooling_strategy
        self.linear_bias = linear_bias