1"""
2Class factory module for Lightning IR.
3
4This module provides factory classes for creating various components of the Lightning IR library
5by extending Hugging Face Transformers classes.
6"""
7
8from __future__ import annotations
9
10from abc import ABC, abstractmethod
11from pathlib import Path
12from typing import TYPE_CHECKING, Any, Tuple, Type
13
14from transformers import (
15 CONFIG_MAPPING,
16 MODEL_MAPPING,
17 TOKENIZER_MAPPING,
18 PretrainedConfig,
19 PreTrainedModel,
20 PreTrainedTokenizerBase,
21)
22from transformers.models.auto.tokenization_auto import get_tokenizer_config, tokenizer_class_from_name
23
24if TYPE_CHECKING:
25 from . import LightningIRConfig, LightningIRModel, LightningIRTokenizer
26
27
def _get_model_class(config: PretrainedConfig | Type[PretrainedConfig]) -> Type[PreTrainedModel]:
    """Looks up the model class registered in ``MODEL_MAPPING`` for a configuration.

    Args:
        config (PretrainedConfig | Type[PretrainedConfig]): Configuration instance or configuration class.
    Returns:
        Type[PreTrainedModel]: The model class registered for the configuration. If several model classes are
            registered, the one named in ``config.architectures`` (optionally prefixed with ``TF`` or ``Flax``) is
            returned, falling back to the first registered class.
    """
    # https://github.com/huggingface/transformers/blob/356b3cd71d7bfb51c88fea3e8a0c054f3a457ab9/src/transformers/models/auto/auto_factory.py#L387
    config_class = config if isinstance(config, type) else type(config)
    supported_models = MODEL_MAPPING[config_class]
    if not isinstance(supported_models, (list, tuple)):
        return supported_models

    if isinstance(config, type):
        # we cannot parse architectures from a config class, we need an instance for this
        return supported_models[0]

    name_to_model = {model.__name__: model for model in supported_models}
    # getattr's default only applies when the attribute is missing; an attribute that exists but is set to None
    # must also be treated as "no architectures" instead of failing when iterated.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        for candidate in (arch, f"TF{arch}", f"Flax{arch}"):
            if candidate in name_to_model:
                return name_to_model[candidate]

    # If no architecture is set in the config or none matches the supported models, the first element of the tuple
    # is the default.
    return supported_models[0]
54
55
[docs]
class LightningIRClassFactory(ABC):
    """Abstract factory for deriving Lightning IR classes from HuggingFace classes."""

    def __init__(self, MixinConfig: Type[LightningIRConfig]) -> None:
        """Initializes the factory with a LightningIRConfig mixin class.

        If the passed class carries a non-None ``backbone_model_type`` attribute, its first base class is stored
        instead of the class itself.

        Args:
            MixinConfig (Type[LightningIRConfig]): LightningIRConfig mixin class.
        """
        has_backbone = getattr(MixinConfig, "backbone_model_type", None) is not None
        self.MixinConfig = MixinConfig.__bases__[0] if has_backbone else MixinConfig

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Loads the backbone configuration from a checkpoint of a pretrained HuggingFace model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            PretrainedConfig: Configuration of the backbone model.
        """
        model_type = LightningIRClassFactory.get_backbone_model_type(model_name_or_path)
        BackboneConfig = CONFIG_MAPPING[model_type]
        return BackboneConfig.from_pretrained(model_name_or_path)

    @staticmethod
    def get_lightning_ir_config(model_name_or_path: str | Path) -> LightningIRConfig | None:
        """Loads the Lightning IR configuration from a checkpoint of a pretrained Lightning IR model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            LightningIRConfig | None: Configuration loaded for the Lightning IR model type, or None if no Lightning IR
                model type could be determined for the checkpoint.
        """
        lir_model_type = LightningIRClassFactory.get_lightning_ir_model_type(model_name_or_path)
        if lir_model_type is not None:
            return CONFIG_MAPPING[lir_model_type].from_pretrained(model_name_or_path)
        return None

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Determines the backbone model type from a checkpoint of a pretrained HuggingFace model.

        The ``backbone_model_type`` entry of the configuration dictionary takes precedence; the ``model_type`` entry
        is used when the former is missing or falsy.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
            *args: Forwarded to ``PretrainedConfig.get_config_dict``.
            **kwargs: Forwarded to ``PretrainedConfig.get_config_dict``.
        Returns:
            str: Model type of the backbone model.
        Raises:
            ValueError: If the resolved model type is None.
        """
        config_dict, _unused = PretrainedConfig.get_config_dict(model_name_or_path, *args, **kwargs)
        model_type = config_dict.get("backbone_model_type", None)
        if not model_type:
            model_type = config_dict.get("model_type")
        if model_type is None:
            raise ValueError(f"Unable to load PretrainedConfig from {model_name_or_path}")
        return model_type

    @staticmethod
    def get_lightning_ir_model_type(model_name_or_path: str | Path) -> str | None:
        """Determines the Lightning IR model type from a checkpoint of a pretrained model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            str | None: The ``model_type`` entry of the configuration dictionary if the dictionary contains a
                ``backbone_model_type`` key, otherwise None.
        """
        config_dict, _unused = PretrainedConfig.get_config_dict(model_name_or_path)
        if "backbone_model_type" in config_dict:
            return config_dict.get("model_type", None)
        return None

    @property
    def cc_lir_model_type(self) -> str:
        """Camel case model type of the Lightning IR model (hyphen-separated parts are title-cased and joined)."""
        parts = self.MixinConfig.model_type.split("-")
        return "".join([part.title() for part in parts])

    @abstractmethod
    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Any:
        """Loads a derived Lightning IR class from a pretrained HuggingFace model. Subclasses must implement this.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            Any: Derived Lightning IR class.
        """
        ...

    @abstractmethod
    def from_backbone_class(self, BackboneClass: Type) -> Type:
        """Creates a derived Lightning IR class from a backbone HuggingFace class. Subclasses must implement this.

        Args:
            BackboneClass (Type): Backbone class.
        Returns:
            Type: Derived Lightning IR class.
        """
        ...
154
155
[docs]
class LightningIRConfigClassFactory(LightningIRClassFactory):
    """Factory that derives LightningIRConfig classes from HuggingFace configuration classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRConfig]:
        """Derives a LightningIRConfig class from the checkpoint of a pretrained HuggingFace model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
            *args: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.
        Returns:
            Type[LightningIRConfig]: Derived LightningIRConfig.
        """
        BackboneConfig = type(self.get_backbone_config(model_name_or_path))
        return self.from_backbone_class(BackboneConfig)

    def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[LightningIRConfig]:
        """Derives a LightningIRConfig class from a transformers.PretrainedConfig_ backbone configuration class.

        A backbone class that already has a non-None ``backbone_model_type`` attribute is returned unchanged.
        Otherwise a new class is created that inherits from the registered Lightning IR config mixin and the backbone
        class, and records the mixin's model type, the backbone's model type, and the mixin config class.

        .. _transformers.PretrainedConfig: \
https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig

        Args:
            BackboneClass (Type[PretrainedConfig]): Backbone configuration class.
        Returns:
            Type[LightningIRConfig]: Derived LightningIRConfig.
        """
        already_derived = getattr(BackboneClass, "backbone_model_type", None) is not None
        if already_derived:
            return BackboneClass

        LightningIRConfigMixin: Type[LightningIRConfig] = CONFIG_MAPPING[self.MixinConfig.model_type]
        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        derived_bases = (LightningIRConfigMixin, BackboneClass)
        derived_namespace = {
            "model_type": self.MixinConfig.model_type,
            "backbone_model_type": BackboneClass.model_type,
            "mixin_config": self.MixinConfig,
        }
        return type(derived_name, derived_bases, derived_namespace)
197
198
[docs]
class LightningIRModelClassFactory(LightningIRClassFactory):
    """Factory that derives LightningIRModel classes from HuggingFace model classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRModel]:
        """Derives a LightningIRModel class from the checkpoint of a pretrained HuggingFace model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
            *args: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.
        Returns:
            Type[LightningIRModel]: Derived LightningIRModel.
        """
        BackboneModel = _get_model_class(self.get_backbone_config(model_name_or_path))
        return self.from_backbone_class(BackboneModel)

    def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[LightningIRModel]:
        """Derives a LightningIRModel class from a transformers.PreTrainedModel_ backbone model class.

        A backbone class whose ``config_class`` already has a non-None ``backbone_model_type`` attribute is returned
        unchanged. Otherwise a new class is created that inherits from the Lightning IR model mixin and the backbone
        class, uses the derived Lightning IR config as ``config_class``, and keeps the backbone's ``forward`` as
        ``_backbone_forward``.

        .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model#transformers.PreTrainedModel

        Args:
            BackboneClass (Type[PreTrainedModel]): Backbone model class.
        Returns:
            Type[LightningIRModel]: Derived LightningIRModel.
        Raises:
            ValueError: If the ``config_class`` of the backbone model is None.
        """
        BackboneConfig = BackboneClass.config_class
        if getattr(BackboneConfig, "backbone_model_type", None) is not None:
            return BackboneClass
        if BackboneConfig is None:
            raise ValueError(
                f"Model {BackboneClass} is not a valid backbone model because it is missing a `config_class`."
            )

        LightningIRModelMixin: Type[LightningIRModel] = _get_model_class(self.MixinConfig)
        config_factory = LightningIRConfigClassFactory(self.MixinConfig)
        DerivedLightningIRConfig = config_factory.from_backbone_class(BackboneConfig)

        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        derived_namespace = {"config_class": DerivedLightningIRConfig, "_backbone_forward": BackboneClass.forward}
        return type(derived_name, (LightningIRModelMixin, BackboneClass), derived_namespace)
249
250
[docs]
class LightningIRTokenizerClassFactory(LightningIRClassFactory):
    """Factory that derives LightningIRTokenizer classes from HuggingFace tokenizer classes."""

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Loads the backbone configuration for a checkpoint of a pretrained HuggingFace tokenizer.

        The model type is resolved with this class's ``get_backbone_model_type``.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
        Returns:
            PretrainedConfig: Configuration loaded for the backbone model type.
        """
        model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(model_name_or_path)
        BackboneConfig = CONFIG_MAPPING[model_type]
        return BackboneConfig.from_pretrained(model_name_or_path)

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Determines the backbone model type from a checkpoint of a pretrained HuggingFace tokenizer.

        The model configuration is tried first. If that fails with an OSError or ValueError, the tokenizer
        configuration's ``backbone_tokenizer_class`` entry is used to search ``TOKENIZER_MAPPING`` for a matching
        configuration class.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
            *args: Forwarded to ``LightningIRClassFactory.get_backbone_model_type``.
            **kwargs: Forwarded to ``LightningIRClassFactory.get_backbone_model_type``.
        Returns:
            str: Model type of the backbone tokenizer.
        Raises:
            ValueError: If the fallback lookup does not find a backbone model type either.
        """
        try:
            return LightningIRClassFactory.get_backbone_model_type(model_name_or_path, *args, **kwargs)
        except (OSError, ValueError):
            # best guess at model type
            tokenizer_config = get_tokenizer_config(model_name_or_path)
            tokenizer_class_name = tokenizer_config.get("backbone_tokenizer_class", None)
            if tokenizer_class_name is not None:
                Tokenizer = tokenizer_class_from_name(tokenizer_class_name)
                for config, tokenizers in TOKENIZER_MAPPING.items():
                    if Tokenizer in tokenizers:
                        return config.model_type
            raise ValueError("No backbone model found in the configuration")

    def from_pretrained(
        self, model_name_or_path: str | Path, *args, use_fast: bool = True, **kwargs
    ) -> Type[LightningIRTokenizer]:
        """Derives a LightningIRTokenizer class from the checkpoint of a pretrained HuggingFace tokenizer.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
            *args: Accepted for interface compatibility; not used here.
            use_fast (bool, optional): Whether to use the fast tokenizer. Defaults to True.
            **kwargs: Accepted for interface compatibility; not used here.
        Returns:
            Type[LightningIRTokenizer]: Derived LightningIRTokenizer.
        Raises:
            ValueError: If no fast tokenizer is found when `use_fast` is True.
            ValueError: If no slow tokenizer is found when `use_fast` is False.
        """
        BackboneConfig = type(self.get_backbone_config(model_name_or_path))
        BackboneTokenizers = TOKENIZER_MAPPING[BackboneConfig]
        slow_and_fast = self.from_backbone_classes(BackboneTokenizers, BackboneConfig)
        # The derived tuple is ordered (slow, fast).
        if use_fast:
            position, missing_message = 1, "No fast tokenizer found."
        else:
            position, missing_message = 0, "No slow tokenizer found."
        DerivedLightningIRTokenizer = slow_and_fast[position]
        if DerivedLightningIRTokenizer is None:
            raise ValueError(missing_message)
        return DerivedLightningIRTokenizer

    def from_backbone_classes(
        self,
        BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None],
        BackboneConfig: Type[PretrainedConfig] | None = None,
    ) -> Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]:
        """Derives slow and fast LightningIRTokenizer classes from a tuple of backbone HuggingFace tokenizer classes.

        Entries that are None stay None. If a fast tokenizer class is derived, its ``slow_tokenizer_class``
        attribute is set to the derived slow tokenizer class.

        Args:
            BackboneClasses (Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None]):
                Slow and fast backbone tokenizer classes.
            BackboneConfig (Type[PretrainedConfig] | None, optional): Backbone configuration class. Not used by this
                implementation. Defaults to None.
        Returns:
            Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]: Slow and fast derived
            LightningIRTokenizers.
        """
        derived = tuple(
            self.from_backbone_class(BackboneClass) if BackboneClass is not None else None
            for BackboneClass in BackboneClasses
        )
        fast_tokenizer = derived[1]
        if fast_tokenizer is not None:
            fast_tokenizer.slow_tokenizer_class = derived[0]
        return derived

    def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
        """Derives a LightningIRTokenizer class from a transformers.PreTrainedTokenizerBase_ backbone tokenizer class.

        A backbone class that already has a ``config_class`` attribute is returned unchanged (this is how an already
        derived LightningIRTokenizer is recognized here). Otherwise a new class is created that inherits from the
        first tokenizer class registered for the mixin config and from the backbone class.

        .. _transformers.PreTrainedTokenizerBase: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerBase

        Args:
            BackboneClass (Type[PreTrainedTokenizerBase]): Backbone tokenizer class.
        Returns:
            Type[LightningIRTokenizer]: Derived LightningIRTokenizer.
        """
        if hasattr(BackboneClass, "config_class"):
            return BackboneClass

        LightningIRTokenizerMixin = TOKENIZER_MAPPING[self.MixinConfig][0]
        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        return type(derived_name, (LightningIRTokenizerMixin, BackboneClass), {})
358 return DerivedLightningIRTokenizer