Source code for lightning_ir.models.cross_encoders.mono

  1"""
  2Model implementation for mono cross-encoder models.
  3
  4A mono cross-encoder model, such as MonoBERT or MonoT5, maximizes retrieval accuracy by processing the user's query and
  5the target document simultaneously.  Instead of encoding texts separately like DPR or delaying their interaction like
  6ColBERT, a cross-encoder combines the query and document into a single text sequence before passing them through the
  7neural network. This "early interaction" allows every word in the query to deeply contextualize with every word in the
  8document, producing a highly precise relevance score. However, because this architecture requires processing every
  9potential query and document pair together from scratch, it is computationally prohibitive for large databases and is
 10instead used almost exclusively as a second-stage re-ranker to carefully sort a small list of candidate documents
 11already found by faster models.
 12
 13Originally introduced in
 14`Passage Re-ranking with BERT
 15<https://arxiv.org/abs/1901.04085>`_.
 16"""
 17
 18from typing import Literal
 19
 20import torch
 21from transformers import BatchEncoding
 22
 23from ...base.model import batch_encoding_wrapper
 24from ...cross_encoder import CrossEncoderConfig, CrossEncoderModel, CrossEncoderOutput
 25from ...modeling_utils.embedding_post_processing import Pooler
 26
 27
class ScaleLinear(torch.nn.Linear):
    """Linear layer that rescales its input by ``in_features ** -0.5`` before projecting.

    Replicates the output-layer scaling used by the Mesh TensorFlow transformer:
    https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Scale by 1 / sqrt(hidden_dim) of the last axis, then apply the usual affine map.
        scaled = input * input.shape[-1] ** -0.5
        return super().forward(scaled)
class MonoConfig(CrossEncoderConfig):
    """Configuration class for mono cross-encoder models."""

    model_type = "mono"
    """Model type for mono cross-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum", "bert_pool"] = "first",
        linear_bias: bool = False,
        scoring_strategy: Literal["mono", "rank"] = "rank",
        tokenizer_pattern: str | None = None,
        **kwargs,
    ):
        """Initialize the configuration for mono cross-encoder models.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            pooling_strategy (Literal["first", "mean", "max", "sum", "bert_pool"]): Pooling strategy for the
                embeddings. Defaults to "first".
            linear_bias (bool): Whether to use bias in the final linear layer. Defaults to False.
            scoring_strategy (Literal["mono", "rank"]): Scoring strategy to use. Defaults to "rank".
            tokenizer_pattern (str | None): Optional pattern for tokenization. Defaults to None.
        """
        # Some models use the backbone's internal BERT pooler. Remember that here and hand the
        # base class the equivalent "first"-token pooling it understands.
        self._bert_pool = pooling_strategy == "bert_pool"
        if self._bert_pool:
            pooling_strategy = "first"
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            pooling_strategy=pooling_strategy,
            linear_bias=linear_bias,
            **kwargs,
        )
        self.scoring_strategy = scoring_strategy
        self.tokenizer_pattern = tokenizer_pattern
class MonoModel(CrossEncoderModel):
    config_class: type[MonoConfig] = MonoConfig
    """Configuration class for mono cross-encoder models."""

    def __init__(self, config: MonoConfig, *args, **kwargs):
        """A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are
        aggregated into a single vector and fed to a linear layer which computes a final relevance score.

        Args:
            config (MonoConfig): Configuration for the mono cross-encoder model.

        Raises:
            ValueError: If ``config.scoring_strategy`` is neither "mono" nor "rank".
        """
        super().__init__(config, *args, **kwargs)

        # "mono" predicts a (non-relevant, relevant) logit pair; "rank" predicts a single score.
        if self.config.scoring_strategy == "mono":
            output_dim = 2
        elif self.config.scoring_strategy == "rank":
            output_dim = 1
        else:
            raise ValueError(
                f"Unknown scoring strategy {self.config.scoring_strategy}. Supported strategies are 'mono' and 'rank'."
            )

        # Optional dense + tanh transform emulating the internal BERT pooler; identity otherwise.
        self.bert_pool = torch.nn.Identity()
        if self.config._bert_pool:
            self.bert_pool = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size), torch.nn.Tanh()
            )

        # T5 backbones use an input-scaled projection (see ScaleLinear); others use a plain linear layer.
        if self.config.backbone_model_type == "t5":
            self.linear = ScaleLinear(config.hidden_size, output_dim, bias=self.config.linear_bias)
        else:
            self.linear = torch.nn.Linear(config.hidden_size, output_dim, bias=self.config.linear_bias)
        self.pooler = Pooler(config)

    @batch_encoding_wrapper
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Computes contextualized embeddings for the joint query-document input sequence and computes a relevance
        score.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the joint query-document input sequence.
        Returns:
            CrossEncoderOutput: Output of the model.
        """
        if self.config.is_encoder_decoder:
            # Single-step decoding: feed one start token per sequence so the decoder emits one state.
            # NOTE encoder-decoder models other than t5 might not use 0 as the sos token id
            decoder_input_ids = torch.zeros(
                (encoding["input_ids"].shape[0], 1), device=encoding["input_ids"].device, dtype=torch.long
            )
            encoding["decoder_input_ids"] = decoder_input_ids
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        # Pool token embeddings to one vector per sequence, then optionally apply the BERT-style pooler.
        embeddings = self.pooler(embeddings, encoding.get("attention_mask", None))
        embeddings = self.bert_pool(embeddings)
        scores = self.linear(embeddings)

        if self.config.scoring_strategy == "mono":
            # Log-softmax over the (non-relevant, relevant) logit pair; keep the relevant-class log-probability.
            scores = torch.nn.functional.log_softmax(scores.view(-1, 2), dim=-1)[:, 1]

        return CrossEncoderOutput(scores=scores.view(-1), embeddings=embeddings)