Source code for lightning_ir.base.class_factory

  1"""
  2Class factory module for Lightning IR.
  3
  4This module provides factory classes for creating various components of the Lightning IR library
  5by extending Hugging Face Transformers classes.
  6"""
  7
  8from __future__ import annotations
  9
 10from abc import ABC, abstractmethod
 11from pathlib import Path
 12from typing import TYPE_CHECKING, Any, Tuple, Type
 13
 14from transformers import (
 15    CONFIG_MAPPING,
 16    MODEL_MAPPING,
 17    TOKENIZER_MAPPING,
 18    PretrainedConfig,
 19    PreTrainedModel,
 20    PreTrainedTokenizerBase,
 21)
 22from transformers.models.auto.tokenization_auto import get_tokenizer_config, tokenizer_class_from_name
 23
 24if TYPE_CHECKING:
 25    from . import LightningIRConfig, LightningIRModel, LightningIRTokenizer
 26
 27
 28def _get_model_class(config: PretrainedConfig | Type[PretrainedConfig]) -> Type[PreTrainedModel]:
 29    # https://github.com/huggingface/transformers/blob/356b3cd71d7bfb51c88fea3e8a0c054f3a457ab9/src/transformers/models/auto/auto_factory.py#L387
 30    if isinstance(config, type):
 31        supported_models = MODEL_MAPPING[config]
 32    else:
 33        supported_models = MODEL_MAPPING[type(config)]
 34    if not isinstance(supported_models, (list, tuple)):
 35        return supported_models
 36
 37    if isinstance(config, type):
 38        # we cannot parse architectures from a config class, we need an instance for this
 39        return supported_models[0]
 40
 41    name_to_model = {model.__name__: model for model in supported_models}
 42    architectures = getattr(config, "architectures", [])
 43    for arch in architectures:
 44        if arch in name_to_model:
 45            return name_to_model[arch]
 46        elif f"TF{arch}" in name_to_model:
 47            return name_to_model[f"TF{arch}"]
 48        elif f"Flax{arch}" in name_to_model:
 49            return name_to_model[f"Flax{arch}"]
 50
 51    # If not architecture is set in the config or match the supported models, the first element of the tuple is the
 52    # defaults.
 53    return supported_models[0]
 54
 55
class LightningIRClassFactory(ABC):
    """Base class for creating derived Lightning IR classes from HuggingFace classes."""

    def __init__(self, MixinConfig: Type[LightningIRConfig]) -> None:
        """Creates a new LightningIRClassFactory.

        If ``MixinConfig`` is itself a derived configuration (it defines a non-None ``backbone_model_type``), its
        first base class is stored instead, so the factory always holds the plain mixin configuration.

        :param MixinConfig: LightningIRConfig mixin class
        :type MixinConfig: Type[LightningIRConfig]
        """
        is_derived = getattr(MixinConfig, "backbone_model_type", None) is not None
        self.MixinConfig = MixinConfig.__bases__[0] if is_derived else MixinConfig

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Loads the backbone configuration stored in a pretrained HuggingFace checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Configuration instance of the backbone model
        :rtype: PretrainedConfig
        """
        BackboneConfigClass = CONFIG_MAPPING[LightningIRClassFactory.get_backbone_model_type(model_name_or_path)]
        return BackboneConfigClass.from_pretrained(model_name_or_path)

    @staticmethod
    def get_lightning_ir_config(model_name_or_path: str | Path) -> LightningIRConfig | None:
        """Loads the Lightning IR configuration stored in a pretrained Lightning IR checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Configuration instance of the Lightning IR model, or None if the checkpoint is not a Lightning IR
            checkpoint
        :rtype: LightningIRConfig | None
        """
        lir_model_type = LightningIRClassFactory.get_lightning_ir_model_type(model_name_or_path)
        if lir_model_type is not None:
            return CONFIG_MAPPING[lir_model_type].from_pretrained(model_name_or_path)
        return None

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Determines the backbone model type recorded in a pretrained HuggingFace checkpoint.

        A non-empty ``backbone_model_type`` entry takes precedence; otherwise ``model_type`` is used.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :raises ValueError: If the configuration records neither entry
        :return: Model type of the backbone model
        :rtype: str
        """
        config_dict, _ = PretrainedConfig.get_config_dict(model_name_or_path, *args, **kwargs)
        model_type = config_dict.get("backbone_model_type")
        if not model_type:
            model_type = config_dict.get("model_type")
        if model_type is None:
            raise ValueError(f"Unable to load PretrainedConfig from {model_name_or_path}")
        return model_type

    @staticmethod
    def get_lightning_ir_model_type(model_name_or_path: str | Path) -> str | None:
        """Determines the Lightning IR model type recorded in a pretrained checkpoint.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Model type of the Lightning IR model, or None if the configuration has no ``backbone_model_type``
        :rtype: str | None
        """
        config_dict, _ = PretrainedConfig.get_config_dict(model_name_or_path)
        # The presence of ``backbone_model_type`` is treated as the marker of a Lightning IR checkpoint.
        return config_dict.get("model_type", None) if "backbone_model_type" in config_dict else None

    @property
    def cc_lir_model_type(self) -> str:
        """Camel case model type of the Lightning IR model (e.g. ``"cross-encoder"`` -> ``"CrossEncoder"``)."""
        segments = self.MixinConfig.model_type.split("-")
        return "".join(map(str.title, segments))

    @abstractmethod
    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Any:
        """Builds a derived Lightning IR class for a pretrained HuggingFace checkpoint. Subclasses must implement this.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived Lightning IR class
        :rtype: Any
        """
        ...

    @abstractmethod
    def from_backbone_class(self, BackboneClass: Type) -> Type:
        """Builds a derived Lightning IR class from a backbone HuggingFace class. Subclasses must implement this.

        :param BackboneClass: Backbone class
        :type BackboneClass: Type
        :return: Derived Lightning IR class
        :rtype: Type
        """
        ...
150 151
class LightningIRConfigClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRConfig classes from HuggingFace configuration classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRConfig]:
        """Builds a derived LightningIRConfig class for a pretrained HuggingFace checkpoint.

        The backbone configuration is loaded from the checkpoint and its class is passed to
        :meth:`from_backbone_class`.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        BackboneConfigClass = type(self.get_backbone_config(model_name_or_path))
        return self.from_backbone_class(BackboneConfigClass)

    def from_backbone_class(self, BackboneClass: Type[PretrainedConfig]) -> Type[LightningIRConfig]:
        """Builds a derived LightningIRConfig from a transformers.PretrainedConfig_ backbone configuration class. A
        backbone configuration class that is already a derived LightningIRConfig is returned unchanged.

        The new class inherits from the registered Lightning IR config mixin first and the backbone configuration
        second, and records ``model_type``, ``backbone_model_type`` and ``mixin_config`` as class attributes.

        .. _transformers.PretrainedConfig: \
https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig

        :param BackboneClass: Backbone configuration class
        :type BackboneClass: Type[PretrainedConfig]
        :return: Derived LightningIRConfig
        :rtype: Type[LightningIRConfig]
        """
        already_derived = getattr(BackboneClass, "backbone_model_type", None) is not None
        if already_derived:
            return BackboneClass
        lir_model_type = self.MixinConfig.model_type
        LightningIRConfigMixin: Type[LightningIRConfig] = CONFIG_MAPPING[lir_model_type]
        class_name = f"{self.cc_lir_model_type}{BackboneClass.__name__}"
        class_attributes = {
            "model_type": lir_model_type,
            "backbone_model_type": BackboneClass.model_type,
            "mixin_config": self.MixinConfig,
        }
        return type(class_name, (LightningIRConfigMixin, BackboneClass), class_attributes)
193 194
class LightningIRModelClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRModel classes from HuggingFace model classes."""

    def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Type[LightningIRModel]:
        """Loads a derived LightningIRModel class from a pretrained HuggingFace model.

        The backbone configuration is read from the checkpoint, and the matching backbone model class is resolved
        via the Hugging Face model mapping. A derived class is then built from it. Only the class is returned; no
        weights are loaded here.

        :param model_name_or_path: Path to the model or its name
        :type model_name_or_path: str | Path
        :return: Derived LightningIRModel
        :rtype: Type[LightningIRModel]
        """
        backbone_config = self.get_backbone_config(model_name_or_path)
        BackboneModel = _get_model_class(backbone_config)
        return self.from_backbone_class(BackboneModel)

    def from_backbone_class(self, BackboneClass: Type[PreTrainedModel]) -> Type[LightningIRModel]:
        """Creates a derived LightningIRModel from a transformers.PreTrainedModel_ backbone model. If the backbone model
        is already a LightningIRModel (its ``config_class`` has a ``backbone_model_type``), it is returned as is.

        .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model#transformers.PreTrainedModel

        :param BackboneClass: Backbone model
        :type BackboneClass: Type[PreTrainedModel]
        :raises ValueError: If the backbone model has no ``config_class`` and therefore is not a valid backbone model.
        :return: The derived LightningIRModel. Lookup errors raised while resolving the Lightning IR model mixin for
            this factory's mixin configuration are propagated unchanged.
        :rtype: Type[LightningIRModel]
        """
        # Validate the config class before inspecting it; a class without the attribute is also rejected.
        BackboneConfig = getattr(BackboneClass, "config_class", None)
        if BackboneConfig is None:
            raise ValueError(
                f"Model {BackboneClass} is not a valid backbone model because it is missing a `config_class`."
            )
        if getattr(BackboneConfig, "backbone_model_type", None) is not None:
            return BackboneClass

        LightningIRModelMixin: Type[LightningIRModel] = _get_model_class(self.MixinConfig)

        DerivedLightningIRConfig = LightningIRConfigClassFactory(self.MixinConfig).from_backbone_class(BackboneConfig)

        DerivedLightningIRModel = type(
            f"{self.cc_lir_model_type}{BackboneClass.__name__}",
            (LightningIRModelMixin, BackboneClass),
            # ``_backbone_forward`` keeps a handle on the backbone's own ``forward``, presumably so the mixin can
            # delegate to it after overriding ``forward`` (confirm in the LightningIRModel mixin).
            {"config_class": DerivedLightningIRConfig, "_backbone_forward": BackboneClass.forward},
        )
        return DerivedLightningIRModel
244 245
class LightningIRTokenizerClassFactory(LightningIRClassFactory):
    """Class factory for creating derived LightningIRTokenizer classes from HuggingFace tokenizer classes."""

    @staticmethod
    def get_backbone_config(model_name_or_path: str | Path) -> PretrainedConfig:
        """Grabs the backbone configuration from a checkpoint of a pretrained HuggingFace tokenizer.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :return: Configuration instance of the backbone model
        :rtype: PretrainedConfig
        """
        backbone_model_type = LightningIRTokenizerClassFactory.get_backbone_model_type(model_name_or_path)
        return CONFIG_MAPPING[backbone_model_type].from_pretrained(model_name_or_path)

    @staticmethod
    def get_backbone_model_type(model_name_or_path: str | Path, *args, **kwargs) -> str:
        """Grabs the model type from a checkpoint of a pretrained HuggingFace tokenizer.

        The model configuration is tried first. If it cannot be read (``OSError``) or has no model type
        (``ValueError``), the model type is guessed from ``backbone_tokenizer_class`` in the tokenizer configuration.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :raises ValueError: If neither the model configuration nor the tokenizer configuration identifies a backbone
        :return: Model type of the backbone tokenizer
        :rtype: str
        """
        try:
            return LightningIRClassFactory.get_backbone_model_type(model_name_or_path, *args, **kwargs)
        except (OSError, ValueError) as err:
            # best guess at model type
            config_dict = get_tokenizer_config(model_name_or_path)
            backbone_tokenizer_class = config_dict.get("backbone_tokenizer_class", None)
            if backbone_tokenizer_class is not None:
                Tokenizer = tokenizer_class_from_name(backbone_tokenizer_class)
                # ``tokenizer_class_from_name`` may return None for an unknown name, and mapping entries use None
                # for a missing slow/fast variant, so a None lookup would spuriously match such an entry.
                if Tokenizer is not None:
                    for config, tokenizers in TOKENIZER_MAPPING.items():
                        if Tokenizer in tokenizers:
                            return config.model_type
            raise ValueError("No backbone model found in the configuration") from err

    def from_pretrained(
        self, model_name_or_path: str | Path, *args, use_fast: bool = True, **kwargs
    ) -> Type[LightningIRTokenizer]:
        """Loads a derived LightningIRTokenizer from a pretrained HuggingFace tokenizer.

        :param model_name_or_path: Path to the tokenizer or its name
        :type model_name_or_path: str | Path
        :param use_fast: Whether to use the fast or slow tokenizer, defaults to True
        :type use_fast: bool, optional
        :raises ValueError: If use_fast is True and no fast tokenizer is found
        :raises ValueError: If use_fast is False and no slow tokenizer is found
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        backbone_config = self.get_backbone_config(model_name_or_path)
        BackboneTokenizers = TOKENIZER_MAPPING[type(backbone_config)]
        # Index 0 holds the slow tokenizer, index 1 the fast one; either may be None.
        DerivedLightningIRTokenizers = self.from_backbone_classes(BackboneTokenizers, type(backbone_config))
        if use_fast:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[1]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No fast tokenizer found.")
        else:
            DerivedLightningIRTokenizer = DerivedLightningIRTokenizers[0]
            if DerivedLightningIRTokenizer is None:
                raise ValueError("No slow tokenizer found.")
        return DerivedLightningIRTokenizer

    def from_backbone_classes(
        self,
        BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None],
        BackboneConfig: Type[PretrainedConfig] | None = None,
    ) -> Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]:
        """Creates derived slow and fast LightningIRTokenizers from a tuple of backbone HuggingFace tokenizer classes.

        :param BackboneClasses: Slow and fast backbone tokenizer classes
        :type BackboneClasses: Tuple[Type[PreTrainedTokenizerBase] | None, Type[PreTrainedTokenizerBase] | None]
        :param BackboneConfig: Backbone configuration class, defaults to None (currently unused)
        :type BackboneConfig: Type[PretrainedConfig], optional
        :return: Slow and fast derived LightningIRTokenizers
        :rtype: Tuple[Type[LightningIRTokenizer] | None, Type[LightningIRTokenizer] | None]
        """
        DerivedLightningIRTokenizers = tuple(
            None if BackboneClass is None else self.from_backbone_class(BackboneClass)
            for BackboneClass in BackboneClasses
        )
        # Link the derived fast tokenizer to its derived slow counterpart (None if there is no slow variant).
        if DerivedLightningIRTokenizers[1] is not None:
            DerivedLightningIRTokenizers[1].slow_tokenizer_class = DerivedLightningIRTokenizers[0]
        return DerivedLightningIRTokenizers

    def from_backbone_class(self, BackboneClass: Type[PreTrainedTokenizerBase]) -> Type[LightningIRTokenizer]:
        """Creates a derived LightningIRTokenizer from a transformers.PreTrainedTokenizerBase_ backbone tokenizer. If
        the backbone tokenizer is already a LightningIRTokenizer, it is returned as is.

        .. _transformers.PreTrainedTokenizerBase: \
https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerBase

        :param BackboneClass: Backbone tokenizer class
        :type BackboneClass: Type[PreTrainedTokenizerBase]
        :return: Derived LightningIRTokenizer
        :rtype: Type[LightningIRTokenizer]
        """
        # A ``config_class`` attribute is used as the marker of an already-derived tokenizer (presumably supplied by
        # the LightningIRTokenizer mixin — confirm there).
        if hasattr(BackboneClass, "config_class"):
            return BackboneClass
        # The mixin is registered in the slow (index 0) slot of the tokenizer mapping for the mixin configuration.
        LightningIRTokenizerMixin = TOKENIZER_MAPPING[self.MixinConfig][0]

        DerivedLightningIRTokenizer = type(
            f"{self.cc_lir_model_type}{BackboneClass.__name__}", (LightningIRTokenizerMixin, BackboneClass), {}
        )

        return DerivedLightningIRTokenizer