1"""
2Tokenizer module for bi-encoder models.
3
4This module contains the tokenizer class bi-encoder models.
5"""

import warnings
from typing import Dict, Literal, Sequence, Type

from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding

from ..base import LightningIRClassFactory, LightningIRTokenizer
from .bi_encoder_config import BiEncoderConfig

ADD_MARKER_TOKEN_MAPPING = {
    "bert": {
        "single": "[CLS] {TOKEN} $0 [SEP]",
        "pair": "[CLS] {TOKEN_1} $A [SEP] {TOKEN_2} $B:1 [SEP]:1",
    },
    "modernbert": {
        "single": "[CLS] {TOKEN} $0 [SEP]",
        "pair": "[CLS] {TOKEN_1} $A [SEP] {TOKEN_2} $B:1 [SEP]:1",
    },
}


class BiEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the tokenizer."""

    QUERY_TOKEN: str = "[QUE]"
    """Token to mark a query sequence."""
    DOC_TOKEN: str = "[DOC]"
    """Token to mark a document sequence."""

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """:class:`.LightningIRTokenizer` for bi-encoder models. Encodes queries and documents separately. Optionally
        adds marker tokens to the encoded input sequences.

        Args:
            query_length (int): Maximum query length in number of tokens. Defaults to 32.
            doc_length (int): Maximum document length in number of tokens. Defaults to 512.
            add_marker_tokens (bool): Whether to add marker tokens to the query and document input sequences.
                Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and an unsupported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_length = query_length
        self.doc_length = doc_length
        self.add_marker_tokens = add_marker_tokens

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
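        # When marker tokens are enabled, template post-processors inject a [QUE] / [DOC] marker token
        # directly after [CLS] for query and document sequences, respectively (see ADD_MARKER_TOKEN_MAPPING).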
        if add_marker_tokens:
            backbone_model_type = LightningIRClassFactory.get_backbone_model_type(self.name_or_path)
            if backbone_model_type not in ADD_MARKER_TOKEN_MAPPING:
                raise ValueError(
                    f"Adding marker tokens is not supported for the backbone model type '{backbone_model_type}'. "
                    f"Supported types are: [{', '.join(ADD_MARKER_TOKEN_MAPPING.keys())}]. "
                    "Please set `add_marker_tokens=False` "
                    "or add the backbone model type to `ADD_MARKER_TOKEN_MAPPING`."
                )
            self.add_tokens([self.QUERY_TOKEN, self.DOC_TOKEN], special_tokens=True)
            pattern = ADD_MARKER_TOKEN_MAPPING[backbone_model_type]
            self.query_post_processor = TemplateProcessing(
                single=pattern["single"].format(TOKEN=self.QUERY_TOKEN),
                pair=pattern["pair"].format(TOKEN_1=self.QUERY_TOKEN, TOKEN_2=self.DOC_TOKEN),
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )
            self.doc_post_processor = TemplateProcessing(
                single=pattern["single"].format(TOKEN=self.DOC_TOKEN),
                pair=pattern["pair"].format(TOKEN_1=self.QUERY_TOKEN, TOKEN_2=self.DOC_TOKEN),
                special_tokens=[
                    ("[CLS]", self.cls_token_id),
                    ("[SEP]", self.sep_token_id),
                    (self.QUERY_TOKEN, self.query_token_id),
                    (self.DOC_TOKEN, self.doc_token_id),
                ],
            )

    @property
    def query_token_id(self) -> int | None:
        """The token id of the query token if marker tokens are added.

        Returns:
            Token id of the query token if added, otherwise None.
        """
        if self.QUERY_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.QUERY_TOKEN]
        return None

    @property
    def doc_token_id(self) -> int | None:
        """The token id of the document token if marker tokens are added.

        Returns:
            Token id of the document token if added, otherwise None.
        """
        if self.DOC_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.DOC_TOKEN]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
        """Overrides the PretrainedTokenizer.__call__ method to warn the user to use the :meth:`.tokenize_query` and
        :meth:`.tokenize_doc` methods instead.

        .. PretrainedTokenizer.__call__: \
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        Args:
            text (str | Sequence[str]): Text to tokenize.
            warn (bool): Set to False to silence the warning. Defaults to True.
        Returns:
            BatchEncoding: Tokenized text.
        """
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being directly called. Use `tokenize`, `tokenize_query`, or `tokenize_doc` "
                "to make sure tokenization is done correctly.",
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encodes text with an optional post-processor."""
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
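        # When tensors are requested, pad sequences to a multiple of 8; this commonly improves GPU efficiency
        # (e.g., tensor-core utilization) and otherwise leaves the tokenization result unchanged.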
        if kwargs.get("return_tensors", None) is not None:
            kwargs["pad_to_multiple_of"] = 8
        encoding = self(text, *args, warn=False, **kwargs)
        self._tokenizer.post_processor = orig_post_processor
        return encoding

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
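        # Replace padding tokens with mask tokens so the sequence is expanded to the full maximum length
        # (ColBERT-style query augmentation); optionally let the model attend to the expanded positions.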
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

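    # NOTE: The `tokenize_input_sequence` method was omitted from this listing. The sketch below is a hedged
    # reconstruction inferred only from how `tokenize_query` and `tokenize_doc` call it; the library's actual
    # implementation may differ (for example, it may additionally apply `_expand` when expansion is configured).
    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence as either a query or a document."""
        # Select the marker-token post-processor and maximum length for the given input type.
        post_processor = self.query_post_processor if input_type == "query" else self.doc_post_processor
        kwargs["max_length"] = self.query_length if input_type == "query" else self.doc_length
        kwargs.setdefault("truncation", True)
        return self._encode(text, *args, post_processor=post_processor, **kwargs)
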
    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input queries.

        Args:
            queries (Sequence[str] | str): Query or queries to tokenize.
        Returns:
            BatchEncoding: Tokenized queries.
        """
        encoding = self.tokenize_input_sequence(queries, "query", *args, **kwargs)
        return encoding

    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input documents.

        Args:
            docs (Sequence[str] | str): Document or documents to tokenize.
        Returns:
            BatchEncoding: Tokenized documents.
        """
        encoding = self.tokenize_input_sequence(docs, "doc", *args, **kwargs)
        return encoding

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
        Returns:
            Dict[str, BatchEncoding]: Dictionary containing tokenized queries and documents.
        """
        encodings = {}
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings
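

# Minimal usage sketch (illustrative only). It assumes a BERT-style checkpoint such as "bert-base-uncased"
# is available and that the tokenizer is loaded via the standard Hugging Face `from_pretrained` interface,
# which forwards the extra keyword arguments to `__init__`:
#
#     tokenizer = BiEncoderTokenizer.from_pretrained(
#         "bert-base-uncased", query_length=32, doc_length=512, add_marker_tokens=True
#     )
#     encodings = tokenizer.tokenize(
#         queries="what is information retrieval?",
#         docs=["Information retrieval is the task of finding relevant documents.", "Another document."],
#         return_tensors="pt",
#         padding=True,
#         truncation=True,
#     )
#     # encodings["query_encoding"] and encodings["doc_encoding"] are BatchEncoding objects whose input ids
#     # start with "[CLS] [QUE] ..." and "[CLS] [DOC] ..." respectively when add_marker_tokens=True.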