Source code for lightning_ir.base.adapter

 1"""
 2Adapter module for Lightning IR models.
 3
 4This module provides LoRA adapter support for Lightning IR models using the PEFT library.
 5The adapter functionality is optional and only enabled when explicitly configured.
 6"""
 7
 8from __future__ import annotations
 9
10try:
11    from peft import LoraConfig, get_peft_model
12
13    PEFT_AVAILABLE = True
14except ImportError:
15    PEFT_AVAILABLE = False
16
17
class LightningIRAdapterMixin:
    """Mixin class that adds LoRA adapter functionality to Lightning IR models.

    The mixin tracks two pieces of state on the instance:

    * ``_hf_peft_config_loaded`` -- set once :meth:`init_adapters` has attached adapters.
    * ``_adapter_enabled`` -- whether the attached adapters are currently switched on.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the next class in the MRO and reset the adapter state."""
        super().__init__(*args, **kwargs)
        self._adapter_enabled = False
        self._hf_peft_config_loaded = False

    def init_adapters(self, adapter_config: LoraConfig) -> None:
        """Enable LoRA adapters on the model.

        Args:
            adapter_config: Configuration for the LoRA adapter.

        Raises:
            ImportError: If PEFT is not available.
            ValueError: If adapters are already enabled.
        """
        if not PEFT_AVAILABLE:
            raise ImportError(
                "PEFT is required for adapter functionality. "
                "Install it with: pip install lightning-ir[adapters]"
            )

        if self._hf_peft_config_loaded:
            raise ValueError("Adapters are already enabled on this model")

        peft_model = get_peft_model(self, adapter_config)

        # Copy the wrapper's direct children back onto this model, skipping the
        # wrapper's own ``base_model`` entry and anything this model does not have.
        for name, module in peft_model.named_children():
            if hasattr(self, name) and name != "base_model":
                original_module = getattr(self, name)
                if original_module is not module:  # Only set if it's actually different
                    setattr(self, name, module)

        self._adapter_enabled = True
        self._hf_peft_config_loaded = True

    def disable_adapters(self) -> None:
        """Disable LoRA adapters.

        Does nothing when the adapters are not currently enabled. The enabled
        flag is cleared only if one of the disable hooks was actually invoked,
        so that a later :meth:`enable_adapters` call can switch them back on.
        """
        if not self._adapter_enabled:
            return
        # NOTE(review): if the host model exposes neither hook this call is a
        # no-op -- confirm which of the two the concrete model classes provide.
        if hasattr(self, "disable_adapter_layers"):
            self.disable_adapter_layers()
        elif hasattr(self, "disable_adapter"):
            self.disable_adapter()
        else:
            return
        # Record the new state; without this, enable_adapters() would keep
        # short-circuiting and the adapters could never be re-enabled.
        self._adapter_enabled = False

    def enable_adapters(self) -> None:
        """(Re-)Enable LoRA adapters.

        Does nothing when the adapters are already enabled or when no adapters
        have been attached to the model yet.
        """
        if self._adapter_enabled or not self._hf_peft_config_loaded:
            return
        # NOTE(review): if the host model exposes neither hook this call is a
        # no-op -- confirm which of the two the concrete model classes provide.
        if hasattr(self, "enable_adapter_layers"):
            self.enable_adapter_layers()
        elif hasattr(self, "enable_adapter"):
            self.enable_adapter()
        else:
            return
        self._adapter_enabled = True