Source code for lightning_ir.cross_encoder.cross_encoder_tokenizer

  1"""
  2Tokenizer module for cross-encoder models.
  3
  4This module contains the tokenizer class cross-encoder models.
  5"""

from typing import Dict, List, Sequence, Tuple, Type

from transformers import BatchEncoding

from ..base import LightningIRTokenizer
from .cross_encoder_config import CrossEncoderConfig

class CrossEncoderTokenizer(LightningIRTokenizer):

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for the tokenizer."""
    def __init__(
        self, *args, query_length: int = 32, doc_length: int = 512, tokenizer_pattern: str | None = None, **kwargs
    ):
        """:class:`.LightningIRTokenizer` for cross-encoder models. Encodes queries and documents jointly and ensures
        that the input sequences are of the correct length.

        Args:
            query_length (int): Maximum number of tokens per query. Defaults to 32.
            doc_length (int): Maximum number of tokens per document. Defaults to 512.
            tokenizer_pattern (str | None): Optional format string with ``{query}`` and ``{doc}`` placeholders used to
                render each query-document pair into a single input string. Defaults to None.
        """
        super().__init__(
            *args, query_length=query_length, doc_length=doc_length, tokenizer_pattern=tokenizer_pattern, **kwargs
        )
        self.tokenizer_pattern = tokenizer_pattern

    def _truncate(self, text: Sequence[str], max_length: int) -> List[str]:
        """Encodes a list of texts, truncates them to a maximum number of tokens, and decodes them back to strings."""
        return self.batch_decode(
            self(
                text,
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
                return_attention_mask=False,
                return_token_type_ids=False,
            ).input_ids
        )

    def _repeat_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]:
        """Repeats each query to match the number of documents for that query."""
        return [query for query_idx, query in enumerate(queries) for _ in range(num_docs[query_idx])]

    def _preprocess(
        self,
        queries: Sequence[str],
        docs: Sequence[str],
        num_docs: Sequence[int],
    ) -> Tuple[str | Sequence[str], str | Sequence[str]]:
        """Preprocesses queries and documents, truncating each to its respective maximum length."""
        truncated_queries = self._repeat_queries(self._truncate(queries, self.query_length), num_docs)
        truncated_docs = self._truncate(docs, self.doc_length)
        return truncated_queries, truncated_docs

    def _process_num_docs(
        self,
        queries: str | Sequence[str],
        docs: str | Sequence[str],
        num_docs: Sequence[int] | int | None,
    ) -> List[int]:
        if num_docs is None:
            # Infer an equal number of documents per query.
            if len(docs) % len(queries) != 0:
                raise ValueError("Number of documents must be divisible by the number of queries.")
            num_docs = [len(docs) // len(queries)] * len(queries)
        elif isinstance(num_docs, int):
            num_docs = [num_docs] * len(queries)
        return num_docs

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents into a single sequence of tokens.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
            num_docs (Sequence[int] | int | None): Specifies how many documents are passed per query. If a sequence of
                integers, `len(num_docs)` should be equal to the number of queries and `sum(num_docs)` equal to the
                number of documents, i.e., the sequence contains one value per query specifying the number of documents
                for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer
                the number of documents by dividing the number of documents by the number of queries. Defaults to None.
        Returns:
            Dict[str, BatchEncoding]: Tokenized query-document sequence.
        Raises:
            ValueError: If either queries or docs is None.
            ValueError: If queries and docs are not both lists or both strings.
        """
        if queries is None or docs is None:
            raise ValueError("Both queries and docs must be provided.")
        if isinstance(docs, str) and not isinstance(queries, str):
            raise ValueError("Queries and docs must be both lists or both strings.")
        if isinstance(queries, str):
            queries = [queries]
        if isinstance(docs, str):
            docs = [docs]
        num_docs = self._process_num_docs(queries, docs, num_docs)
        queries, docs = self._preprocess(queries, docs, num_docs)

        if self.tokenizer_pattern is not None:
            # Render each query-document pair into a single string using the pattern.
            input_texts = [self.tokenizer_pattern.format(query=query, doc=doc) for query, doc in zip(queries, docs)]
            encoding = self(input_texts, **kwargs)
        else:
            # Fall back to the tokenizer's native text-pair encoding
            # (e.g., [CLS] query [SEP] doc [SEP] for BERT-style tokenizers).
            encoding = self(queries, docs, **kwargs)

        return {"encoding": encoding}
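
The sketch below is not part of the module source; it illustrates how the tokenizer is intended to be used. It assumes ``from_pretrained`` is available through the Hugging Face tokenizer machinery inherited via :class:`.LightningIRTokenizer`, and uses ``bert-base-uncased`` as a placeholder checkpoint.

from lightning_ir.cross_encoder.cross_encoder_tokenizer import CrossEncoderTokenizer

# Hypothetical usage sketch; assumes from_pretrained is inherited from the
# underlying Hugging Face tokenizer via LightningIRTokenizer.
tokenizer = CrossEncoderTokenizer.from_pretrained(
    "bert-base-uncased", query_length=32, doc_length=512
)

# One query and three candidate documents. num_docs is inferred as [3]
# because 3 documents divide evenly across 1 query; the query is repeated
# once per document before the pair encoding is built.
output = tokenizer.tokenize(
    queries="what is information retrieval",
    docs=[
        "Information retrieval is the science of searching for information.",
        "A cross-encoder scores a query and a document jointly.",
        "Lightning IR is built on PyTorch Lightning.",
    ],
    return_tensors="pt",  # standard transformers kwargs pass through **kwargs
    padding=True,
)
encoding = output["encoding"]  # a transformers BatchEncoding
print(encoding.input_ids.shape)  # (3, seq_len): one row per query-document pair

# With a tokenizer_pattern, each pair is first rendered into a single string
# (here "query: ... document: ...", a made-up pattern) instead of relying on
# the tokenizer's native text-pair encoding.
patterned_tokenizer = CrossEncoderTokenizer.from_pretrained(
    "bert-base-uncased", tokenizer_pattern="query: {query} document: {doc}"
)

Passing ``num_docs`` explicitly (either one integer per query or a single integer shared by all queries) skips the divisibility-based inference, which is useful when queries have varying numbers of candidate documents.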