Source code for lightning_ir.bi_encoder.bi_encoder_tokenizer

  1"""
  2Tokenizer module for bi-encoder models.
  3
  4This module contains the tokenizer class bi-encoder models.
  5"""
  6
  7import warnings
  8from typing import Dict, Literal, Sequence, Type
  9
 10from tokenizers.processors import TemplateProcessing
 11from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast
 12
 13from ..base import LightningIRTokenizer
 14from .bi_encoder_config import BiEncoderConfig
 15
 16
class BiEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the tokenizer."""

    QUERY_TOKEN: str = "[QUE]"
    """Token to mark a query sequence."""
    DOC_TOKEN: str = "[DOC]"
    """Token to mark a document sequence."""

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """:class:`.LightningIRTokenizer` for bi-encoder models. Encodes queries and documents separately and
        optionally adds marker tokens to the encoded input sequences.

        :param query_length: Maximum query length in number of tokens, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length in number of tokens, defaults to 512
        :type doc_length: int, optional
        :param add_marker_tokens: Whether to add marker tokens to the query and document input sequences.
            A warning is emitted if add_marker_tokens is True and a tokenizer other than a BERT tokenizer is
            used, since marker tokens may not be supported. Defaults to False
        :type add_marker_tokens: bool, optional
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_length = query_length
        self.doc_length = doc_length
        self.add_marker_tokens = add_marker_tokens

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
        if add_marker_tokens:
            # TODO support other tokenizers
            if not isinstance(self, (BertTokenizer, BertTokenizerFast)):
                warnings.warn(f"Adding marker tokens may not be supported for {type(self)}.")
            self.add_tokens([self.QUERY_TOKEN, self.DOC_TOKEN], special_tokens=True)
            self.query_post_processor = TemplateProcessing(
                single=f"[CLS] {self.QUERY_TOKEN} $0 [SEP]",
                pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )
            self.doc_post_processor = TemplateProcessing(
                single=f"[CLS] {self.DOC_TOKEN} $0 [SEP]",
                pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )
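            # Illustrative note, not part of the original source: with a BERT-style special-token vocabulary
            # the templates above render a single query roughly as
            #   [CLS] [QUE] <query tokens> [SEP]
            # and a single document as
            #   [CLS] [DOC] <document tokens> [SEP]
            # The pair templates only apply when a query/document text pair is passed to the tokenizer in a
            # single call.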
    @property
    def query_token_id(self) -> int | None:
        """The token id of the query token if marker tokens are added.

        :return: Token id of the query token
        :rtype: int | None
        """
        if self.QUERY_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.QUERY_TOKEN]
        return None

    @property
    def doc_token_id(self) -> int | None:
        """The token id of the document token if marker tokens are added.

        :return: Token id of the document token
        :rtype: int | None
        """
        if self.DOC_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.DOC_TOKEN]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
        """Overrides the `PreTrainedTokenizer.__call__`_ method to warn the user to use the
        :meth:`.tokenize_query` and :meth:`.tokenize_doc` methods instead.

        .. _PreTrainedTokenizer.__call__: \
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        :param text: Text to tokenize
        :type text: str | Sequence[str]
        :param warn: Set to False to silence the warning, defaults to True
        :type warn: bool, optional
        :return: Tokenized text
        :rtype: BatchEncoding
        """
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being directly called. Use `tokenize`, `tokenize_query`, or `tokenize_doc` "
                "to make sure tokenization is done correctly.",
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encodes text with an optional post-processor."""
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
        if kwargs.get("return_tensors", None) is not None:
            kwargs["pad_to_multiple_of"] = 8
        encoding = self(text, *args, warn=False, **kwargs)
        self._tokenizer.post_processor = orig_post_processor
        return encoding

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

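    # Illustrative note, not part of the original source: `_expand` performs mask-based query expansion.
    # Assuming a BERT-style vocabulary ([CLS]=101, [SEP]=102, [MASK]=103, [PAD]=0), a padded query encoding
    #   input_ids      = [[101, t1, t2, 102, 0, 0]]
    #   attention_mask = [[1, 1, 1, 1, 0, 0]]
    # becomes
    #   input_ids      = [[101, t1, t2, 102, 103, 103]]
    #   attention_mask = [[1, 1, 1, 1, 1, 1]]  # set to all ones only if attend_to_expanded_tokens=True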
    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence. This method is used to tokenize both queries and documents.

        :param text: Single string or multiple strings to tokenize
        :type text: Sequence[str] | str
        :param input_type: Whether to tokenize the text as a query or a document
        :type input_type: Literal["query", "doc"]
        :return: Tokenized input sequences
        :rtype: BatchEncoding
        """
        post_processor = getattr(self, f"{input_type}_post_processor")
        kwargs["max_length"] = getattr(self, f"{input_type}_length")
        if "padding" not in kwargs:
            kwargs["truncation"] = True
        return self._encode(text, *args, post_processor=post_processor, **kwargs)

    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input queries.

        :param queries: Query or queries to tokenize
        :type queries: Sequence[str] | str
        :return: Tokenized queries
        :rtype: BatchEncoding
        """
        encoding = self.tokenize_input_sequence(queries, "query", *args, **kwargs)
        return encoding

    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input documents.

        :param docs: Document or documents to tokenize
        :type docs: Sequence[str] | str
        :return: Tokenized documents
        :rtype: BatchEncoding
        """
        encoding = self.tokenize_input_sequence(docs, "doc", *args, **kwargs)
        return encoding

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :return: Dictionary of tokenized queries and documents
        :rtype: Dict[str, BatchEncoding]
        """
        encodings = {}
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings
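

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original source. In practice the class is imported from the
    # `lightning_ir` package rather than by running this module directly; the checkpoint name, lengths, and
    # example texts below are assumptions for demonstration only.
    tokenizer = BiEncoderTokenizer.from_pretrained(
        "bert-base-uncased", query_length=32, doc_length=256, add_marker_tokens=True
    )
    encodings = tokenizer.tokenize(
        queries="what is information retrieval?",
        docs=["Information retrieval is about finding relevant documents for a query."],
    )
    # `tokenize` returns a dict with "query_encoding" and "doc_encoding" entries, each a BatchEncoding.
    print(tokenizer.convert_ids_to_tokens(encodings["query_encoding"]["input_ids"]))
    print(tokenizer.convert_ids_to_tokens(encodings["doc_encoding"]["input_ids"][0]))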