Source code for lightning_ir.bi_encoder.bi_encoder_tokenizer

  1"""
  2Tokenizer module for bi-encoder models.
  3
  4This module contains the tokenizer class bi-encoder models.
  5"""

import warnings
from typing import Dict, Literal, Sequence, Type

from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding

from ..base import LightningIRClassFactory, LightningIRTokenizer
from .bi_encoder_config import BiEncoderConfig

ADD_MARKER_TOKEN_MAPPING = {
    "bert": {"pattern": "[CLS] {TOKEN} $0 [SEP]", "special_tokens": ["[CLS]", "[SEP]"]},
    "modernbert": {"pattern": "[CLS] {TOKEN} $0 [SEP]", "special_tokens": ["[CLS]", "[SEP]"]},
}
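
# The "pattern" values are ``tokenizers.processors.TemplateProcessing`` templates:
# ``$0`` stands for the tokenized input sequence and ``{TOKEN}`` is filled in with the
# query or document marker in ``__init__`` below. For a BERT backbone, a query is thus
# post-processed to, e.g. (decoded sketch):
#
#   [CLS] [QUE] how do bi-encoders work? [SEP]
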

class BiEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the tokenizer."""

    QUERY_TOKEN: str = "[QUE]"
    """Token to mark a query sequence."""
    DOC_TOKEN: str = "[DOC]"
    """Token to mark a document sequence."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """:class:`.LightningIRTokenizer` for bi-encoder models. Encodes queries and documents separately and
        optionally adds marker tokens to the encoded input sequences.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            add_marker_tokens (bool): Whether to add marker tokens to the query and document input sequences.
                Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and an unsupported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_length = query_length
        self.doc_length = doc_length
        self.add_marker_tokens = add_marker_tokens

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
        if add_marker_tokens:
            backbone_model_type = LightningIRClassFactory.get_backbone_model_type(self.name_or_path)
            if backbone_model_type not in ADD_MARKER_TOKEN_MAPPING:
                raise ValueError(
                    f"Adding marker tokens is not supported for the backbone model type '{backbone_model_type}'. "
                    f"Supported types are: [{', '.join(ADD_MARKER_TOKEN_MAPPING.keys())}]. "
                    "Please set `add_marker_tokens=False` "
                    "or add the backbone model type to `ADD_MARKER_TOKEN_MAPPING`."
                )
            self.add_tokens([self.QUERY_TOKEN, self.DOC_TOKEN], special_tokens=True)
            pattern = ADD_MARKER_TOKEN_MAPPING[backbone_model_type]["pattern"]
            special_tokens = [
                (token, self.convert_tokens_to_ids(token))
                for token in ADD_MARKER_TOKEN_MAPPING[backbone_model_type]["special_tokens"]
            ]

            self.query_post_processor = TemplateProcessing(
                single=pattern.format(TOKEN=self.QUERY_TOKEN),
                pair=None,
                special_tokens=special_tokens + [(self.QUERY_TOKEN, self.query_token_id)],
            )
            self.doc_post_processor = TemplateProcessing(
                single=pattern.format(TOKEN=self.DOC_TOKEN),
                pair=None,
                special_tokens=special_tokens + [(self.DOC_TOKEN, self.doc_token_id)],
            )
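
    # Construction sketch (uses the public "bert-base-uncased" checkpoint): with
    # ``add_marker_tokens=True``, the [QUE] and [DOC] markers are registered as special
    # tokens and the post-processors above are installed, e.g.:
    #
    #   tokenizer = BiEncoderTokenizer.from_pretrained(
    #       "bert-base-uncased", query_length=32, doc_length=512, add_marker_tokens=True
    #   )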

    @property
    def query_token_id(self) -> int | None:
        """The token id of the query token if marker tokens are added.

        Returns:
            Token id of the query token if added, otherwise None.
        """
        if self.QUERY_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.QUERY_TOKEN]
        return None

    @property
    def doc_token_id(self) -> int | None:
        """The token id of the document token if marker tokens are added.

        Returns:
            Token id of the document token if added, otherwise None.
        """
        if self.DOC_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.DOC_TOKEN]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
        """Overrides the `PretrainedTokenizer.__call__`_ method to warn the user to use :meth:`.tokenize_query` and
        :meth:`.tokenize_doc` instead.

        .. _PretrainedTokenizer.__call__: \
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        Args:
            text (str | Sequence[str]): Text to tokenize.
            warn (bool): Set to False to silence the warning. Defaults to True.
        Returns:
            BatchEncoding: Tokenized text.
        """
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being called directly. Use `tokenize`, `tokenize_query`, or `tokenize_doc` "
                "to make sure tokenization is done correctly.",
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encodes text with an optional post-processor."""
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
        if kwargs.get("return_tensors", None) is not None:
            kwargs["pad_to_multiple_of"] = 8
        encoding = self(text, *args, warn=False, **kwargs)
        self._tokenizer.post_processor = orig_post_processor
        return encoding

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding
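
    # Expansion sketch: an encoded query (decoded for readability)
    #   [CLS] [QUE] hi [SEP] [PAD] [PAD]    attention_mask: 1 1 1 1 0 0
    # has its padding positions rewritten to mask tokens by ``_expand``,
    #   [CLS] [QUE] hi [SEP] [MASK] [MASK]
    # and, if ``attend_to_expanded_tokens`` is True, the attention mask becomes all ones.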

    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        Args:
            text (Sequence[str] | str): Input text to tokenize.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BatchEncoding: Tokenized input sequences.
        """
        post_processor = getattr(self, f"{input_type}_post_processor")
        if hasattr(self, f"{input_type}_length"):
            kwargs["max_length"] = getattr(self, f"{input_type}_length")
            if "padding" not in kwargs:
                kwargs["truncation"] = True
        return self._encode(text, *args, post_processor=post_processor, **kwargs)
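
    # Usage sketch: ``input_type`` selects the matching post-processor and maximum
    # length, so e.g. (with the tokenizer constructed above)
    #
    #   tokenizer.tokenize_input_sequence("hello", "query", return_tensors="pt")
    #
    # truncates to ``query_length`` and inserts the [QUE] marker when enabled.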

    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input queries.

        Args:
            queries (Sequence[str] | str): Query or queries to tokenize.
        Returns:
            BatchEncoding: Tokenized queries.
        """
        encoding = self.tokenize_input_sequence(queries, "query", *args, **kwargs)
        return encoding

    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input documents.

        Args:
            docs (Sequence[str] | str): Document or documents to tokenize.
        Returns:
            BatchEncoding: Tokenized documents.
        """
        encoding = self.tokenize_input_sequence(docs, "doc", *args, **kwargs)
        return encoding

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
        Returns:
            Dict[str, BatchEncoding]: Dictionary containing tokenized queries and documents.
        """
        encodings = {}
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings
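

if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming the public "bert-base-uncased" checkpoint can
    # be downloaded. Run as a module so the relative imports resolve:
    #   python -m lightning_ir.bi_encoder.bi_encoder_tokenizer
    tokenizer = BiEncoderTokenizer.from_pretrained("bert-base-uncased", add_marker_tokens=True)
    encodings = tokenizer.tokenize(
        queries="what is a bi-encoder?",
        docs="A bi-encoder encodes queries and documents separately.",
        return_tensors="pt",
    )
    print(encodings["query_encoding"].input_ids)  # contains the [QUE] marker after [CLS]
    print(encodings["doc_encoding"].input_ids)  # contains the [DOC] marker after [CLS]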