Source code for lightning_ir.cross_encoder.cross_encoder_model
1"""
2Model module for cross-encoder models.
3
4This module defines the model class used to implement cross-encoder models.
5"""
6
7from abc import ABC, abstractmethod
8from dataclasses import dataclass
9
10import torch
11from transformers import BatchEncoding
12
13from ..base import LightningIRModel, LightningIROutput
14from ..base.model import batch_encoding_wrapper
15from . import CrossEncoderConfig
16
17
@dataclass
class CrossEncoderOutput(LightningIROutput):
    """Dataclass containing the output of a cross-encoder model"""

    embeddings: torch.Tensor | None = None
    """Joint query-document embeddings"""


class CrossEncoderModel(LightningIRModel, ABC):
    config_class: type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for cross-encoder models."""

    def __init__(self, config: CrossEncoderConfig, *args, **kwargs):
        """A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are
        aggregated into a single vector and fed to a linear layer which computes a final relevance score.

        Args:
            config (CrossEncoderConfig): Configuration for the cross-encoder model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: CrossEncoderConfig

    @batch_encoding_wrapper
    @abstractmethod
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Computes contextualized embeddings for the joint query-document input sequence and computes a relevance
        score.

        Args:
            encoding (BatchEncoding): Tokenizer encoding for the joint query-document input sequence.

        Returns:
            CrossEncoderOutput: Output of the model.
        """
        pass
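

# --- Illustrative sketch (not part of the module above) -----------------------
# The pattern described in ``CrossEncoderModel.__init__`` (jointly encode the
# query and document, aggregate the contextualized embeddings into a single
# vector, and score that vector with a linear layer) can be sketched with a
# plain Hugging Face encoder. The model name, CLS-token pooling, and scoring
# head below are assumptions chosen for illustration; they are not the
# lightning-ir implementation, which returns a ``CrossEncoderOutput`` instead
# of a bare score tensor.
from transformers import AutoModel, AutoTokenizer


class TinyCrossEncoder(torch.nn.Module):
    def __init__(self, model_name: str = "bert-base-uncased"):
        super().__init__()
        # Backbone encoder producing contextualized embeddings of shape
        # (batch_size, seq_len, hidden_size).
        self.backbone = AutoModel.from_pretrained(model_name)
        # Linear scoring head mapping the aggregated embedding to a scalar.
        self.linear = torch.nn.Linear(self.backbone.config.hidden_size, 1)

    def forward(self, encoding: BatchEncoding) -> torch.Tensor:
        embeddings = self.backbone(**encoding).last_hidden_state
        pooled = embeddings[:, 0]  # aggregate via the CLS-token embedding
        return self.linear(pooled).squeeze(-1)  # one relevance score per pair


# Example usage: tokenize the query and document as one joint sequence, then score it.
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# encoding = tokenizer("what is python?", "Python is a programming language.", return_tensors="pt")
# scores = TinyCrossEncoder()(encoding)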