1"""
2Model implementation for mono cross-encoder models.
3
4A mono cross-encoder model, such as MonoBERT or MonoT5, maximizes retrieval accuracy by processing the user's query and
5the target document simultaneously. Instead of encoding texts separately like DPR or delaying their interaction like
6ColBERT, a cross-encoder combines the query and document into a single text sequence before passing them through the
7neural network. This "early interaction" allows every word in the query to deeply contextualize with every word in the
8document, producing a highly precise relevance score. However, because this architecture requires processing every
9potential query and document pair together from scratch, it is computationally prohibitive for large databases and is
10instead used almost exclusively as a second-stage re-ranker to carefully sort a small list of candidate documents
11already found by faster models.
12
13Originally introduced in
14`Passage Re-ranking with BERT
15<https://arxiv.org/abs/1901.04085>`_.
16"""
17
18from typing import Literal
19
20import torch
21from transformers import BatchEncoding
22
23from ...base.model import batch_encoding_wrapper
24from ...cross_encoder import CrossEncoderConfig, CrossEncoderModel, CrossEncoderOutput
25from ...modeling_utils.embedding_post_processing import Pooler
26
27
class ScaleLinear(torch.nn.Linear):
    """Linear layer that rescales its input by ``1 / sqrt(in_features)`` before projecting.

    Mirrors the dense-layer scaling used by the Mesh TensorFlow transformer (and hence T5):
    https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Scale by the inverse square root of the last (feature) dimension, then apply
        # the standard affine projection.
        scale = input.shape[-1] ** -0.5
        return super().forward(input * scale)
34
class MonoConfig(CrossEncoderConfig):
    """Configuration class for mono cross-encoder models."""

    model_type = "mono"
    """Model type for mono cross-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum", "bert_pool"] = "first",
        linear_bias: bool = False,
        scoring_strategy: Literal["mono", "rank"] = "rank",
        tokenizer_pattern: str | None = None,
        **kwargs,
    ):
        """Initialize the configuration for mono cross-encoder models.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            pooling_strategy (Literal["first", "mean", "max", "sum", "bert_pool"]): Pooling strategy for the
                embeddings. Defaults to "first".
            linear_bias (bool): Whether to use bias in the final linear layer. Defaults to False.
            scoring_strategy (Literal["mono", "rank"]): Scoring strategy to use. Defaults to "rank".
            tokenizer_pattern (str | None): Optional pattern for tokenization. Defaults to None.
        """
        # "bert_pool" is not a real pooling strategy: some models pool with the backbone's
        # internal BERT pooler on top of the first token. Remember the flag and fall back
        # to "first" pooling for the base configuration.
        self._bert_pool = pooling_strategy == "bert_pool"
        if self._bert_pool:
            pooling_strategy = "first"
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            pooling_strategy=pooling_strategy,
            linear_bias=linear_bias,
            **kwargs,
        )
        self.scoring_strategy = scoring_strategy
        self.tokenizer_pattern = tokenizer_pattern
76
class MonoModel(CrossEncoderModel):
    config_class: type[MonoConfig] = MonoConfig
    """Configuration class for mono cross-encoder models."""

    def __init__(self, config: MonoConfig, *args, **kwargs):
        """A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are
        aggregated into a single vector and fed to a linear layer which computes a final relevance score.

        Args:
            config (MonoConfig): Configuration for the mono cross-encoder model.

        Raises:
            ValueError: If ``config.scoring_strategy`` is neither "mono" nor "rank".
        """
        super().__init__(config, *args, **kwargs)

        # "mono" predicts a (non-relevant, relevant) pair of logits, "rank" a single
        # relevance logit.
        if self.config.scoring_strategy == "mono":
            output_dim = 2
        elif self.config.scoring_strategy == "rank":
            output_dim = 1
        else:
            raise ValueError(
                f"Unknown scoring strategy {self.config.scoring_strategy}. Supported strategies are 'mono' and 'rank'."
            )

        # Some models pool via BERT's internal pooler (dense + tanh over the pooled
        # embedding); otherwise this is a no-op.
        self.bert_pool = torch.nn.Identity()
        if self.config._bert_pool:
            self.bert_pool = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size), torch.nn.Tanh()
            )

        # T5-based models (e.g., MonoT5) rescale the input before the final projection.
        if self.config.backbone_model_type == "t5":
            self.linear = ScaleLinear(config.hidden_size, output_dim, bias=self.config.linear_bias)
        else:
            self.linear = torch.nn.Linear(config.hidden_size, output_dim, bias=self.config.linear_bias)
        self.pooler = Pooler(config)

    @batch_encoding_wrapper
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Computes contextualized embeddings for the joint query-document input sequence and computes a relevance
        score.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the joint query-document input sequence.

        Returns:
            CrossEncoderOutput: Output of the model.
        """
        if self.config.is_encoder_decoder:
            # NOTE encoder-decoder models other than t5 might not use 0 as the sos token id
            decoder_input_ids = torch.zeros(
                (encoding["input_ids"].shape[0], 1), device=encoding["input_ids"].device, dtype=torch.long
            )
            encoding["decoder_input_ids"] = decoder_input_ids
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.pooler(embeddings, encoding.get("attention_mask", None))
        embeddings = self.bert_pool(embeddings)
        scores = self.linear(embeddings)

        if self.config.scoring_strategy == "mono":
            # Keep only the log-probability of the "relevant" class.
            scores = torch.nn.functional.log_softmax(scores.view(-1, 2), dim=-1)[:, 1]

        return CrossEncoderOutput(scores=scores.view(-1), embeddings=embeddings)