Source code for lightning_ir.models.set_encoder

from functools import partial
from typing import Dict, Sequence, Tuple

import torch
from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding

from ..cross_encoder.cross_encoder_config import CrossEncoderConfig
from ..cross_encoder.cross_encoder_model import CrossEncoderModel, CrossEncoderOutput
from ..cross_encoder.cross_encoder_tokenizer import CrossEncoderTokenizer

class SetEncoderConfig(CrossEncoderConfig):
    model_type = "set-encoder"

    def __init__(
        self,
        *args,
        depth: int = 100,
        add_extra_token: bool = False,
        sample_missing_docs: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.depth = depth
        self.add_extra_token = add_extra_token
        self.sample_missing_docs = sample_missing_docs

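A minimal usage sketch (illustrative only, not part of the module source), assuming the package is installed and importable as lightning_ir:

    # Hypothetical example: configure a Set-Encoder with a ranking depth of 100,
    # the extra [INT] interaction token, and sampling of placeholder embeddings
    # for queries that come with fewer than `depth` documents.
    from lightning_ir.models.set_encoder import SetEncoderConfig

    config = SetEncoderConfig(depth=100, add_extra_token=True, sample_missing_docs=True)
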
class SetEncoderModel(CrossEncoderModel):
    config_class = SetEncoderConfig
    self_attention_pattern = "self"

    ALLOW_SUB_BATCHING = False  # listwise model

    def __init__(self, config: SetEncoderConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config: SetEncoderConfig
        self.attn_implementation = "eager"
        if self.config.backbone_model_type is not None and self.config.backbone_model_type not in ("bert", "electra"):
            raise ValueError(
                f"SetEncoderModel does not support backbone model type {self.config.backbone_model_type}. "
                f"Supported types are 'bert' and 'electra'."
            )

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int, ...],
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        num_docs: Sequence[int] | None = None,
    ) -> torch.Tensor:
        if num_docs is not None:
            eye = (1 - torch.eye(self.config.depth, device=device)).long()
            if not self.config.sample_missing_docs:
                eye = eye[:, : max(num_docs)]
            other_doc_attention_mask = torch.cat([eye[:n] for n in num_docs])
            attention_mask = torch.cat(
                [attention_mask, other_doc_attention_mask.to(attention_mask)],
                dim=-1,
            )
            input_shape = tuple(attention_mask.shape)
        return super().get_extended_attention_mask(attention_mask, input_shape, device, dtype)

    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        num_docs = encoding.pop("num_docs", None)
        self.get_extended_attention_mask = partial(self.get_extended_attention_mask, num_docs=num_docs)
        for name, module in self.named_modules():
            if name.endswith(self.self_attention_pattern):
                module.forward = partial(self.attention_forward, self, module, num_docs=num_docs)
        return super().forward(encoding)

    @staticmethod
    def attention_forward(
        _self,
        self: torch.nn.Module,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None,
        *args,
        num_docs: Sequence[int],
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        key_value_hidden_states = hidden_states
        if num_docs is not None:
            key_value_hidden_states = _self.cat_other_doc_hidden_states(hidden_states, num_docs)
        query = self.transpose_for_scores(self.query(hidden_states))
        key = self.transpose_for_scores(self.key(key_value_hidden_states))
        value = self.transpose_for_scores(self.value(key_value_hidden_states))

        context = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attention_mask.to(query.dtype) if attention_mask is not None else None,
            self.dropout.p if self.training else 0,
        )

        context = context.permute(0, 2, 1, 3).contiguous()
        new_context_shape = context.size()[:-2] + (self.all_head_size,)
        context = context.view(new_context_shape)
        return (context,)

    def cat_other_doc_hidden_states(
        self,
        hidden_states: torch.Tensor,
        num_docs: Sequence[int],
    ) -> torch.Tensor:
        idx = 1 if self.config.add_extra_token else 0
        split_other_doc_hidden_states = torch.split(hidden_states[:, idx], list(num_docs))
        repeated_other_doc_hidden_states = []
        for idx, h_states in enumerate(split_other_doc_hidden_states):
            missing_docs = 0 if self.config.depth is None else self.config.depth - num_docs[idx]
            if missing_docs and self.config.sample_missing_docs:
                mean = h_states.mean(0, keepdim=True).expand(missing_docs, -1)
                if num_docs[idx] == 1:
                    std = torch.zeros_like(mean)
                else:
                    std = h_states.std(0, keepdim=True).expand(missing_docs, -1)
                sampled_h_states = torch.normal(mean, std).to(h_states)
                h_states = torch.cat([h_states, sampled_h_states])
            repeated_other_doc_hidden_states.append(h_states.unsqueeze(0).expand(num_docs[idx], -1, -1))
        other_doc_hidden_states = torch.cat(repeated_other_doc_hidden_states)
        key_value_hidden_states = torch.cat([hidden_states, other_doc_hidden_states], dim=1)
        return key_value_hidden_states

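A toy sketch (illustrative only, not part of the module source) of how the extra attention-mask columns are assembled, mirroring the logic in get_extended_attention_mask; depth, num_docs, and seq_len are made-up values:

    import torch

    depth = 4                # config.depth (padded ranking depth)
    num_docs = [3, 2]        # documents per query in the batch (5 sequences total)
    seq_len = 6              # token length of each query-document sequence

    token_mask = torch.ones(sum(num_docs), seq_len, dtype=torch.long)

    # Each document may attend to the other documents of its query but not to itself,
    # hence the inverted identity matrix; one extra mask column per ranking position.
    eye = (1 - torch.eye(depth)).long()
    other_doc_mask = torch.cat([eye[:n] for n in num_docs])          # shape (5, depth)

    attention_mask = torch.cat([token_mask, other_doc_mask], dim=-1)
    print(attention_mask.shape)                                      # torch.Size([5, 10])
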
class SetEncoderTokenizer(CrossEncoderTokenizer):

    config_class = SetEncoderConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        add_extra_token: bool = False,
        **kwargs,
    ):
        super().__init__(
            *args, query_length=query_length, doc_length=doc_length, add_extra_token=add_extra_token, **kwargs
        )
        self.add_extra_token = add_extra_token
        self.interaction_token = "[INT]"
        if add_extra_token:
            self.add_tokens([self.interaction_token], special_tokens=True)
            self._tokenizer.post_processor = TemplateProcessing(
                single="[CLS] $0 [SEP]",
                pair="[CLS] [INT] $A [SEP] $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    ("[INT]", self.interaction_token_id),
                ],
            )

    @property
    def interaction_token_id(self) -> int:
        if self.interaction_token in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.interaction_token]
        raise ValueError(f"Token {self.interaction_token} not found in tokenizer")

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents into a single sequence of tokens.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents per query
            by dividing the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :return: Tokenized query-document sequence
        :rtype: Dict[str, BatchEncoding]
        """
        if queries is None or docs is None:
            raise ValueError("Both queries and docs must be provided.")
        if isinstance(docs, str) and not isinstance(queries, str):
            raise ValueError("Queries and docs must be both lists or both strings.")
        is_string_queries = False
        is_string_docs = False
        if isinstance(queries, str):
            queries = [queries]
            is_string_queries = True
        if isinstance(docs, str):
            docs = [docs]
            is_string_docs = True
        is_string_both = is_string_queries and is_string_docs
        num_docs = self._process_num_docs(queries, docs, num_docs)
        queries, docs = self._preprocess(queries, docs, num_docs)
        return_tensors = kwargs.get("return_tensors", None)
        if return_tensors is not None:
            kwargs["pad_to_multiple_of"] = 8
        if is_string_both:
            encoding = self(queries[0], docs[0], **kwargs)
        else:
            encoding = self(queries, docs, **kwargs)
        return {"encoding": BatchEncoding({**encoding, "num_docs": num_docs})}
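
An end-to-end tokenization sketch (illustrative only, not part of the module source); the backbone checkpoint name and the from_pretrained call are assumptions:

    from lightning_ir.models.set_encoder import SetEncoderTokenizer

    # Hypothetical checkpoint; any BERT/ELECTRA-style tokenizer backbone is assumed to work.
    tokenizer = SetEncoderTokenizer.from_pretrained("bert-base-uncased", add_extra_token=True)

    queries = ["what is a set-encoder", "how are documents ranked"]
    docs = ["doc a", "doc b", "doc c", "doc d", "doc e"]

    # Three documents belong to the first query, two to the second.
    out = tokenizer.tokenize(queries, docs, num_docs=[3, 2], return_tensors="pt")
    encoding = out["encoding"]  # BatchEncoding carrying input_ids, attention_mask, ..., and num_docs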