1"""
2Configuration and model implementation for SetEncoder type models. Originally proposed in
3`Set-Encoder: Permutation-Invariant Inter-passage Attention for Listwise Passage Re-ranking with Cross-Encoders
4<https://link.springer.com/chapter/10.1007/978-3-031-88711-6_1>`_.
5"""
6
7from functools import partial
8from typing import Dict, Sequence, Tuple
9
10import torch
11from tokenizers.processors import TemplateProcessing
12from transformers import BatchEncoding
13
14from ...cross_encoder import CrossEncoderOutput, CrossEncoderTokenizer
15from .mono import MonoConfig, MonoModel
16
17
class SetEncoderConfig(MonoConfig):
    """Configuration class for a SetEncoder model."""

    model_type = "set-encoder"
    """Model type for a SetEncoder model."""

    def __init__(
        self,
        *args,
        depth: int = 100,
        add_extra_token: bool = False,
        sample_missing_docs: bool = True,
        **kwargs,
    ):
        """
        A SetEncoder model encodes a query and a set of documents jointly.
        Each document's embedding is updated with context from the entire set,
        and a relevance score is computed per document using a linear layer.

        Args:
            depth (int): Number of documents to encode per query. Defaults to 100.
            add_extra_token (bool): Whether to add an extra token to the input sequence to separate
                the query from the documents. Defaults to False.
            sample_missing_docs (bool): Whether to sample missing documents when the number of documents is less
                than the specified depth. Defaults to True.
        """

        super().__init__(*args, **kwargs)
        self.depth = depth
        self.add_extra_token = add_extra_token
        self.sample_missing_docs = sample_missing_docs


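# A minimal configuration sketch (illustrative only, not executed as part of this module; the
# backbone checkpoint name and the exact loading call are assumptions):
#
#     config = SetEncoderConfig(depth=100, add_extra_token=True, sample_missing_docs=True)
#     model = SetEncoderModel.from_pretrained("bert-base-uncased", config=config)
#
# ``depth`` fixes the ranking size the model attends over, ``add_extra_token`` inserts an [INT]
# token used for inter-passage attention, and ``sample_missing_docs`` pads rankings shorter than
# ``depth`` with sampled placeholder embeddings.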
class SetEncoderModel(MonoModel):
    """SetEncoder model. See :class:`SetEncoderConfig` for configuration options."""

    config_class = SetEncoderConfig
    self_attention_pattern = "self"

    ALLOW_SUB_BATCHING = False  # listwise model

    def __init__(self, config: SetEncoderConfig, *args, **kwargs):
        """Initializes a SetEncoder model given a :class:`SetEncoderConfig`.

        Args:
            config (SetEncoderConfig): Configuration for the SetEncoder model.
        """
        super().__init__(config, *args, **kwargs)
        self.config: SetEncoderConfig
        self.attn_implementation = "eager"
        if self.config.backbone_model_type is not None and self.config.backbone_model_type not in ("bert", "electra"):
            raise ValueError(
                f"SetEncoderModel does not support backbone model type {self.config.backbone_model_type}. "
                f"Supported types are 'bert' and 'electra'."
            )

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int, ...],
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        num_docs: Sequence[int] | None = None,
    ) -> torch.Tensor:
        """
        Extends the attention mask to account for the number of documents per query.

        Args:
            attention_mask (torch.Tensor): Attention mask for the input sequence.
            input_shape (Tuple[int, ...]): Shape of the input sequence.
            device (torch.device | None): Device to move the attention mask to. Defaults to None.
            dtype (torch.dtype | None): Data type of the attention mask. Defaults to None.
            num_docs (Sequence[int] | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of
                documents for that query. If an integer, assumes an equal number of documents per query. If None,
                tries to infer the number of documents by dividing the number of documents by the number of
                queries. Defaults to None.
        Returns:
            torch.Tensor: Extended attention mask.
        """
        if num_docs is not None:
            eye = (1 - torch.eye(self.config.depth, device=device)).long()
            if not self.config.sample_missing_docs:
                eye = eye[:, : max(num_docs)]
            other_doc_attention_mask = torch.cat([eye[:n] for n in num_docs])
            attention_mask = torch.cat(
                [attention_mask, other_doc_attention_mask.to(attention_mask)],
                dim=-1,
            )
            input_shape = tuple(attention_mask.shape)
        return super().get_extended_attention_mask(attention_mask, input_shape, device, dtype)

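    # Illustrative mask sketch (not executed): for two queries with num_docs = [2, 1] and depth = 3,
    # ``eye = 1 - torch.eye(3)`` has zeros on the diagonal, and ``torch.cat([eye[:2], eye[:1]])``
    # yields one row per document that allows attention to the other documents of its ranking while
    # masking attention to its own inter-passage embedding. These extra columns are appended to the
    # token-level attention mask along the last dimension.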
    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Computes contextualized embeddings for the joint query-document input sequence and computes a relevance
        score.

        Args:
            encoding (BatchEncoding): Tokenizer encoding for the joint query-document input sequence.
        Returns:
            CrossEncoderOutput: Output of the model.
        """
        num_docs = encoding.pop("num_docs", None)
        self.get_extended_attention_mask = partial(self.get_extended_attention_mask, num_docs=num_docs)
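        # Patch the backbone per call: the extended attention mask above accounts for the appended
        # other-document embeddings, and each self-attention module is rerouted through
        # ``attention_forward`` so keys and values also cover those embeddings.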
        for name, module in self.named_modules():
            if name.endswith(self.self_attention_pattern):
                module.forward = partial(self.attention_forward, self, module, num_docs=num_docs)
        return super().forward(encoding)

    @staticmethod
    def attention_forward(
        _self,
        self: torch.nn.Module,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None,
        *args,
        num_docs: Sequence[int],
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        """Performs the attention forward pass for the SetEncoder model.

        Args:
            _self (SetEncoderModel): Reference to the SetEncoder instance.
            self (torch.nn.Module): Reference to the attention module.
            hidden_states (torch.Tensor): Hidden states from the previous layer.
            attention_mask (torch.FloatTensor | None): Attention mask for the input sequence.
            num_docs (Sequence[int]): Specifies how many documents are passed per query. If a sequence of integers,
                `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the number of
                documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to
                infer the number of documents by dividing the number of documents by the number of queries.
        Returns:
            Tuple[torch.Tensor]: Contextualized embeddings.
        """
        key_value_hidden_states = hidden_states
        if num_docs is not None:
            key_value_hidden_states = _self.cat_other_doc_hidden_states(hidden_states, num_docs)

        batch_size = hidden_states.shape[0]
        query = (
            self.query(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        key = (
            self.key(key_value_hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        value = (
            self.value(key_value_hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        context = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attention_mask.to(query.dtype) if attention_mask is not None else None,
            self.dropout.p if self.training else 0,
        )

        context = context.permute(0, 2, 1, 3).contiguous()
        new_context_shape = context.size()[:-2] + (self.all_head_size,)
        context = context.view(new_context_shape)
        return (context,)

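    # Shape sketch (illustrative): with ``sum(num_docs)`` query-document sequences of length
    # ``seq_len``, the queries span a sequence's own ``seq_len`` positions, while keys and values
    # span ``seq_len`` plus one inter-passage embedding per (possibly sampled) other document in the
    # ranking, as built by ``cat_other_doc_hidden_states`` below.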
    def cat_other_doc_hidden_states(
        self,
        hidden_states: torch.Tensor,
        num_docs: Sequence[int],
    ) -> torch.Tensor:
        """Concatenates the hidden states of other documents to the hidden states of the query and documents.

        Args:
            hidden_states (torch.Tensor): Hidden states of the query and documents.
            num_docs (Sequence[int]): Specifies how many documents are passed per query. If a sequence of integers,
                `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the number of
                documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to
                infer the number of documents by dividing the number of documents by the number of queries.
        Returns:
            torch.Tensor: Concatenated hidden states of the query and documents.
        """
        idx = 1 if self.config.add_extra_token else 0
        split_other_doc_hidden_states = torch.split(hidden_states[:, idx], list(num_docs))
        repeated_other_doc_hidden_states = []
        for idx, h_states in enumerate(split_other_doc_hidden_states):
            missing_docs = 0 if self.config.depth is None else self.config.depth - num_docs[idx]
            if missing_docs and self.config.sample_missing_docs:
                mean = h_states.mean(0, keepdim=True).expand(missing_docs, -1)
                if num_docs[idx] == 1:
                    std = torch.zeros_like(mean)
                else:
                    std = h_states.std(0, keepdim=True).expand(missing_docs, -1)
                sampled_h_states = torch.normal(mean, std).to(h_states)
                h_states = torch.cat([h_states, sampled_h_states])
            repeated_other_doc_hidden_states.append(h_states.unsqueeze(0).expand(num_docs[idx], -1, -1))
        other_doc_hidden_states = torch.cat(repeated_other_doc_hidden_states)
        key_value_hidden_states = torch.cat([hidden_states, other_doc_hidden_states], dim=1)
        return key_value_hidden_states


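# Illustrative sampling sketch (not executed): if ``depth=100`` but a query has only 10 documents
# and ``sample_missing_docs=True``, the 90 missing inter-passage embeddings are drawn element-wise
# from a normal distribution whose mean and standard deviation are estimated from the 10 available
# [CLS]/[INT] embeddings, so every document attends over a fixed-size set of other-document
# embeddings.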
class SetEncoderTokenizer(CrossEncoderTokenizer):
    """Tokenizer for SetEncoder models. See :class:`SetEncoderConfig` for configuration options."""

    config_class = SetEncoderConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_extra_token: bool = False,
        **kwargs,
    ):
        """Initializes a SetEncoder tokenizer.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            add_extra_token (bool): Whether to add an extra interaction token. Defaults to False.
        """
        super().__init__(
            *args, query_length=query_length, doc_length=doc_length, add_extra_token=add_extra_token, **kwargs
        )
        self.add_extra_token = add_extra_token
        self.interaction_token = "[INT]"
        if add_extra_token:
            self.add_tokens([self.interaction_token], special_tokens=True)
            self._tokenizer.post_processor = TemplateProcessing(
                single="[CLS] $0 [SEP]",
                pair="[CLS] [INT] $A [SEP] $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    ("[INT]", self.interaction_token_id),
                ],
            )

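    # Illustrative sketch (not executed): with ``add_extra_token=True`` the post-processor above
    # renders query-document pairs as
    #
    #     [CLS] [INT] <query tokens> [SEP] <document tokens> [SEP]
    #
    # and the [INT] embedding is the one exchanged between documents of the same ranking.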
    @property
    def interaction_token_id(self) -> int:
        """The token id of the interaction token."""
        if self.interaction_token in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.interaction_token]
        raise ValueError(f"Token {self.interaction_token} not found in tokenizer")

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents into a single sequence of tokens.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence
                of integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to
                the number of documents, i.e., the sequence contains one value per query specifying the number of
                documents for that query. If an integer, assumes an equal number of documents per query. If None,
                tries to infer the number of documents by dividing the number of documents by the number of
                queries. Defaults to None.
        Returns:
            Dict[str, BatchEncoding]: Tokenized query-document sequence.
        Raises:
            ValueError: If both queries and docs are None.
            ValueError: If queries and docs are not both lists or both strings.
        """
        encoding_dict = super().tokenize(queries, docs, num_docs, **kwargs)
        encoding_dict["encoding"] = BatchEncoding({**encoding_dict["encoding"], "num_docs": num_docs})
        return encoding_dict
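

# A minimal end-to-end usage sketch (illustrative only; the checkpoint name and any keyword
# arguments forwarded to the underlying tokenizer are assumptions, not part of this module):
#
#     tokenizer = SetEncoderTokenizer.from_pretrained("bert-base-uncased", add_extra_token=True)
#     encoding_dict = tokenizer.tokenize(
#         queries=["what is a set encoder"],
#         docs=["The Set-Encoder attends across documents.", "An unrelated passage."],
#         num_docs=[2],
#     )
#     # encoding_dict["encoding"] now carries ``num_docs`` alongside the token tensors and can be
#     # passed to ``SetEncoderModel.forward``.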