Source code for lightning_ir.base.config
1"""
2Base configuration class for Lightning IR models.
3
4This module defines the configuration class `LightningIRConfig` which is used to instantiate
5a Lightning IR model. The configuration class acts as a mixin for the `transformers.PretrainedConfig`
6class from the Hugging Face Transformers library.
7"""
8
9from __future__ import annotations
10
11import inspect
12from pathlib import Path
13from typing import TYPE_CHECKING, Any, Dict, Type
14
15from transformers import PretrainedConfig
16
17from .class_factory import LightningIRConfigClassFactory
18from .external_model_hub import CHECKPOINT_MAPPING
19
20if TYPE_CHECKING:
21 from .tokenizer import LightningIRTokenizer
22
23
class LightningIRConfig(PretrainedConfig):
    """The configuration class to instantiate a Lightning IR model. Acts as a mixin for the
    transformers.PretrainedConfig_ class.

    .. _transformers.PretrainedConfig: \
https://huggingface.co/transformers/main_classes/configuration.html#transformers.PretrainedConfig
    """

    model_type = "lightning-ir"
    """Model type for the configuration."""
    backbone_model_type: str | None = None
    """Backbone model type for the configuration. Set by :func:`LightningIRModelClassFactory`."""

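    # Illustrative sketch (added comment, not upstream code): the class factory derives
    # concrete configs by mixing this class with a backbone config, conceptually like
    #
    #   from transformers import BertConfig  # assumed backbone for illustration
    #
    #   class BertLightningIRConfig(LightningIRConfig, BertConfig):
    #       backbone_model_type = "bert"
    #
    # so a derived config exposes the backbone's arguments alongside the Lightning IR
    # arguments query_length and doc_length.
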
    def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs):
        """Initializes the configuration.

        :param query_length: Maximum query length, defaults to 32
        :type query_length: int, optional
        :param doc_length: Maximum document length, defaults to 512
        :type doc_length: int, optional
        """
        super().__init__(*args, **kwargs)
        self.query_length = query_length
        self.doc_length = doc_length

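    # Minimal usage sketch (comment only; the values are arbitrary examples):
    #
    #   config = LightningIRConfig(query_length=16, doc_length=256)
    #   config.query_length  # 16
    #   config.doc_length    # 256
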
    def get_tokenizer_kwargs(self, Tokenizer: Type[LightningIRTokenizer]) -> Dict[str, Any]:
        """Returns the keyword arguments for the tokenizer. This method is used to pass the configuration
        parameters to the tokenizer.

        :param Tokenizer: Class of the tokenizer to be used
        :type Tokenizer: Type[LightningIRTokenizer]
        :return: Keyword arguments for the tokenizer
        :rtype: Dict[str, Any]
        """
        return {k: getattr(self, k) for k in inspect.signature(Tokenizer.__init__).parameters if hasattr(self, k)}

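    # Sketch of the introspection above (hypothetical tokenizer, comment only): given
    #
    #   class MyTokenizer(LightningIRTokenizer):
    #       def __init__(self, *args, query_length: int = 32, doc_length: int = 512, **kwargs): ...
    #
    # get_tokenizer_kwargs(MyTokenizer) reads the parameter names of MyTokenizer.__init__
    # and returns {"query_length": self.query_length, "doc_length": self.doc_length},
    # skipping any parameter the config does not define as an attribute.
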
    def to_dict(self) -> Dict[str, Any]:
        """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments and the
        backbone model type.

        .. _transformers.PretrainedConfig.to_dict: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.to_dict

        :return: Configuration dictionary
        :rtype: Dict[str, Any]
        """
        output = super().to_dict()
        if self.backbone_model_type is not None:
            output["backbone_model_type"] = self.backbone_model_type
        return output

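    # Example (comment only; assumes a derived config whose backbone_model_type was set
    # to "bert" by the class factory):
    #
    #   config.to_dict()
    #   # {..., "query_length": 32, "doc_length": 512, "backbone_model_type": "bert"}
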
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRConfig":
        """Loads the configuration from a pretrained model. Wraps the
        transformers.PretrainedConfig.from_pretrained_ method.

        .. _transformers.PretrainedConfig.from_pretrained: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.from_pretrained

        :param pretrained_model_name_or_path: Pretrained model name or path
        :type pretrained_model_name_or_path: str | Path
        :raises ValueError: If `pretrained_model_name_or_path` is not a Lightning IR model and no
            :py:class:`LightningIRConfig` is passed
        :return: Derived LightningIRConfig
        :rtype: LightningIRConfig
        """
        # provides AutoConfig.from_pretrained support
        if cls is LightningIRConfig or all(issubclass(base, LightningIRConfig) for base in cls.__bases__):
            # no backbone config found; derive a lightning-ir config class based on the backbone config
            config = None
            if pretrained_model_name_or_path in CHECKPOINT_MAPPING:
                # known external checkpoint: reuse its registered lightning-ir config
                config = CHECKPOINT_MAPPING[pretrained_model_name_or_path]
                ConfigClass = config.__class__
            elif cls is not LightningIRConfig:
                # called on a subclass: use that subclass as the lightning-ir config class
                ConfigClass = cls
            else:
                # called on the base class: look up the lightning-ir config stored with the checkpoint
                ConfigClass = type(LightningIRConfigClassFactory.get_lightning_ir_config(pretrained_model_name_or_path))
            if ConfigClass is None:
                raise ValueError("Pass a config to `from_pretrained`.")
            backbone_config = LightningIRConfigClassFactory.get_backbone_config(pretrained_model_name_or_path)
            cls = LightningIRConfigClassFactory(ConfigClass).from_backbone_class(type(backbone_config))
            if config is not None and all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                derived_config = cls.from_pretrained(pretrained_model_name_or_path, config=config)
                derived_config.update(config.to_dict())
            return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        return super(LightningIRConfig, cls).from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
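
# Usage sketch (comment only; assumes the LightningIRConfig subclass
# lightning_ir.BiEncoderConfig and the Hugging Face checkpoint "bert-base-uncased"):
#
#   from lightning_ir import BiEncoderConfig
#
#   # derives a BERT-backed config class via the class factory, then instantiates it
#   # with the checkpoint's values plus the Lightning IR arguments
#   config = BiEncoderConfig.from_pretrained("bert-base-uncased", query_length=16)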