Source code for lightning_ir.base.model

  1"""
  2Model module for Lightning IR.
  3
  4This module contains the main model class and output class for the Lightning IR library.
  5"""
  6
  7from collections import defaultdict
  8from collections.abc import Mapping, Sequence
  9from dataclasses import dataclass
 10from functools import wraps
 11from pathlib import Path
 12from typing import Any, Protocol, Self, TypeVar
 13
 14import torch
 15from transformers import BatchEncoding, BertModel, PreTrainedModel
 16from transformers.modeling_outputs import ModelOutput
 17
 18from .adapter import LightningIRAdapterMixin
 19from .class_factory import LightningIRModelClassFactory, _get_model_class
 20from .config import LightningIRConfig
 21from .external_model_hub import (
 22    BACKBONE_MAPPING,
 23    CHECKPOINT_MAPPING,
 24    POST_LOAD_CALLBACKS,
 25    STATE_DICT_KEY_MAPPING,
 26)
 27
 28
 29def _update_config_with_kwargs(config: LightningIRConfig, **kwargs):
 30    config.update(kwargs)
 31
 32    used_keys = set(config.to_dict().keys()) & set(kwargs.keys())
 33
 34    for key in used_keys:
 35        kwargs.pop(key)
 36
 37    return config, kwargs
 38
 39
@dataclass
class LightningIROutput(ModelOutput):
    """Base class for the output of the Lightning IR model. It is a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    Attributes:
        scores (torch.Tensor | None): Output relevance scores for query--document pairs. Defaults to None.
    """

    # Relevance scores for query--document pairs; None when no scores were computed.
    scores: torch.Tensor | None = None
51 52
class LightningIRModel(LightningIRAdapterMixin, PreTrainedModel):
    """Base class for Lightning IR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel

    Attributes:
        config_class (type[LightningIRConfig]): Configuration class for the model.
        ALLOW_SUB_BATCHING (bool): Flag to allow mini batches of documents for a single query.
            set to false for listwise models to ensure correctness.
    """

    config_class: type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini batches of documents for a single query. set to false for listwise models to ensure
    correctness."""

    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        Args:
            config (LightningIRConfig): Configuration class for the model.
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        # Sub-batch size used by batch_encoding_wrapper; lowered on CUDA OOM. None means full batch.
        self._sub_batch_size: int | None = None

    def _initialize_adapters(self) -> None:
        """Initialize adapters based on configuration."""
        if not self.config.use_adapter:
            return

        # Enable adapters if configuration is provided
        if self.config.adapter_config is not None:
            self.init_adapters(self.config.adapter_config)

        # Load adapter weights if path is provided
        if self.config.pretrained_adapter_name_or_path is not None:
            self.load_adapter(self.config.pretrained_adapter_name_or_path)

    def _init_weights(self, module: torch.nn.Module) -> None:
        """Initialize weights of ``module``, covering layers the backbone initializer skips."""
        super()._init_weights(module)
        if self.config.backbone_model_type == "modernbert":
            # NOTE modernbert does not initialize the weights of non-modernbert layers
            # So we need to initialize them separately using the default initialization of the PreTrainedModel
            PreTrainedModel._init_weights(self, module)

    def _backbone_forward(self, *args, **kwargs):
        """Runs the forward method of the backbone model. Is overridden in
        :class:`~lightning_ir.base.class_factory.LightningIRModelClassFactory`.

        Raises:
            NotImplementedError: If not overridden in the derived class
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: str | Path, *args, BackboneModel: type[PreTrainedModel] | None = None, **kwargs
    ) -> Self:
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

.. _transformers.PreTrainedModel.from_pretrained: \
    https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
            BackboneModel (type[PreTrainedModel] | None): Huggingface PreTrainedModel class to use as backbone
                instead of the default AutoModel. Defaults to None.
        Raises:
            ValueError: If called on the abstract class `LightningIRModel` and no config is passed.
        Returns:
            LightningIRModel: A derived `LightningIRModel` consisting of a backbone model
                and a `LightningIRModel` mixin.
        """
        # The external-model hub registries are keyed by plain strings. Normalize once so that Path
        # inputs hit the same entries as their string equivalents (previously membership was tested
        # with the raw argument but indexed with str(...), so hooks never fired for Path inputs).
        model_name = str(model_name_or_path)
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if config is not None:
                ConfigClass = config.__class__
            elif model_name in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name]
                ConfigClass = _config.__class__
                if config is None:
                    config = _config
            elif cls is not LightningIRModel:
                ConfigClass = cls.config_class
            else:
                # NOTE(review): get_lightning_ir_config presumably resolves the config from the hub
                # checkpoint; a missing config is caught by the guard below.
                ConfigClass = type(LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path))
            if ConfigClass is None:
                raise ValueError("Pass a config to `from_pretrained`.")
            if BackboneModel is None:
                if model_name in BACKBONE_MAPPING:
                    BackboneModel = BACKBONE_MAPPING[model_name]
                else:
                    backbone_config = LightningIRModelClassFactory.get_backbone_config(
                        model_name_or_path
                    ).from_pretrained(model_name_or_path)
                    BackboneModel = _get_model_class(backbone_config)
            # Build the derived class (backbone + LightningIRModel mixin) and re-dispatch to it.
            cls = LightningIRModelClassFactory(ConfigClass).from_backbone_class(BackboneModel)
            if config is not None:
                if all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                    derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                    derived_config.update(config.to_diff_dict())
                    config = derived_config
                kwargs["config"] = config
                # NOTE 'config' is contained in kwargs, so we can update it
                config, kwargs = _update_config_with_kwargs(**kwargs)
                kwargs["config"] = config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            # The BERT pooling layer is unused by Lightning IR models.
            kwargs["add_pooling_layer"] = False
        key_mapping = kwargs.pop("key_mapping", {})
        if model_name in STATE_DICT_KEY_MAPPING:
            key_mapping.update(STATE_DICT_KEY_MAPPING[model_name])
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        if model_name in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[model_name](model)

        # Initialize adapters after model is fully loaded
        model._initialize_adapters()

        return model
197 198 199T = TypeVar("T") 200 201 202def _cat_outputs( 203 outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: type[T] | None 204) -> torch.Tensor | T | None: 205 """Helper method to concatenate outputs of the model. 206 207 Args: 208 outputs (Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None]): Outputs from the model. 209 OutputClass (type[T] | None): Class to return the concatenated output as. 210 Returns: 211 torch.Tensor | T | None: Concatenated output. 212 """ 213 if len(outputs) == 1: 214 return outputs[0] 215 if len(outputs) == 0 or outputs[0] is None or OutputClass is None: 216 return None 217 if isinstance(outputs[0], torch.Tensor): 218 return torch.cat(outputs, dim=0) 219 agg = defaultdict(list) 220 types = {} 221 for output in outputs: 222 for key, value in output.items(): 223 agg[key].append(value) 224 types[key] = type(value) 225 kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()} 226 if OutputClass is BatchEncoding: 227 return OutputClass(kwargs) 228 return OutputClass(**kwargs) 229 230
class BatchEncodingWrapper(Protocol):
    """Structural type for callables whose first positional argument is a ``BatchEncoding``
    (e.g. methods wrapped by :func:`batch_encoding_wrapper`)."""

    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...
233 234
def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
    """Decorator to enable sub-batching for models that support it. Lowers the batch size of the input batch encoding
    if the model runs out of memory.

    Args:
        func (BatchEncodingWrapper): Function to wrap that takes a batch encoding.
    Returns:
        BatchEncodingWrapper: Wrapped function that handles sub-batching.
    Raises:
        RuntimeError: If CUDA runs out of memory and the batch size cannot be lowered further.
        ValueError: If no output was generated.
    """

    @wraps(func)
    def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
        # Listwise models must see the whole batch at once; skip sub-batching entirely.
        if not self.ALLOW_SUB_BATCHING:
            return func(self, encoding, *args, **kwargs)
        # Start from the previously working sub-batch size, or the full batch on first use.
        chunk_size = self._sub_batch_size or encoding.input_ids.shape[0]
        pending = encoding
        collected = []
        OutputClass = None
        finished = False
        while not finished:
            try:
                num_chunks = -(pending.input_ids.shape[0] // -chunk_size)  # ceil division
                for _ in range(num_chunks):
                    chunk = BatchEncoding({key: value[:chunk_size] for key, value in pending.items()})
                    result = func(self, chunk, *args, **kwargs)
                    OutputClass = result.__class__
                    collected.append(result)
                    # Only trim after a successful call so a failed chunk is retried smaller.
                    pending = BatchEncoding({key: value[chunk_size:] for key, value in pending.items()})
                finished = True
            except RuntimeError as exc:
                is_oom = "CUDA out of memory" in str(exc) or "CUDACachingAllocator.cpp" in str(exc)
                if not is_oom:
                    raise exc
                # Halve the sub-batch size and remember it for subsequent calls.
                chunk_size //= 2
                self._sub_batch_size = chunk_size
                if chunk_size == 0:
                    raise exc
        if OutputClass is None:
            raise ValueError("No output was generated.")
        return _cat_outputs(collected, OutputClass)

    return wrapper