Source code for lightning_ir.cross_encoder.cross_encoder_tokenizer

  1"""
  2Tokenizer module for cross-encoder models.
  3
  4This module contains the tokenizer class cross-encoder models.
  5"""
  6
  7from typing import Dict, List, Sequence, Tuple, Type
  8
  9from tokenizers.processors import TemplateProcessing
 10from transformers import BatchEncoding
 11
 12from ..base import LightningIRTokenizer, LightningIRTokenizerClassFactory
 13from .cross_encoder_config import CrossEncoderConfig
 14
SCORING_STRATEGY_POST_PROCESSOR_MAPPING = {
    "t5": {
        "mono": {
            "pattern": "pre que col $A doc col $B rel1 rel2 col eos",
            "special_tokens": [
                ("pre", "▁"),
                ("que", "Query"),
                ("col", ":"),
                ("doc", "▁Document"),
                ("rel1", "▁Relevan"),
                ("rel2", "t"),
                ("eos", "</s>"),
            ],
        },
        "rank": {
            "pattern": "pre que col $A doc col $B eos",
            "special_tokens": [
                ("pre", "▁"),
                ("que", "Query"),
                ("col", ":"),
                ("doc", "▁Document"),
                ("rel1", "▁Relevan"),
                ("rel2", "t"),
                ("eos", "</s>"),
            ],
        },
    },
}
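
# For illustration: with the "mono" pattern above, TemplateProcessing renders a
# query-document pair as
#
#     ▁ Query : $A ▁Document : $B ▁Relevan t : </s>
#
# which decodes to the monoT5-style prompt "Query: {query} Document: {doc} Relevant:".
# The "rank" pattern omits the trailing "Relevant" cue, mirroring RankT5-style inputs
# that are scored without a relevance prompt.
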
class CrossEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for the tokenizer."""
    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        scoring_strategy: str | None = None,
        **kwargs,
    ):
        """:class:`.LightningIRTokenizer` for cross-encoder models. Encodes queries and documents jointly and
        ensures that the input sequences are of the correct length.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            scoring_strategy (str | None): Scoring strategy that selects a backbone-specific post-processing
                template (e.g., "mono" or "rank" for T5 backbones). Defaults to None.
        """
        super().__init__(
            *args, query_length=query_length, doc_length=doc_length, scoring_strategy=scoring_strategy, **kwargs
        )
        self.scoring_strategy = scoring_strategy
        backbone_model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(self.name_or_path)
        self.post_processor: TemplateProcessing | None = None
        if backbone_model_type in SCORING_STRATEGY_POST_PROCESSOR_MAPPING:
            mapping = SCORING_STRATEGY_POST_PROCESSOR_MAPPING[backbone_model_type]
            if scoring_strategy is not None and scoring_strategy in mapping:
                pattern = mapping[scoring_strategy]["pattern"]
                special_tokens = [
                    (placeholder, self.convert_tokens_to_ids(token))
                    for (placeholder, token) in mapping[scoring_strategy]["special_tokens"]
                ]
                self.post_processor = TemplateProcessing(
                    single=None,
                    pair=pattern,
                    special_tokens=special_tokens,
                )
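
    # Note on the template above: TemplateProcessing takes (template_token, token_id)
    # pairs, so the placeholder names in the pattern ("pre", "que", ...) are bound to
    # the vocabulary ids of the actual pieces ("▁", "Query", ...) looked up via
    # convert_tokens_to_ids; the rendered sequence therefore contains real subword ids.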

    def _truncate(self, text: Sequence[str], max_length: int | None) -> List[str]:
        """Encodes a list of texts, truncates them to a maximum number of tokens, and decodes them back to
        strings."""
        if max_length is None:
            return list(text)
        return self.batch_decode(
            self(
                text,
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
                return_attention_mask=False,
                return_token_type_ids=False,
            ).input_ids
        )

    def _repeat_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
        """Repeats each query once per document so that queries and documents align one-to-one."""
        return [query for query_idx, query in enumerate(queries) for _ in range(num_docs[query_idx])]

    def _preprocess(
        self,
        queries: Sequence[str],
        docs: Sequence[str],
        num_docs: Sequence[int],
    ) -> Tuple[str | Sequence[str], str | Sequence[str]]:
        """Truncates queries and documents to their respective maximum lengths."""
        truncated_queries = self._repeat_queries(self._truncate(queries, self.query_length), num_docs)
        truncated_docs = self._truncate(docs, self.doc_length)
        return truncated_queries, truncated_docs

    def _process_num_docs(
        self,
        queries: str | Sequence[str],
        docs: str | Sequence[str],
        num_docs: Sequence[int] | int | None,
    ) -> List[int]:
        # An explicit integer is broadcast to every query; None infers an equal
        # number of documents per query from the batch sizes.
        if isinstance(num_docs, int):
            num_docs = [num_docs] * len(queries)
        elif num_docs is None:
            if len(docs) % len(queries) != 0:
                raise ValueError("Number of documents must be divisible by the number of queries.")
            num_docs = [len(docs) // len(queries) for _ in range(len(queries))]
        return num_docs
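
    # Example of the helpers above (hypothetical values): for queries ["q1", "q2"]
    # and six documents with num_docs=None, _process_num_docs infers [3, 3], and
    # _repeat_queries returns ["q1", "q1", "q1", "q2", "q2", "q2"], aligning each
    # document with its query before the pair is tokenized.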

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents into a single sequence of tokens.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a
                sequence of integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)`
                equal to the number of documents, i.e., the sequence contains one value per query specifying the
                number of documents for that query. If an integer, assumes an equal number of documents per query.
                If None, tries to infer the number of documents by dividing the number of documents by the number
                of queries. Defaults to None.
        Returns:
            Dict[str, BatchEncoding]: Tokenized query-document sequence.
        Raises:
            ValueError: If either queries or docs is None.
            ValueError: If queries and docs are not both lists or both strings.
        """
        if queries is None or docs is None:
            raise ValueError("Both queries and docs must be provided.")
        if isinstance(docs, str) and not isinstance(queries, str):
            raise ValueError("Queries and docs must be both lists or both strings.")
        if isinstance(queries, str):
            queries = [queries]
        if isinstance(docs, str):
            docs = [docs]
        num_docs = self._process_num_docs(queries, docs, num_docs)
        queries, docs = self._preprocess(queries, docs, num_docs)

        # Temporarily swap in the scoring-strategy-specific post-processor, if one
        # was configured, and restore the original afterwards.
        orig_post_processor = self._tokenizer.post_processor
        if self.post_processor is not None:
            self._tokenizer.post_processor = self.post_processor

        encoding = self(queries, docs, **kwargs)

        if self.post_processor is not None:
            self._tokenizer.post_processor = orig_post_processor

        return {"encoding": encoding}
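
# A minimal usage sketch (hypothetical checkpoint and lengths; assumes the tokenizer
# is loaded via the from_pretrained entry point inherited from the Hugging Face
# tokenizer base class):
#
#     tokenizer = CrossEncoderTokenizer.from_pretrained(
#         "castorini/monot5-base-msmarco",
#         query_length=32,
#         doc_length=256,
#         scoring_strategy="mono",
#     )
#     batch = tokenizer.tokenize(
#         queries="how do solar panels work",
#         docs=[
#             "Solar panels convert sunlight into electricity ...",
#             "The stock market closed higher today ...",
#         ],
#         num_docs=2,
#         padding=True,
#         return_tensors="pt",
#     )["encoding"]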