# Source code for lightning_ir.base.tokenizer

  1"""
  2Tokenizer module for Lightning IR.
  3
  4This module contains the main tokenizer class for the Lightning IR library.
  5"""
  6
  7import json
  8from os import PathLike
  9from typing import Dict, Self, Sequence, Tuple, Type
 10
 11from transformers import TOKENIZER_MAPPING, BatchEncoding, PreTrainedTokenizerBase
 12
 13from .class_factory import LightningIRTokenizerClassFactory
 14from .config import LightningIRConfig
 15from .external_model_hub import CHECKPOINT_MAPPING
 16
 17
class LightningIRTokenizer(PreTrainedTokenizerBase):
    """Base class for Lightning IR tokenizers. Derived classes implement the tokenize method for handling query
    and document tokenization. It acts as mixin for a transformers.PreTrainedTokenizer_ backbone tokenizer.

    .. _transformers.PreTrainedTokenizer: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizer
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the tokenizer."""

    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """Initializes the tokenizer.

        :param query_length: Maximum number of tokens per query, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum number of tokens per document, defaults to 512
        :type doc_length: int, optional
        """
        # Forward the length limits to the backbone tokenizer's __init__ as kwargs as well, so the
        # backbone records them in its init kwargs (and they survive save/load round-trips).
        super().__init__(*args, query_length=query_length, doc_length=doc_length, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

    def tokenize(
        self, queries: str | Sequence[str] | None = None, docs: str | Sequence[str] | None = None, **kwargs
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents. Abstract; derived tokenizers implement the actual
        query/document tokenization strategy.

        :param queries: Queries to tokenize, defaults to None
        :type queries: str | Sequence[str] | None, optional
        :param docs: Documents to tokenize, defaults to None
        :type docs: str | Sequence[str] | None, optional
        :raises NotImplementedError: Must be implemented by the derived class
        :return: Dictionary of tokenized queries and documents
        :rtype: Dict[str, BatchEncoding]
        """
        raise NotImplementedError

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> Self:
        """Loads a pretrained tokenizer. Wraps the transformers.PreTrainedTokenizer.from_pretrained_ method to return a
        derived LightningIRTokenizer class. See :class:`.LightningIRTokenizerClassFactory` for more details.

        .. _transformers.PreTrainedTokenizer.from_pretrained: \
https://huggingface.co/docs/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizer.from_pretrained

        .. highlight:: python
        .. code-block:: python

            # Loading using model class and backbone checkpoint
            >>> type(BiEncoderTokenizer.from_pretrained("bert-base-uncased"))
            ...
            <class 'lightning_ir.base.class_factory.BiEncoderBertTokenizerFast'>
            # Loading using base class and backbone checkpoint
            >>> type(LightningIRTokenizer.from_pretrained("bert-base-uncased", config=BiEncoderConfig()))
            ...
            <class 'lightning_ir.base.class_factory.BiEncoderBertTokenizerFast'>

        :param model_name_or_path: Name or path of the pretrained tokenizer
        :type model_name_or_path: str
        :raises ValueError: If called on the abstract class :class:`LightningIRTokenizer` and no config is passed
        :return: A derived LightningIRTokenizer consisting of a backbone tokenizer and a LightningIRTokenizer mixin
        :rtype: LightningIRTokenizer
        """
        # provides AutoTokenizer.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRTokenizer or all(issubclass(base, LightningIRTokenizer) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir tokenizer based on backbone model
            # Resolve the Lightning IR config class, in order of precedence: explicitly passed
            # config, external checkpoint registry, the calling subclass's config_class, or the
            # config stored alongside the checkpoint itself.
            if config is not None:
                ConfigClass = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                ConfigClass = _config.__class__
                if config is None:
                    kwargs["config"] = _config
            elif cls is not LightningIRTokenizer and hasattr(cls, "config_class"):
                ConfigClass = cls.config_class
            else:
                ConfigClass = LightningIRTokenizerClassFactory.get_lightning_ir_config(model_name_or_path)
                if ConfigClass is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            # Prefer the mixin config when the resolved config class declares one.
            ConfigClass = getattr(ConfigClass, "mixin_config", ConfigClass)
            backbone_config = LightningIRTokenizerClassFactory.get_backbone_config(model_name_or_path)
            # TOKENIZER_MAPPING maps a backbone config class to a (slow, fast) tokenizer pair;
            # index 1 is the fast (Rust-backed) variant selected via `use_fast`.
            BackboneTokenizers = TOKENIZER_MAPPING[type(backbone_config)]
            if kwargs.get("use_fast", True):
                BackboneTokenizer = BackboneTokenizers[1]
            else:
                BackboneTokenizer = BackboneTokenizers[0]
            # Dynamically derive the concrete tokenizer class (backbone + Lightning IR mixin)
            # and retry loading with it.
            cls = LightningIRTokenizerClassFactory(ConfigClass).from_backbone_class(BackboneTokenizer)
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        # cls is already a derived tokenizer class: merge tokenizer kwargs contributed by the
        # config (if any) and defer to the backbone tokenizer's from_pretrained.
        config = kwargs.pop("config", None)
        if config is not None:
            kwargs.update(config.get_tokenizer_kwargs(cls))
        return super(LightningIRTokenizer, cls).from_pretrained(model_name_or_path, *args, **kwargs)

    def _save_pretrained(
        self,
        save_directory: str | PathLike,
        file_names: Tuple[str],
        legacy_format: bool | None = None,
        filename_prefix: str | None = None,
    ) -> Tuple[str]:
        # bit of a hack to change the tokenizer class in the stored tokenizer config to only contain the
        # lightning_ir tokenizer class (removing the backbone tokenizer class)
        save_files = super()._save_pretrained(save_directory, file_names, legacy_format, filename_prefix)
        # NOTE(review): assumes the tokenizer config JSON is the first file returned by the
        # backbone's _save_pretrained — confirm against the installed transformers version.
        config_file = save_files[0]
        with open(config_file, "r") as file:
            tokenizer_config = json.load(file)

        # Split this derived class's bases into the Lightning IR mixin class and the backbone
        # tokenizer class; exactly one Lightning IR base is expected.
        tokenizer_class = None
        backbone_tokenizer_class = None
        for base in self.__class__.__bases__:
            if issubclass(base, LightningIRTokenizer):
                if tokenizer_class is not None:
                    raise ValueError("Multiple Lightning IR tokenizer classes found.")
                tokenizer_class = base.__name__
                continue
            if issubclass(base, PreTrainedTokenizerBase):
                backbone_tokenizer_class = base.__name__

        tokenizer_config["tokenizer_class"] = tokenizer_class
        tokenizer_config["backbone_tokenizer_class"] = backbone_tokenizer_class

        # Rewrite the config file in place with the adjusted class names.
        with open(config_file, "w") as file:
            out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
            file.write(out_str)
        return save_files