1"""
2Tokenizer module for bi-encoder models.
3
4This module contains the tokenizer class bi-encoder models.
5"""

import warnings
from collections.abc import Sequence
from typing import Literal

from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding

from ..base import LightningIRClassFactory, LightningIRTokenizer
from .bi_encoder_config import BiEncoderConfig

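# Post-processing templates for backbones that support marker tokens: `{TOKEN}` is replaced by the
# query or document marker token and `$0` by the tokenized input sequence.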
ADD_MARKER_TOKEN_MAPPING = {
    "bert": {"pattern": "[CLS] {TOKEN} $0 [SEP]", "special_tokens": ["[CLS]", "[SEP]"]},
    "modernbert": {"pattern": "[CLS] {TOKEN} $0 [SEP]", "special_tokens": ["[CLS]", "[SEP]"]},
}


class BiEncoderTokenizer(LightningIRTokenizer):
    config_class: type[BiEncoderConfig] = BiEncoderConfig
    """Configuration class for the tokenizer."""

    QUERY_TOKEN: str = "[QUE]"
    """Token to mark a query sequence."""
    DOC_TOKEN: str = "[DOC]"
    """Token to mark a document sequence."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        **kwargs,
    ):
        """:class:`.LightningIRTokenizer` for bi-encoder models. Encodes queries and documents separately. Optionally
        adds marker tokens to the encoded input sequences.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            add_marker_tokens (bool): Whether to add marker tokens to the query and document input sequences.
                Defaults to False.
        Raises:
            ValueError: If `add_marker_tokens` is True and an unsupported tokenizer is used.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            **kwargs,
        )
        self.query_length = query_length
        self.doc_length = doc_length
        self.add_marker_tokens = add_marker_tokens

        self.query_post_processor: TemplateProcessing | None = None
        self.doc_post_processor: TemplateProcessing | None = None
        if add_marker_tokens:
            backbone_model_type = LightningIRClassFactory.get_backbone_model_type(self.name_or_path)
            if backbone_model_type not in ADD_MARKER_TOKEN_MAPPING:
                raise ValueError(
                    f"Adding marker tokens is not supported for the backbone model type '{backbone_model_type}'. "
                    f"Supported types are: [{', '.join(ADD_MARKER_TOKEN_MAPPING.keys())}]. "
                    "Please set `add_marker_tokens=False` "
                    "or add the backbone model type to `ADD_MARKER_TOKEN_MAPPING`."
                )
            self.add_tokens([self.QUERY_TOKEN, self.DOC_TOKEN], special_tokens=True)
            pattern = ADD_MARKER_TOKEN_MAPPING[backbone_model_type]["pattern"]
            special_tokens = [
                (token, self.convert_tokens_to_ids(token))
                for token in ADD_MARKER_TOKEN_MAPPING[backbone_model_type]["special_tokens"]
            ]

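            # With the patterns above, queries are encoded as "[CLS] [QUE] <query> [SEP]" and
            # documents as "[CLS] [DOC] <document> [SEP]".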
            self.query_post_processor = TemplateProcessing(
                single=pattern.format(TOKEN=self.QUERY_TOKEN),
                pair=None,
                special_tokens=special_tokens + [(self.QUERY_TOKEN, self.query_token_id)],
            )
            self.doc_post_processor = TemplateProcessing(
                single=pattern.format(TOKEN=self.DOC_TOKEN),
                pair=None,
                special_tokens=special_tokens + [(self.DOC_TOKEN, self.doc_token_id)],
            )

    @property
    def query_token_id(self) -> int | None:
        """The token id of the query token if marker tokens are added.

        Returns:
            Token id of the query token if added, otherwise None.
        """
        if self.QUERY_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.QUERY_TOKEN]
        return None

    @property
    def doc_token_id(self) -> int | None:
        """The token id of the document token if marker tokens are added.

        Returns:
            Token id of the document token if added, otherwise None.
        """
        if self.DOC_TOKEN in self.added_tokens_encoder:
            return self.added_tokens_encoder[self.DOC_TOKEN]
        return None

    def __call__(self, *args, warn: bool = True, **kwargs) -> BatchEncoding:
        """Overrides the `PreTrainedTokenizer.__call__`_ method to warn the user to use :meth:`.tokenize_query` and
        :meth:`.tokenize_doc` instead.

        .. _PreTrainedTokenizer.__call__: \
https://huggingface.co/docs/transformers/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__

        Args:
            text (str | Sequence[str]): Text to tokenize.
            warn (bool): Set to False to silence the warning. Defaults to True.
        Returns:
            BatchEncoding: Tokenized text.
        """
        if warn:
            warnings.warn(
                "BiEncoderTokenizer is being directly called. Use `tokenize`, `tokenize_query`, or `tokenize_doc` "
                "to make sure tokenization is done correctly.",
                stacklevel=2,
            )
        return super().__call__(*args, **kwargs)

    def _encode(
        self,
        text: str | Sequence[str],
        *args,
        post_processor: TemplateProcessing | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encodes text with an optional post-processor."""
        orig_post_processor = self._tokenizer.post_processor
        if post_processor is not None:
            self._tokenizer.post_processor = post_processor
        if kwargs.get("return_tensors", None) is not None:
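            # When tensors are requested, pad to a multiple of 8 (presumably to keep batch shapes
            # hardware-friendly, e.g. for tensor cores).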
146 kwargs["pad_to_multiple_of"] = 8
147 encoding = self(text, *args, warn=False, **kwargs)
148 self._tokenizer.post_processor = orig_post_processor
149 return encoding
150
    def _expand(self, encoding: BatchEncoding, attend_to_expanded_tokens: bool) -> BatchEncoding:
        """Applies mask expansion to the input encoding."""
        input_ids = encoding["input_ids"]
        input_ids[input_ids == self.pad_token_id] = self.mask_token_id
        encoding["input_ids"] = input_ids
        if attend_to_expanded_tokens:
            encoding["attention_mask"].fill_(1)
        return encoding

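    # NOTE: `tokenize_input_sequence` is referenced by `tokenize_query` and `tokenize_doc` below, but
    # its body is elided in this excerpt. The following is a minimal sketch reconstructed from how it
    # is called; the original implementation may differ (e.g., it may also apply `_expand`).
    def tokenize_input_sequence(
        self, text: Sequence[str] | str, input_type: Literal["query", "doc"], *args, **kwargs
    ) -> BatchEncoding:
        """Tokenizes a query or document input sequence (sketch of the elided method)."""
        post_processor = getattr(self, f"{input_type}_post_processor")
        max_length = getattr(self, f"{input_type}_length")
        kwargs.setdefault("padding", True)
        if max_length is not None:
            kwargs.setdefault("max_length", max_length)
            kwargs.setdefault("truncation", True)
        return self._encode(text, *args, post_processor=post_processor, **kwargs)
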
    def tokenize_query(self, queries: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input queries.

        Args:
            queries (Sequence[str] | str): Query or queries to tokenize.
        Returns:
            BatchEncoding: Tokenized queries.
        """
        encoding = self.tokenize_input_sequence(queries, "query", *args, **kwargs)
        return encoding

    def tokenize_doc(self, docs: Sequence[str] | str, *args, **kwargs) -> BatchEncoding:
        """Tokenizes input documents.

        Args:
            docs (Sequence[str] | str): Document or documents to tokenize.
        Returns:
            BatchEncoding: Tokenized documents.
        """
        encoding = self.tokenize_input_sequence(docs, "doc", *args, **kwargs)
        return encoding

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        **kwargs,
    ) -> dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
        Returns:
            dict[str, BatchEncoding]: Dictionary containing tokenized queries and documents.
        """
        encodings = {}
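        # Discard `num_docs` if it was passed along with other kwargs; it is not a tokenizer argument.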
        kwargs.pop("num_docs", None)
        if queries is not None:
            encodings["query_encoding"] = self.tokenize_query(queries, **kwargs)
        if docs is not None:
            encodings["doc_encoding"] = self.tokenize_doc(docs, **kwargs)
        return encodings
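

# Minimal usage sketch (the checkpoint name below is illustrative and not part of this module):
#
#   tokenizer = BiEncoderTokenizer.from_pretrained(
#       "bert-base-uncased", query_length=32, doc_length=512, add_marker_tokens=False
#   )
#   encodings = tokenizer.tokenize(
#       queries="what is a bi-encoder?",
#       docs=["A bi-encoder encodes queries and documents separately."],
#       return_tensors="pt",
#   )
#   # encodings["query_encoding"] and encodings["doc_encoding"] are BatchEncoding objects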