Source code for lightning_ir.bi_encoder.bi_encoder_tokenizer

  1"""
  2Tokenizer module for bi-encoder models.
  3
  4This module contains the tokenizer class for bi-encoder models.
  5"""
  6
  7import warnings
  8from collections.abc import Sequence
  9from typing import Literal
 10
 11from tokenizers.processors import TemplateProcessing
 12from transformers import BatchEncoding
 13
 14from ..base import LightningIRClassFactory, LightningIRTokenizer
 15from .bi_encoder_config import BiEncoderConfig
 16
# Per-backbone post-processing templates used when ``add_marker_tokens=True``.
# ``pattern`` is a ``tokenizers.processors.TemplateProcessing`` single-sequence
# template in which ``{TOKEN}`` is later substituted with the query or document
# marker token (see ``BiEncoderTokenizer.__init__``); ``special_tokens`` lists the
# special tokens referenced by the pattern so their ids can be resolved.
ADD_MARKER_TOKEN_MAPPING = {
    "bert": {"pattern": "[CLS] {TOKEN} $0 [SEP]", "special_tokens": ["[CLS]", "[SEP]"]},
    "modernbert": {"pattern": "[CLS] {TOKEN} $0 [SEP]", "special_tokens": ["[CLS]", "[SEP]"]},
}
 21
 22
class BiEncoderTokenizer(LightningIRTokenizer):
    """:class:`.LightningIRTokenizer` for bi-encoder models. Encodes queries and documents separately and
    optionally adds marker tokens to the encoded input sequences."""

    config_class: type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the tokenizer."""

    QUERY_TOKEN: str = "[QUE]"
    """Token to mark a query sequence."""
    DOC_TOKEN: str = "[DOC]"
    """Token to mark a document sequence."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """:class:`.LightningIRTokenizer` for bi-encoder models. Encodes queries and documents separately.
        Optionally adds marker tokens to the encoded input sequences.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate.
                Defaults to 512.
            add_marker_tokens (bool): Whether to add marker tokens to the query and document input sequences.
                Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and a non-supported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_length = query_length
        self.doc_length = doc_length
        self.add_marker_tokens = add_marker_tokens

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
        if add_marker_tokens:
            backbone_model_type = LightningIRClassFactory.get_backbone_model_type(self.name_or_path)
            if backbone_model_type not in ADD_MARKER_TOKEN_MAPPING:
                raise ValueError(
                    f"Adding marker tokens is not supported for the backbone model type '{backbone_model_type}'. "
                    f"Supported types are: [{', '.join(ADD_MARKER_TOKEN_MAPPING.keys())}]. "
                    "Please set `add_marker_tokens=False` "
                    "or add the backbone model type to `ADD_MARKER_TOKEN_MAPPING`."
                )
            self.add_tokens([self.QUERY_TOKEN, self.DOC_TOKEN], special_tokens=True)
            pattern = ADD_MARKER_TOKEN_MAPPING[backbone_model_type]["pattern"]
            # Resolve each special token referenced by the pattern to its id, as
            # required by TemplateProcessing.
            special_tokens = [
                (token, self.convert_tokens_to_ids(token))
                for token in ADD_MARKER_TOKEN_MAPPING[backbone_model_type]["special_tokens"]
            ]

            self.query_post_processor = TemplateProcessing(
                single=pattern.format(TOKEN=self.QUERY_TOKEN),
                pair=None,
                special_tokens=special_tokens + [(self.QUERY_TOKEN, self.query_token_id)],
            )
            self.doc_post_processor = TemplateProcessing(
                single=pattern.format(TOKEN=self.DOC_TOKEN),
                pair=None,
                special_tokens=special_tokens + [(self.DOC_TOKEN, self.doc_token_id)],
            )

    @property
    def query_token_id(self) -> int | None:
        """The token id of the query token if marker tokens are added.

        Returns:
            Token id of the query token if added, otherwise None.
        """
        if self.QUERY_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.QUERY_TOKEN]
        return None

    @property
    def doc_token_id(self) -> int | None:
        """The token id of the document token if marker tokens are added.

        Returns:
            Token id of the document token if added, otherwise None.
        """
        if self.DOC_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.DOC_TOKEN]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
        """Overrides the PretrainedTokenizer.__call__ method to warn the user to use :meth:`.tokenize_query` and
        :meth:`.tokenize_doc` methods instead.

        .. PretrainedTokenizer.__call__: \
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        Args:
            text (str | Sequence[str]): Text to tokenize.
            warn (bool): set to False to silence warning. Defaults to True.
        Returns:
            BatchEncoding: Tokenized text.
        """
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being directly called. Use `tokenize`, `tokenize_query`, or `tokenize_doc` "
                "to make sure tokenization is done correctly.",
                stacklevel=2,
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encodes text with an optional post-processor.

        The underlying fast tokenizer's post-processor is temporarily swapped out and
        restored after encoding.
        """
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
        if kwargs.get("return_tensors", None) is not None:
            # Pad tensor outputs to a multiple of 8 (presumably for hardware-friendly
            # shapes, e.g. tensor cores — TODO confirm).
            kwargs["pad_to_multiple_of"] = 8
        encoding = self(text, *args, warn=False, **kwargs)
        self._tokenizer.post_processor = orig_post_processor
        return encoding

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding by replacing padding tokens
        with mask tokens and optionally attending to them."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        Args:
            text (Sequence[str] | str): Input text to tokenize.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            BatchEncoding: Tokenized input sequences.
        """
        post_processor = getattr(self, f"{input_type}_post_processor")
        if hasattr(self, f"{input_type}_length"):
            kwargs["max_length"] = getattr(self, f"{input_type}_length")
            # Only enable truncation by default when the caller has not taken control
            # of padding themselves.
            if "padding" not in kwargs:
                kwargs["truncation"] = True
        return self._encode(text, *args, post_processor=post_processor, **kwargs)

    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input queries.

        Args:
            queries (Sequence[str] | str): Query or queries to tokenize.
        Returns:
            BatchEncoding: Tokenized queries.
        """
        return self.tokenize_input_sequence(queries, "query", *args, **kwargs)

    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input documents.

        Args:
            docs (Sequence[str] | str): Document or documents to tokenize.
        Returns:
            BatchEncoding: Tokenized documents.
        """
        return self.tokenize_input_sequence(docs, "doc", *args, **kwargs)

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
        Returns:
            dict[str, BatchEncoding]: Dictionary containing tokenized queries and documents.
        """
        encodings = {}
        # `num_docs` is accepted for API compatibility but not used during tokenization.
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings