"""
Class factory module for Lightning IR.

This module provides factory classes that create the configuration, model, and tokenizer components of the Lightning IR
library by deriving them from Hugging Face Transformers classes.
"""
7
from __future__ import annotations

import importlib
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Mapping, Tuple, Type

from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    TOKENIZER_MAPPING,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from transformers.models.auto.configuration_auto import model_type_to_module_name
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES, get_tokenizer_config

if TYPE_CHECKING:
    from . import LightningIRConfig, LightningIRModel, LightningIRTokenizer
28
29
def _get_model_class(
    config: PretrainedConfig | Type[PretrainedConfig],
    model_mapping: Mapping[Type[PretrainedConfig], Any] | None = None,
) -> Type[PreTrainedModel]:
    """Resolves the model class registered for a configuration in a Hugging Face model mapping.

    Adapted from the ``_get_model_class`` helper of transformers:
    https://github.com/huggingface/transformers/blob/356b3cd71d7bfb51c88fea3e8a0c054f3a457ab9/src/transformers/models/auto/auto_factory.py#L387

    Args:
        config (PretrainedConfig | Type[PretrainedConfig]): Configuration instance or configuration class.
        model_mapping (Mapping[Type[PretrainedConfig], Any] | None, optional): Mapping from configuration classes to
            either a single model class or a list/tuple of candidate model classes. Defaults to
            ``transformers.MODEL_MAPPING``.
    Returns:
        Type[PreTrainedModel]: The model class registered for the configuration. If several candidates are
        registered, the one matching the config's ``architectures`` is chosen (also trying ``TF``/``Flax``-prefixed
        names); otherwise the first candidate is returned.

    The mapping's lookup error (e.g. ``KeyError``) propagates if the configuration class is not registered.
    """
    if model_mapping is None:
        model_mapping = MODEL_MAPPING
    config_class = config if isinstance(config, type) else type(config)
    supported_models = model_mapping[config_class]
    if not isinstance(supported_models, (list, tuple)):
        return supported_models

    if isinstance(config, type):
        # We cannot parse architectures from a config class, we need an instance for this.
        return supported_models[0]

    name_to_model = {model.__name__: model for model in supported_models}
    # ``architectures`` may be missing *or* present but ``None`` (it is optional on PretrainedConfig), so normalise
    # both cases to an empty list instead of iterating over ``None``.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        for candidate in (arch, f"TF{arch}", f"Flax{arch}"):
            if candidate in name_to_model:
                return name_to_model[candidate]

    # If no architecture is set in the config or none matches the supported models, the first element is the default.
    return supported_models[0]
56
57
class LightningIRClassFactory(ABC):
    """Base class for creating derived Lightning IR classes from HuggingFace classes.

    Subclasses combine a Lightning IR mixin, identified by its configuration class, with a HuggingFace backbone class
    to create a new derived class.
    """

    def __init__(self, MixinConfig: Type[LightningIRConfig]) -> None:
        """Creates a new LightningIRClassFactory.

        Args:
            MixinConfig (Type[LightningIRConfig]): LightningIRConfig mixin class. If an already derived configuration
                class (one with a non-None ``backbone_model_type``) is passed, its first base class is used instead.
        """
        if getattr(MixinConfig, "backbone_model_type", None) is not None:
            # Derived configuration classes list the Lightning IR mixin as their first base class, so unwrap it.
            MixinConfig = MixinConfig.__bases__[0]
        self.MixinConfig = MixinConfig

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Grabs the configuration from a checkpoint of a pretrained HuggingFace model.

        The backbone model type is determined with :meth:`LightningIRClassFactory.get_backbone_model_type` and the
        matching HuggingFace configuration class is then loaded from the checkpoint.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            PretrainedConfig: Configuration of the backbone model.
        """
        backbone_model_type = LightningIRClassFactory.get_backbone_model_type(model_name_or_path)
        return CONFIG_MAPPING[backbone_model_type].from_pretrained(model_name_or_path)

    @staticmethod
    def get_lightning_ir_config(model_name_or_path: str | Path) -> LightningIRConfig | None:
        """Grabs the Lightning IR configuration from a checkpoint of a pretrained Lightning IR model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            LightningIRConfig | None: Configuration of the Lightning IR model, or None if
            :meth:`LightningIRClassFactory.get_lightning_ir_model_type` finds no Lightning IR model type.
        """
        model_type = LightningIRClassFactory.get_lightning_ir_model_type(model_name_or_path)
        if model_type is None:
            return None
        return CONFIG_MAPPING[model_type].from_pretrained(model_name_or_path)

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Grabs the model type from a checkpoint of a pretrained HuggingFace model.

        If the configuration stores a ``backbone_model_type`` (as Lightning IR checkpoints do), that value is
        returned; otherwise the configuration's ``model_type`` is returned.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
            *args: Additional positional arguments forwarded to ``PretrainedConfig.get_config_dict``.
            **kwargs: Additional keyword arguments forwarded to ``PretrainedConfig.get_config_dict``.
        Returns:
            str: Model type of the backbone model.
        Raises:
            ValueError: If no backbone model type can be determined from the configuration.
        """
        config_dict, _ = PretrainedConfig.get_config_dict(model_name_or_path, *args, **kwargs)
        backbone_model_type = config_dict.get("backbone_model_type", None) or config_dict.get("model_type")
        if backbone_model_type is None:
            raise ValueError(f"Unable to load PretrainedConfig from {model_name_or_path}")
        return backbone_model_type

    @staticmethod
    def get_lightning_ir_model_type(model_name_or_path: str | Path) -> str | None:
        """Grabs the Lightning IR model type from a checkpoint of a pretrained HuggingFace model.

        A checkpoint is treated as a Lightning IR checkpoint only if its configuration contains a
        ``backbone_model_type`` entry.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            str | None: Model type of the Lightning IR model, or None if the configuration has no
            ``backbone_model_type`` entry (i.e., it is not a Lightning IR checkpoint) or no ``model_type``.
        """
        config_dict, _ = PretrainedConfig.get_config_dict(model_name_or_path)
        if "backbone_model_type" not in config_dict:
            return None
        return config_dict.get("model_type", None)

    def cc_lir_model_type(self, Config: Type[LightningIRConfig]) -> str:
        """Camel case model type of a Lightning IR configuration class.

        Each ``-``-separated part of ``Config.model_type`` is title-cased and the parts are concatenated, e.g.
        ``"col-bert"`` becomes ``"ColBert"``.

        Args:
            Config (Type[LightningIRConfig]): Lightning IR configuration class.
        Returns:
            str: Camel case model type.
        """
        return "".join(s.title() for s in Config.model_type.split("-"))

    @abstractmethod
    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Any:
        """Loads a derived Lightning IR class from a pretrained HuggingFace model. Must be implemented by subclasses.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            Any: Derived Lightning IR class.
        """
        ...

    @abstractmethod
    def from_backbone_class(self, BackboneClass: Type) -> Type:
        """Creates a derived Lightning IR class from a backbone HuggingFace class. Must be implemented by subclasses.

        Args:
            BackboneClass (Type): Backbone class.
        Returns:
            Type: Derived Lightning IR class.
        """
        ...
155
156
class LightningIRConfigClassFactory(LightningIRClassFactory):
    """Class factory that derives LightningIRConfig classes from HuggingFace configuration classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRConfig]:
        """Builds the derived LightningIRConfig class that matches a pretrained HuggingFace checkpoint.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            Type[LightningIRConfig]: Derived LightningIRConfig.
        """
        return self.from_backbone_class(type(self.get_backbone_config(model_name_or_path)))

    def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[LightningIRConfig]:
        """Creates a derived LightningIRConfig from a transformers.PretrainedConfig_ backbone configuration class.

        The derived class inherits from the registered Lightning IR configuration mixin first and the backbone
        configuration class second. A backbone class that already defines ``backbone_model_type`` is treated as an
        already derived LightningIRConfig and is returned unchanged.

        .. _transformers.PretrainedConfig: \
https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig

        Args:
            BackboneClass (Type[PretrainedConfig]): Backbone configuration class.
        Returns:
            Type[LightningIRConfig]: Derived LightningIRConfig.
        """
        already_derived = getattr(BackboneClass, "backbone_model_type", None) is not None
        if already_derived:
            return BackboneClass

        mixin_model_type = self.MixinConfig.model_type
        LightningIRConfigMixin: Type[LightningIRConfig] = CONFIG_MAPPING[mixin_model_type]

        class_name = self.cc_lir_model_type(LightningIRConfigMixin) + BackboneClass.__name__
        namespace = {
            "model_type": mixin_model_type,
            "backbone_model_type": BackboneClass.model_type,
            "mixin_config": self.MixinConfig,
        }
        return type(class_name, (LightningIRConfigMixin, BackboneClass), namespace)
198
199
class LightningIRModelClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRModel classes from HuggingFace model classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRModel]:
        """Loads a derived LightningIRModel from a pretrained HuggingFace model.

        The backbone configuration is read from the checkpoint, the registered backbone model class is looked up for
        it, and that class is combined with the Lightning IR model mixin. Extra positional and keyword arguments are
        accepted for interface compatibility but are not used.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            Type[LightningIRModel]: Derived LightningIRModel.
        """
        backbone_config = self.get_backbone_config(model_name_or_path)
        BackboneModel = _get_model_class(backbone_config)
        return self.from_backbone_class(BackboneModel)

    def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[LightningIRModel]:
        """Creates a derived LightningIRModel from a transformers.PreTrainedModel_ backbone model. If the backbone model
        is already a LightningIRModel (its ``config_class`` defines ``backbone_model_type``), it is returned as is.

        .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model#transformers.PreTrainedModel

        Args:
            BackboneClass (Type[PreTrainedModel]): Backbone model class.
        Returns:
            Type[LightningIRModel]: Derived LightningIRModel whose ``config_class`` is the matching derived
            LightningIRConfig.
        Raises:
            ValueError: If the backbone model class has no ``config_class``.

        If the Lightning IR model mixin is not registered in the HuggingFace model mapping, the mapping's lookup error
        (e.g. ``KeyError``) propagates unchanged.
        """
        BackboneConfig = BackboneClass.config_class
        if getattr(BackboneConfig, "backbone_model_type", None) is not None:
            # The backbone is already a derived Lightning IR model.
            return BackboneClass
        if BackboneConfig is None:
            raise ValueError(
                f"Model {BackboneClass} is not a valid backbone model because it is missing a `config_class`."
            )

        LightningIRModelMixin: Type[LightningIRModel] = _get_model_class(self.MixinConfig)
        DerivedLightningIRConfig = LightningIRConfigClassFactory(self.MixinConfig).from_backbone_class(BackboneConfig)

        # The mixin precedes the backbone in the bases, so a ``forward`` defined on the mixin shadows the backbone's;
        # the backbone's own implementation stays reachable as ``_backbone_forward``.
        return type(
            f"{self.cc_lir_model_type(LightningIRModelMixin.config_class)}{BackboneClass.__name__}",
            (LightningIRModelMixin, BackboneClass),
            {"config_class": DerivedLightningIRConfig, "_backbone_forward": BackboneClass.forward},
        )
250
251
class LightningIRTokenizerClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRTokenizer classes from HuggingFace tokenizer classes."""

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Grabs the backbone configuration from a checkpoint of a pretrained HuggingFace tokenizer.

        The backbone model type is determined with :meth:`LightningIRTokenizerClassFactory.get_backbone_model_type`.
        That method can fall back to the tokenizer configuration, but the returned configuration itself is still
        loaded from the checkpoint.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
        Returns:
            PretrainedConfig: Configuration of the backbone model.
        """
        backbone_model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(model_name_or_path)
        return CONFIG_MAPPING[backbone_model_type].from_pretrained(model_name_or_path)

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Grabs the model type from a checkpoint of a pretrained HuggingFace tokenizer.

        The model configuration is tried first (see :meth:`LightningIRClassFactory.get_backbone_model_type`). If that
        fails with an ``OSError`` or ``ValueError``, the model type is inferred from the ``backbone_tokenizer_class``
        entry of the tokenizer configuration. The inference looks for a transformers model module that defines both
        the tokenizer class and the matching configuration class.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
            *args: Additional positional arguments forwarded to the model-configuration lookup.
            **kwargs: Additional keyword arguments forwarded to the model-configuration lookup.
        Returns:
            str: Model type of the backbone tokenizer.
        Raises:
            ValueError: If the model type can be determined neither from the model configuration nor from the
                tokenizer configuration.
        """
        try:
            return LightningIRClassFactory.get_backbone_model_type(model_name_or_path, *args, **kwargs)
        except (OSError, ValueError):
            # Best guess at the model type based on the tokenizer configuration.
            tokenizer_config = get_tokenizer_config(model_name_or_path)
            backbone_tokenizer_class = tokenizer_config.get("backbone_tokenizer_class", None)
            if backbone_tokenizer_class is not None:
                # e.g. "BertTokenizerFast" -> "BertTokenizer" -> "BertConfig"
                config_class_name = backbone_tokenizer_class.removesuffix("Fast").replace("Tokenizer", "Config")
                for model_type, tokenizer_class_names in TOKENIZER_MAPPING_NAMES.items():
                    if backbone_tokenizer_class not in tokenizer_class_names:
                        continue
                    module_name = model_type_to_module_name(model_type)
                    module = importlib.import_module(f".{module_name}", "transformers.models")
                    if hasattr(module, backbone_tokenizer_class) and hasattr(module, config_class_name):
                        return getattr(module, config_class_name).model_type
            raise ValueError("No backbone model found in the configuration")

    def from_pretrained(
        self, model_name_or_path: str | Path, *args, use_fast: bool = True, **kwargs
    ) -> Type[LightningIRTokenizer]:
        """Loads a derived LightningIRTokenizer from a pretrained HuggingFace tokenizer.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
            use_fast (bool, optional): Whether to use the fast tokenizer. Defaults to True.
        Returns:
            Type[LightningIRTokenizer]: Derived LightningIRTokenizer.
        Raises:
            ValueError: If no fast tokenizer is found when `use_fast` is True.
            ValueError: If no slow tokenizer is found when `use_fast` is False.
        """
        # Only the backbone configuration *class* is needed to look up the tokenizers. Loading a config instance via
        # ``from_pretrained`` would re-read the checkpoint configuration whose absence is exactly what triggers the
        # tokenizer-config fallback in ``get_backbone_model_type``, defeating that fallback.
        BackboneConfig = CONFIG_MAPPING[self.get_backbone_model_type(model_name_or_path)]
        BackboneTokenizers = TOKENIZER_MAPPING[BackboneConfig]
        DerivedLightningIRTokenizers = self.from_backbone_classes(BackboneTokenizers, BackboneConfig)
        if use_fast:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[1]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No fast tokenizer found.")
        else:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[0]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No slow tokenizer found.")
        return DerivedLightningIRTokenizer

    def from_backbone_classes(
        self,
        BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None],
        BackboneConfig: Type[PretrainedConfig] | None = None,
    ) -> Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]:
        """Creates derived slow and fast LightningIRTokenizers from a tuple of backbone HuggingFace tokenizer classes.

        Args:
            BackboneClasses (Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None]):
                Slow and fast backbone tokenizer classes. Missing entries are None.
            BackboneConfig (Type[PretrainedConfig] | None, optional): Backbone configuration class. Currently unused;
                accepted for interface compatibility. Defaults to None.
        Returns:
            Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]: Slow and fast derived
            LightningIRTokenizers (None where the corresponding backbone class is None).
        """
        DerivedLightningIRTokenizers = tuple(
            None if BackboneClass is None else self.from_backbone_class(BackboneClass)
            for BackboneClass in BackboneClasses
        )
        if DerivedLightningIRTokenizers[1] is not None:
            # Point the derived fast tokenizer at the derived slow tokenizer (which may be None).
            DerivedLightningIRTokenizers[1].slow_tokenizer_class = DerivedLightningIRTokenizers[0]
        return DerivedLightningIRTokenizers

    def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
        """Creates a derived LightningIRTokenizer from a transformers.PreTrainedTokenizerBase_ backbone tokenizer. If
        the backbone tokenizer is already a LightningIRTokenizer, it is returned as is.

        .. _transformers.PreTrainedTokenizerBase: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerBase

        Args:
            BackboneClass (Type[PreTrainedTokenizerBase]): Backbone tokenizer class.
        Returns:
            Type[LightningIRTokenizer]: Derived LightningIRTokenizer.
        """
        # NOTE(review): a ``config_class`` attribute is used as the marker for an existing Lightning IR tokenizer;
        # confirm that plain HuggingFace tokenizer classes never define it.
        if hasattr(BackboneClass, "config_class"):
            return BackboneClass
        # The Lightning IR tokenizer mixin is taken from the slow-tokenizer slot of its mapping entry.
        LightningIRTokenizerMixin = TOKENIZER_MAPPING[self.MixinConfig][0]

        DerivedLightningIRTokenizer = type(
            f"{self.cc_lir_model_type(LightningIRTokenizerMixin.config_class)}{BackboneClass.__name__}",
            (LightningIRTokenizerMixin, BackboneClass),
            {},
        )

        return DerivedLightningIRTokenizer