Source code for lightning_ir.base.class_factory

  1"""
  2Class factory module for Lightning IR.
  3
  4This module provides factory classes for creating various components of the Lightning IR library
  5by extending Hugging Face Transformers classes.
  6"""
  7
  8from __future__ import annotations
  9
 10from abc import ABC, abstractmethod
 11from pathlib import Path
 12from typing import TYPE_CHECKING, Any, Tuple, Type
 13
 14from transformers import (
 15    CONFIG_MAPPING,
 16    MODEL_MAPPING,
 17    TOKENIZER_MAPPING,
 18    PretrainedConfig,
 19    PreTrainedModel,
 20    PreTrainedTokenizerBase,
 21)
 22from transformers.models.auto.tokenization_auto import get_tokenizer_config, tokenizer_class_from_name
 23
 24if TYPE_CHECKING:
 25    from . import LightningIRConfig, LightningIRModel, LightningIRTokenizer
 26
 27
 28def _get_model_class(config: PretrainedConfig | Type[PretrainedConfig]) -> Type[PreTrainedModel]:
 29    # https://github.com/huggingface/transformers/blob/356b3cd71d7bfb51c88fea3e8a0c054f3a457ab9/src/transformers/models/auto/auto_factory.py#L387
 30    if isinstance(config, type):
 31        supported_models = MODEL_MAPPING[config]
 32    else:
 33        supported_models = MODEL_MAPPING[type(config)]
 34    if not isinstance(supported_models, (list, tuple)):
 35        return supported_models
 36
 37    if isinstance(config, type):
 38        # we cannot parse architectures from a config class, we need an instance for this
 39        return supported_models[0]
 40
 41    name_to_model = {model.__name__: model for model in supported_models}
 42    architectures = getattr(config, "architectures", [])
 43    for arch in architectures:
 44        if arch in name_to_model:
 45            return name_to_model[arch]
 46        elif f"TF{arch}" in name_to_model:
 47            return name_to_model[f"TF{arch}"]
 48        elif f"Flax{arch}" in name_to_model:
 49            return name_to_model[f"Flax{arch}"]
 50
 51    # If not architecture is set in the config or match the supported models, the first element of the tuple is the
 52    # defaults.
 53    return supported_models[0]
 54
 55
class LightningIRClassFactory(ABC):
    """Abstract base for factories that derive Lightning IR classes from HuggingFace classes."""

    def __init__(self, MixinConfig: Type[LightningIRConfig]) -> None:
        """Initializes the factory with the configuration mixin class.

        If ``MixinConfig`` has a ``backbone_model_type`` attribute that is not None, its first base class
        (``MixinConfig.__bases__[0]``) is stored instead of the class itself.

        Args:
            MixinConfig (Type[LightningIRConfig]): LightningIRConfig mixin class.
        """
        already_derived = getattr(MixinConfig, "backbone_model_type", None) is not None
        self.MixinConfig = MixinConfig.__bases__[0] if already_derived else MixinConfig

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Loads the backbone configuration of a pretrained HuggingFace checkpoint.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            PretrainedConfig: Configuration of the backbone model.
        """
        model_type = LightningIRClassFactory.get_backbone_model_type(model_name_or_path)
        BackboneConfig = CONFIG_MAPPING[model_type]
        return BackboneConfig.from_pretrained(model_name_or_path)

    @staticmethod
    def get_lightning_ir_config(model_name_or_path: str | Path) -> LightningIRConfig | None:
        """Loads the Lightning IR configuration of a pretrained Lightning IR checkpoint.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            LightningIRConfig | None: Configuration loaded via ``CONFIG_MAPPING``, or None if no Lightning IR model
                type could be determined for the checkpoint.
        """
        model_type = LightningIRClassFactory.get_lightning_ir_model_type(model_name_or_path)
        if model_type is not None:
            return CONFIG_MAPPING[model_type].from_pretrained(model_name_or_path)
        return None

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Reads the backbone model type from the configuration of a pretrained HuggingFace checkpoint.

        The ``backbone_model_type`` entry of the configuration dictionary takes precedence; if it is missing or
        falsy, the ``model_type`` entry is used.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
            *args: Forwarded to ``PretrainedConfig.get_config_dict``.
            **kwargs: Forwarded to ``PretrainedConfig.get_config_dict``.
        Returns:
            str: Model type of the backbone model.
        Raises:
            ValueError: If the type of the model is None in the configuration.
        """
        config_dict, _ = PretrainedConfig.get_config_dict(model_name_or_path, *args, **kwargs)
        model_type = config_dict.get("backbone_model_type")
        if not model_type:
            model_type = config_dict.get("model_type")
        if model_type is None:
            raise ValueError(f"Unable to load PretrainedConfig from {model_name_or_path}")
        return model_type

    @staticmethod
    def get_lightning_ir_model_type(model_name_or_path: str | Path) -> str | None:
        """Reads the Lightning IR model type from the configuration of a pretrained checkpoint.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            str | None: The ``model_type`` entry of the configuration, or None if the configuration has no
                ``backbone_model_type`` entry.
        """
        config_dict, _ = PretrainedConfig.get_config_dict(model_name_or_path)
        if "backbone_model_type" in config_dict:
            return config_dict.get("model_type", None)
        return None

    @property
    def cc_lir_model_type(self) -> str:
        """Camel case model type of the Lightning IR model."""
        # Title-case every dash-separated part and concatenate them, e.g. "bi-encoder" becomes "BiEncoder".
        parts = self.MixinConfig.model_type.split("-")
        return "".join(map(str.title, parts))

    @abstractmethod
    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Any:
        """Loads a derived Lightning IR class from a pretrained HuggingFace model. Must be implemented by subclasses.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
        Returns:
            Any: Derived Lightning IR class.
        """
        ...

    @abstractmethod
    def from_backbone_class(self, BackboneClass: Type) -> Type:
        """Creates a derived Lightning IR class from a backbone HuggingFace class. Must be implemented by subclasses.

        Args:
            BackboneClass (Type): Backbone class.
        Returns:
            Type: Derived Lightning IR class.
        """
        ...
class LightningIRConfigClassFactory(LightningIRClassFactory):
    """Factory that derives LightningIRConfig classes from HuggingFace configuration classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRConfig]:
        """Derives a LightningIRConfig class from the configuration of a pretrained HuggingFace model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
            *args: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.
        Returns:
            Type[LightningIRConfig]: Derived LightningIRConfig.
        """
        BackboneConfig = type(self.get_backbone_config(model_name_or_path))
        return self.from_backbone_class(BackboneConfig)

    def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[LightningIRConfig]:
        """Creates a derived LightningIRConfig from a transformers.PretrainedConfig_ backbone configuration class. A
        class that already has a ``backbone_model_type`` attribute that is not None is returned as is.

        .. _transformers.PretrainedConfig: \
https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig

        Args:
            BackboneClass (Type[PretrainedConfig]): Backbone configuration class.
        Returns:
            Type[LightningIRConfig]: Derived LightningIRConfig.
        """
        if getattr(BackboneClass, "backbone_model_type", None) is not None:
            return BackboneClass

        # The mixin is the class registered in CONFIG_MAPPING under the mixin's model type.
        LightningIRConfigMixin: Type[LightningIRConfig] = CONFIG_MAPPING[self.MixinConfig.model_type]
        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        class_attributes = {
            "model_type": self.MixinConfig.model_type,
            "backbone_model_type": BackboneClass.model_type,
            "mixin_config": self.MixinConfig,
        }
        # The mixin comes first in the bases so that its attributes take precedence over the backbone's.
        return type(derived_name, (LightningIRConfigMixin, BackboneClass), class_attributes)
class LightningIRModelClassFactory(LightningIRClassFactory):
    """Factory that derives LightningIRModel classes from HuggingFace model classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRModel]:
        """Derives a LightningIRModel class from a pretrained HuggingFace model.

        Args:
            model_name_or_path (str | Path): Path to the model or its name.
            *args: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.
        Returns:
            Type[LightningIRModel]: Derived LightningIRModel.
        """
        backbone_config = self.get_backbone_config(model_name_or_path)
        return self.from_backbone_class(_get_model_class(backbone_config))

    def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[LightningIRModel]:
        """Creates a derived LightningIRModel from a transformers.PreTrainedModel_ backbone model. A backbone whose
        ``config_class`` already has a ``backbone_model_type`` attribute that is not None is returned as is.

        .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model#transformers.PreTrainedModel

        Args:
            BackboneClass (Type[PreTrainedModel]): Backbone model class.
        Returns:
            Type[LightningIRModel]: Derived LightningIRModel.
        Raises:
            ValueError: If the ``config_class`` of the backbone model is None.
        """
        BackboneConfig = BackboneClass.config_class
        if getattr(BackboneConfig, "backbone_model_type", None) is not None:
            return BackboneClass
        if BackboneConfig is None:
            raise ValueError(
                f"Model {BackboneClass} is not a valid backbone model because it is missing a `config_class`."
            )

        LightningIRModelMixin: Type[LightningIRModel] = _get_model_class(self.MixinConfig)
        config_factory = LightningIRConfigClassFactory(self.MixinConfig)
        DerivedLightningIRConfig = config_factory.from_backbone_class(BackboneConfig)

        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        class_attributes = {
            "config_class": DerivedLightningIRConfig,
            # Stores the backbone's own ``forward`` on the derived class under a separate name.
            "_backbone_forward": BackboneClass.forward,
        }
        return type(derived_name, (LightningIRModelMixin, BackboneClass), class_attributes)
class LightningIRTokenizerClassFactory(LightningIRClassFactory):
    """Factory that derives LightningIRTokenizer classes from HuggingFace tokenizer classes."""

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Loads the backbone configuration for a pretrained HuggingFace tokenizer checkpoint.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
        Returns:
            PretrainedConfig: Configuration of the backbone the tokenizer belongs to.
        """
        model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(model_name_or_path)
        BackboneConfig = CONFIG_MAPPING[model_type]
        return BackboneConfig.from_pretrained(model_name_or_path)

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Determines the backbone model type for a pretrained HuggingFace tokenizer checkpoint.

        The model configuration is tried first. If that raises an OSError or ValueError, the
        ``backbone_tokenizer_class`` entry of the tokenizer configuration is looked up in ``TOKENIZER_MAPPING``.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
            *args: Forwarded to ``LightningIRClassFactory.get_backbone_model_type``.
            **kwargs: Forwarded to ``LightningIRClassFactory.get_backbone_model_type``.
        Returns:
            str: Model type of the backbone tokenizer.
        Raises:
            ValueError: If neither the model configuration nor the tokenizer configuration yields a model type.
        """
        try:
            return LightningIRClassFactory.get_backbone_model_type(model_name_or_path, *args, **kwargs)
        except (OSError, ValueError):
            # best guess at model type
            tokenizer_config = get_tokenizer_config(model_name_or_path)
            tokenizer_name = tokenizer_config.get("backbone_tokenizer_class", None)
            if tokenizer_name is None:
                raise ValueError("No backbone model found in the configuration")
            Tokenizer = tokenizer_class_from_name(tokenizer_name)
            for config, tokenizers in TOKENIZER_MAPPING.items():
                if Tokenizer in tokenizers:
                    return config.model_type
            raise ValueError("No backbone model found in the configuration")

    def from_pretrained(
        self, model_name_or_path: str | Path, *args, use_fast: bool = True, **kwargs
    ) -> Type[LightningIRTokenizer]:
        """Derives a LightningIRTokenizer class from a pretrained HuggingFace tokenizer.

        Args:
            model_name_or_path (str | Path): Path to the tokenizer or its name.
            use_fast (bool, optional): Whether to use the fast tokenizer. Defaults to True.
        Returns:
            Type[LightningIRTokenizer]: Derived LightningIRTokenizer.
        Raises:
            ValueError: If no fast tokenizer is found when `use_fast` is True.
            ValueError: If no slow tokenizer is found when `use_fast` is False.
        """
        BackboneConfig = type(self.get_backbone_config(model_name_or_path))
        BackboneTokenizers = TOKENIZER_MAPPING[BackboneConfig]
        slow_and_fast = self.from_backbone_classes(BackboneTokenizers, BackboneConfig)
        # Index 0 holds the slow tokenizer class and index 1 the fast one.
        if use_fast:
            index, missing_message = 1, "No fast tokenizer found."
        else:
            index, missing_message = 0, "No slow tokenizer found."
        DerivedLightningIRTokenizer = slow_and_fast[index]
        if DerivedLightningIRTokenizer is None:
            raise ValueError(missing_message)
        return DerivedLightningIRTokenizer

    def from_backbone_classes(
        self,
        BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None],
        BackboneConfig: Type[PretrainedConfig] | None = None,
    ) -> Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]:
        """Creates derived slow and fast LightningIRTokenizers from a tuple of backbone HuggingFace tokenizer classes.

        Args:
            BackboneClasses (Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None]):
                Slow and fast backbone tokenizer classes.
            BackboneConfig (Type[PretrainedConfig] | None, optional): Backbone configuration class. Not used by this
                method. Defaults to None.
        Returns:
            Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]: Slow and fast derived
                LightningIRTokenizers.
        """
        derived = tuple(
            self.from_backbone_class(BackboneClass) if BackboneClass is not None else None
            for BackboneClass in BackboneClasses
        )
        fast_tokenizer = derived[1]
        if fast_tokenizer is not None:
            # Point the derived fast tokenizer at the derived slow tokenizer (which may be None).
            fast_tokenizer.slow_tokenizer_class = derived[0]
        return derived

    def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
        """Creates a derived LightningIRTokenizer from a transformers.PreTrainedTokenizerBase_ backbone tokenizer. A
        backbone tokenizer class that already has a ``config_class`` attribute is returned as is.

        .. _transformers.PreTrainedTokenizerBase: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerBase

        Args:
            BackboneClass (Type[PreTrainedTokenizerBase]): Backbone tokenizer class.
        Returns:
            Type[LightningIRTokenizer]: Derived LightningIRTokenizer.
        """
        if hasattr(BackboneClass, "config_class"):
            return BackboneClass

        # The first tokenizer class registered for the mixin configuration serves as the mixin.
        LightningIRTokenizerMixin = TOKENIZER_MAPPING[self.MixinConfig][0]
        derived_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        return type(derived_name, (LightningIRTokenizerMixin, BackboneClass), {})