1"""
2Tokenizer module for bi-encoder models.
3
4This module contains the tokenizer class bi-encoder models.
5"""
6
7import warnings
8from typing import Dict, Literal, Sequence, Type
9
10from tokenizers.processors import TemplateProcessing
11from transformers import BatchEncoding
12
13from ..base import LightningIRClassFactory, LightningIRTokenizer
14from .bi_encoder_config import BiEncoderConfig
15
16ADD_MARKER_TOKEN_MAPPING = {
17 "bert": {"pattern": "[CLS] {TOKEN} $0 [SEP]", "special_tokens": ["[CLS]", "[SEP]"]},
18 "modernbert": {"pattern": "[CLS] {TOKEN} $0 [SEP]", "special_tokens": ["[CLS]", "[SEP]"]},
19}
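# For a BERT-style backbone, the query template above resolves to "[CLS] [QUE] $0 [SEP]" and the
# document template to "[CLS] [DOC] $0 [SEP]", where $0 stands for the tokenized input sequence.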


class BiEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the tokenizer."""

    QUERY_TOKEN: str = "[QUE]"
    """Token to mark a query sequence."""
    DOC_TOKEN: str = "[DOC]"
    """Token to mark a document sequence."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """:class:`.LightningIRTokenizer` for bi-encoder models. Encodes queries and documents separately.
        Optionally adds marker tokens to the encoded input sequences.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            add_marker_tokens (bool): Whether to add marker tokens to the query and document input sequences.
                Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and an unsupported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_length = query_length
        self.doc_length = doc_length
        self.add_marker_tokens = add_marker_tokens

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
        if add_marker_tokens:
            backbone_model_type = LightningIRClassFactory.get_backbone_model_type(self.name_or_path)
            if backbone_model_type not in ADD_MARKER_TOKEN_MAPPING:
                raise ValueError(
                    f"Adding marker tokens is not supported for the backbone model type '{backbone_model_type}'. "
                    f"Supported types are: [{', '.join(ADD_MARKER_TOKEN_MAPPING.keys())}]. "
                    "Please set `add_marker_tokens=False` "
                    "or add the backbone model type to `ADD_MARKER_TOKEN_MAPPING`."
                )
            self.add_tokens([self.QUERY_TOKEN, self.DOC_TOKEN], special_tokens=True)
            pattern = ADD_MARKER_TOKEN_MAPPING[backbone_model_type]["pattern"]
            special_tokens = [
                (token, self.convert_tokens_to_ids(token))
                for token in ADD_MARKER_TOKEN_MAPPING[backbone_model_type]["special_tokens"]
            ]

            self.query_post_processor = TemplateProcessing(
                single=pattern.format(TOKEN=self.QUERY_TOKEN),
                pair=None,
                special_tokens=special_tokens + [(self.QUERY_TOKEN, self.query_token_id)],
            )
            self.doc_post_processor = TemplateProcessing(
                single=pattern.format(TOKEN=self.DOC_TOKEN),
                pair=None,
                special_tokens=special_tokens + [(self.DOC_TOKEN, self.doc_token_id)],
            )

    @property
    def query_token_id(self) -> int | None:
        """The token id of the query token if marker tokens are added.

        Returns:
            Token id of the query token if added, otherwise None.
        """
        if self.QUERY_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.QUERY_TOKEN]
        return None

    @property
    def doc_token_id(self) -> int | None:
        """The token id of the document token if marker tokens are added.

        Returns:
            Token id of the document token if added, otherwise None.
        """
        if self.DOC_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.DOC_TOKEN]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
        """Overrides the `PreTrainedTokenizer.__call__`_ method to warn the user to use :meth:`.tokenize_query` and
        :meth:`.tokenize_doc` instead.

        .. _PreTrainedTokenizer.__call__: \
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        Args:
            text (str | Sequence[str]): Text to tokenize.
            warn (bool): Set to False to silence the warning. Defaults to True.
        Returns:
            BatchEncoding: Tokenized text.
        """
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being directly called. Use `tokenize`, `tokenize_query`, or `tokenize_doc` "
                "to make sure tokenization is done correctly.",
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encodes text with an optional post-processor."""
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
        if kwargs.get("return_tensors", None) is not None:
            kwargs["pad_to_multiple_of"] = 8
        encoding = self(text, *args, warn=False, **kwargs)
        self._tokenizer.post_processor = orig_post_processor
        return encoding

    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding by replacing padding tokens with mask tokens."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

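    # NOTE: `tokenize_input_sequence` was elided from this excerpt. The sketch below is a hedged
    # reconstruction from its call sites in `tokenize_query` and `tokenize_doc`: it picks the
    # query- or document-specific maximum length and post-processor and delegates to `_encode`.
    # The actual implementation's truncation and expansion handling may differ.
    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes an input sequence of the given type ("query" or "doc").

        Args:
            text (Sequence[str] | str): Text to tokenize.
            input_type (Literal["query", "doc"]): Whether the text is a query or a document.
        Returns:
            BatchEncoding: Tokenized text.
        """
        length = getattr(self, f"{input_type}_length")
        post_processor = getattr(self, f"{input_type}_post_processor")
        if length is not None:
            kwargs.setdefault("max_length", length)
            kwargs.setdefault("truncation", True)
        return self._encode(text, *args, post_processor=post_processor, **kwargs)
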
    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input queries.

        Args:
            queries (Sequence[str] | str): Query or queries to tokenize.
        Returns:
            BatchEncoding: Tokenized queries.
        """
        encoding = self.tokenize_input_sequence(queries, "query", *args, **kwargs)
        return encoding

    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input documents.

        Args:
            docs (Sequence[str] | str): Document or documents to tokenize.
        Returns:
            BatchEncoding: Tokenized documents.
        """
        encoding = self.tokenize_input_sequence(docs, "doc", *args, **kwargs)
        return encoding

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
        Returns:
            Dict[str, BatchEncoding]: Dictionary containing tokenized queries and documents.
        """
        encodings = {}
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings
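

# Illustrative usage (a sketch, not part of the module). The checkpoint name below is a placeholder;
# substitute any bi-encoder checkpoint or backbone accepted by BiEncoderTokenizer.from_pretrained:
#
#     tokenizer = BiEncoderTokenizer.from_pretrained("bert-base-uncased", add_marker_tokens=True)
#     encodings = tokenizer.tokenize(
#         queries="what is information retrieval?",
#         docs=["Information retrieval deals with searching document collections."],
#         return_tensors="pt",
#     )
#     encodings["query_encoding"]  # queries truncated to query_length (default 32)
#     encodings["doc_encoding"]    # documents truncated to doc_length (default 512)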