Source code for lightning_ir.cross_encoder.cross_encoder_tokenizer

  1"""
  2Tokenizer module for cross-encoder models.
  3
  4This module contains the tokenizer class for cross-encoder models.
  5"""
  6
  7from typing import Dict, List, Sequence, Tuple, Type
  8
  9from transformers import BatchEncoding
 10
 11from ..base import LightningIRTokenizer
 12from .cross_encoder_config import CrossEncoderConfig
 13
 14
class CrossEncoderTokenizer(LightningIRTokenizer):
    # Tokenizer for cross-encoder models: queries and documents are encoded
    # jointly as a single paired input sequence (see `tokenize`).

    config_class: Type[CrossEncoderConfig] = CrossEncoderConfig
    """Configuration class for the tokenizer."""
    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """:class:`.LightningIRTokenizer` for cross-encoder models. Encodes queries and documents jointly and ensures
        that the input sequences are of the correct length.

        :param query_length: Maximum number of tokens per query, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum number of tokens per document, defaults to 512
        :type doc_length: int, optional
        """
        super().__init__(*args, query_length=query_length, doc_length=doc_length, **kwargs)
31 32 def _truncate(self, text: Sequence[str], max_length: int) -> List[str]: 33 """Encodes a list of texts, truncates them to a maximum number of tokens and decodes them to strings.""" 34 return self.batch_decode( 35 self( 36 text, 37 add_special_tokens=False, 38 truncation=True, 39 max_length=max_length, 40 return_attention_mask=False, 41 return_token_type_ids=False, 42 ).input_ids 43 ) 44 45 def _repeat_queries(self, queries: Sequence[str], num_docs: Sequence[int]) -> List[str]: 46 """Repeats queries to match the number of documents.""" 47 return [query for query_idx, query in enumerate(queries) for _ in range(num_docs[query_idx])] 48 49 def _preprocess( 50 self, 51 queries: Sequence[str], 52 docs: Sequence[str], 53 num_docs: Sequence[int], 54 ) -> Tuple[str | Sequence[str], str | Sequence[str]]: 55 """Preprocesses queries and documents to ensure that they are truncated their respective maximum lengths.""" 56 truncated_queries = self._repeat_queries(self._truncate(queries, self.query_length), num_docs) 57 truncated_docs = self._truncate(docs, self.doc_length) 58 return truncated_queries, truncated_docs 59 60 def _process_num_docs( 61 self, 62 queries: str | Sequence[str], 63 docs: str | Sequence[str], 64 num_docs: Sequence[int] | int | None, 65 ) -> List[int]: 66 if num_docs is None: 67 if isinstance(num_docs, int): 68 num_docs = [num_docs] * len(queries) 69 else: 70 if len(docs) % len(queries) != 0: 71 raise ValueError("Number of documents must be divisible by the number of queries.") 72 num_docs = [len(docs) // len(queries) for _ in range(len(queries))] 73 return num_docs 74
[docs] 75 def tokenize( 76 self, 77 queries: str | Sequence[str] | None = None, 78 docs: str | Sequence[str] | None = None, 79 num_docs: Sequence[int] | int | None = None, 80 **kwargs, 81 ) -> Dict[str, BatchEncoding]: 82 """Tokenizes queries and documents into a single sequence of tokens. 83 84 :param queries: Queries to tokenize, defaults to None 85 :type queries: str | Sequence[str] | None, optional 86 :param docs: Documents to tokenize, defaults to None 87 :type docs: str | Sequence[str] | None, optional 88 :param num_docs: Specifies how many documents are passed per query. If a sequence of integers, `len(num_doc)` 89 should be equal to the number of queries and `sum(num_docs)` equal to the number of documents, i.e., the 90 sequence contains one value per query specifying the number of documents for that query. If an integer, 91 assumes an equal number of documents per query. If None, tries to infer the number of documents by dividing 92 the number of documents by the number of queries, defaults to None 93 :type num_docs: Sequence[int] | int | None, optional 94 :return: Tokenized query-document sequence 95 :rtype: Dict[str, BatchEncoding] 96 """ 97 if queries is None or docs is None: 98 raise ValueError("Both queries and docs must be provided.") 99 if isinstance(docs, str) and not isinstance(queries, str): 100 raise ValueError("Queries and docs must be both lists or both strings.") 101 is_string_queries = False 102 is_string_docs = False 103 if isinstance(queries, str): 104 queries = [queries] 105 is_string_queries = True 106 if isinstance(docs, str): 107 docs = [docs] 108 is_string_docs = True 109 is_string_both = is_string_queries and is_string_docs 110 num_docs = self._process_num_docs(queries, docs, num_docs) 111 queries, docs = self._preprocess(queries, docs, num_docs) 112 return_tensors = kwargs.get("return_tensors", None) 113 if return_tensors is not None: 114 kwargs["pad_to_multiple_of"] = 8 115 if is_string_both: 116 encoding = self(queries[0], 
docs[0], **kwargs) 117 else: 118 encoding = self(queries, docs, **kwargs) 119 return {"encoding": encoding}