Source code for lightning_ir.models.cross_encoders.set_encoder

  1"""
  2Configuration and model implementation for SetEncoder type models.
  3
  4The Set-Encoder is a cross-encoder architecture designed for listwise passage re-ranking that evaluates an
  5entire group of candidate documents simultaneously while eliminating positional bias. Traditional listwise models
  6concatenate all candidate passages into a single long text sequence, which is computationally heavy and
  7changes relevance scores based on the order in which the documents are input. The Set-Encoder circumvents this by processing
  8each passage in parallel and inserting a dedicated interaction token into each document's sequence. Through a novel
  9inter-passage attention mechanism, all the passages can share context and mathematically "communicate" by attending
 10exclusively to these special interaction tokens. This ensures the model is permutation invariant, meaning the
 11input order has zero effect on the final ranking, while keeping computational costs much lower than standard
 12concatenation methods.
 13
 14Originally proposed in
  15`Set-Encoder: Permutation-Invariant Inter-passage Attention for Listwise Passage Re-ranking with Cross-Encoders
 16<https://link.springer.com/chapter/10.1007/978-3-031-88711-6_1>`_.
 17"""
 18
 19from collections.abc import Sequence
 20from functools import partial
 21
 22import torch
 23from tokenizers.processors import TemplateProcessing
 24from transformers import BatchEncoding
 25
 26from ...cross_encoder import CrossEncoderOutput, CrossEncoderTokenizer
 27from .mono import MonoConfig, MonoModel
 28
 29
class SetEncoderConfig(MonoConfig):
    """Configuration class for a SetEncoder model."""

    model_type = "set-encoder"
    """Model type for a SetEncoder model."""

    def __init__(
        self,
        *args,
        depth: int = 100,
        add_extra_token: bool = False,
        sample_missing_docs: bool = True,
        **kwargs,
    ):
        """
        Configuration for a SetEncoder model, which jointly encodes a query together with a set of
        documents. Contextual information is shared across the whole set before a per-document
        relevance score is produced by a linear layer.

        Args:
            depth (int): Number of documents encoded per query. Defaults to 100.
            add_extra_token (bool): Whether an extra interaction token is inserted into the input
                sequence to separate the query from the documents. Defaults to False.
            sample_missing_docs (bool): Whether missing documents are sampled when fewer than
                `depth` documents are provided for a query. Defaults to True.
        """
        super().__init__(*args, **kwargs)
        self.sample_missing_docs = sample_missing_docs
        self.add_extra_token = add_extra_token
        self.depth = depth
class SetEncoderModel(MonoModel):
    """SetEncoder model. See :class:`SetEncoderConfig` for configuration options."""

    config_class = SetEncoderConfig
    # Suffix used to locate the backbone's self-attention modules whose forward
    # is replaced in :meth:`forward` (e.g. ``...attention.self`` in BERT/ELECTRA).
    self_attention_pattern = "self"

    ALLOW_SUB_BATCHING = False  # listwise model: all documents of a query must stay in one forward pass

    def __init__(self, config: SetEncoderConfig, *args, **kwargs):
        """Initializes a SetEncoder model given a :class:`SetEncoderConfig`.

        Args:
            config (SetEncoderConfig): Configuration for the SetEncoder model.
        Raises:
            ValueError: If the backbone model type is neither 'bert' nor 'electra'.
        """
        super().__init__(config, *args, **kwargs)
        self.config: SetEncoderConfig
        # Force eager attention: the custom attention_forward below is patched in
        # place of the backbone's self-attention forward and relies on the eager
        # module layout (query/key/value sub-modules).
        self.attn_implementation = "eager"
        if self.config.backbone_model_type is not None and self.config.backbone_model_type not in ("bert", "electra"):
            raise ValueError(
                f"SetEncoderModel does not support backbone model type {self.config.backbone_model_type}. "
                f"Supported types are 'bert' and 'electra'."
            )

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: tuple[int, ...],
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        num_docs: Sequence[int] | None = None,
    ) -> torch.Tensor:
        """
        Extends the attention mask to account for the number of documents per query.

        Args:
            attention_mask (torch.Tensor): Attention mask for the input sequence.
            input_shape (tuple[int, ...]): Shape of the input sequence.
            device (torch.device | None): Device to move the attention mask to. Defaults to None.
            dtype (torch.dtype | None): Data type of the attention mask. Defaults to None.
            num_docs (Sequence[int] | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_doc)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            torch.Tensor: Extended attention mask.
        """
        if num_docs is not None:
            # (depth, depth) mask with zeros on the diagonal: each document may
            # attend to every *other* document's interaction token, but not to
            # its own appended copy.
            eye = (1 - torch.eye(self.config.depth, device=device)).long()
            if not self.config.sample_missing_docs:
                # No sampling of missing docs: only append as many extra columns
                # as the largest actual per-query document count.
                eye = eye[:, : max(num_docs)]
            # One row per document; rows are grouped per query via num_docs.
            other_doc_attention_mask = torch.cat([eye[:n] for n in num_docs])
            attention_mask = torch.cat(
                [attention_mask, other_doc_attention_mask.to(attention_mask)],
                dim=-1,
            )
            input_shape = tuple(attention_mask.shape)
        return super().get_extended_attention_mask(attention_mask, input_shape, device, dtype)

    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        """Computes contextualized embeddings for the joint query-document input sequence and computes a relevance
        score.

        Args:
            encoding (BatchEncoding): Tokenizer encoding for the joint query-document input sequence.
        Returns:
            CrossEncoderOutput: Output of the model.
        """
        num_docs = encoding.pop("num_docs", None)
        # NOTE(review): this re-wraps the (possibly already wrapped) bound method in a
        # new partial on every call; the most recent num_docs keyword wins, but the
        # wrapper chain grows with each forward call — consider binding once.
        self.get_extended_attention_mask = partial(self.get_extended_attention_mask, num_docs=num_docs)
        # Patch every self-attention module so keys/values include the other
        # documents' interaction-token states (inter-passage attention).
        for name, module in self.named_modules():
            if name.endswith(self.self_attention_pattern):
                module.forward = partial(self.attention_forward, self, module, num_docs=num_docs)
        return super().forward(encoding)

    @staticmethod
    def attention_forward(
        _self,
        self: torch.nn.Module,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None,
        *args,
        num_docs: Sequence[int],
        **kwargs,
    ) -> tuple[torch.Tensor]:
        """Performs the attention forward pass for the SetEncoder model.

        Args:
            _self (SetEncoderModel): Reference to the SetEncoder instance.
            self (torch.nn.Module): Reference to the attention module.
            hidden_states (torch.Tensor): Hidden states from the previous layer.
            attention_mask (torch.FloatTensor | None): Attention mask for the input sequence.
            num_docs (Sequence[int]): Specifies how many documents are passed per query. If a sequence of integers,
                `len(num_doc)` should be equal to the number of queries and `sum(num_docs)` equal to the number of
                documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries.
        Returns:
            tuple[torch.Tensor]: Contextualized embeddings.
        """
        key_value_hidden_states = hidden_states
        if num_docs is not None:
            # Append the other documents' interaction-token states so keys and
            # values span the whole document set, while queries stay per-document.
            key_value_hidden_states = _self.cat_other_doc_hidden_states(hidden_states, num_docs)

        # hidden_states is assumed to be (batch, seq_len, hidden); key/value
        # sequences are longer after concatenation above — TODO confirm.
        batch_size = hidden_states.shape[0]
        query = (
            self.query(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        key = (
            self.key(key_value_hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        value = (
            self.value(key_value_hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        # attention_mask here is the extended (additive) mask already widened by
        # get_extended_attention_mask; dropout only applies in training mode.
        context = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attention_mask.to(query.dtype) if attention_mask is not None else None,
            self.dropout.p if self.training else 0,
        )

        # Merge heads back: (batch, heads, seq, head_size) -> (batch, seq, all_head_size).
        context = context.permute(0, 2, 1, 3).contiguous()
        new_context_shape = context.size()[:-2] + (self.all_head_size,)
        context = context.view(new_context_shape)
        return (context,)

    def cat_other_doc_hidden_states(
        self,
        hidden_states: torch.Tensor,
        num_docs: Sequence[int],
    ) -> torch.Tensor:
        """Concatenates the hidden states of other documents to the hidden states of the query and documents.

        Args:
            hidden_states (torch.Tensor): Hidden states of the query and documents.
            num_docs (Sequence[int]): Specifies how many documents are passed per query. If a sequence of integers,
                `len(num_doc)` should be equal to the number of queries and `sum(num_docs)` equal to the number of
                documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries.
        Returns:
            torch.Tensor: Concatenated hidden states of the query and documents.
        """
        # Position of the interaction token in each sequence: presumably [CLS] at 0,
        # or the extra [INT] token at 1 when add_extra_token is set — see tokenizer.
        idx = 1 if self.config.add_extra_token else 0
        # Split the per-document interaction-token states into per-query groups.
        # NOTE(review): the loop variable below shadows this `idx`; harmless, since
        # the token index is no longer needed after the split.
        split_other_doc_hidden_states = torch.split(hidden_states[:, idx], list(num_docs))
        repeated_other_doc_hidden_states = []
        for idx, h_states in enumerate(split_other_doc_hidden_states):
            missing_docs = 0 if self.config.depth is None else self.config.depth - num_docs[idx]
            if missing_docs and self.config.sample_missing_docs:
                # Pad the set up to `depth` by sampling pseudo-documents from a
                # normal distribution fitted to this query's document states.
                mean = h_states.mean(0, keepdim=True).expand(missing_docs, -1)
                if num_docs[idx] == 1:
                    # std is undefined for a single sample; fall back to zeros.
                    std = torch.zeros_like(mean)
                else:
                    std = h_states.std(0, keepdim=True).expand(missing_docs, -1)
                sampled_h_states = torch.normal(mean, std).to(h_states)
                h_states = torch.cat([h_states, sampled_h_states])
            # Every document of query idx sees the same set of other-document states.
            repeated_other_doc_hidden_states.append(h_states.unsqueeze(0).expand(num_docs[idx], -1, -1))
        other_doc_hidden_states = torch.cat(repeated_other_doc_hidden_states)
        key_value_hidden_states = torch.cat([hidden_states, other_doc_hidden_states], dim=1)
        return key_value_hidden_states
class SetEncoderTokenizer(CrossEncoderTokenizer):
    config_class = SetEncoderConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_extra_token: bool = False,
        **kwargs,
    ):
        """Creates a SetEncoder tokenizer.

        Args:
            query_length (int | None): Maximum number of tokens per query; None disables truncation.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document; None disables truncation.
                Defaults to 512.
            add_extra_token (bool): Whether to register and insert a dedicated interaction token.
                Defaults to False.
        """
        super().__init__(
            *args, query_length=query_length, doc_length=doc_length, add_extra_token=add_extra_token, **kwargs
        )
        self.add_extra_token = add_extra_token
        self.interaction_token = "[INT]"
        if not add_extra_token:
            return
        # Register the interaction token and rebuild the post-processor so query-document
        # pairs are rendered as "[CLS] [INT] query [SEP] doc [SEP]".
        self.add_tokens([self.interaction_token], special_tokens=True)
        self._tokenizer.post_processor = TemplateProcessing(
            single="[CLS] $0 [SEP]",
            pair="[CLS] [INT] $A [SEP] $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.cls_token_id),
                ("[SEP]", self.sep_token_id),
                ("[INT]", self.interaction_token_id),
            ],
        )

    @property
    def interaction_token_id(self) -> int:
        """Token id of the interaction token.

        Raises:
            ValueError: If the interaction token has not been added to the tokenizer.
        """
        token_id = self.added_tokens_encoder.get(self.interaction_token)
        if token_id is None:
            raise ValueError(f"Token {self.interaction_token} not found in tokenizer")
        return token_id

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> dict[str, BatchEncoding]:
        """Tokenizes queries and documents into a single sequence of tokens.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            dict[str, BatchEncoding]: Tokenized query-document sequence.
        Raises:
            ValueError: If both queries and docs are None.
            ValueError: If queries and docs are not both lists or both strings.
        """
        encoding_dict = super().tokenize(queries, docs, num_docs, **kwargs)
        # Rebuild the encoding so num_docs travels with it into the model's forward pass.
        merged = dict(encoding_dict["encoding"])
        merged["num_docs"] = num_docs
        encoding_dict["encoding"] = BatchEncoding(merged)
        return encoding_dict