1"""
2Tokenizer module for Lightning IR.
3
4This module contains the main tokenizer class for the Lightning IR library.
5"""
6
7import json
8from os import PathLike
9from typing import Dict, Self, Sequence, Tuple, Type
10
11from transformers import TOKENIZER_MAPPING, BatchEncoding, PreTrainedTokenizerBase
12
13from .class_factory import LightningIRTokenizerClassFactory
14from .config import LightningIRConfig
15from .external_model_hub import CHECKPOINT_MAPPING
16
17
class LightningIRTokenizer(PreTrainedTokenizerBase):
    """Base class for Lightning IR tokenizers. Derived classes implement the tokenize method for handling query
    and document tokenization. It acts as mixin for a transformers.PreTrainedTokenizer_ backbone tokenizer.

    .. _transformers.PreTrainedTokenizer: \
https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the tokenizer."""

    def __init__(self, *args, query_length: int | None = 32, doc_length: int | None = 512, **kwargs):
        """Initializes the tokenizer.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
        """
        # query_length/doc_length are also forwarded to the backbone tokenizer's __init__ so they end up in
        # its init kwargs — NOTE(review): presumably so they are persisted by save_pretrained; confirm.
        super().__init__(*args, query_length=query_length, doc_length=doc_length, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

    def tokenize(
        self, queries: str | Sequence[str] | None = None, docs: str | Sequence[str] | None = None, **kwargs
    ) -> Dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
        Returns:
            Dict[str, BatchEncoding]: Dictionary containing tokenized queries and documents.
        Raises:
            NotImplementedError: Must be implemented by the derived class.
        """
        raise NotImplementedError

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> Self:
        """Loads a pretrained tokenizer. Wraps the transformers.PreTrainedTokenizer.from_pretrained_ method to return a
        derived LightningIRTokenizer class. See :class:`.LightningIRTokenizerClassFactory` for more details.

        .. _transformers.PreTrainedTokenizer.from_pretrained: \
https://huggingface.co/docs/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizer.from_pretrained

        .. highlight:: python
        .. code-block:: python

            # Loading using model class and backbone checkpoint
            >>> type(BiEncoderTokenizer.from_pretrained("bert-base-uncased"))
            ...
            <class 'lightning_ir.base.class_factory.BiEncoderBertTokenizerFast'>
            # Loading using base class and backbone checkpoint
            >>> type(LightningIRTokenizer.from_pretrained("bert-base-uncased", config=BiEncoderConfig()))
            ...
            <class 'lightning_ir.base.class_factory.BiEncoderBertTokenizerFast'>

        Args:
            model_name_or_path (str): Name or path of the pretrained tokenizer.
        Returns:
            Self: A derived LightningIRTokenizer consisting of a backbone tokenizer and a LightningIRTokenizer mixin.
        Raises:
            ValueError: If called on the abstract class `LightningIRTokenizer` and no config is passed.
        """
        # provides AutoTokenizer.from_pretrained support
        config = kwargs.get("config", None)
        # This branch is taken when cls has not yet been combined with a concrete backbone tokenizer:
        # either the abstract base class itself, or a subclass whose bases are all Lightning IR tokenizers.
        if cls is LightningIRTokenizer or all(issubclass(base, LightningIRTokenizer) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir tokenizer based on backbone model
            if config is not None:
                # explicit config wins: it determines which Lightning IR mixin to use
                ConfigClass = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                # known external checkpoint: reuse the config registered for it
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                ConfigClass = _config.__class__
                if config is None:
                    kwargs["config"] = _config
            elif cls is not LightningIRTokenizer and hasattr(cls, "config_class"):
                # a concrete subclass (e.g. BiEncoderTokenizer) carries its own config class
                ConfigClass = cls.config_class
            else:
                # fall back to reading the Lightning IR config stored with the checkpoint, if any
                ConfigClass = LightningIRTokenizerClassFactory.get_lightning_ir_config(model_name_or_path)
                if ConfigClass is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            # prefer the mixin config if the config class declares one; otherwise keep the class as-is
            ConfigClass = getattr(ConfigClass, "mixin_config", ConfigClass)
            backbone_config = LightningIRTokenizerClassFactory.get_backbone_config(model_name_or_path)
            # TOKENIZER_MAPPING maps a config type to a (slow, fast) tokenizer-class pair
            BackboneTokenizers = TOKENIZER_MAPPING[type(backbone_config)]
            if kwargs.get("use_fast", True):
                BackboneTokenizer = BackboneTokenizers[1]
            else:
                BackboneTokenizer = BackboneTokenizers[0]
            # synthesize the combined (mixin + backbone) tokenizer class and re-dispatch through it
            cls = LightningIRTokenizerClassFactory(ConfigClass).from_backbone_class(BackboneTokenizer)
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        # cls is already a combined class: translate config settings into tokenizer kwargs and defer to transformers
        config = kwargs.pop("config", None)
        if config is not None:
            kwargs.update(config.get_tokenizer_kwargs(cls))
        return super(LightningIRTokenizer, cls).from_pretrained(model_name_or_path, *args, **kwargs)

    def _save_pretrained(
        self,
        save_directory: str | PathLike,
        file_names: Tuple[str, ...],
        legacy_format: bool | None = None,
        filename_prefix: str | None = None,
    ) -> Tuple[str, ...]:
        """Saves the tokenizer, rewriting the stored tokenizer config so it records the Lightning IR
        tokenizer class (and the backbone tokenizer class separately) instead of the combined class.

        Args:
            save_directory (str | PathLike): Directory to save the tokenizer files into.
            file_names (Tuple[str, ...]): File names already written by the caller.
            legacy_format (bool | None): Forwarded to the backbone save logic. Defaults to None.
            filename_prefix (str | None): Forwarded to the backbone save logic. Defaults to None.
        Returns:
            Tuple[str, ...]: Paths of all saved files.
        Raises:
            ValueError: If multiple Lightning IR tokenizer base classes are found on this class.
        """
        # bit of a hack to change the tokenizer class in the stored tokenizer config to only contain the
        # lightning_ir tokenizer class (removing the backbone tokenizer class)
        save_files = super()._save_pretrained(save_directory, file_names, legacy_format, filename_prefix)
        # NOTE(review): assumes the first saved file is the tokenizer config JSON — relies on transformers'
        # save order; confirm if transformers is upgraded.
        config_file = save_files[0]
        with open(config_file, "r") as file:
            tokenizer_config = json.load(file)

        # The combined class has exactly one Lightning IR mixin base and one backbone tokenizer base;
        # split them so loading can re-assemble the combined class.
        tokenizer_class = None
        backbone_tokenizer_class = None
        for base in self.__class__.__bases__:
            if issubclass(base, LightningIRTokenizer):
                if tokenizer_class is not None:
                    raise ValueError("Multiple Lightning IR tokenizer classes found.")
                tokenizer_class = base.__name__
                continue
            if issubclass(base, PreTrainedTokenizerBase):
                backbone_tokenizer_class = base.__name__

        tokenizer_config["tokenizer_class"] = tokenizer_class
        tokenizer_config["backbone_tokenizer_class"] = backbone_tokenizer_class

        with open(config_file, "w") as file:
            out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
            file.write(out_str)
        return save_files