Source code for lightning_ir.base.adapter
1"""
2Adapter module for Lightning IR models.
3
4This module provides LoRA adapter support for Lightning IR models using the PEFT library.
5The adapter functionality is optional and only enabled when explicitly configured.
6"""
7
8from __future__ import annotations
9
10try:
11 from peft import LoraConfig, get_peft_model
12
13 PEFT_AVAILABLE = True
14except ImportError:
15 PEFT_AVAILABLE = False
16
17
[docs]
class LightningIRAdapterMixin:
    """Mixin class that adds LoRA adapter functionality to Lightning IR models.

    The mixin keeps two flags on the instance:

    - ``_hf_peft_config_loaded``: set once :meth:`init_adapters` has attached adapters.
    - ``_adapter_enabled``: whether the attached adapters are currently active.

    :meth:`disable_adapters` and :meth:`enable_adapters` toggle the adapters through whichever
    hook the model exposes (``disable_adapter_layers``/``enable_adapter_layers`` first, falling
    back to ``disable_adapter``/``enable_adapter``) and keep ``_adapter_enabled`` in sync.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Adapters are opt-in: nothing is attached or active until init_adapters() is called.
        self._adapter_enabled = False
        self._hf_peft_config_loaded = False

    def init_adapters(self, adapter_config: LoraConfig) -> None:
        """Enable LoRA adapters on the model.

        Args:
            adapter_config: Configuration for the LoRA adapter.

        Raises:
            ImportError: If PEFT is not available.
            ValueError: If adapters are already enabled.
        """
        if not PEFT_AVAILABLE:
            raise ImportError(
                "PEFT is required for adapter functionality. " "Install it with: pip install lightning-ir[adapters]"
            )

        if self._hf_peft_config_loaded:
            raise ValueError("Adapters are already enabled on this model")

        # NOTE(review): get_peft_model is expected to inject the LoRA layers into this model's
        # submodules in place (PEFT behaviour - verify for the installed PEFT version).
        peft_model = get_peft_model(self, adapter_config)

        # Copy any differing wrapper-level children back onto this model. "base_model" is
        # skipped because it presumably wraps this very model; copying it would create a cycle.
        for name, module in peft_model.named_children():
            if hasattr(self, name) and name != "base_model":
                original_module = getattr(self, name)
                if original_module is not module:  # Only set if it's actually different
                    setattr(self, name, module)

        self._adapter_enabled = True
        self._hf_peft_config_loaded = True

    def _set_adapter_layers_enabled(self, enabled: bool) -> bool:
        """Call the model's adapter toggle hook for the requested state.

        The ``*_adapter_layers`` hook is preferred; the ``*_adapter`` hook is the fallback.

        Args:
            enabled: True to enable the adapter layers, False to disable them.

        Returns:
            bool: True if a hook was found and called, False if the model exposes neither hook.
        """
        layer_hook = "enable_adapter_layers" if enabled else "disable_adapter_layers"
        fallback_hook = "enable_adapter" if enabled else "disable_adapter"
        for hook_name in (layer_hook, fallback_hook):
            hook = getattr(self, hook_name, None)
            if hook is not None:
                hook()
                return True

        # Previously this case was a silent no-op; surface it so callers know nothing changed.
        import warnings

        warnings.warn(
            f"{type(self).__name__} exposes neither {layer_hook!r} nor {fallback_hook!r}; "
            "adapter state was left unchanged.",
            stacklevel=3,
        )
        return False

    def disable_adapters(self) -> None:
        """Disable LoRA adapters.

        Does nothing if the adapters are not currently enabled. The enabled flag is cleared only
        when a disable hook was actually called, so it keeps reflecting the real adapter state.
        """
        if not self._adapter_enabled:
            return
        if self._set_adapter_layers_enabled(False):
            self._adapter_enabled = False

    def enable_adapters(self) -> None:
        """(Re-)Enable LoRA adapters.

        Does nothing if no adapters were initialized via :meth:`init_adapters` or if they are
        already enabled. The enabled flag is set only when an enable hook was actually called.
        """
        if not self._hf_peft_config_loaded or self._adapter_enabled:
            return
        if self._set_adapter_layers_enabled(True):
            self._adapter_enabled = True