Source code for lightning_ir.models.mono

  1"""
  2Model implementation for mono cross-encoder models. Originally introduced in
  3`Passage Re-ranking with BERT
  4<https://arxiv.org/abs/1901.04085>`_.
  5"""
  6
  7from typing import Literal, Type
  8
  9import torch
 10from transformers import BatchEncoding
 11
 12from ..base.model import batch_encoding_wrapper
 13from ..cross_encoder.cross_encoder_config import CrossEncoderConfig
 14from ..cross_encoder.cross_encoder_model import CrossEncoderModel, CrossEncoderOutput
 15
 16
class ScaleLinear(torch.nn.Linear):

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa
        input = input * (input.shape[-1] ** -0.5)
        return super().forward(input)

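ScaleLinear rescales its input by the inverse square root of the feature dimension before applying the usual affine projection, mirroring the rescaled classification head referenced in the Mesh TensorFlow link above. A minimal sketch of the equivalent computation (the hidden size and batch size are hypothetical; with a bias the same rescaling applies before the bias is added):

    hidden_size = 4
    layer = ScaleLinear(hidden_size, 1, bias=False)
    x = torch.randn(2, hidden_size)
    # ScaleLinear(x) equals a plain linear projection of x * hidden_size ** -0.5
    expected = torch.nn.functional.linear(x * hidden_size ** -0.5, layer.weight)
    assert torch.allclose(layer(x), expected)
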
class MonoConfig(CrossEncoderConfig):
    """Configuration class for mono cross-encoder models."""

    model_type = "mono"
    """Model type for mono cross-encoder models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum", "bert_pool"] = "first",
        linear_bias: bool = False,
        scoring_strategy: Literal["mono", "rank"] = "rank",
        tokenizer_pattern: str | None = None,
        **kwargs,
    ):
        """Initialize the configuration for mono cross-encoder models.

        Args:
            query_length (int): Maximum query length. Defaults to 32.
            doc_length (int): Maximum document length. Defaults to 512.
            pooling_strategy (Literal["first", "mean", "max", "sum", "bert_pool"]): Pooling strategy for the
                embeddings. Defaults to "first".
            linear_bias (bool): Whether to use bias in the final linear layer. Defaults to False.
            scoring_strategy (Literal["mono", "rank"]): Scoring strategy to use. Defaults to "rank".
            tokenizer_pattern (str | None): Optional pattern for tokenization. Defaults to None.
        """
        self._bert_pool = False
        if pooling_strategy == "bert_pool":
            self._bert_pool = True
            pooling_strategy = "first"
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            pooling_strategy=pooling_strategy,
            linear_bias=linear_bias,
            **kwargs,
        )
        self.scoring_strategy = scoring_strategy
        self.tokenizer_pattern = tokenizer_pattern

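A minimal usage sketch of the configuration (the lengths below are illustrative choices, not defaults). Note that "bert_pool" is not passed through as-is: the constructor maps it to "first" pooling and records the extra pooler via the private `_bert_pool` flag, assuming `CrossEncoderConfig` stores `pooling_strategy` as an attribute, as the forward pass below implies:

    config = MonoConfig(
        query_length=32,
        doc_length=256,
        pooling_strategy="bert_pool",  # mapped to "first"; enables the Linear+Tanh pooler in MonoModel
        scoring_strategy="mono",       # two-logit head, scored via log-softmax
    )
    assert config.pooling_strategy == "first" and config._bert_pool is True
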
class MonoModel(CrossEncoderModel):
    config_class: Type[MonoConfig] = MonoConfig
    """Configuration class for mono cross-encoder models."""

    def __init__(self, config: MonoConfig, *args, **kwargs):
        """A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are
        aggregated into a single vector and fed to a linear layer which computes a final relevance score.

        Args:
            config (MonoConfig): Configuration for the mono cross-encoder model.
        """
        super().__init__(config, *args, **kwargs)

        if self.config.scoring_strategy == "mono":
            output_dim = 2
        elif self.config.scoring_strategy == "rank":
            output_dim = 1
        else:
            raise ValueError(
                f"Unknown scoring strategy {self.config.scoring_strategy}. Supported strategies are 'mono' and 'rank'."
            )

        self.bert_pool = torch.nn.Identity()
        if self.config._bert_pool:
            self.bert_pool = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size), torch.nn.Tanh()
            )

        if self.config.backbone_model_type == "t5":
            self.linear = ScaleLinear(config.hidden_size, output_dim, bias=self.config.linear_bias)
        else:
            self.linear = torch.nn.Linear(config.hidden_size, output_dim, bias=self.config.linear_bias)

    @batch_encoding_wrapper
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Computes contextualized embeddings for the joint query-document input sequence and predicts a relevance
        score.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the joint query-document input sequence.

        Returns:
            CrossEncoderOutput: Output of the model.
        """
        if self.config.is_encoder_decoder:
            # NOTE encoder-decoder models other than t5 might not use 0 as the sos token id
            decoder_input_ids = torch.zeros(
                (encoding["input_ids"].shape[0], 1), device=encoding["input_ids"].device, dtype=torch.long
            )
            encoding["decoder_input_ids"] = decoder_input_ids
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.pooling(
            embeddings, encoding.get("attention_mask", None), pooling_strategy=self.config.pooling_strategy
        )
        embeddings = self.bert_pool(embeddings)
        scores = self.linear(embeddings)

        if self.config.scoring_strategy == "mono":
            scores = torch.nn.functional.log_softmax(scores.view(-1, 2), dim=-1)[:, 1]

        return CrossEncoderOutput(scores=scores.view(-1), embeddings=embeddings)
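With scoring_strategy="mono", the two logits per query-document pair are converted into a single relevance score: the log-softmax value of the logit at index 1, which is treated as the "relevant" class; with "rank", the single raw logit is used directly. A minimal sketch of that final step with hypothetical logits for three pairs:

    # Hypothetical raw outputs of self.linear for three pairs (output_dim == 2 for "mono" scoring)
    logits = torch.tensor([[0.2, 1.5], [1.0, -0.3], [0.0, 0.0]])
    # Index 1 is taken as the relevant class; its log-probability becomes the relevance score
    scores = torch.nn.functional.log_softmax(logits.view(-1, 2), dim=-1)[:, 1]
    print(scores.shape)  # torch.Size([3]) -- one score per query-document pair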