"""
Tokenizer module for cross-encoder models.

This module contains the tokenizer class for cross-encoder models.
"""
6
7from typing import Dict, List, Sequence, Tuple, Type
8
9from transformers import BatchEncoding
10
11from ..base import LightningIRTokenizer
12from .cross_encoder_config import CrossEncoderConfig
13
14
[docs]
class CrossEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for the tokenizer."""

    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """:class:`.LightningIRTokenizer` for cross-encoder models. Encodes queries and documents jointly and ensures
        that the input sequences are of the correct length.

        :param query_length: Maximum number of tokens per query, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum number of tokens per document, defaults to 512
        :type doc_length: int, optional
        """
        super().__init__(*args, query_length=query_length, doc_length=doc_length, **kwargs)

    def _truncate(self, text: Sequence[str], max_length: int) -> List[str]:
        """Encodes a list of texts, truncates them to a maximum number of tokens and decodes them to strings.

        :param text: Texts to truncate
        :type text: Sequence[str]
        :param max_length: Maximum number of tokens to keep per text
        :type max_length: int
        :return: Decoded, truncated texts
        :rtype: List[str]
        """
        # Special tokens are left out here so that only the raw text is cut; the attention mask and token type ids
        # are not requested because only ``input_ids`` is decoded again.
        truncated_ids = self(
            text,
            add_special_tokens=False,
            truncation=True,
            max_length=max_length,
            return_attention_mask=False,
            return_token_type_ids=False,
        ).input_ids
        return self.batch_decode(truncated_ids)

    def _repeat_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
        """Repeats queries to match the number of documents.

        :param queries: Queries to repeat
        :type queries: Sequence[str]
        :param num_docs: Number of repetitions per query, indexed by the position of the query
        :type num_docs: Sequence[int]
        :return: Flat list in which the i-th query occurs ``num_docs[i]`` times in a row
        :rtype: List[str]
        """
        return [query for query_idx, query in enumerate(queries) for _ in range(num_docs[query_idx])]

    def _preprocess(
        self,
        queries: Sequence[str],
        docs: Sequence[str],
        num_docs: Sequence[int],
    ) -> Tuple[str | Sequence[str], str | Sequence[str]]:
        """Preprocesses queries and documents to ensure that they are truncated to their respective maximum lengths.

        :param queries: Queries to truncate and repeat
        :type queries: Sequence[str]
        :param docs: Documents to truncate
        :type docs: Sequence[str]
        :param num_docs: Number of documents per query
        :type num_docs: Sequence[int]
        :return: Truncated queries, repeated once per document, and truncated documents
        :rtype: Tuple[str | Sequence[str], str | Sequence[str]]
        """
        # Queries are truncated before they are repeated so that every distinct query is only tokenized once.
        truncated_queries = self._repeat_queries(self._truncate(queries, self.query_length), num_docs)
        truncated_docs = self._truncate(docs, self.doc_length)
        return truncated_queries, truncated_docs

    def _process_num_docs(
        self,
        queries: str | Sequence[str],
        docs: str | Sequence[str],
        num_docs: Sequence[int] | int | None,
    ) -> List[int]:
        """Normalizes ``num_docs`` to one document count per query.

        :param queries: Queries that are tokenized
        :type queries: str | Sequence[str]
        :param docs: Documents that are tokenized
        :type docs: str | Sequence[str]
        :param num_docs: Number of documents per query. If None, an equal number of documents per query is inferred
            from ``len(docs)`` and ``len(queries)``. If an integer, that value is used for every query. If a sequence,
            it is returned unchanged
        :type num_docs: Sequence[int] | int | None
        :raises ValueError: If ``num_docs`` is None and the number of documents is not divisible by the number of
            queries
        :return: Number of documents for each query
        :rtype: List[int]
        """
        if num_docs is None:
            num_queries = len(queries)
            if len(docs) % num_queries != 0:
                raise ValueError("Number of documents must be divisible by the number of queries.")
            return [len(docs) // num_queries] * num_queries
        if isinstance(num_docs, int):
            # This check used to be nested inside the ``num_docs is None`` branch and could never be true, so an
            # integer was passed on as-is and broke the per-query indexing in ``_repeat_queries``.
            return [num_docs] * len(queries)
        return num_docs

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents into a single sequence of tokens.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_docs)`
            should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the
            sequence contains one value per query specifying the number of documents for that query. If an integer,
            assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing
            the number of documents by the number of queries, defaults to None
        :type num_docs: Sequence[int] | int | None, optional
        :raises ValueError: If queries or docs are missing
        :raises ValueError: If docs is a single string but queries is not
        :return: Tokenized query-document sequence
        :rtype: Dict[str, BatchEncoding]
        """
        if queries is None or docs is None:
            raise ValueError("Both queries and docs must be provided.")
        is_string_queries = isinstance(queries, str)
        is_string_docs = isinstance(docs, str)
        # A single query string may be combined with a list of documents, but not the other way around.
        if is_string_docs and not is_string_queries:
            raise ValueError("Queries and docs must be both lists or both strings.")
        if is_string_queries:
            queries = [queries]
        if is_string_docs:
            docs = [docs]
        num_docs = self._process_num_docs(queries, docs, num_docs)
        queries, docs = self._preprocess(queries, docs, num_docs)
        if kwargs.get("return_tensors", None) is not None:
            # Tensor output is always padded to a multiple of 8; this replaces any caller-supplied value.
            kwargs["pad_to_multiple_of"] = 8
        if is_string_queries and is_string_docs:
            # Both inputs were plain strings, so encode a single pair instead of a batch of size one.
            encoding = self(queries[0], docs[0], **kwargs)
        else:
            encoding = self(queries, docs, **kwargs)
        return {"encoding": encoding}