from functools import partial
from typing import Dict, Sequence, Tuple

import torch
from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding

from ..cross_encoder.cross_encoder_config import CrossEncoderConfig
from ..cross_encoder.cross_encoder_model import CrossEncoderModel, CrossEncoderOutput
from ..cross_encoder.cross_encoder_tokenizer import CrossEncoderTokenizer


class SetEncoderConfig(CrossEncoderConfig):
    """Configuration for the Set-Encoder, a listwise cross-encoder that lets the
    documents of a query exchange information during encoding."""

    model_type = "set-encoder"

    def __init__(
        self,
        *args,
        depth: int = 100,
        add_extra_token: bool = False,
        sample_missing_docs: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Re-ranking depth, i.e., the (maximum) number of documents per query the
        # model jointly attends over.
        self.depth = depth
        # Whether to prepend a dedicated [INT] interaction token whose embedding is
        # exchanged between the documents of a query (otherwise [CLS] is used).
        self.add_extra_token = add_extra_token
        # Whether to sample placeholder document embeddings when a query comes with
        # fewer than `depth` documents.
        self.sample_missing_docs = sample_missing_docs

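# Minimal configuration sketch (the values shown are the defaults defined above):
#
#     config = SetEncoderConfig(depth=100, add_extra_token=False, sample_missing_docs=True)
#
# `depth` fixes how many documents per query the model attends over; when a query has
# fewer documents and `sample_missing_docs` is True, placeholder document embeddings
# are sampled in `SetEncoderModel.cat_other_doc_hidden_states`.
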
class SetEncoderModel(CrossEncoderModel):
    """Set-Encoder cross-encoder model. Scores all documents of a query jointly: each
    document can attend to the [CLS] (or [INT]) embeddings of the other documents."""

    config_class = SetEncoderConfig
    self_attention_pattern = "self"

    ALLOW_SUB_BATCHING = False  # listwise model; the documents of a query must stay in one batch

    def __init__(self, config: SetEncoderConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config: SetEncoderConfig
        # The attention modules are monkey-patched in `forward`, which requires the
        # eager (non-fused) attention implementation of the backbone.
        self.attn_implementation = "eager"
        if self.config.backbone_model_type is not None and self.config.backbone_model_type not in ("bert", "electra"):
            raise ValueError(
                f"SetEncoderModel does not support backbone model type {self.config.backbone_model_type}. "
                "Supported types are 'bert' and 'electra'."
            )

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int, ...],
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        num_docs: Sequence[int] | None = None,
    ) -> torch.Tensor:
        if num_docs is not None:
            # Each document may attend to the embeddings of the other documents of
            # its query but not to a copy of its own (hence 1 - identity).
            eye = (1 - torch.eye(self.config.depth, device=device)).long()
            if not self.config.sample_missing_docs:
                # Without sampled placeholder documents, keep only the columns for
                # documents that are actually present.
                eye = eye[:, : max(num_docs)]
            other_doc_attention_mask = torch.cat([eye[:n] for n in num_docs])
            # Append the inter-document columns to the token-level attention mask.
            attention_mask = torch.cat(
                [attention_mask, other_doc_attention_mask.to(attention_mask)],
                dim=-1,
            )
            input_shape = tuple(attention_mask.shape)
        return super().get_extended_attention_mask(attention_mask, input_shape, device, dtype)

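    # Illustrative sketch of the extra mask columns built above (assuming depth=3 and
    # a single query with num_docs=[2]):
    #
    #     eye = 1 - torch.eye(3)   # [[0., 1., 1.], [1., 0., 1.], [1., 1., 0.]]
    #     eye[:2]                  # one row per present document; a document may
    #                              # attend to the other documents but not to itself
    #
    # These columns are concatenated to the token-level attention mask, matching the
    # other-document embeddings appended to the keys/values in `attention_forward`.
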
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        num_docs = encoding.pop("num_docs", None)
        # Bind num_docs to the mask computation and patch every self-attention module
        # so that keys and values include the other documents' embeddings.
        self.get_extended_attention_mask = partial(self.get_extended_attention_mask, num_docs=num_docs)
        for name, module in self.named_modules():
            if name.endswith(self.self_attention_pattern):
                module.forward = partial(self.attention_forward, self, module, num_docs=num_docs)
        return super().forward(encoding)

    @staticmethod
    def attention_forward(
        _self,  # the SetEncoderModel instance (bound via functools.partial in `forward`)
        self: torch.nn.Module,  # the patched self-attention module of the backbone
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None,
        *args,
        num_docs: Sequence[int],
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        key_value_hidden_states = hidden_states
        if num_docs is not None:
            # Keys and values additionally contain the other documents' embeddings.
            key_value_hidden_states = _self.cat_other_doc_hidden_states(hidden_states, num_docs)
        query = self.transpose_for_scores(self.query(hidden_states))
        key = self.transpose_for_scores(self.key(key_value_hidden_states))
        value = self.transpose_for_scores(self.value(key_value_hidden_states))

        context = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attention_mask.to(query.dtype) if attention_mask is not None else None,
            self.dropout.p if self.training else 0,
        )

        context = context.permute(0, 2, 1, 3).contiguous()
        new_context_shape = context.size()[:-2] + (self.all_head_size,)
        context = context.view(new_context_shape)
        return (context,)

    def cat_other_doc_hidden_states(
        self,
        hidden_states: torch.Tensor,
        num_docs: Sequence[int],
    ) -> torch.Tensor:
        # Position 1 holds the [INT] token when an extra interaction token is used,
        # position 0 the [CLS] token otherwise.
        idx = 1 if self.config.add_extra_token else 0
        split_other_doc_hidden_states = torch.split(hidden_states[:, idx], list(num_docs))
        repeated_other_doc_hidden_states = []
        for idx, h_states in enumerate(split_other_doc_hidden_states):
            missing_docs = 0 if self.config.depth is None else self.config.depth - num_docs[idx]
            if missing_docs and self.config.sample_missing_docs:
                # Pad up to `depth` documents by sampling placeholder embeddings from
                # a normal distribution fitted to the present documents.
                mean = h_states.mean(0, keepdim=True).expand(missing_docs, -1)
                if num_docs[idx] == 1:
                    std = torch.zeros_like(mean)
                else:
                    std = h_states.std(0, keepdim=True).expand(missing_docs, -1)
                sampled_h_states = torch.normal(mean, std).to(h_states)
                h_states = torch.cat([h_states, sampled_h_states])
            repeated_other_doc_hidden_states.append(h_states.unsqueeze(0).expand(num_docs[idx], -1, -1))
        other_doc_hidden_states = torch.cat(repeated_other_doc_hidden_states)
        key_value_hidden_states = torch.cat([hidden_states, other_doc_hidden_states], dim=1)
        return key_value_hidden_states


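# Shape sketch for `cat_other_doc_hidden_states` (illustrative values only, assuming a
# hidden size of 768, depth=100, and a single query with num_docs=[10]):
#
#     hidden_states:            (10, seq_len, 768)        # token embeddings per document
#     other_doc_hidden_states:  (10, 100, 768)            # 90 placeholder docs sampled if
#                                                         # sample_missing_docs is True
#     key_value_hidden_states:  (10, seq_len + 100, 768)
#
# Keys and values in `attention_forward` are computed over this concatenation, while
# queries remain the per-document token representations.
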
class SetEncoderTokenizer(CrossEncoderTokenizer):

    config_class = SetEncoderConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        add_extra_token: bool = False,
        **kwargs,
    ):
        super().__init__(
            *args, query_length=query_length, doc_length=doc_length, add_extra_token=add_extra_token, **kwargs
        )
        self.add_extra_token = add_extra_token
        self.interaction_token = "[INT]"
        if add_extra_token:
            # Register [INT] as a special token and make the post-processor insert it
            # right after [CLS] for every query-document pair.
            self.add_tokens([self.interaction_token], special_tokens=True)
            self._tokenizer.post_processor = TemplateProcessing(
                single="[CLS] $0 [SEP]",
                pair="[CLS] [INT] $A [SEP] $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    ("[INT]", self.interaction_token_id),
                ],
            )

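    # With add_extra_token=True, a tokenized query-document pair has the form
    #
    #     [CLS] [INT] <query tokens> [SEP] <document tokens> [SEP]
    #
    # so the [INT] embedding sits at position 1, which is the position
    # `SetEncoderModel.cat_other_doc_hidden_states` exchanges between documents;
    # without the extra token the [CLS] embedding at position 0 is used instead.
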
    @property
    def interaction_token_id(self) -> int:
        if self.interaction_token in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.interaction_token]
        raise ValueError(f"Token {self.interaction_token} not found in tokenizer")

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents into a single sequence of tokens.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents per query
            by dividing the total number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Tokenized query-document sequence
        :rtype: Dict[str, BatchEncoding]
        """
        if queries is None or docs is None:
            raise ValueError("Both queries and docs must be provided.")
        if isinstance(docs, str) and not isinstance(queries, str):
            raise ValueError("Queries and docs must be both lists or both strings.")
        is_string_queries = False
        is_string_docs = False
        if isinstance(queries, str):
            queries = [queries]
            is_string_queries = True
        if isinstance(docs, str):
            docs = [docs]
            is_string_docs = True
        is_string_both = is_string_queries and is_string_docs
        num_docs = self._process_num_docs(queries, docs, num_docs)
        queries, docs = self._preprocess(queries, docs, num_docs)
        return_tensors = kwargs.get("return_tensors", None)
        if return_tensors is not None:
            # Pad sequence lengths to a multiple of 8 when returning tensors.
            kwargs["pad_to_multiple_of"] = 8
        if is_string_both:
            encoding = self(queries[0], docs[0], **kwargs)
        else:
            encoding = self(queries, docs, **kwargs)
        return {"encoding": BatchEncoding({**encoding, "num_docs": num_docs})}
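

# Usage sketch (hedged: "my-set-encoder" is a placeholder checkpoint name, and
# `from_pretrained` is assumed to follow the usual Hugging Face pattern inherited from
# the cross-encoder tokenizer):
#
#     tokenizer = SetEncoderTokenizer.from_pretrained("my-set-encoder", add_extra_token=True)
#     batch = tokenizer.tokenize(
#         queries=["what is a set-encoder"],
#         docs=["first passage text", "second passage text"],
#         num_docs=[2],
#         return_tensors="pt",
#     )
#     batch["encoding"]["num_docs"]  # [2]; popped and consumed by SetEncoderModel.forward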