Source code for lightning_ir.cross_encoder.cross_encoder_model

 1"""
 2Model module for cross-encoder models.
 3
 4This module defines the model class used to implement cross-encoder models.
 5"""
 6
 7from abc import ABC, abstractmethod
 8from dataclasses import dataclass
 9
10import torch
11from transformers import BatchEncoding
12
13from ..base import LightningIRModel, LightningIROutput
14from ..base.model import batch_encoding_wrapper
15from . import CrossEncoderConfig
16
17
@dataclass
class CrossEncoderOutput(LightningIROutput):
    """Dataclass containing the output of a cross-encoder model"""

    embeddings: torch.Tensor | None = None
    """Joint query-document embeddings"""


class CrossEncoderModel(LightningIRModel, ABC):
    config_class: type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for cross-encoder models."""

    def __init__(self, config: CrossEncoderConfig, *args, **kwargs):
        """A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are
        aggregated into a single vector and fed to a linear layer that computes the final relevance score.

        Args:
            config (CrossEncoderConfig): Configuration for the cross-encoder model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: CrossEncoderConfig

    @batch_encoding_wrapper
    @abstractmethod
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Computes contextualized embeddings for the joint query-document input sequence and a relevance score.

        Args:
            encoding (BatchEncoding): Tokenizer encoding for the joint query-document input sequence.

        Returns:
            CrossEncoderOutput: Output of the model.
        """
        pass
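
For intuition, the sketch below reproduces the scoring recipe described in the docstrings: jointly encode a query and document as one input sequence, aggregate the contextualized embeddings into a single vector, and score it with a linear layer. It deliberately bypasses the Lightning IR class machinery; the backbone checkpoint, CLS-token pooling, and untrained linear head are illustrative assumptions, and the output is wrapped in CrossEncoderOutput assuming the scores field inherited from LightningIROutput.

import torch
from transformers import AutoModel, AutoTokenizer

from lightning_ir.cross_encoder.cross_encoder_model import CrossEncoderOutput

# Illustrative backbone and scoring head (assumptions, not Lightning IR internals)
backbone = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
linear = torch.nn.Linear(backbone.config.hidden_size, 1)

# Jointly encode a query-document pair into a single input sequence
encoding = tokenizer(
    ["what is python"],                     # query
    ["Python is a programming language."],  # document
    return_tensors="pt",
)

# Contextualized embeddings for the joint sequence: (batch, seq_len, hidden)
embeddings = backbone(**encoding).last_hidden_state
# Aggregate into a single vector (here: the [CLS] token embedding)
pooled = embeddings[:, 0]
# The linear layer yields one relevance score per query-document pair
scores = linear(pooled).squeeze(-1)

output = CrossEncoderOutput(scores=scores, embeddings=embeddings)

A concrete Lightning IR cross-encoder instead implements the abstract forward above with its own backbone and aggregation; the sketch only mirrors the contract for illustration.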