Source code for lightning_ir.base.config
1"""
2Base configuration class for Lightning IR models.
3
4This module defines the configuration class `LightningIRConfig` which is used to instantiate
5a Lightning IR model. The configuration class acts as a mixin for the `transformers.PretrainedConfig`
6class from the Hugging Face Transformers library.
7"""
8
9from __future__ import annotations
10
11import inspect
12from pathlib import Path
13from typing import TYPE_CHECKING, Any, Dict, Optional, Type
14
15from transformers import PretrainedConfig
16
17from .class_factory import LightningIRConfigClassFactory
18from .external_model_hub import CHECKPOINT_MAPPING
19
20if TYPE_CHECKING:
21 from .tokenizer import LightningIRTokenizer
22
23 try:
24 from peft import LoraConfig
25 except ImportError:
26
27 class LoraConfig:
28 pass
29
30
class LightningIRConfig(PretrainedConfig):
    """The configuration class to instantiate a Lightning IR model. Acts as a mixin for the
    transformers.PretrainedConfig_ class.

    .. _transformers.PretrainedConfig: \
https://huggingface.co/transformers/main_classes/configuration.html#transformers.PretrainedConfig
    """

    model_type = "lightning-ir"
    """Model type for the configuration."""
    backbone_model_type: str | None = None
    """Backbone model type for the configuration. Set by :func:`LightningIRModelClassFactory`."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        use_adapter: bool = False,
        adapter_config: Optional["LoraConfig"] = None,
        pretrained_adapter_name_or_path: Optional[str] = None,
        **kwargs,
    ):
        """Initializes the configuration.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, queries are not truncated.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, documents are not truncated.
                Defaults to 512.
            use_adapter (bool, optional): Whether to use LoRA adapters. Defaults to False.
            adapter_config (Optional[LoraConfig], optional): Configuration for LoRA adapters.
                Only used if use_adapter is True. Defaults to None.
            pretrained_adapter_name_or_path (Optional[str], optional): The path to a pretrained adapter to load.
                Defaults to None.
        """
        super().__init__(*args, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length
        self.use_adapter = use_adapter
        self.adapter_config = adapter_config
        self.pretrained_adapter_name_or_path = pretrained_adapter_name_or_path

    def get_tokenizer_kwargs(self, Tokenizer: Type[LightningIRTokenizer]) -> Dict[str, Any]:
        """Returns the keyword arguments for the tokenizer. This method is used to pass the configuration
        parameters to the tokenizer.

        Args:
            Tokenizer (Type[LightningIRTokenizer]): Class of the tokenizer to be used.
        Returns:
            Dict[str, Any]: Keyword arguments for the tokenizer.
        """
        return {k: getattr(self, k) for k in inspect.signature(Tokenizer.__init__).parameters if hasattr(self, k)}

    def to_dict(self) -> Dict[str, Any]:
        """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments and the
        backbone model type.

        .. _transformers.PretrainedConfig.to_dict: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.to_dict

        Returns:
            Dict[str, Any]: Configuration dictionary.
        """
        output = super().to_dict()
        if self.backbone_model_type is not None:
            output["backbone_model_type"] = self.backbone_model_type
        return output

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRConfig":
        """Loads the configuration from a pretrained model. Wraps the transformers.PretrainedConfig.from_pretrained_
        method.

        .. _transformers.PretrainedConfig.from_pretrained: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.from_pretrained

        Args:
            pretrained_model_name_or_path (str | Path): Pretrained model name or path.
        Returns:
            LightningIRConfig: Instance of the derived LightningIRConfig class.
        Raises:
            ValueError: If `pretrained_model_name_or_path` is not a Lightning IR model and no
                :py:class:`LightningIRConfig` is passed.
        """
        # provides AutoConfig.from_pretrained support
        if cls is LightningIRConfig or all(issubclass(base, LightningIRConfig) for base in cls.__bases__):
            # no backbone config found, create derived lightning-ir config based on backbone config
            config = None
            if pretrained_model_name_or_path in CHECKPOINT_MAPPING:
                # known external checkpoint: reuse its mapped Lightning IR config
                config = CHECKPOINT_MAPPING[pretrained_model_name_or_path]
                ConfigClass = config.__class__
            elif cls is not LightningIRConfig:
                ConfigClass = cls
            else:
                ConfigClass = type(LightningIRConfigClassFactory.get_lightning_ir_config(pretrained_model_name_or_path))
                if ConfigClass is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            # derive a config class that combines the Lightning IR config with the backbone config
            backbone_config = LightningIRConfigClassFactory.get_backbone_config(pretrained_model_name_or_path)
            cls = LightningIRConfigClassFactory(ConfigClass).from_backbone_class(type(backbone_config))
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.from_pretrained(pretrained_model_name_or_path, config=config)
                derived_config.update(config.to_dict())
            return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        return super(LightningIRConfig, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
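

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the module). It shows a
# configuration being built directly with the added query/document length
# arguments and being derived from a backbone checkpoint via ``from_pretrained``.
# ``BiEncoderConfig`` and the ``bert-base-uncased`` checkpoint are assumptions
# chosen for the example; any LightningIRConfig subclass and Hugging Face
# backbone should work analogously.
if __name__ == "__main__":
    from lightning_ir import BiEncoderConfig

    # construct a config directly, overriding the query/document truncation lengths
    config = BiEncoderConfig(query_length=32, doc_length=256)
    print(config.query_length, config.doc_length)

    # derive a backbone-specific config class from an existing checkpoint; the
    # resulting config also records the backbone model type in ``to_dict``
    derived = BiEncoderConfig.from_pretrained("bert-base-uncased", query_length=32, doc_length=256)
    print(type(derived).__name__, derived.to_dict().get("backbone_model_type"))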