Source code for lightning_ir.base.config

  1"""
  2Base configuration class for Lightning IR models.
  3
  4This module defines the configuration class `LightningIRConfig` which is used to instantiate
  5a Lightning IR model. The configuration class acts as a mixin for the `transformers.PretrainedConfig`
  6class from the Hugging Face Transformers library.
  7"""
  8
  9from __future__ import annotations
 10
 11import inspect
 12from pathlib import Path
 13from typing import TYPE_CHECKING, Any, Dict, Optional, Type
 14
 15from transformers import PretrainedConfig
 16
 17from .class_factory import LightningIRConfigClassFactory
 18from .external_model_hub import CHECKPOINT_MAPPING
 19
 20if TYPE_CHECKING:
 21    from .tokenizer import LightningIRTokenizer
 22
 23    try:
 24        from peft import LoraConfig
 25    except ImportError:
 26
 27        class LoraConfig:
 28            pass
 29
 30
class LightningIRConfig(PretrainedConfig):
    """The configuration class to instantiate a Lightning IR model. Acts as a mixin for the
    transformers.PretrainedConfig_ class.

    .. _transformers.PretrainedConfig: \
https://huggingface.co/transformers/main_classes/configuration.html#transformers.PretrainedConfig
    """

    model_type = "lightning-ir"
    """Model type for the configuration."""
    backbone_model_type: str | None = None
    """Backbone model type for the configuration. Set by :func:`LightningIRModelClassFactory`."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        use_adapter: bool = False,
        adapter_config: Optional["LoraConfig"] = None,
        pretrained_adapter_name_or_path: Optional[str] = None,
        **kwargs,
    ):
        """Initializes the configuration.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, queries are not truncated.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, documents are not truncated.
                Defaults to 512.
            use_adapter (bool, optional): Whether to use LoRA adapters. Defaults to False.
            adapter_config (Optional[LoraConfig], optional): Configuration for LoRA adapters.
                Only used if use_adapter is True. Defaults to None.
            pretrained_adapter_name_or_path (Optional[str], optional): The path to a pretrained adapter to load.
                Defaults to None.
        """
        super().__init__(*args, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length
        self.use_adapter = use_adapter
        self.adapter_config = adapter_config
        self.pretrained_adapter_name_or_path = pretrained_adapter_name_or_path

    def get_tokenizer_kwargs(self, Tokenizer: Type[LightningIRTokenizer]) -> Dict[str, Any]:
        """Returns the keyword arguments for the tokenizer. This method is used to pass the configuration
        parameters to the tokenizer.

        Args:
            Tokenizer (Type[LightningIRTokenizer]): Class of the tokenizer to be used.
        Returns:
            Dict[str, Any]: Keyword arguments for the tokenizer.
        """
        return {k: getattr(self, k) for k in inspect.signature(Tokenizer.__init__).parameters if hasattr(self, k)}

    def to_dict(self) -> Dict[str, Any]:
        """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments and the
        backbone model type.

        .. _transformers.PretrainedConfig.to_dict: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.to_dict

        Returns:
            Dict[str, Any]: Configuration dictionary.
        """
        output = super().to_dict()
        if self.backbone_model_type is not None:
            output["backbone_model_type"] = self.backbone_model_type
        return output

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRConfig":
        """Loads the configuration from a pretrained model. Wraps the
        transformers.PretrainedConfig.from_pretrained_ method.

        .. _transformers.PretrainedConfig.from_pretrained: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.from_pretrained

        Args:
            pretrained_model_name_or_path (str | Path): Pretrained model name or path.
        Returns:
            LightningIRConfig: Derived LightningIRConfig class.
        Raises:
            ValueError: If `pretrained_model_name_or_path` is not a Lightning IR model and no
                :py:class:`LightningIRConfig` is passed.
        """
        # provides AutoConfig.from_pretrained support
        if cls is LightningIRConfig or all(issubclass(base, LightningIRConfig) for base in cls.__bases__):
            # no backbone config found, create derived lightning-ir config based on backbone config
            config = None
            if pretrained_model_name_or_path in CHECKPOINT_MAPPING:
                config = CHECKPOINT_MAPPING[pretrained_model_name_or_path]
                ConfigClass = config.__class__
            elif cls is not LightningIRConfig:
                ConfigClass = cls
            else:
                ConfigClass = type(LightningIRConfigClassFactory.get_lightning_ir_config(pretrained_model_name_or_path))
                if ConfigClass is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            backbone_config = LightningIRConfigClassFactory.get_backbone_config(pretrained_model_name_or_path)
            cls = LightningIRConfigClassFactory(ConfigClass).from_backbone_class(type(backbone_config))
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.from_pretrained(pretrained_model_name_or_path, config=config)
                derived_config.update(config.to_dict())
            return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        return super(LightningIRConfig, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
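
The snippets below are illustrative usage sketches, not part of the module source. Assuming the module is importable as ``lightning_ir.base.config``, instantiating the config directly shows how the truncation arguments from ``__init__`` are stored::

    from lightning_ir.base.config import LightningIRConfig

    # query_length and doc_length default to 32 and 512; passing None disables truncation.
    config = LightningIRConfig(query_length=16, doc_length=256)
    print(config.query_length, config.doc_length)  # 16 256
    print(config.use_adapter)  # False by default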
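
``get_tokenizer_kwargs`` forwards exactly those constructor parameters of the tokenizer class that also exist as attributes on the config. A minimal sketch with a stand-in tokenizer class (``DummyTokenizer`` is hypothetical, not the real ``LightningIRTokenizer``)::

    from lightning_ir.base.config import LightningIRConfig


    class DummyTokenizer:
        # Stand-in used only to illustrate the signature-inspection mechanism.
        def __init__(self, query_length: int = 32, doc_length: int = 512, extra_option: bool = False):
            ...


    config = LightningIRConfig(query_length=16, doc_length=256)
    # query_length and doc_length are config attributes and are forwarded;
    # extra_option is not a config attribute, so it is skipped.
    print(config.get_tokenizer_kwargs(DummyTokenizer))
    # expected: {'query_length': 16, 'doc_length': 256}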
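
``to_dict`` serializes the added arguments alongside the usual ``transformers`` fields; ``backbone_model_type`` is only written once a derived config has set it. A small sketch::

    from lightning_ir.base.config import LightningIRConfig

    config = LightningIRConfig(query_length=16, doc_length=256)
    config_dict = config.to_dict()
    print(config_dict["query_length"], config_dict["doc_length"])  # 16 256
    # For the base class, backbone_model_type is None, so the key is not added.
    print("backbone_model_type" in config_dict)  # expected: False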
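
``from_pretrained`` resolves a checkpoint to a derived Lightning IR config, either via ``CHECKPOINT_MAPPING`` or by inspecting the saved configuration. The checkpoint name below is a placeholder, not a real model on the Hugging Face Hub::

    from lightning_ir.base.config import LightningIRConfig

    # "my-org/my-lightning-ir-model" is a hypothetical checkpoint saved by a Lightning IR model.
    config = LightningIRConfig.from_pretrained("my-org/my-lightning-ir-model")
    print(type(config).__name__, config.query_length, config.doc_length)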