1"""
2Model implementation for mono cross-encoder models. Originally introduced in
3`Passage Re-ranking with BERT
4<https://arxiv.org/abs/1901.04085>`_.
5"""
6
7from typing import Literal, Type
8
9import torch
10from transformers import BatchEncoding
11
12from ...base.model import batch_encoding_wrapper
13from ...cross_encoder import CrossEncoderConfig, CrossEncoderModel, CrossEncoderOutput
14
15
[docs]
16class ScaleLinear(torch.nn.Linear):
17
[docs]
18 def forward(self, input: torch.Tensor) -> torch.Tensor:
19 # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa
20 input = input * (input.shape[-1] ** -0.5)
21 return super().forward(input)
22
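# Illustrative sketch (not part of the module API): ``ScaleLinear`` rescales its input by
# ``1 / sqrt(in_features)`` before the affine projection, so it is equivalent to a plain
# ``torch.nn.Linear`` applied to a pre-scaled input (768 here is a hypothetical hidden size):
#
#     layer = ScaleLinear(768, 1, bias=False)
#     plain = torch.nn.Linear(768, 1, bias=False)
#     plain.weight.data.copy_(layer.weight.data)
#     x = torch.randn(2, 768)
#     assert torch.allclose(layer(x), plain(x * 768 ** -0.5))
#
# The scaled head mirrors the linked Mesh TensorFlow T5 implementation; ``MonoModel`` below
# therefore uses ``ScaleLinear`` only for T5 backbones.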


class MonoConfig(CrossEncoderConfig):
    """Configuration class for mono cross-encoder models."""

    model_type = "mono"
    """Model type for mono cross-encoder models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        pooling_strategy: Literal["first", "mean", "max", "sum", "bert_pool"] = "first",
        linear_bias: bool = False,
        scoring_strategy: Literal["mono", "rank"] = "rank",
        tokenizer_pattern: str | None = None,
        **kwargs,
    ):
        """Initialize the configuration for mono cross-encoder models.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, queries are not truncated.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, documents are not truncated.
                Defaults to 512.
            pooling_strategy (Literal["first", "mean", "max", "sum", "bert_pool"]): Pooling strategy for the
                embeddings. "bert_pool" applies an additional linear layer with tanh activation on top of the
                first token's embedding (BERT-style pooling). Defaults to "first".
            linear_bias (bool): Whether to use a bias in the final linear layer. Defaults to False.
            scoring_strategy (Literal["mono", "rank"]): Scoring strategy to use. "mono" predicts a two-class
                distribution and scores with the log-probability of the relevant class; "rank" predicts a single
                relevance logit. Defaults to "rank".
            tokenizer_pattern (str | None): Optional pattern for tokenization. Defaults to None.
        """
        self._bert_pool = False
        if pooling_strategy == "bert_pool":
            self._bert_pool = True
            pooling_strategy = "first"
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            pooling_strategy=pooling_strategy,
            linear_bias=linear_bias,
            **kwargs,
        )
        self.scoring_strategy = scoring_strategy
        self.tokenizer_pattern = tokenizer_pattern
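# Illustrative configuration sketch (hypothetical values): truncate queries to 32 tokens and
# documents to 256 tokens and score with the two-class "mono" strategy.
#
#     config = MonoConfig(
#         query_length=32,
#         doc_length=256,
#         pooling_strategy="first",
#         scoring_strategy="mono",
#     )
#
# With scoring_strategy="rank" the model instead emits a single unnormalized relevance logit
# per query-document pair.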


class MonoModel(CrossEncoderModel):
    config_class: Type[MonoConfig] = MonoConfig
    """Configuration class for mono cross-encoder models."""

    def __init__(self, config: MonoConfig, *args, **kwargs):
        """A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are
        aggregated into a single vector and fed to a linear layer which computes a final relevance score.

        Args:
            config (MonoConfig): Configuration for the mono cross-encoder model.
        """
        super().__init__(config, *args, **kwargs)

        if self.config.scoring_strategy == "mono":
            output_dim = 2
        elif self.config.scoring_strategy == "rank":
            output_dim = 1
        else:
            raise ValueError(
                f"Unknown scoring strategy {self.config.scoring_strategy}. Supported strategies are 'mono' and 'rank'."
            )

        self.bert_pool = torch.nn.Identity()
        if self.config._bert_pool:
            self.bert_pool = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size), torch.nn.Tanh()
            )

        # T5 backbones use the scaled linear head; all other backbones use a standard linear layer.
        if self.config.backbone_model_type == "t5":
            self.linear = ScaleLinear(config.hidden_size, output_dim, bias=self.config.linear_bias)
        else:
            self.linear = torch.nn.Linear(config.hidden_size, output_dim, bias=self.config.linear_bias)
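    # Shape sketch (illustrative): for a batch of N query-document pairs with hidden size H,
    # ``self.pooling`` reduces the contextualized embeddings to an (N, H) tensor and the linear
    # head maps it to (N, 2) for scoring_strategy="mono" or (N, 1) for scoring_strategy="rank".
    # The "mono" logits are collapsed to a single score per pair in ``forward`` below.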

    @batch_encoding_wrapper
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Computes contextualized embeddings for the joint query-document input sequence and a relevance score.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the joint query-document input sequence.

        Returns:
            CrossEncoderOutput: Output of the model.
        """
        if self.config.is_encoder_decoder:
            # NOTE: encoder-decoder models other than T5 might not use 0 as the start-of-sequence token id.
            decoder_input_ids = torch.zeros(
                (encoding["input_ids"].shape[0], 1), device=encoding["input_ids"].device, dtype=torch.long
            )
            encoding["decoder_input_ids"] = decoder_input_ids
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.pooling(
            embeddings, encoding.get("attention_mask", None), pooling_strategy=self.config.pooling_strategy
        )
        embeddings = self.bert_pool(embeddings)
        scores = self.linear(embeddings)

        if self.config.scoring_strategy == "mono":
            # Interpret the two logits as (non-relevant, relevant) and score with the log-probability
            # of the relevant class.
            scores = torch.nn.functional.log_softmax(scores.view(-1, 2), dim=-1)[:, 1]

        return CrossEncoderOutput(scores=scores.view(-1), embeddings=embeddings)
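# End-to-end usage sketch (hypothetical; the exact loading and tokenization entry points of the
# surrounding library may differ). Assuming a Hugging Face tokenizer that packs each
# query-document pair into one input sequence:
#
#     encoding = tokenizer(queries, docs, padding=True, truncation=True, return_tensors="pt")
#     output = model(encoding)  # model is a MonoModel instance
#     scores = output.scores    # shape: (num_query_document_pairs,)
#
# Higher scores indicate higher estimated relevance; with scoring_strategy="mono" the scores are
# log-probabilities of the relevant class, with scoring_strategy="rank" unnormalized logits.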