Source code for lightning_ir.base.tokenizer
1"""
2Tokenizer module for Lightning IR.
3
4This module contains the main tokenizer class for the Lightning IR library.
5"""
6
7import json
8from collections.abc import Sequence
9from os import PathLike
10from typing import Self
11
12from transformers import TOKENIZER_MAPPING, BatchEncoding, PreTrainedTokenizerBase
13
14from .class_factory import LightningIRTokenizerClassFactory
15from .config import LightningIRConfig
16from .external_model_hub import CHECKPOINT_MAPPING
17
18
[docs]
class LightningIRTokenizer(PreTrainedTokenizerBase):
    """Base class for Lightning IR tokenizers. Derived classes implement the tokenize method for handling query
    and document tokenization. It acts as mixin for a transformers.PreTrainedTokenizer_ backbone tokenizer.

    .. _transformers.PreTrainedTokenizer: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizer
    """

    config_class: type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the tokenizer."""

    def __init__(self, *args, query_length: int | None = 32, doc_length: int | None = 512, **kwargs):
        """Initializes the tokenizer.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
        """
        # The lengths are forwarded to the backbone tokenizer's __init__ and also stored on the instance.
        # NOTE(review): presumably forwarded so transformers records them in ``init_kwargs`` and persists them in
        # tokenizer_config.json on save — confirm against the installed transformers version.
        super().__init__(*args, query_length=query_length, doc_length=doc_length, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

    def tokenize(
        self, queries: str | Sequence[str] | None = None, docs: str | Sequence[str] | None = None, **kwargs
    ) -> dict[str, BatchEncoding]:
        """Tokenizes queries and documents.

        Args:
            queries (str | Sequence[str] | None): Queries to tokenize. Defaults to None.
            docs (str | Sequence[str] | None): Documents to tokenize. Defaults to None.
        Returns:
            dict[str, BatchEncoding]: Dictionary containing tokenized queries and documents.
        Raises:
            NotImplementedError: Must be implemented by the derived class.
        """
        raise NotImplementedError

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *args, **kwargs) -> Self:
        """Loads a pretrained tokenizer. Wraps the transformers.PreTrainedTokenizer.from_pretrained_ method to return a
        derived LightningIRTokenizer class. See :class:`.LightningIRTokenizerClassFactory` for more details.

        When called on :class:`LightningIRTokenizer` itself, or on a Lightning IR tokenizer class whose bases are all
        Lightning IR tokenizers (i.e. no backbone tokenizer is mixed in yet), a derived class combining the Lightning IR
        tokenizer mixin with the checkpoint's backbone tokenizer is created and loading is delegated to it. The fast
        backbone tokenizer is used by default. If the requested implementation (fast or slow) is not available, the
        other one is used instead.

        .. _transformers.PreTrainedTokenizer.from_pretrained: \
https://huggingface.co/docs/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizer.from_pretrained

        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(BiEncoderTokenizer.from_pretrained("bert-base-uncased"))
            ...
            <class 'lightning_ir.base.class_factory.BiEncoderBertTokenizerFast'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRTokenizer.from_pretrained("bert-base-uncased", config=BiEncoderConfig()))
            ...
            <class 'lightning_ir.base.class_factory.BiEncoderBertTokenizerFast'>

        Args:
            model_name_or_path (str): Name or path of the pretrained tokenizer.
            *args: Positional arguments forwarded to the backbone tokenizer's ``from_pretrained``.
            **kwargs: Keyword arguments forwarded to the backbone tokenizer's ``from_pretrained``. ``config`` selects
                the Lightning IR tokenizer class and contributes tokenizer kwargs via ``config.get_tokenizer_kwargs``.
                ``use_fast`` (default True) prefers the fast backbone tokenizer.
        Returns:
            Self: A derived LightningIRTokenizer consisting of a backbone tokenizer and a LightningIRTokenizer mixin.
        Raises:
            ValueError: If called on the abstract class `LightningIRTokenizer`, no config is passed and none can be
                inferred from the checkpoint.
            ValueError: If neither a fast nor a slow tokenizer class is registered for the checkpoint's backbone.
        """
        # provides AutoTokenizer.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRTokenizer or all(issubclass(base, LightningIRTokenizer) for base in cls.__bases__):
            # no backbone tokenizer among the bases: create a derived lightning-ir tokenizer based on the backbone
            # model of the checkpoint and delegate loading to it
            if config is not None:
                ConfigClass = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                ConfigClass = _config.__class__
                # no config was passed (handled by the first branch), so forward the mapped one to the derived class
                kwargs["config"] = _config
            elif cls is not LightningIRTokenizer and hasattr(cls, "config_class"):
                ConfigClass = cls.config_class
            else:
                ConfigClass = LightningIRTokenizerClassFactory.get_lightning_ir_config(model_name_or_path)
                if ConfigClass is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            ConfigClass = getattr(ConfigClass, "mixin_config", ConfigClass)
            backbone_config = LightningIRTokenizerClassFactory.get_backbone_config(model_name_or_path)
            # transformers maps a backbone config type to a (slow, fast) pair of tokenizer classes. An entry can be
            # None when that implementation is unavailable, so fall back to the other one like AutoTokenizer does
            # instead of handing None to the class factory.
            BackboneTokenizers = TOKENIZER_MAPPING[type(backbone_config)]
            slow_tokenizer, fast_tokenizer = BackboneTokenizers[0], BackboneTokenizers[1]
            if kwargs.get("use_fast", True) and fast_tokenizer is not None:
                BackboneTokenizer = fast_tokenizer
            else:
                BackboneTokenizer = slow_tokenizer if slow_tokenizer is not None else fast_tokenizer
            if BackboneTokenizer is None:
                raise ValueError(
                    f"No tokenizer class is registered for backbone config {type(backbone_config).__name__}."
                )
            cls = LightningIRTokenizerClassFactory(ConfigClass).from_backbone_class(BackboneTokenizer)
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        # some base of cls is not a Lightning IR tokenizer (the backbone tokenizer is already mixed in): merge the
        # config's tokenizer kwargs and defer to the backbone's from_pretrained
        config = kwargs.pop("config", None)
        if config is not None:
            kwargs.update(config.get_tokenizer_kwargs(cls))
        return super().from_pretrained(model_name_or_path, *args, **kwargs)

    def _save_pretrained(
        self,
        save_directory: str | PathLike,
        file_names: tuple[str, ...],
        legacy_format: bool | None = None,
        filename_prefix: str | None = None,
    ) -> tuple[str, ...]:
        """Saves the tokenizer files and rewrites the tokenizer class entries in the stored tokenizer config.

        After the backbone tokenizer has written its files, ``tokenizer_class`` is set to the name of the Lightning IR
        tokenizer base class and ``backbone_tokenizer_class`` to the name of the backbone tokenizer base class.

        Args:
            save_directory (str | PathLike): Directory to save the tokenizer files to.
            file_names (tuple[str, ...]): File names passed through to the backbone implementation.
            legacy_format (bool | None): Passed through to the backbone implementation. Defaults to None.
            filename_prefix (str | None): Passed through to the backbone implementation. Defaults to None.
        Returns:
            tuple[str, ...]: The paths of the saved files, as returned by the backbone implementation.
        Raises:
            ValueError: If more than one Lightning IR tokenizer class is found among the class's bases.
        """
        # bit of a hack to change the tokenizer class in the stored tokenizer config to only contain the
        # lightning_ir tokenizer class (removing the backbone tokenizer class)
        save_files = super()._save_pretrained(save_directory, file_names, legacy_format, filename_prefix)
        # assumes the first returned path is the tokenizer config JSON — TODO confirm for the installed transformers
        config_file = save_files[0]
        # explicit UTF-8: the config is dumped with ensure_ascii=False, so the platform default encoding could
        # corrupt or fail on non-ASCII content
        with open(config_file, encoding="utf-8") as file:
            tokenizer_config = json.load(file)

        tokenizer_class = None
        backbone_tokenizer_class = None
        for base in self.__class__.__bases__:
            if issubclass(base, LightningIRTokenizer):
                if tokenizer_class is not None:
                    raise ValueError("Multiple Lightning IR tokenizer classes found.")
                tokenizer_class = base.__name__
                continue
            if issubclass(base, PreTrainedTokenizerBase):
                backbone_tokenizer_class = base.__name__

        tokenizer_config["tokenizer_class"] = tokenizer_class
        tokenizer_config["backbone_tokenizer_class"] = backbone_tokenizer_class

        with open(config_file, "w", encoding="utf-8") as file:
            out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
            file.write(out_str)
        return save_files