Source code for lightning_ir.cross_encoder.cross_encoder_model
1"""
2Model module for cross-encoder models.
3
4This module defines the model class used to implement cross-encoder models.
5"""
6
7from abc import ABC, abstractmethod
8from dataclasses import dataclass
9from typing import Type
10
11import torch
12from transformers import BatchEncoding
13
14from ..base import LightningIRModel, LightningIROutput
15from ..base.model import batch_encoding_wrapper
16from . import CrossEncoderConfig
17
18
[docs]
@dataclass
class CrossEncoderOutput(LightningIROutput):
    """Output container returned by cross-encoder models.

    Extends :class:`LightningIROutput` with one additional optional field that
    holds the embeddings of the jointly encoded query-document input.
    """

    embeddings: torch.Tensor | None = None
    """Joint query-document embeddings; defaults to ``None`` when not provided."""
25
26
[docs]
class CrossEncoderModel(LightningIRModel, ABC):
    """Abstract base class for cross-encoder models.

    Concrete subclasses must provide an implementation of :meth:`forward`.
    """

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class used by cross-encoder models."""

    def __init__(self, config: CrossEncoderConfig, *args, **kwargs):
        """Set up a cross-encoder model.

        A cross-encoder jointly encodes a query together with its document(s).
        The contextualized embeddings are aggregated into a single vector,
        which is passed to a linear layer that computes the final relevance
        score.

        Args:
            config (CrossEncoderConfig): Configuration for the cross-encoder model.
            *args: Additional positional arguments forwarded unchanged to the parent constructor.
            **kwargs: Additional keyword arguments forwarded unchanged to the parent constructor.
        """
        super().__init__(config, *args, **kwargs)
        # Annotation only (no assignment): narrows the declared type of
        # ``self.config`` to the cross-encoder configuration for type checkers.
        self.config: CrossEncoderConfig

    # ``abstractmethod`` is applied first (innermost) and the result is then
    # wrapped by ``batch_encoding_wrapper``; keep this decorator order.
    @batch_encoding_wrapper
    @abstractmethod
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Encode the joint query-document sequence and score its relevance.

        Computes contextualized embeddings for the joint query-document input
        sequence and derives a relevance score from them. This method is
        abstract and has to be implemented by subclasses.

        Args:
            encoding (BatchEncoding): Tokenizer encoding for the joint query-document input sequence.

        Returns:
            CrossEncoderOutput: Output of the model.
        """
        ...