1"""
2Tokenizer module for bi-encoder models.
3
4This module contains the tokenizer class bi-encoder models.
5"""
6
7import warnings
8from typing import Dict, Literal, Sequence, Type
9
10from tokenizers.processors import TemplateProcessing
11from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast
12
13from ..base import LightningIRTokenizer
14from .bi_encoder_config import BiEncoderConfig
15
16
class BiEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the tokenizer."""

    QUERY_TOKEN: str = "[QUE]"
    """Token to mark a query sequence."""
    DOC_TOKEN: str = "[DOC]"
    """Token to mark a document sequence."""

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
35 """:class:`.LightningIRTokenizer` for bi-encoder models. Encodes queries and documents separately. Optionally
36 adds marker tokens are added to encoded input sequences.
37
38 :param query_length: Maximum query length in number of tokens, defaults to 32
39 :type query_length: int, optional
40 :param doc_length: Maximum document length in number of tokens, defaults to 512
41 :type doc_length: int, optional
42 :param add_marker_tokens: Whether to add marker tokens to the query and document input sequences,
43 defaults to False
44 :type add_marker_tokens: bool, optional
45 :raises ValueError: If add_marker_tokens is True and a non-supported tokenizer is used
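
        A minimal usage sketch (the checkpoint name and the tokenization keyword arguments are
        illustrative, not prescribed by this class):

        .. code-block:: python

            tokenizer = BiEncoderTokenizer.from_pretrained(
                "bert-base-uncased", query_length=32, doc_length=512, add_marker_tokens=True
            )
            query_encoding = tokenizer.tokenize_query(
                "what is information retrieval?", padding=True, truncation=True, return_tensors="pt"
            )
            doc_encoding = tokenizer.tokenize_doc(
                ["Information retrieval is about finding relevant documents."],
                padding=True,
                truncation=True,
                return_tensors="pt",
            )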
46 """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_length = query_length
        self.doc_length = doc_length
        self.add_marker_tokens = add_marker_tokens

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
        if add_marker_tokens:
            # TODO support other tokenizers
            if not isinstance(self, (BertTokenizer, BertTokenizerFast)):
                warnings.warn(f"Adding marker tokens may not be supported for {type(self)}.")
            self.add_tokens([self.QUERY_TOKEN, self.DOC_TOKEN], special_tokens=True)
            # The post-processors insert the marker token directly after [CLS]: queries become
            # "[CLS] [QUE] ... [SEP]" and documents become "[CLS] [DOC] ... [SEP]".
            self.query_post_processor = TemplateProcessing(
                single=f"[CLS] {self.QUERY_TOKEN} $0 [SEP]",
                pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )
            self.doc_post_processor = TemplateProcessing(
                single=f"[CLS] {self.DOC_TOKEN} $0 [SEP]",
                pair=f"[CLS] {self.QUERY_TOKEN} $A [SEP] {self.DOC_TOKEN} $B:1 [SEP]:1",
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )

    @property
    def query_token_id(self) -> int | None:
        """The token id of the query token if marker tokens are added.

        :return: Token id of the query token
        :rtype: int | None
        """
        if self.QUERY_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.QUERY_TOKEN]
        return None

    @property
    def doc_token_id(self) -> int | None:
        """The token id of the document token if marker tokens are added.

        :return: Token id of the document token
        :rtype: int | None
        """
        if self.DOC_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.DOC_TOKEN]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
        """Overrides the `PreTrainedTokenizer.__call__`_ method to warn the user to use the :meth:`.tokenize_query`
        and :meth:`.tokenize_doc` methods instead.

        .. _PreTrainedTokenizer.__call__: \
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        :param text: Text to tokenize
        :type text: str | Sequence[str]
        :param warn: Set to False to silence the warning, defaults to True
        :type warn: bool, optional
        :return: Tokenized text
        :rtype: BatchEncoding
        """
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being directly called. Use `tokenize`, `tokenize_query`, or `tokenize_doc` "
                "to make sure tokenization is done correctly.",
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encodes text with an optional post-processor."""
        # Temporarily swap in the provided post-processor (e.g., to add marker tokens) for this call.
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
        if kwargs.get("return_tensors", None) is not None:
            # When tensors are returned, pad sequences to a multiple of 8.
            kwargs["pad_to_multiple_of"] = 8
        encoding = self(text, *args, warn=False, **kwargs)
        self._tokenizer.post_processor = orig_post_processor
        return encoding

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        # Replace padding tokens with mask tokens so the sequence is expanded to its full padded length.
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            # Optionally attend to the expanded (formerly padding) positions.
            encoding["attention_mask"].fill_(1)
        return encoding

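    # NOTE: minimal sketch of ``tokenize_input_sequence``; it assumes that query and document tokenization
    # differ only in the post-processor and maximum length that are applied before delegating to `_encode`.
    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence as either a query or a document."""
        post_processor = self.query_post_processor if input_type == "query" else self.doc_post_processor
        kwargs["max_length"] = self.query_length if input_type == "query" else self.doc_length
        if "truncation" not in kwargs:
            kwargs["truncation"] = True
        return self._encode(text, *args, post_processor=post_processor, **kwargs)
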
    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input queries.

        :param queries: Query or queries to tokenize
        :type queries: Sequence[str] | str
        :return: Tokenized queries
        :rtype: BatchEncoding
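
        Illustrative call (assuming a tokenizer instance loaded via ``from_pretrained``):

        .. code-block:: python

            encoding = tokenizer.tokenize_query(
                ["what is a bi-encoder?", "how are queries tokenized?"],
                padding=True,
                truncation=True,
                return_tensors="pt",
            )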
178 """
179 encoding = self.tokenize_input_sequence(queries, "query", *args, **kwargs)
180 return encoding
181
    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input documents.

        :param docs: Document or documents to tokenize
        :type docs: Sequence[str] | str
        :return: Tokenized documents
        :rtype: BatchEncoding
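
        Illustrative call (assuming a tokenizer instance loaded via ``from_pretrained``):

        .. code-block:: python

            encoding = tokenizer.tokenize_doc(
                ["A bi-encoder encodes queries and documents separately."],
                padding=True,
                truncation=True,
                return_tensors="pt",
            )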
189 """
190 encoding = self.tokenize_input_sequence(docs, "doc", *args, **kwargs)
191 return encoding
192
    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :return: Dictionary of tokenized queries and documents
        :rtype: Dict[str, BatchEncoding]
        """
        encodings = {}
        # num_docs is not used by the bi-encoder tokenizer; drop it so it is not forwarded.
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings