Source code for lightning_ir.base.class_factory

  1"""
  2Class factory module for Lightning IR.
  3
  4This module provides factory classes for creating various components of the Lightning IR library
  5by extending Hugging Face Transformers classes.
  6"""
  7
  8from __future__ import annotations
  9
 10import importlib
 11from abc import ABC, abstractmethod
 12from pathlib import Path
 13from typing import TYPE_CHECKING, Any, Tuple, Type
 14
 15from transformers import (
 16    CONFIG_MAPPING,
 17    MODEL_MAPPING,
 18    TOKENIZER_MAPPING,
 19    PretrainedConfig,
 20    PreTrainedModel,
 21    PreTrainedTokenizerBase,
 22)
 23from transformers.models.auto.configuration_auto import model_type_to_module_name
 24from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES, get_tokenizer_config
 25
 26if TYPE_CHECKING:
 27    from . import LightningIRConfig, LightningIRModel, LightningIRTokenizer
 28
 29
 30def _get_model_class(config: PretrainedConfig | Type[PretrainedConfig]) -> Type[PreTrainedModel]:
 31    # https://github.com/huggingface/transformers/blob/356b3cd71d7bfb51c88fea3e8a0c054f3a457ab9/src/transformers/models/auto/auto_factory.py#L387
 32    if isinstance(config, type):
 33        supported_models = MODEL_MAPPING[config]
 34    else:
 35        supported_models = MODEL_MAPPING[type(config)]
 36    if not isinstance(supported_models, (list, tuple)):
 37        return supported_models
 38
 39    if isinstance(config, type):
 40        # we cannot parse architectures from a config class, we need an instance for this
 41        return supported_models[0]
 42
 43    name_to_model = {model.__name__: model for model in supported_models}
 44    architectures = getattr(config, "architectures", [])
 45    for arch in architectures:
 46        if arch in name_to_model:
 47            return name_to_model[arch]
 48        elif f"TF{arch}" in name_to_model:
 49            return name_to_model[f"TF{arch}"]
 50        elif f"Flax{arch}" in name_to_model:
 51            return name_to_model[f"Flax{arch}"]
 52
 53    # If not architecture is set in the config or match the supported models, the first element of the tuple is the
 54    # defaults.
 55    return supported_models[0]
 56
 57
class LightningIRClassFactory(ABC):
    """Base class for creating derived Lightning IR classes from HuggingFace classes."""

    def __init__(self, MixinConfig: Type[LightningIRConfig]) -> None:
        """Creates a new LightningIRClassFactory.

        If a derived configuration class (one whose ``backbone_model_type`` is set) is passed, it is unwrapped to
        the mixin configuration class by following the first base class.

        Args:
            MixinConfig (Type[LightningIRConfig]): LightningIRConfig mixin class.
        """
        # Derived configuration classes list the mixin as their first base. Keep unwrapping so that subclasses
        # of a derived class also resolve to the mixin rather than to another derived class.
        while getattr(MixinConfig, "backbone_model_type", None) is not None and MixinConfig.__bases__:
            MixinConfig = MixinConfig.__bases__[0]
        self.MixinConfig = MixinConfig

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Grabs the configuration from a checkpoint of a pretrained HuggingFace model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            PretrainedConfig: Configuration of the backbone model.
        """
        backbone_model_type = LightningIRClassFactory.get_backbone_model_type(model_name_or_path)
        return CONFIG_MAPPING[backbone_model_type].from_pretrained(model_name_or_path)

    @staticmethod
    def get_lightning_ir_config(model_name_or_path: str | Path) -> LightningIRConfig | None:
        """Grabs the Lightning IR configuration from a checkpoint of a pretrained Lightning IR model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            LightningIRConfig | None: Configuration of the Lightning IR model, or None if no Lightning IR model
                type could be determined for the checkpoint.
        """
        model_type = LightningIRClassFactory.get_lightning_ir_model_type(model_name_or_path)
        if model_type is None:
            return None
        return CONFIG_MAPPING[model_type].from_pretrained(model_name_or_path)

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Grabs the model type from a checkpoint of a pretrained HuggingFace model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
            *args: Forwarded to ``PretrainedConfig.get_config_dict``.
            **kwargs: Forwarded to ``PretrainedConfig.get_config_dict``.
        Returns:
            str: Model type of the backbone model.
        Raises:
            ValueError: If the type of the model is None in the configuration.
        """
        config_dict, _ = PretrainedConfig.get_config_dict(model_name_or_path, *args, **kwargs)
        # Lightning IR checkpoints store the backbone under "backbone_model_type"; plain HuggingFace
        # checkpoints only have "model_type".
        backbone_model_type = config_dict.get("backbone_model_type", None) or config_dict.get("model_type")
        if backbone_model_type is None:
            raise ValueError(f"Unable to load PretrainedConfig from {model_name_or_path}")
        return backbone_model_type

    @staticmethod
    def get_lightning_ir_model_type(model_name_or_path: str | Path) -> str | None:
        """Grabs the Lightning IR model type from a checkpoint of a pretrained HuggingFace model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            str | None: Model type of the Lightning IR model, or None if the configuration has no
                ``backbone_model_type`` entry or no ``model_type`` entry.
        """
        config_dict, _ = PretrainedConfig.get_config_dict(model_name_or_path)
        # Without a "backbone_model_type" entry the checkpoint is not treated as a Lightning IR checkpoint.
        if "backbone_model_type" not in config_dict:
            return None
        return config_dict.get("model_type", None)

    def cc_lir_model_type(self, Config: Type[LightningIRConfig]) -> str:
        """Camel case model type of the Lightning IR model, e.g. ``"bi-encoder"`` becomes ``"BiEncoder"``."""
        return "".join(part.title() for part in Config.model_type.split("-"))

    @abstractmethod
    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Any:
        """Loads a derived Lightning IR class from a pretrained HuggingFace model. Must be implemented by subclasses.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            Any: Derived Lightning IR class.
        """
        ...

    @abstractmethod
    def from_backbone_class(self, BackboneClass: Type) -> Type:
        """Creates a derived Lightning IR class from a backbone HuggingFace class. Must be implemented by subclasses.

        Args:
            BackboneClass (Type): Backbone class.
        Returns:
            Type: Derived Lightning IR class.
        """
        ...
155 156
class LightningIRConfigClassFactory(LightningIRClassFactory):
    """Class factory that derives LightningIRConfig classes from HuggingFace configuration classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRConfig]:
        """Loads a derived LightningIRConfig class for a pretrained HuggingFace model.

        The backbone configuration is loaded from the checkpoint and its class is used as the backbone class.
        Additional positional and keyword arguments are accepted but not used.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            Type[LightningIRConfig]: Derived LightningIRConfig.
        """
        BackboneConfig = type(self.get_backbone_config(model_name_or_path))
        return self.from_backbone_class(BackboneConfig)

    def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[LightningIRConfig]:
        """Creates a derived LightningIRConfig from a transformers.PretrainedConfig_ backbone configuration class.

        A backbone configuration class that is already a derived LightningIRConfig (its ``backbone_model_type``
        is set) is returned unchanged.

        .. _transformers.PretrainedConfig:
            https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig

        Args:
            BackboneClass (Type[PretrainedConfig]): Backbone configuration class.
        Returns:
            Type[LightningIRConfig]: Derived LightningIRConfig.
        """
        already_derived = getattr(BackboneClass, "backbone_model_type", None) is not None
        if already_derived:
            return BackboneClass

        mixin_model_type = self.MixinConfig.model_type
        # The mixin class is looked up by model type in the HuggingFace config registry.
        LightningIRConfigMixin: Type[LightningIRConfig] = CONFIG_MAPPING[mixin_model_type]

        derived_name = f"{self.cc_lir_model_type(LightningIRConfigMixin)}{BackboneClass.__name__}"
        class_attributes = {
            "model_type": mixin_model_type,
            "backbone_model_type": BackboneClass.model_type,
            "mixin_config": self.MixinConfig,
        }
        # The mixin is listed first so that it precedes the backbone configuration in the MRO.
        return type(derived_name, (LightningIRConfigMixin, BackboneClass), class_attributes)
198 199
class LightningIRModelClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRModel classes from HuggingFace model classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRModel]:
        """Loads a derived LightningIRModel from a pretrained HuggingFace model.

        Additional positional and keyword arguments are accepted but not used.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            Type[LightningIRModel]: Derived LightningIRModel.
        """
        backbone_config = self.get_backbone_config(model_name_or_path)
        BackboneModel = _get_model_class(backbone_config)
        return self.from_backbone_class(BackboneModel)

    def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[LightningIRModel]:
        """Creates a derived LightningIRModel from a transformers.PreTrainedModel_ backbone model. If the backbone
        model is already a LightningIRModel, it is returned as is.

        .. _transformers.PreTrainedModel:
            https://huggingface.co/transformers/main_classes/model#transformers.PreTrainedModel

        Args:
            BackboneClass (Type[PreTrainedModel]): Backbone model class.
        Returns:
            Type[LightningIRModel]: Derived LightningIRModel.
        Raises:
            ValueError: If the backbone model is not a valid backbone model, i.e. it has no ``config_class``.
        """
        # Read the attribute defensively so that a class without any `config_class` attribute is reported with
        # the documented ValueError instead of an AttributeError.
        BackboneConfig = getattr(BackboneClass, "config_class", None)
        if BackboneConfig is None:
            raise ValueError(
                f"Model {BackboneClass} is not a valid backbone model because it is missing a `config_class`."
            )
        # A configuration class with a backbone model type marks an already derived Lightning IR model.
        if getattr(BackboneConfig, "backbone_model_type", None) is not None:
            return BackboneClass

        LightningIRModelMixin: Type[LightningIRModel] = _get_model_class(self.MixinConfig)

        DerivedLightningIRConfig = LightningIRConfigClassFactory(self.MixinConfig).from_backbone_class(BackboneConfig)

        DerivedLightningIRModel = type(
            f"{self.cc_lir_model_type(LightningIRModelMixin.config_class)}{BackboneClass.__name__}",
            (LightningIRModelMixin, BackboneClass),
            # The backbone's forward is stored under a separate name on the derived class.
            {"config_class": DerivedLightningIRConfig, "_backbone_forward": BackboneClass.forward},
        )
        return DerivedLightningIRModel
250 251
class LightningIRTokenizerClassFactory(LightningIRClassFactory):
    """Class factory that derives LightningIRTokenizer classes from HuggingFace tokenizer classes."""

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Grabs the backbone configuration for a checkpoint of a pretrained HuggingFace tokenizer.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
        Returns:
            PretrainedConfig: Configuration of the backbone the tokenizer belongs to.
        """
        model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(model_name_or_path)
        config_class = CONFIG_MAPPING[model_type]
        return config_class.from_pretrained(model_name_or_path)

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Grabs the model type from a checkpoint of a pretrained HuggingFace tokenizer.

        The model configuration is tried first. If that fails, the model type is guessed from the
        ``backbone_tokenizer_class`` entry of the tokenizer configuration.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
            *args: Forwarded to ``LightningIRClassFactory.get_backbone_model_type``.
            **kwargs: Forwarded to ``LightningIRClassFactory.get_backbone_model_type``.
        Returns:
            str: Model type of the backbone tokenizer.
        Raises:
            ValueError: If no backbone model type can be determined.
        """
        try:
            return LightningIRClassFactory.get_backbone_model_type(model_name_or_path, *args, **kwargs)
        except (OSError, ValueError):
            # Best guess at the model type based on the tokenizer configuration.
            tokenizer_config = get_tokenizer_config(model_name_or_path)
            tokenizer_name = tokenizer_config.get("backbone_tokenizer_class", None)
            if tokenizer_name is not None:
                # E.g. a tokenizer class "XTokenizerFast" maps to the configuration class name "XConfig".
                config_name = tokenizer_name.removesuffix("Fast").replace("Tokenizer", "Config")
                for model_type, tokenizer_names in TOKENIZER_MAPPING_NAMES.items():
                    if tokenizer_name not in tokenizer_names:
                        continue
                    module_name = model_type_to_module_name(model_type)
                    module = importlib.import_module(f".{module_name}", "transformers.models")
                    if hasattr(module, tokenizer_name) and hasattr(module, config_name):
                        return getattr(module, config_name).model_type
            raise ValueError("No backbone model found in the configuration")

    def from_pretrained(
        self, model_name_or_path: str | Path, *args, use_fast: bool = True, **kwargs
    ) -> Type[LightningIRTokenizer]:
        """Loads a derived LightningIRTokenizer from a pretrained HuggingFace tokenizer.

        Additional positional and keyword arguments are accepted but not used.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
            use_fast (bool, optional): Whether to use the fast tokenizer. Defaults to True.
        Returns:
            Type[LightningIRTokenizer]: Derived LightningIRTokenizer.
        Raises:
            ValueError: If no fast tokenizer is found when `use_fast` is True.
            ValueError: If no slow tokenizer is found when `use_fast` is False.
        """
        BackboneConfig = type(self.get_backbone_config(model_name_or_path))
        derived_tokenizers = self.from_backbone_classes(TOKENIZER_MAPPING[BackboneConfig], BackboneConfig)
        # Index 0 holds the slow tokenizer class, index 1 the fast one.
        if use_fast:
            selected, missing_message = derived_tokenizers[1], "No fast tokenizer found."
        else:
            selected, missing_message = derived_tokenizers[0], "No slow tokenizer found."
        if selected is None:
            raise ValueError(missing_message)
        return selected

    def from_backbone_classes(
        self,
        BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None],
        BackboneConfig: Type[PretrainedConfig] | None = None,
    ) -> Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]:
        """Creates derived slow and fast LightningIRTokenizers from a tuple of backbone HuggingFace tokenizer classes.

        Args:
            BackboneClasses (Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None]):
                Slow and fast backbone tokenizer classes.
            BackboneConfig (Type[PretrainedConfig] | None, optional): Backbone configuration class. Not used by
                this method. Defaults to None.
        Returns:
            Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]: Slow and fast derived
                LightningIRTokenizers.
        """
        derived_tokenizers = tuple(
            self.from_backbone_class(backbone) if backbone is not None else None for backbone in BackboneClasses
        )
        fast_tokenizer = derived_tokenizers[1]
        if fast_tokenizer is not None:
            # Point the fast tokenizer at the derived slow tokenizer (which may be None).
            fast_tokenizer.slow_tokenizer_class = derived_tokenizers[0]
        return derived_tokenizers

    def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
        """Creates a derived LightningIRTokenizer from a transformers.PreTrainedTokenizerBase_ backbone tokenizer.

        A backbone tokenizer class that has a ``config_class`` attribute is treated as an already derived
        LightningIRTokenizer and returned unchanged.

        .. _transformers.PreTrainedTokenizerBase:
            https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerBase

        Args:
            BackboneClass (Type[PreTrainedTokenizerBase]): Backbone tokenizer class.
        Returns:
            Type[LightningIRTokenizer]: Derived LightningIRTokenizer.
        """
        if hasattr(BackboneClass, "config_class"):
            return BackboneClass

        # The first tokenizer class registered for the mixin configuration serves as the tokenizer mixin.
        LightningIRTokenizerMixin = TOKENIZER_MAPPING[self.MixinConfig][0]
        derived_name = f"{self.cc_lir_model_type(LightningIRTokenizerMixin.config_class)}{BackboneClass.__name__}"
        return type(derived_name, (LightningIRTokenizerMixin, BackboneClass), {})