Source code for lightning_ir.base.config

  1"""
  2Base configuration class for Lightning IR models.
  3
  4This module defines the configuration class `LightningIRConfig` which is used to instantiate
  5a Lightning IR model. The configuration class acts as a mixin for the `transformers.PretrainedConfig`
  6class from the Hugging Face Transformers library.
  7"""
  8
  9from __future__ import annotations
 10
 11import inspect
 12from pathlib import Path
 13from typing import TYPE_CHECKING, Any, Dict, Type
 14
 15from transformers import PretrainedConfig
 16
 17from .class_factory import LightningIRConfigClassFactory
 18from .external_model_hub import CHECKPOINT_MAPPING
 19
 20if TYPE_CHECKING:
 21    from .tokenizer import LightningIRTokenizer
 22
 23
class LightningIRConfig(PretrainedConfig):
    """The configuration class to instantiate a Lightning IR model. Acts as a mixin for the
    transformers.PretrainedConfig_ class.

    .. _transformers.PretrainedConfig: \
https://huggingface.co/transformers/main_classes/configuration.html#transformers.PretrainedConfig
    """

    model_type = "lightning-ir"
    """Model type for the configuration."""
    backbone_model_type: str | None = None
    """Backbone model type for the configuration. Set by :func:`LightningIRModelClassFactory`."""

    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """Initializes the configuration.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        """
        super().__init__(*args, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

    def get_tokenizer_kwargs(self, Tokenizer: Type[LightningIRTokenizer]) -> Dict[str, Any]:
        """Returns the keyword arguments for the tokenizer. This method is used to pass the configuration
        parameters to the tokenizer.

        :param Tokenizer: Class of the tokenizer to be used
        :type Tokenizer: Type[LightningIRTokenizer]
        :return: Keyword arguments for the tokenizer
        :rtype: Dict[str, Any]
        """
        # Only forward config attributes that the tokenizer's __init__ actually accepts.
        return {k: getattr(self, k) for k in inspect.signature(Tokenizer.__init__).parameters if hasattr(self, k)}

    def to_dict(self) -> Dict[str, Any]:
        """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments and the backbone
        model type.

        .. _transformers.PretrainedConfig.to_dict: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.to_dict

        :return: Configuration dictionary
        :rtype: Dict[str, Any]
        """
        output = super().to_dict()
        if self.backbone_model_type is not None:
            output["backbone_model_type"] = self.backbone_model_type
        return output

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRConfig":
        """Loads the configuration from a pretrained model. Wraps the transformers.PretrainedConfig.from_pretrained_

        .. _transformers.PretrainedConfig.from_pretrained: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.from_pretrained

        :param pretrained_model_name_or_path: Pretrained model name or path
        :type pretrained_model_name_or_path: str | Path
        :raises ValueError: If `pre_trained_model_name_or_path` is not a Lightning IR model and no
            :py:class:`LightningIRConfig` is passed
        :return: Derived LightningIRConfig class
        :rtype: LightningIRConfig
        """
        # provides AutoConfig.from_pretrained support
        if cls is LightningIRConfig or all(issubclass(base, LightningIRConfig) for base in cls.__bases__):
            # `cls` is not yet bound to a backbone: resolve which Lightning IR config class to
            # derive, then rebind `cls` to the factory-generated backbone-specific class.
            config = None
            if pretrained_model_name_or_path in CHECKPOINT_MAPPING:
                # Known external checkpoint: its registered config dictates the config class.
                config = CHECKPOINT_MAPPING[pretrained_model_name_or_path]
                ConfigClass = config.__class__
            elif cls is not LightningIRConfig:
                ConfigClass = cls
            else:
                ConfigClass = type(LightningIRConfigClassFactory.get_lightning_ir_config(pretrained_model_name_or_path))
            # Defensive guard kept from the original implementation; callers may catch this
            # ValueError when the checkpoint cannot be resolved to a Lightning IR config class.
            if ConfigClass is None:
                raise ValueError("Pass a config to `from_pretrained`.")
            backbone_config = LightningIRConfigClassFactory.get_backbone_config(pretrained_model_name_or_path)
            cls = LightningIRConfigClassFactory(ConfigClass).from_backbone_class(type(backbone_config))
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.from_pretrained(pretrained_model_name_or_path, config=config)
                derived_config.update(config.to_dict())
                # BUG FIX: previously `derived_config` was built and then discarded, so the
                # registered checkpoint config's overrides were silently lost. Return it so the
                # merged configuration actually takes effect.
                return derived_config
            return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        # `cls` is already a derived, backbone-bound config class: defer to transformers.
        return super(LightningIRConfig, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs)