from typing import Dict, Literal, Sequence, Type

import torch
from transformers import BatchEncoding

from ..cross_encoder.cross_encoder_config import CrossEncoderConfig
from ..cross_encoder.cross_encoder_model import CrossEncoderModel, CrossEncoderOutput
from ..cross_encoder.cross_encoder_tokenizer import CrossEncoderTokenizer

class T5CrossEncoderConfig(CrossEncoderConfig):

    model_type = "encoder-decoder-cross-encoder"

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        decoder_strategy: Literal["mono", "rank"] = "mono",
        **kwargs,
    ) -> None:
        kwargs["pooling_strategy"] = "first"
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
        self.decoder_strategy = decoder_strategy

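# Usage sketch (annotation added here, not part of the original module): the config only
# records how the decoder head is used, e.g.
#
#   mono_config = T5CrossEncoderConfig(query_length=32, doc_length=512, decoder_strategy="mono")
#   rank_config = T5CrossEncoderConfig(decoder_strategy="rank")
#
# pooling_strategy is forced to "first" because scoring reads the single (first) decoder
# position produced in T5CrossEncoderModel.forward below.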

class ScaleLinear(torch.nn.Linear):

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa
        input = input * (input.shape[-1] ** -0.5)
        return super().forward(input)

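# Note (added annotation): T5's Mesh TensorFlow reference implementation, linked above,
# rescales decoder activations by d_model ** -0.5 before the (tied) vocabulary projection.
# ScaleLinear mirrors that rescaling for the small classification head, so its logits stay
# on the same scale as the original LM-head logits; e.g. with hidden_size = 768 the
# activations are multiplied by 1 / sqrt(768) ≈ 0.036 before the matmul.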

class T5CrossEncoderModel(CrossEncoderModel):
    config_class = T5CrossEncoderConfig

    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "linear.weight"]

    def __init__(self, config: T5CrossEncoderConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config: T5CrossEncoderConfig
        if self.config.decoder_strategy == "mono":
            self.linear = ScaleLinear(config.hidden_size, 2, bias=config.linear_bias)
        else:
            self.linear = ScaleLinear(config.hidden_size, 1, bias=config.linear_bias)

    # TODO tying of weights does not work when setting linear to only use a slice of the lm head for efficiency
    # def get_output_embeddings(self):
    #     shared = self.shared
    #     if self.config.decoder_strategy == "mono":
    #         self.linear.weight.data = shared.weight.data[[1176, 6136]]
    #     elif self.config.decoder_strategy == "rank":
    #         self.linear.weight.data = shared.weight.data[[32089]]
    #     else:
    #         raise ValueError("Unknown decoder strategy")
    #     return shared

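    # Scoring note (annotation added here): forward feeds one all-zero decoder input token per
    # query-document pair, so the decoder emits exactly one position and "first" pooling picks it
    # up. With decoder_strategy="mono" the 2-way head (presumably sliced from the true/false
    # vocabulary rows referenced in the TODO above) is log-softmaxed and the log-probability of
    # the first class is returned as the relevance score; with "rank" the single raw logit is
    # used directly.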
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        decoder_input_ids = torch.zeros(
            (encoding["input_ids"].shape[0], 1), device=encoding["input_ids"].device, dtype=torch.long
        )
        encoding["decoder_input_ids"] = decoder_input_ids
        output = super().forward(encoding)
        if output.scores is None:
            raise ValueError("Scores are None")
        if self.config.decoder_strategy == "mono":
            scores = output.scores.view(-1, 2)
            scores = torch.nn.functional.log_softmax(scores, dim=-1)[:, 0]
            output.scores = scores.view(-1)
        return output


class T5CrossEncoderTokenizer(CrossEncoderTokenizer):

    config_class: Type[T5CrossEncoderConfig] = T5CrossEncoderConfig

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        decoder_strategy: Literal["mono", "rank"] = "mono",
        **kwargs,
    ):
        super().__init__(
            *args, query_length=query_length, doc_length=doc_length, decoder_strategy=decoder_strategy, **kwargs
        )
        self.decoder_strategy = decoder_strategy

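    # Prompt note (annotation added here): tokenize below renders each query-document pair into a
    # single text sequence. The "mono" template ends in "Relevant:" in the style of monoT5 prompts,
    # while the "rank" template drops that suffix; the rendered strings are then passed to the
    # underlying tokenizer call at the end of the method.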
    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        expanded_queries, docs = self._preprocess(queries, docs, num_docs)
        if self.decoder_strategy == "mono":
            pattern = "Query: {query} Document: {doc} Relevant:"
        elif self.decoder_strategy == "rank":
            pattern = "Query: {query} Document: {doc}"
        else:
            raise ValueError(f"Unknown decoder strategy: {self.decoder_strategy}")
        input_texts = [pattern.format(query=query, doc=doc) for query, doc in zip(expanded_queries, docs)]

        return_tensors = kwargs.get("return_tensors", None)
        if return_tensors is not None:
            kwargs["pad_to_multiple_of"] = 8
        return {"encoding": self(input_texts, **kwargs)}
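
# End-to-end sketch (added as a comment; checkpoint names and loading details outside this module
# are assumptions, not part of the original source). The intended flow is roughly:
#
#   config = T5CrossEncoderConfig(decoder_strategy="mono")
#   tokenizer = T5CrossEncoderTokenizer.from_pretrained("<t5-checkpoint>", decoder_strategy="mono")
#   model = T5CrossEncoderModel.from_pretrained("<t5-checkpoint>", config=config)
#
#   batch = tokenizer.tokenize(
#       queries="what is information retrieval",
#       docs=["IR is the task of finding relevant documents.", "Unrelated text."],
#       num_docs=[2],
#       return_tensors="pt",
#       padding=True,
#       truncation=True,
#   )
#   scores = model.forward(batch["encoding"]).scores  # one relevance score per document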