Source code for lightning_ir.models.cross_encoders.mono

  1"""
  2Model implementation for mono cross-encoder models. Originally introduced in
  3`Passage Re-ranking with BERT
  4<https://arxiv.org/abs/1901.04085>`_.
  5"""
  6
  7from typing import Literal
  8
  9import torch
 10from transformers import BatchEncoding
 11
 12from ...base.model import batch_encoding_wrapper
 13from ...cross_encoder import CrossEncoderConfig, CrossEncoderModel, CrossEncoderOutput
 14from ...modeling_utils.embedding_post_processing import Pooler


class ScaleLinear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa
        input = input * (input.shape[-1] ** -0.5)
        return super().forward(input)

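
# Illustrative sketch (not part of the original module): ScaleLinear rescales its input by
# hidden_size ** -0.5 before applying the affine projection, mirroring the mesh-tensorflow/T5
# logits computation linked above. The helper below is a hypothetical, self-contained check of
# that equivalence on random data.
def _scale_linear_sketch(hidden_size: int = 8) -> None:
    layer = ScaleLinear(hidden_size, 1, bias=False)
    x = torch.randn(2, hidden_size)
    # ScaleLinear(x) should equal a plain linear projection of the pre-scaled input.
    expected = torch.nn.functional.linear(x * hidden_size**-0.5, layer.weight)
    assert torch.allclose(layer(x), expected)
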
class MonoConfig(CrossEncoderConfig):
    """Configuration class for mono cross-encoder models."""

    model_type = "mono"
    """Model type for mono cross-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum", "bert_pool"] = "first",
        linear_bias: bool = False,
        scoring_strategy: Literal["mono", "rank"] = "rank",
        tokenizer_pattern: str | None = None,
        **kwargs,
    ):
        """Initialize the configuration for mono cross-encoder models.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            pooling_strategy (Literal["first", "mean", "max", "sum", "bert_pool"]): Pooling strategy for the
                embeddings. Defaults to "first".
            linear_bias (bool): Whether to use a bias in the final linear layer. Defaults to False.
            scoring_strategy (Literal["mono", "rank"]): Scoring strategy to use. Defaults to "rank".
            tokenizer_pattern (str | None): Optional pattern for tokenization. Defaults to None.
        """
        self._bert_pool = False
        if pooling_strategy == "bert_pool":  # some models use the internal BERT pooler
            self._bert_pool = True
            pooling_strategy = "first"
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            pooling_strategy=pooling_strategy,
            linear_bias=linear_bias,
            **kwargs,
        )
        self.scoring_strategy = scoring_strategy
        self.tokenizer_pattern = tokenizer_pattern

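
# Illustrative sketch (not part of the original module): constructing a MonoConfig. Passing
# pooling_strategy="bert_pool" is recorded in the private _bert_pool flag and remapped to
# "first" pooling, so the model later adds a Linear + Tanh pooler on top of the first token's
# embedding. The helper name and values below are arbitrary example settings.
def _example_mono_config() -> MonoConfig:
    config = MonoConfig(
        query_length=32,
        doc_length=256,
        pooling_strategy="bert_pool",
        scoring_strategy="mono",
    )
    assert config._bert_pool is True  # "bert_pool" was remapped to "first" pooling internally
    return config
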
class MonoModel(CrossEncoderModel):
    config_class: type[MonoConfig] = MonoConfig
    """Configuration class for mono cross-encoder models."""

    def __init__(self, config: MonoConfig, *args, **kwargs):
        """A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are
        aggregated into a single vector and fed to a linear layer which computes the final relevance score.

        Args:
            config (MonoConfig): Configuration for the mono cross-encoder model.
        """
        super().__init__(config, *args, **kwargs)

        if self.config.scoring_strategy == "mono":
            output_dim = 2
        elif self.config.scoring_strategy == "rank":
            output_dim = 1
        else:
            raise ValueError(
                f"Unknown scoring strategy {self.config.scoring_strategy}. Supported strategies are 'mono' and 'rank'."
            )

        self.bert_pool = torch.nn.Identity()
        if self.config._bert_pool:
            self.bert_pool = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size), torch.nn.Tanh()
            )

        if self.config.backbone_model_type == "t5":
            self.linear = ScaleLinear(config.hidden_size, output_dim, bias=self.config.linear_bias)
        else:
            self.linear = torch.nn.Linear(config.hidden_size, output_dim, bias=self.config.linear_bias)
        self.pooler = Pooler(config)

    @batch_encoding_wrapper
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Computes contextualized embeddings for the joint query-document input sequence and computes a relevance
        score.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the joint query-document input sequence.

        Returns:
            CrossEncoderOutput: Output of the model.
        """
        if self.config.is_encoder_decoder:
            # NOTE encoder-decoder models other than t5 might not use 0 as the sos token id
            decoder_input_ids = torch.zeros(
                (encoding["input_ids"].shape[0], 1), device=encoding["input_ids"].device, dtype=torch.long
            )
            encoding["decoder_input_ids"] = decoder_input_ids
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.pooler(embeddings, encoding.get("attention_mask", None))
        embeddings = self.bert_pool(embeddings)
        scores = self.linear(embeddings)

        if self.config.scoring_strategy == "mono":
            scores = torch.nn.functional.log_softmax(scores.view(-1, 2), dim=-1)[:, 1]

        return CrossEncoderOutput(scores=scores.view(-1), embeddings=embeddings)
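
# Illustrative sketch (not part of the original module): how the two scoring strategies turn the
# pooled embedding into a relevance score. With "rank" the single logit is used directly; with
# "mono" the two logits are log-softmax normalized and the log-probability of the "relevant"
# class (index 1) becomes the score, as in the forward pass above. The helper name, shapes, and
# values are arbitrary example data.
def _scoring_strategy_sketch(hidden_size: int = 8, num_pairs: int = 4) -> None:
    pooled = torch.randn(num_pairs, hidden_size)  # stand-in for pooled query-document embeddings

    rank_head = torch.nn.Linear(hidden_size, 1, bias=False)
    rank_scores = rank_head(pooled).view(-1)  # one unnormalized logit per query-document pair

    mono_head = torch.nn.Linear(hidden_size, 2, bias=False)
    mono_logits = mono_head(pooled)
    mono_scores = torch.nn.functional.log_softmax(mono_logits.view(-1, 2), dim=-1)[:, 1]

    assert rank_scores.shape == mono_scores.shape == (num_pairs,)
    assert (mono_scores <= 0).all()  # log-probabilities are never positive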