Source code for lightning_ir.bi_encoder.bi_encoder_tokenizer

  1"""
  2Tokenizer module for bi-encoder models.
  3
  4This module contains the tokenizer class bi-encoder models.
  5"""
  6
  7import warnings
  8from typing import Dict, Literal, Sequence, Type
  9
 10from tokenizers.processors import TemplateProcessing
 11from transformers import BatchEncoding
 12
 13from ..base import LightningIRClassFactory, LightningIRTokenizer
 14from .bi_encoder_config import BiEncoderConfig
 15
 16ADD_MARKER_TOKEN_MAPPING = {
 17    "bert": {
 18        "single": "[CLS] {TOKEN} $0 [SEP]",
 19        "pair": "[CLS] {TOKEN_1} $A [SEP] {TOKEN_2} $B:1 [SEP]:1",
 20    },
 21    "modernbert": {
 22        "single": "[CLS] {TOKEN} $0 [SEP]",
 23        "pair": "[CLS] {TOKEN_1} $A [SEP] {TOKEN_2} $B:1 [SEP]:1",
 24    },
 25}
 26
 27
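# Illustrative note (not part of the original source): the entries above are plain
# str.format templates. Substituting the marker tokens for a BERT-style backbone gives the
# post-processing templates used below, e.g.
#
#   ADD_MARKER_TOKEN_MAPPING["bert"]["single"].format(TOKEN="[QUE]")
#   # -> "[CLS] [QUE] $0 [SEP]"
#
# so an encoded query becomes "[CLS] [QUE] <query tokens> [SEP]" and an encoded document
# "[CLS] [DOC] <document tokens> [SEP]" ($0 / $A / $B are TemplateProcessing placeholders
# for the input sequences).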
class BiEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the tokenizer."""

    QUERY_TOKEN: str = "[QUE]"
    """Token to mark a query sequence."""
    DOC_TOKEN: str = "[DOC]"
    """Token to mark a document sequence."""
    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """:class:`.LightningIRTokenizer` for bi-encoder models. Encodes queries and documents separately. Optionally
        adds marker tokens to the encoded input sequences.

        Args:
            query_length (int): Maximum query length in number of tokens. Defaults to 32.
            doc_length (int): Maximum document length in number of tokens. Defaults to 512.
            add_marker_tokens (bool): Whether to add marker tokens to the query and document input sequences.
                Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_length = query_length
        self.doc_length = doc_length
        self.add_marker_tokens = add_marker_tokens

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
        if add_marker_tokens:
            backbone_model_type = LightningIRClassFactory.get_backbone_model_type(self.name_or_path)
            if backbone_model_type not in ADD_MARKER_TOKEN_MAPPING:
                raise ValueError(
                    f"Adding marker tokens is not supported for the backbone model type '{backbone_model_type}'. "
                    f"Supported types are: [{', '.join(ADD_MARKER_TOKEN_MAPPING.keys())}]. "
                    "Please set `add_marker_tokens=False` "
                    "or add the backbone model type to `ADD_MARKER_TOKEN_MAPPING`."
                )
            self.add_tokens([self.QUERY_TOKEN, self.DOC_TOKEN], special_tokens=True)
            pattern = ADD_MARKER_TOKEN_MAPPING[backbone_model_type]
            self.query_post_processor = TemplateProcessing(
                single=pattern["single"].format(TOKEN=self.QUERY_TOKEN),
                pair=pattern["pair"].format(TOKEN_1=self.QUERY_TOKEN, TOKEN_2=self.DOC_TOKEN),
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )
            self.doc_post_processor = TemplateProcessing(
                single=pattern["single"].format(TOKEN=self.DOC_TOKEN),
                pair=pattern["pair"].format(TOKEN_1=self.QUERY_TOKEN, TOKEN_2=self.DOC_TOKEN),
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )
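    # Construction sketch (illustrative, not part of the original source; the checkpoint
    # name is an assumption):
    #
    #   tokenizer = BiEncoderTokenizer.from_pretrained("bert-base-uncased", add_marker_tokens=True)
    #   tokenizer.query_token_id  # id assigned to "[QUE]" by `add_tokens`
    #
    # For a backbone whose model type is missing from ADD_MARKER_TOKEN_MAPPING, the same
    # call raises the ValueError above; set add_marker_tokens=False in that case.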
    @property
    def query_token_id(self) -> int | None:
        """The token id of the query token if marker tokens are added.

        Returns:
            Token id of the query token if added, otherwise None.
        """
        if self.QUERY_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.QUERY_TOKEN]
        return None

    @property
    def doc_token_id(self) -> int | None:
        """The token id of the document token if marker tokens are added.

        Returns:
            Token id of the document token if added, otherwise None.
        """
        if self.DOC_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.DOC_TOKEN]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
        """Overrides the PreTrainedTokenizer.__call__ method to warn the user to use :meth:`.tokenize_query` and
        :meth:`.tokenize_doc` methods instead.

        .. PreTrainedTokenizer.__call__: \
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        Args:
            text (str | Sequence[str]): Text to tokenize.
            warn (bool): Set to False to silence warning. Defaults to True.
        Returns:
            BatchEncoding: Tokenized text.
        """
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being directly called. Use `tokenize`, `tokenize_query`, or `tokenize_doc` "
                "to make sure tokenization is done correctly.",
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encodes text with an optional post-processor."""
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
        if kwargs.get("return_tensors", None) is not None:
            kwargs["pad_to_multiple_of"] = 8
        encoding = self(text, *args, warn=False, **kwargs)
        self._tokenizer.post_processor = orig_post_processor
        return encoding

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding
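    # Sketch of `_expand` (illustrative, not part of the original source): padding positions
    # in an encoding are replaced by mask tokens, e.g. with pad_token_id=0 and
    # mask_token_id=103 (ids assumed for illustration):
    #
    #   input_ids before:  [101, 2054, 2003, 102,   0,   0]
    #   input_ids after:   [101, 2054, 2003, 102, 103, 103]
    #
    # With attend_to_expanded_tokens=True the attention mask is additionally set to all ones,
    # so the model also attends to the expanded mask positions.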
    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        Args:
            text (Sequence[str] | str): Input text to tokenize.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BatchEncoding: Tokenized input sequences.
        """
        post_processor = getattr(self, f"{input_type}_post_processor")
        kwargs["max_length"] = getattr(self, f"{input_type}_length")
        if "padding" not in kwargs:
            kwargs["truncation"] = True
        return self._encode(text, *args, post_processor=post_processor, **kwargs)
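    # Illustrative call (not part of the original source), given a `tokenizer` instance as in
    # the construction sketch above. Queries and documents share this code path; only the
    # post-processor and maximum length differ:
    #
    #   tokenizer.tokenize_input_sequence("what is information retrieval?", "query")
    #   # truncates to `self.query_length` and applies `self.query_post_processor` (if set),
    #   # which inserts the "[QUE]" marker after "[CLS]".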
    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input queries.

        Args:
            queries (Sequence[str] | str): Query or queries to tokenize.
        Returns:
            BatchEncoding: Tokenized queries.
        """
        encoding = self.tokenize_input_sequence(queries, "query", *args, **kwargs)
        return encoding
    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input documents.

        Args:
            docs (Sequence[str] | str): Document or documents to tokenize.
        Returns:
            BatchEncoding: Tokenized documents.
        """
        encoding = self.tokenize_input_sequence(docs, "doc", *args, **kwargs)
        return encoding
    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
        Returns:
            Dict[str, BatchEncoding]: Dictionary containing tokenized queries and documents.
        """
        encodings = {}
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings
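# ---------------------------------------------------------------------------
# Usage sketch appended for illustration; not part of the original module. The
# checkpoint name is an assumption -- any BERT-style checkpoint works the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = BiEncoderTokenizer.from_pretrained(
        "bert-base-uncased", query_length=32, doc_length=512, add_marker_tokens=True
    )
    encodings = tokenizer.tokenize(
        queries="what is information retrieval?",
        docs=["Information retrieval is the task of finding relevant documents."],
        return_tensors="pt",
    )
    print(encodings["query_encoding"]["input_ids"])  # contains the "[QUE]" marker token id
    print(encodings["doc_encoding"]["input_ids"])  # contains the "[DOC]" marker token id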