1"""
2Tokenizer module for cross-encoder models.
3
4This module contains the tokenizer class cross-encoder models.
5"""
6
from typing import Dict, List, Sequence, Tuple, Type

from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding

from ..base import LightningIRTokenizer, LightningIRTokenizerClassFactory
from .cross_encoder_config import CrossEncoderConfig

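# Maps a backbone model type and scoring strategy to a post-processing template. The placeholders in "pattern"
# are resolved to the token ids of the strings listed under "special_tokens"; e.g., the "mono" strategy for T5
# backbones corresponds to the monoT5-style prompt "Query: {query} Document: {document} Relevant:".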
SCORING_STRATEGY_POST_PROCESSOR_MAPPING = {
    "t5": {
        "mono": {
            "pattern": "pre que col $A doc col $B rel1 rel2 col eos",
            "special_tokens": [
                ("pre", "▁"),
                ("que", "Query"),
                ("col", ":"),
                ("doc", "▁Document"),
                ("rel1", "▁Relevan"),
                ("rel2", "t"),
                ("eos", "</s>"),
            ],
        },
        "rank": {
            "pattern": "pre que col $A doc col $B eos",
            "special_tokens": [
                ("pre", "▁"),
                ("que", "Query"),
                ("col", ":"),
                ("doc", "▁Document"),
                ("rel1", "▁Relevan"),
                ("rel2", "t"),
                ("eos", "</s>"),
            ],
        },
    },
}


class CrossEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        scoring_strategy: str | None = None,
        **kwargs,
    ):
58 """:class:`.LightningIRTokenizer` for cross-encoder models. Encodes queries and documents jointly and ensures
59 that the input sequences are of the correct length.
60
61 Args:
62 query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
63 doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
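
        Example (illustrative sketch; the checkpoint name is an assumption, not part of this module):
            >>> tokenizer = CrossEncoderTokenizer.from_pretrained(
            ...     "bert-base-uncased", query_length=32, doc_length=256
            ... )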
64 """
        super().__init__(
            *args, query_length=query_length, doc_length=doc_length, scoring_strategy=scoring_strategy, **kwargs
        )
        self.scoring_strategy = scoring_strategy
        backbone_model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(self.name_or_path)
        self.post_processor: TemplateProcessing | None = None
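        # If the backbone defines a post-processing template for the requested scoring strategy, resolve the
        # template's placeholder tokens to token ids and build a TemplateProcessing post-processor from it.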
        if backbone_model_type in SCORING_STRATEGY_POST_PROCESSOR_MAPPING:
            mapping = SCORING_STRATEGY_POST_PROCESSOR_MAPPING[backbone_model_type]
            if scoring_strategy is not None and scoring_strategy in mapping:
                pattern = mapping[scoring_strategy]["pattern"]
                special_tokens = [
                    (placeholder, self.convert_tokens_to_ids(token))
                    for (placeholder, token) in mapping[scoring_strategy]["special_tokens"]
                ]
                self.post_processor = TemplateProcessing(
                    single=None,
                    pair=pattern,
                    special_tokens=special_tokens,
                )

    def _truncate(self, text: Sequence[str], max_length: int | None) -> List[str]:
        """Encodes a list of texts, truncates them to a maximum number of tokens and decodes them to strings."""
        if max_length is None:
            return text
        return self.batch_decode(
            self(
                text,
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
                return_attention_mask=False,
                return_token_type_ids=False,
            ).input_ids
        )

    def _repeat_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
        """Repeats queries to match the number of documents."""
        return [query for query_idx, query in enumerate(queries) for _ in range(num_docs[query_idx])]

    def _preprocess(
        self,
        queries: Sequence[str],
        docs: Sequence[str],
        num_docs: Sequence[int],
    ) -> Tuple[str | Sequence[str], str | Sequence[str]]:
        """Preprocesses queries and documents to ensure that they are truncated to their respective maximum
        lengths."""
        truncated_queries = self._repeat_queries(self._truncate(queries, self.query_length), num_docs)
        truncated_docs = self._truncate(docs, self.doc_length)
        return truncated_queries, truncated_docs

    def _process_num_docs(
        self,
        queries: str | Sequence[str],
        docs: str | Sequence[str],
        num_docs: Sequence[int] | int | None,
    ) -> List[int]:
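        """Normalizes ``num_docs`` into a list with one document count per query."""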
        if num_docs is None:
            if len(docs) % len(queries) != 0:
                raise ValueError("Number of documents must be divisible by the number of queries.")
            num_docs = [len(docs) // len(queries) for _ in range(len(queries))]
        elif isinstance(num_docs, int):
            num_docs = [num_docs] * len(queries)
        return num_docs

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
137 """Tokenizes queries and documents into a single sequence of tokens.
138
139 Args:
140 queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
141 docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
142 num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
143 integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
144 number of documents, i.e., the sequence contains one value per query specifying the number of documents
145 for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
146 the number of documents by dividing the number of documents by the number of queries. Defaults to None.
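                For example, 2 queries and 6 documents with ``num_docs=None`` are interpreted as 3 documents per
                query, i.e., ``num_docs == [3, 3]``.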
        Returns:
            Dict[str, BatchEncoding]: Tokenized query-document sequence.
        Raises:
            ValueError: If either queries or docs are None.
            ValueError: If queries and docs are not both lists or both strings.
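
        Example (illustrative sketch; assumes a loaded ``tokenizer``; additional keyword arguments such as
        ``padding`` and ``return_tensors`` are forwarded to the underlying tokenizer call):
            >>> encoding = tokenizer.tokenize(
            ...     queries=["query 1", "query 2"],
            ...     docs=["doc 1 for query 1", "doc 2 for query 1", "doc 1 for query 2"],
            ...     num_docs=[2, 1],
            ...     padding=True,
            ...     return_tensors="pt",
            ... )["encoding"]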
152 """
153 if queries is None or docs is None:
154 raise ValueError("Both queries and docs must be provided.")
155 if isinstance(docs, str) and not isinstance(queries, str):
156 raise ValueError("Queries and docs must be both lists or both strings.")
157 if isinstance(queries, str):
158 queries = [queries]
159 if isinstance(docs, str):
160 docs = [docs]
161 num_docs = self._process_num_docs(queries, docs, num_docs)
162 queries, docs = self._preprocess(queries, docs, num_docs)
163
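        # Temporarily swap in the scoring-strategy-specific post-processor so the joint query-document sequence
        # is assembled according to the template, then restore the original post-processor afterwards.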
        orig_post_processor = self._tokenizer.post_processor
        if self.post_processor is not None:
            self._tokenizer.post_processor = self.post_processor

        encoding = self(queries, docs, **kwargs)

        if self.post_processor is not None:
            self._tokenizer.post_processor = orig_post_processor

        return {"encoding": encoding}