1"""
2Tokenizer module for cross-encoder models.
3
4This module contains the tokenizer class cross-encoder models.
5"""

from collections.abc import Sequence

from tokenizers.processors import TemplateProcessing
from transformers import BatchEncoding

from ..base import LightningIRTokenizer, LightningIRTokenizerClassFactory
from .cross_encoder_config import CrossEncoderConfig

SCORING_STRATEGY_POST_PROCESSOR_MAPPING = {
    "t5": {
        "mono": {
            "pattern": "pre que col $A doc col $B rel1 rel2 col eos",
            "special_tokens": [
                ("pre", "▁"),
                ("que", "Query"),
                ("col", ":"),
                ("doc", "▁Document"),
                ("rel1", "▁Relevan"),
                ("rel2", "t"),
                ("eos", "</s>"),
            ],
        },
        "rank": {
            "pattern": "pre que col $A doc col $B eos",
            "special_tokens": [
                ("pre", "▁"),
                ("que", "Query"),
                ("col", ":"),
                ("doc", "▁Document"),
                ("rel1", "▁Relevan"),
                ("rel2", "t"),
                ("eos", "</s>"),
            ],
        },
    },
}
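
# For illustration: with a T5 tokenizer, the "mono" pattern above decodes to the monoT5-style prompt
# "Query: {query} Document: {doc} Relevant:</s>", and the "rank" pattern to "Query: {query} Document: {doc}</s>".
# This decoding is derived from the special tokens above and is shown for illustration only.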


class CrossEncoderTokenizer(LightningIRTokenizer):
    config_class: type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for the tokenizer."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        scoring_strategy: str | None = None,
        **kwargs,
    ):
        """:class:`.LightningIRTokenizer` for cross-encoder models. Encodes queries and documents jointly and ensures
        that the input sequences are of the correct length.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            scoring_strategy (str | None): Scoring strategy that selects a backbone-specific post-processing template
                (for T5 backbones, "mono" or "rank"). If None or not available for the backbone, the tokenizer's
                default post-processor is kept. Defaults to None.
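
        Examples:
            Illustrative sketch; the backbone checkpoint name is a placeholder and is not prescribed by this module.

            >>> from lightning_ir import CrossEncoderTokenizer
            >>> tokenizer = CrossEncoderTokenizer.from_pretrained(
            ...     "bert-base-uncased", query_length=32, doc_length=256
            ... )  # doctest: +SKIP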
        """
        super().__init__(
            *args, query_length=query_length, doc_length=doc_length, scoring_strategy=scoring_strategy, **kwargs
        )
        self.scoring_strategy = scoring_strategy
        backbone_model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(self.name_or_path)
        self.post_processor: TemplateProcessing | None = None
        if backbone_model_type in SCORING_STRATEGY_POST_PROCESSOR_MAPPING:
            mapping = SCORING_STRATEGY_POST_PROCESSOR_MAPPING[backbone_model_type]
            if scoring_strategy is not None and scoring_strategy in mapping:
                pattern = mapping[scoring_strategy]["pattern"]
                special_tokens = [
                    (placeholder, self.convert_tokens_to_ids(token))
                    for (placeholder, token) in mapping[scoring_strategy]["special_tokens"]
                ]
                self.post_processor = TemplateProcessing(
                    single=None,
                    pair=pattern,
                    special_tokens=special_tokens,
                )

    def _truncate(self, text: Sequence[str], max_length: int | None) -> list[str]:
        """Encodes a list of texts, truncates them to a maximum number of tokens and decodes them to strings."""
        if max_length is None:
            return text
        return self.batch_decode(
            self(
                text,
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
                return_attention_mask=False,
                return_token_type_ids=False,
            ).input_ids
        )

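    # For illustration: queries=["q1", "q2"] with num_docs=[2, 1] yields ["q1", "q1", "q2"], i.e. each query is
    # repeated once per document it is paired with.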
    def _repeat_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> list[str]:
        """Repeats queries to match the number of documents."""
        return [query for query_idx, query in enumerate(queries) for _ in range(num_docs[query_idx])]

    def _preprocess(
        self,
        queries: Sequence[str],
        docs: Sequence[str],
        num_docs: Sequence[int],
    ) -> tuple[str | Sequence[str], str | Sequence[str]]:
        """Truncates queries and documents to their respective maximum lengths and repeats each query per document."""
        truncated_queries = self._repeat_queries(self._truncate(queries, self.query_length), num_docs)
        truncated_docs = self._truncate(docs, self.doc_length)
        return truncated_queries, truncated_docs

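    # For illustration: with 2 queries and 6 documents, num_docs=None infers num_docs=[3, 3]; passing num_docs=3
    # expands to [3, 3], and an explicit sequence such as [4, 2] is used as-is.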
    def _process_num_docs(
        self,
        queries: str | Sequence[str],
        docs: str | Sequence[str],
        num_docs: Sequence[int] | int | None,
    ) -> list[int]:
        """Expands or infers the number of documents per query and returns it as a list of integers."""
        if isinstance(num_docs, int):
            num_docs = [num_docs] * len(queries)
        elif num_docs is None:
            if len(docs) % len(queries) != 0:
                raise ValueError("Number of documents must be divisible by the number of queries.")
            num_docs = [len(docs) // len(queries) for _ in range(len(queries))]
        return list(num_docs)

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> dict[str, BatchEncoding]:
        """Tokenizes queries and documents into a single sequence of tokens per query-document pair.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of
                documents for that query. If an integer, assumes an equal number of documents per query. If None,
                tries to infer the number of documents per query by dividing the number of documents by the number of
                queries. Defaults to None.

        Returns:
            dict[str, BatchEncoding]: Tokenized query-document sequence.

        Raises:
            ValueError: If either queries or docs is None.
            ValueError: If queries and docs are not both lists or both strings.
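
        Examples:
            Illustrative sketch; assumes a tokenizer instance obtained via ``from_pretrained``.

            >>> batch = tokenizer.tokenize(
            ...     queries=["q1", "q2"],
            ...     docs=["d1 for q1", "d2 for q1", "d1 for q2"],
            ...     num_docs=[2, 1],
            ...     padding=True,
            ... )  # doctest: +SKIP
            >>> # batch["encoding"] then holds one tokenized sequence per query-document pair (3 in total).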
        """
        if queries is None or docs is None:
            raise ValueError("Both queries and docs must be provided.")
        if isinstance(docs, str) and not isinstance(queries, str):
            raise ValueError("Queries and docs must be both lists or both strings.")
        if isinstance(queries, str):
            queries = [queries]
        if isinstance(docs, str):
            docs = [docs]
        num_docs = self._process_num_docs(queries, docs, num_docs)
        queries, docs = self._preprocess(queries, docs, num_docs)

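        # Temporarily swap in the scoring-strategy template (if any) so its prompt tokens wrap each
        # query-document pair, then restore the tokenizer's original post-processor afterwards.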
        orig_post_processor = self._tokenizer.post_processor
        if self.post_processor is not None:
            self._tokenizer.post_processor = self.post_processor

        encoding = self(queries, docs, **kwargs)

        if self.post_processor is not None:
            self._tokenizer.post_processor = orig_post_processor

        return {"encoding": encoding}