Source code for lightning_ir.models.cross_encoders.mono

  1"""
  2Model implementation for mono cross-encoder models. Originally introduced in
  3`Passage Re-ranking with BERT
  4<https://arxiv.org/abs/1901.04085>`_.
  5"""
  6
  7from typing import Literal, Type
  8
  9import torch
 10from transformers import BatchEncoding
 11
 12from ...base.model import batch_encoding_wrapper
 13from ...cross_encoder import CrossEncoderConfig, CrossEncoderModel, CrossEncoderOutput
 14
 15
class ScaleLinear(torch.nn.Linear):
    """Linear layer that multiplies its input by ``d ** -0.5`` before the affine projection, where ``d`` is the
    size of the input's last dimension."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Rescale ``input`` along its feature size and apply the linear projection.

        Args:
            input (torch.Tensor): Input features; the last dimension is used to derive the scale factor.
        Returns:
            torch.Tensor: Output of ``torch.nn.Linear.forward`` applied to the rescaled input.
        """
        # The scaling follows this reference implementation:
        # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa
        feature_dim = input.shape[-1]
        scale = feature_dim ** -0.5
        return super().forward(input * scale)
class MonoConfig(CrossEncoderConfig):
    """Configuration class for mono cross-encoder models."""

    #: Model type for mono cross-encoder models.
    model_type = "mono"

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum", "bert_pool"] = "first",
        linear_bias: bool = False,
        scoring_strategy: Literal["mono", "rank"] = "rank",
        tokenizer_pattern: str | None = None,
        **kwargs,
    ):
        """Build the configuration for a mono cross-encoder model.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            pooling_strategy (Literal["first", "mean", "max", "sum", "bert_pool"]): Pooling strategy for the
                embeddings. "bert_pool" is recorded as the private ``_bert_pool`` flag and "first" is forwarded to
                the parent configuration instead. Defaults to "first".
            linear_bias (bool): Whether to use bias in the final linear layer. Defaults to False.
            scoring_strategy (Literal["mono", "rank"]): Scoring strategy to use. Defaults to "rank".
            tokenizer_pattern (str | None): Optional pattern for tokenization. Defaults to None.
        """
        use_bert_pool = pooling_strategy == "bert_pool"
        # NOTE(review): assigned before super().__init__(), presumably so that a serialized ``_bert_pool`` value
        # arriving through ``kwargs`` can still override it — confirm against the parent config.
        self._bert_pool = use_bert_pool
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            pooling_strategy="first" if use_bert_pool else pooling_strategy,
            linear_bias=linear_bias,
            **kwargs,
        )
        self.scoring_strategy = scoring_strategy
        self.tokenizer_pattern = tokenizer_pattern
class MonoModel(CrossEncoderModel):
    config_class: Type[MonoConfig] = MonoConfig
    """Configuration class for mono cross-encoder models."""

    def __init__(self, config: MonoConfig, *args, **kwargs):
        """A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are
        aggregated into a single vector and fed to a linear layer which computes a final relevance score.

        Args:
            config (MonoConfig): Configuration for the mono cross-encoder model.
        Raises:
            ValueError: If ``config.scoring_strategy`` is neither "mono" nor "rank".
        """
        super().__init__(config, *args, **kwargs)

        # "mono" produces two logits that ``forward`` turns into a log-softmax value;
        # "rank" produces a single score directly.
        if self.config.scoring_strategy == "mono":
            output_dim = 2
        elif self.config.scoring_strategy == "rank":
            output_dim = 1
        else:
            raise ValueError(
                f"Unknown scoring strategy {self.config.scoring_strategy}. Supported strategies are 'mono' and 'rank'."
            )

        # Optional dense + tanh projection on the pooled embedding; enabled through the config's ``_bert_pool``
        # flag (set when the config is created with pooling_strategy="bert_pool"). Identity otherwise.
        self.bert_pool = torch.nn.Identity()
        if self.config._bert_pool:
            self.bert_pool = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size), torch.nn.Tanh()
            )

        # T5 backbones use a head that rescales its input by ``hidden_size ** -0.5`` (see ScaleLinear).
        if self.config.backbone_model_type == "t5":
            self.linear = ScaleLinear(config.hidden_size, output_dim, bias=self.config.linear_bias)
        else:
            self.linear = torch.nn.Linear(config.hidden_size, output_dim, bias=self.config.linear_bias)

    @batch_encoding_wrapper
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Computes contextualized embeddings for the joint query-document input sequence and computes a relevance
        score.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the joint query-document input sequence. The encoding
                itself is not modified.
        Returns:
            CrossEncoderOutput: Output of the model. ``scores`` is flattened to one dimension; with the "mono"
                strategy each score is the log-softmax value of logit index 1 (presumably the "relevant" class).
                ``embeddings`` holds the pooled embeddings after the optional ``bert_pool`` projection.
        """
        # Shallow copy so the decoder inputs below are not written into the caller's encoding.
        backbone_inputs = dict(encoding)
        if self.config.is_encoder_decoder:
            # NOTE encoder-decoder models other than t5 might not use 0 as the sos token id
            input_ids = encoding["input_ids"]
            backbone_inputs["decoder_input_ids"] = torch.zeros(
                (input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long
            )
        embeddings = self._backbone_forward(**backbone_inputs).last_hidden_state
        embeddings = self.pooling(
            embeddings, encoding.get("attention_mask", None), pooling_strategy=self.config.pooling_strategy
        )
        embeddings = self.bert_pool(embeddings)
        scores = self.linear(embeddings)

        if self.config.scoring_strategy == "mono":
            scores = torch.nn.functional.log_softmax(scores.view(-1, 2), dim=-1)[:, 1]

        return CrossEncoderOutput(scores=scores.view(-1), embeddings=embeddings)