1"""
2Class factory module for Lightning IR.
3
4This module provides factory classes for creating various components of the Lightning IR library
5by extending Hugging Face Transformers classes.
6"""
7
8from __future__ import annotations
9
10from abc import ABC, abstractmethod
11from pathlib import Path
12from typing import TYPE_CHECKING, Any, Tuple, Type
13
14from transformers import (
15 CONFIG_MAPPING,
16 MODEL_MAPPING,
17 TOKENIZER_MAPPING,
18 PretrainedConfig,
19 PreTrainedModel,
20 PreTrainedTokenizerBase,
21)
22from transformers.models.auto.tokenization_auto import get_tokenizer_config, tokenizer_class_from_name
23
24if TYPE_CHECKING:
25 from . import LightningIRConfig, LightningIRModel, LightningIRTokenizer
26
27
def _get_model_class(config: PretrainedConfig | Type[PretrainedConfig]) -> Type[PreTrainedModel]:
    """Resolves the model class registered in ``MODEL_MAPPING`` for a configuration.

    Adapted from ``transformers.models.auto.auto_factory._get_model_class``:
    https://github.com/huggingface/transformers/blob/356b3cd71d7bfb51c88fea3e8a0c054f3a457ab9/src/transformers/models/auto/auto_factory.py#L387

    When several model classes are registered for the configuration class, the configuration's ``architectures`` are
    matched against their names (also trying the ``TF`` and ``Flax`` prefixed variants). The first registered class
    is the fallback.

    :param config: Configuration instance or configuration class
    :type config: PretrainedConfig | Type[PretrainedConfig]
    :return: Model class registered for the configuration
    :rtype: Type[PreTrainedModel]
    """
    config_class = config if isinstance(config, type) else type(config)
    # Lookup errors for unregistered configuration classes propagate from MODEL_MAPPING unchanged.
    supported_models = MODEL_MAPPING[config_class]
    if not isinstance(supported_models, (list, tuple)):
        return supported_models

    if isinstance(config, type):
        # we cannot parse architectures from a config class, we need an instance for this
        return supported_models[0]

    name_to_model = {model.__name__: model for model in supported_models}
    # `architectures` may exist but be None (the PretrainedConfig default), so normalise it to an empty list.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        for candidate in (arch, f"TF{arch}", f"Flax{arch}"):
            if candidate in name_to_model:
                return name_to_model[candidate]

    # If no architecture is set in the config or none matches the supported models, the first element of the tuple
    # is the default.
    return supported_models[0]
54
55
class LightningIRClassFactory(ABC):
    """Abstract base for factories that derive Lightning IR classes from Hugging Face classes.

    A factory is bound to one Lightning IR mixin configuration class. Subclasses combine that mixin with Hugging Face
    backbone classes to build new, derived classes.
    """

    def __init__(self, MixinConfig: Type[LightningIRConfig]) -> None:
        """Creates a new LightningIRClassFactory.

        If ``MixinConfig`` is itself a derived configuration (it defines a non-None ``backbone_model_type``), its first
        base class is stored as the mixin instead.

        :param MixinConfig: LightningIRConfig mixin class
        :type MixinConfig: Type[LightningIRConfig]
        """
        is_derived = getattr(MixinConfig, "backbone_model_type", None) is not None
        self.MixinConfig = MixinConfig.__bases__[0] if is_derived else MixinConfig

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Loads the backbone configuration stored in a pretrained HuggingFace checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Configuration of the backbone model
        :rtype: PretrainedConfig
        """
        model_type = LightningIRClassFactory.get_backbone_model_type(model_name_or_path)
        BackboneConfigClass = CONFIG_MAPPING[model_type]
        return BackboneConfigClass.from_pretrained(model_name_or_path)

    @staticmethod
    def get_lightning_ir_config(model_name_or_path: str | Path) -> LightningIRConfig | None:
        """Loads the Lightning IR configuration stored in a pretrained Lightning IR checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Lightning IR configuration, or None if the checkpoint is not a Lightning IR checkpoint
        :rtype: LightningIRConfig | None
        """
        lir_model_type = LightningIRClassFactory.get_lightning_ir_model_type(model_name_or_path)
        if lir_model_type is None:
            return None
        LightningIRConfigClass = CONFIG_MAPPING[lir_model_type]
        return LightningIRConfigClass.from_pretrained(model_name_or_path)

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Reads the backbone model type from a pretrained HuggingFace checkpoint.

        A ``backbone_model_type`` entry in the configuration takes precedence over ``model_type``. Extra arguments are
        forwarded to ``PretrainedConfig.get_config_dict``.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :raises ValueError: If the configuration defines neither a backbone model type nor a model type
        :return: Model type of the backbone model
        :rtype: str
        """
        config_dict = PretrainedConfig.get_config_dict(model_name_or_path, *args, **kwargs)[0]
        model_type = config_dict.get("backbone_model_type") or config_dict.get("model_type")
        if model_type is not None:
            return model_type
        raise ValueError(f"Unable to load PretrainedConfig from {model_name_or_path}")

    @staticmethod
    def get_lightning_ir_model_type(model_name_or_path: str | Path) -> str | None:
        """Reads the Lightning IR model type from a pretrained checkpoint.

        A checkpoint counts as a Lightning IR checkpoint when its configuration contains a ``backbone_model_type``
        entry. In that case its ``model_type`` is the Lightning IR model type.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Model type of the Lightning IR model, or None for a plain HuggingFace checkpoint
        :rtype: str | None
        """
        config_dict = PretrainedConfig.get_config_dict(model_name_or_path)[0]
        return config_dict.get("model_type", None) if "backbone_model_type" in config_dict else None

    @property
    def cc_lir_model_type(self) -> str:
        """Camel case model type of the Lightning IR model (e.g. ``"cross-encoder"`` -> ``"CrossEncoder"``)."""
        return "".join(map(str.title, self.MixinConfig.model_type.split("-")))

    @abstractmethod
    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Any:
        """Loads a derived Lightning IR class from a pretrained HuggingFace model. Must be implemented by subclasses.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived Lightning IR class
        :rtype: Any
        """
        ...

    @abstractmethod
    def from_backbone_class(self, BackboneClass: Type) -> Type:
        """Creates a derived Lightning IR class from a backbone HuggingFace class. Must be implemented by subclasses.

        :param BackboneClass: Backbone class
        :type BackboneClass: Type
        :return: Derived Lightning IR class
        :rtype: Type
        """
        ...
151
class LightningIRConfigClassFactory(LightningIRClassFactory):
    """Class factory that derives LightningIRConfig classes from HuggingFace configuration classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRConfig]:
        """Builds a derived LightningIRConfig class for the backbone stored in a pretrained HuggingFace checkpoint.

        Extra positional and keyword arguments are accepted but not used.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        BackboneConfigClass = type(self.get_backbone_config(model_name_or_path))
        return self.from_backbone_class(BackboneConfigClass)

    def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[LightningIRConfig]:
        """Creates a derived LightningIRConfig from a transformers.PretrainedConfig_ backbone configuration class. If
        the backbone configuration class is already a derived LightningIRConfig (it defines a non-None
        ``backbone_model_type``), it is returned as is.

        .. _transformers.PretrainedConfig: \
https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig

        The derived class inherits from the registered Lightning IR configuration mixin and the backbone class, and
        records the mixin's model type, the backbone's model type, and the mixin configuration class.

        :param BackboneClass: Backbone configuration class
        :type BackboneClass: Type[PretrainedConfig]
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        already_derived = getattr(BackboneClass, "backbone_model_type", None) is not None
        if already_derived:
            return BackboneClass

        mixin_model_type = self.MixinConfig.model_type
        LightningIRConfigMixin: Type[LightningIRConfig] = CONFIG_MAPPING[mixin_model_type]
        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        class_attributes = {
            "model_type": mixin_model_type,
            "backbone_model_type": BackboneClass.model_type,
            "mixin_config": self.MixinConfig,
        }
        return type(derived_name, (LightningIRConfigMixin, BackboneClass), class_attributes)
193
194
class LightningIRModelClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRModel classes from HuggingFace model classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRModel]:
        """Loads a derived LightningIRModel from a pretrained HuggingFace model.

        The backbone configuration is read from the checkpoint, and the matching backbone model class is resolved from
        the HuggingFace model mapping. A derived LightningIRModel class is then built on top of it. Extra positional
        and keyword arguments are accepted but not used.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRModel
        :rtype: Type[LightningIRModel]
        """
        backbone_config = self.get_backbone_config(model_name_or_path)
        BackboneModel = _get_model_class(backbone_config)
        DerivedLightningIRModel = self.from_backbone_class(BackboneModel)
        return DerivedLightningIRModel

    def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[LightningIRModel]:
        """Creates a derived LightningIRModel from a transformers.PreTrainedModel_ backbone model. If the backbone model
        is already a LightningIRModel (its ``config_class`` defines a non-None ``backbone_model_type``), it is returned
        as is.

        .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model#transformers.PreTrainedModel

        The derived class inherits from the model mixin registered for the mixin configuration and from the backbone
        model. It uses a derived LightningIRConfig as its ``config_class``. Errors from looking up the model mixin in
        the HuggingFace model mapping (e.g. when the mixin is not registered) are propagated unchanged.

        :param BackboneClass: Backbone model
        :type BackboneClass: Type[PreTrainedModel]
        :raises ValueError: If the backbone model is not a valid backbone model because its ``config_class`` is None.
        :return: The derived LightningIRModel
        :rtype: Type[LightningIRModel]
        """
        BackboneConfig = BackboneClass.config_class
        # A config that already carries a backbone_model_type means the class is already a derived Lightning IR model.
        if getattr(BackboneConfig, "backbone_model_type", None) is not None:
            return BackboneClass
        if BackboneConfig is None:
            raise ValueError(
                f"Model {BackboneClass} is not a valid backbone model because it is missing a `config_class`."
            )

        LightningIRModelMixin: Type[LightningIRModel] = _get_model_class(self.MixinConfig)

        DerivedLightningIRConfig = LightningIRConfigClassFactory(self.MixinConfig).from_backbone_class(BackboneConfig)

        DerivedLightningIRModel = type(
            f"{self.cc_lir_model_type}{BackboneClass.__name__}",
            (LightningIRModelMixin, BackboneClass),
            # `_backbone_forward` keeps a handle to the backbone's own forward; presumably the mixin calls it from its
            # overriding forward — confirm against LightningIRModel.
            {"config_class": DerivedLightningIRConfig, "_backbone_forward": BackboneClass.forward},
        )
        return DerivedLightningIRModel
244
245
class LightningIRTokenizerClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRTokenizer classes from HuggingFace tokenizer classes."""

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Loads the backbone configuration for a pretrained HuggingFace tokenizer checkpoint.

        The backbone model type is resolved with ``LightningIRTokenizerClassFactory.get_backbone_model_type``. That
        method can fall back to the tokenizer configuration when reading the model configuration fails.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :return: Configuration of the backbone model
        :rtype: PretrainedConfig
        """
        backbone_model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(model_name_or_path)
        return CONFIG_MAPPING[backbone_model_type].from_pretrained(model_name_or_path)

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Grabs the model type from a checkpoint of a pretrained HuggingFace tokenizer.

        The model type is first read from the checkpoint's model configuration. If that fails with an ``OSError`` or
        ``ValueError``, the ``backbone_tokenizer_class`` entry of the tokenizer configuration is used instead. That
        entry is resolved to a tokenizer class, and the model type of the first configuration registered for that
        tokenizer is returned.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :raises ValueError: If the model type cannot be determined from either configuration
        :return: Model type of the backbone tokenizer
        :rtype: str
        """
        try:
            return LightningIRClassFactory.get_backbone_model_type(model_name_or_path, *args, **kwargs)
        except (OSError, ValueError):
            # best guess at model type
            config_dict = get_tokenizer_config(model_name_or_path)
            backbone_tokenizer_class = config_dict.get("backbone_tokenizer_class", None)
            if backbone_tokenizer_class is not None:
                Tokenizer = tokenizer_class_from_name(backbone_tokenizer_class)
                # tokenizer_class_from_name can return None for unknown class names. `None in tokenizers` would then
                # spuriously match any mapping entry that lacks a slow or fast tokenizer.
                if Tokenizer is not None:
                    for config, tokenizers in TOKENIZER_MAPPING.items():
                        if Tokenizer in tokenizers:
                            return config.model_type
            raise ValueError("No backbone model found in the configuration")

    def from_pretrained(
        self, model_name_or_path: str | Path, *args, use_fast: bool = True, **kwargs
    ) -> Type[LightningIRTokenizer]:
        """Loads a derived LightningIRTokenizer from a pretrained HuggingFace tokenizer.

        Extra positional and keyword arguments (other than ``use_fast``) are accepted but not used.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :param use_fast: Whether to use the fast or slow tokenizer, defaults to True
        :type use_fast: bool, optional
        :raises ValueError: If use_fast is True and no fast tokenizer is found
        :raises ValueError: If use_fast is False and no slow tokenizer is found
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        backbone_config = self.get_backbone_config(model_name_or_path)
        BackboneTokenizers = TOKENIZER_MAPPING[type(backbone_config)]
        DerivedLightningIRTokenizers = self.from_backbone_classes(BackboneTokenizers, type(backbone_config))
        # Tokenizer pairs are ordered (slow, fast).
        if use_fast:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[1]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No fast tokenizer found.")
        else:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[0]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No slow tokenizer found.")
        return DerivedLightningIRTokenizer

    def from_backbone_classes(
        self,
        BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None],
        BackboneConfig: Type[PretrainedConfig] | None = None,
    ) -> Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]:
        """Creates derived slow and fast LightningIRTokenizers from a tuple of backbone HuggingFace tokenizer classes.

        Missing (None) backbone classes stay None. When a fast tokenizer is derived, its ``slow_tokenizer_class`` is
        set to the derived slow tokenizer, which may be None.

        :param BackboneClasses: Slow and fast backbone tokenizer classes
        :type BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None]
        :param BackboneConfig: Backbone configuration class, defaults to None (currently unused)
        :type BackboneConfig: Type[PretrainedConfig], optional
        :return: Slow and fast derived LightningIRTokenizers
        :rtype: Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]
        """
        DerivedLightningIRTokenizers = tuple(
            None if BackboneClass is None else self.from_backbone_class(BackboneClass)
            for BackboneClass in BackboneClasses
        )
        if DerivedLightningIRTokenizers[1] is not None:
            DerivedLightningIRTokenizers[1].slow_tokenizer_class = DerivedLightningIRTokenizers[0]
        return DerivedLightningIRTokenizers

    def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
        """Creates a derived LightningIRTokenizer from a transformers.PreTrainedTokenizerBase_ backbone tokenizer. If
        the backbone tokenizer already has a ``config_class`` attribute (treated as the marker of a
        LightningIRTokenizer), it is returned as is.

        .. _transformers.PreTrainedTokenizerBase: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerBase

        The first tokenizer registered for the mixin configuration in ``TOKENIZER_MAPPING`` is used as the mixin.

        :param BackboneClass: Backbone tokenizer class
        :type BackboneClass: Type[PreTrainedTokenizerBase]
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        if hasattr(BackboneClass, "config_class"):
            return BackboneClass
        LightningIRTokenizerMixin = TOKENIZER_MAPPING[self.MixinConfig][0]

        DerivedLightningIRTokenizer = type(
            f"{self.cc_lir_model_type}{BackboneClass.__name__}", (LightningIRTokenizerMixin, BackboneClass), {}
        )

        return DerivedLightningIRTokenizer