Source code for lightning_ir.cross_encoder.cross_encoder_model

 1"""
 2Model module for cross-encoder models.
 3
 4This module defines the model class used to implement cross-encoder models.
 5"""
 6
 7from abc import ABC, abstractmethod
 8from dataclasses import dataclass
 9from typing import Type
10
11import torch
12from transformers import BatchEncoding
13
14from ..base import LightningIRModel, LightningIROutput
15from ..base.model import batch_encoding_wrapper
16from . import CrossEncoderConfig
17
18
@dataclass
class CrossEncoderOutput(LightningIROutput):
    """Output container produced by a cross-encoder model.

    Extends :class:`LightningIROutput` with the embeddings of the jointly encoded query-document input.
    """

    embeddings: torch.Tensor | None = None
    """Joint query-document embeddings; defaults to ``None`` when not provided."""
25 26
class CrossEncoderModel(LightningIRModel, ABC):
    """Abstract base class for cross-encoder models; :meth:`forward` is declared abstract."""

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for cross-encoder models."""

    def __init__(self, config: CrossEncoderConfig, *args, **kwargs):
        """Set up a cross-encoder model, i.e., a model that jointly encodes a query and document(s).

        The contextualized embeddings are aggregated into a single vector and fed to a linear layer which
        computes a final relevance score.

        Args:
            config (CrossEncoderConfig): Configuration for the cross-encoder model.
            *args: Additional positional arguments, forwarded unchanged to the parent constructor.
            **kwargs: Additional keyword arguments, forwarded unchanged to the parent constructor.
        """
        super().__init__(config, *args, **kwargs)
        # Bare annotation (no assignment): only narrows the declared type of ``self.config`` for type checkers.
        self.config: CrossEncoderConfig

    # ``abstractmethod`` must stay the innermost decorator when stacked with other decorators.
    @batch_encoding_wrapper
    @abstractmethod
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Compute contextualized embeddings for the joint query-document input sequence and a relevance score.

        This method is abstract and has no implementation here.

        Args:
            encoding (BatchEncoding): Tokenizer encoding for the joint query-document input sequence.
        Returns:
            CrossEncoderOutput: Output of the model.
        """
        ...