1"""
2Model implementation for mono cross-encoder models. Originally introduced in
3`Passage Re-ranking with BERT
4<https://arxiv.org/abs/1901.04085>`_.
5"""
6
7from typing import Literal, Type
8
9import torch
10from transformers import BatchEncoding
11
12from ..base.model import batch_encoding_wrapper
13from ..cross_encoder.cross_encoder_config import CrossEncoderConfig
14from ..cross_encoder.cross_encoder_model import CrossEncoderModel, CrossEncoderOutput
15
16
[docs]
class ScaleLinear(torch.nn.Linear):
    """Linear layer that rescales its input before applying the projection.

    The input is multiplied by ``d ** -0.5``, where ``d`` is the size of the
    input's last dimension, and then passed through the regular linear layer.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Scale ``input`` by the inverse square root of its last dimension and project it.

        Args:
            input (torch.Tensor): Tensor to scale and project.
        Returns:
            torch.Tensor: Result of the linear layer applied to the scaled input.
        """
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa
        scale = input.shape[-1] ** -0.5
        return super().forward(input * scale)
23
24
[docs]
class MonoConfig(CrossEncoderConfig):
    """Configuration class for mono cross-encoder models."""

    model_type = "mono"
    """Model type for mono cross-encoder models."""

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum", "bert_pool"] = "first",
        linear_bias: bool = False,
        scoring_strategy: Literal["mono", "rank"] = "rank",
        tokenizer_pattern: str | None = None,
        **kwargs,
    ):
        """Set up the configuration of a mono cross-encoder model.

        Args:
            query_length (int): Maximum query length. Defaults to 32.
            doc_length (int): Maximum document length. Defaults to 512.
            pooling_strategy (Literal["first", "mean", "max", "sum", "bert_pool"]): Pooling strategy for the
                embeddings. ``"bert_pool"`` is stored as the private ``_bert_pool`` flag and ``"first"`` is
                handed to the parent configuration in its place. Defaults to "first".
            linear_bias (bool): Whether to use bias in the final linear layer. Defaults to False.
            scoring_strategy (Literal["mono", "rank"]): Scoring strategy to use. Defaults to "rank".
            tokenizer_pattern (str | None): Optional pattern for tokenization. Defaults to None.
            **kwargs: Additional keyword arguments forwarded to the parent configuration's ``__init__``.
        """
        # The flag is set before the parent initializer runs, mirroring the original ordering.
        wants_bert_pool = pooling_strategy == "bert_pool"
        self._bert_pool = wants_bert_pool
        parent_pooling = "first" if wants_bert_pool else pooling_strategy
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            pooling_strategy=parent_pooling,
            linear_bias=linear_bias,
            **kwargs,
        )
        self.scoring_strategy = scoring_strategy
        self.tokenizer_pattern = tokenizer_pattern
65
66
[docs]
class MonoModel(CrossEncoderModel):
    """Cross-encoder that pools a joint query-document encoding and scores it with a linear layer."""

    config_class: Type[MonoConfig] = MonoConfig
    """Configuration class for mono cross-encoder models."""

    def __init__(self, config: MonoConfig, *args, **kwargs):
        """A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are
        aggregated into a single vector and fed to a linear layer which computes a final relevance score.

        Args:
            config (MonoConfig): Configuration for the mono cross-encoder model.
            *args: Positional arguments forwarded to the parent model's ``__init__``.
            **kwargs: Keyword arguments forwarded to the parent model's ``__init__``.
        Raises:
            ValueError: If ``config.scoring_strategy`` is neither ``"mono"`` nor ``"rank"``.
        """
        super().__init__(config, *args, **kwargs)

        strategy = self.config.scoring_strategy
        if strategy not in ("mono", "rank"):
            raise ValueError(
                f"Unknown scoring strategy {self.config.scoring_strategy}. Supported strategies are 'mono' and 'rank'."
            )
        # "mono" produces two logits per input, "rank" a single score.
        output_dim = 2 if strategy == "mono" else 1

        # Optional dense + tanh layer applied to the pooled embedding; a no-op otherwise.
        if self.config._bert_pool:
            self.bert_pool = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size), torch.nn.Tanh()
            )
        else:
            self.bert_pool = torch.nn.Identity()

        # T5 backbones get the input-scaling linear layer, all others a plain linear layer.
        linear_cls = ScaleLinear if self.config.backbone_model_type == "t5" else torch.nn.Linear
        self.linear = linear_cls(config.hidden_size, output_dim, bias=self.config.linear_bias)

    @batch_encoding_wrapper
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Computes contextualized embeddings for the joint query-document input sequence and computes a relevance
        score.

        For encoder-decoder backbones a single-token ``decoder_input_ids`` tensor of zeros is added to
        ``encoding`` before the backbone is called.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the joint query-document input sequence.
        Returns:
            CrossEncoderOutput: Output of the model, holding the flattened scores and the pooled embeddings.
        """
        if self.config.is_encoder_decoder:
            # NOTE encoder-decoder models other than t5 might not use 0 as the sos token id
            input_ids = encoding["input_ids"]
            encoding["decoder_input_ids"] = torch.zeros(
                (input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long
            )

        hidden_states = self._backbone_forward(**encoding).last_hidden_state
        pooled = self.pooling(
            hidden_states, encoding.get("attention_mask"), pooling_strategy=self.config.pooling_strategy
        )
        pooled = self.bert_pool(pooled)
        scores = self.linear(pooled)

        if self.config.scoring_strategy == "mono":
            # Keep the log-probability at index 1 of each logit pair.
            pair_log_probs = torch.nn.functional.log_softmax(scores.view(-1, 2), dim=-1)
            scores = pair_log_probs[:, 1]

        return CrossEncoderOutput(scores=scores.view(-1), embeddings=pooled)