Source code for lightning_ir.base.model

  1"""
  2Model module for Lightning IR.
  3
  4This module contains the main model class and output class for the Lightning IR library.
  5"""
  6
  7from collections import defaultdict
  8from collections.abc import Mapping, Sequence
  9from dataclasses import dataclass
 10from functools import wraps
 11from pathlib import Path
 12from typing import Any, Protocol, Self, TypeVar
 13
 14import torch
 15from transformers import BatchEncoding, BertModel, PreTrainedModel
 16from transformers.modeling_outputs import ModelOutput
 17
 18from .adapter import LightningIRAdapterMixin
 19from .class_factory import LightningIRModelClassFactory, _get_model_class
 20from .config import LightningIRConfig
 21from .external_model_hub import (
 22    BACKBONE_MAPPING,
 23    CHECKPOINT_MAPPING,
 24    POST_LOAD_CALLBACKS,
 25    STATE_DICT_KEY_MAPPING,
 26)


def _update_config_with_kwargs(config: LightningIRConfig, **kwargs):
    """Update ``config`` with ``kwargs`` and remove any key that appears in the updated config's
    ``to_dict()`` from ``kwargs``. Returns the updated config and the remaining keyword arguments."""
    config.update(kwargs)

    used_keys = set(config.to_dict().keys()) & set(kwargs.keys())

    for key in used_keys:
        kwargs.pop(key)

    return config, kwargs

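# A minimal sketch of the helper above (illustrative only; it assumes ``query_length``
# is an attribute of ``LightningIRConfig``): keys that appear in the updated config's
# ``to_dict()`` are treated as consumed and removed from ``kwargs``; the remainder is
# forwarded, e.g. to ``PreTrainedModel.from_pretrained``.
#
#   config, remaining = _update_config_with_kwargs(LightningIRConfig(), query_length=16)
#   assert config.query_length == 16 and "query_length" not in remaining
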

@dataclass
class LightningIROutput(ModelOutput):
    """Base class for the output of the Lightning IR model. It is a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    Attributes:
        scores (torch.Tensor | None): Output relevance scores for query--document pairs. Defaults to None.
    """

    scores: torch.Tensor | None = None

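# A minimal sketch (illustrative values): ``ModelOutput`` subclasses support both
# attribute and key access, so the two lookups below are equivalent.
#
#   output = LightningIROutput(scores=torch.tensor([0.8, 0.1]))
#   assert output.scores is output["scores"]
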

class LightningIRModel(LightningIRAdapterMixin, PreTrainedModel):
    """Base class for Lightning IR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as a mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel

    Attributes:
        config_class (type[LightningIRConfig]): Configuration class for the model.
        ALLOW_SUB_BATCHING (bool): Flag to allow mini-batches of documents for a single query.
            Set to False for listwise models to ensure correctness.
    """

    config_class: type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini-batches of documents for a single query. Set to False for listwise models to ensure
    correctness."""

    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        Args:
            config (LightningIRConfig): Configuration for the model.
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        self._sub_batch_size: int | None = None

    def _initialize_adapters(self) -> None:
        """Initialize adapters based on the configuration."""
        if not self.config.use_adapter:
            return

        # Enable adapters if an adapter configuration is provided
        if self.config.adapter_config is not None:
            self.init_adapters(self.config.adapter_config)

        # Load adapter weights if a path is provided
        if self.config.pretrained_adapter_name_or_path is not None:
            self.load_adapter(self.config.pretrained_adapter_name_or_path)

    def _backbone_forward(self, *args, **kwargs):
        """Runs the forward method of the backbone model. This method is overridden in
        :class:`~lightning_ir.base.class_factory.LightningIRModelClassFactory`.

        Raises:
            NotImplementedError: If not overridden in the derived class.
        """
        raise NotImplementedError

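    # A configuration sketch for ``_initialize_adapters`` above (illustrative; the
    # ``init_adapters`` and ``load_adapter`` calls come from ``LightningIRAdapterMixin``,
    # and the adapter path is a placeholder):
    #
    #   config.use_adapter = True
    #   config.pretrained_adapter_name_or_path = "path/to/adapter"  # placeholder
    #   model._initialize_adapters()
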
    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: str | Path, *args, BackboneModel: type[PreTrainedModel] | None = None, **kwargs
    ) -> Self:
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

.. _transformers.PreTrainedModel.from_pretrained: \
    https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
            BackboneModel (type[PreTrainedModel] | None): Hugging Face PreTrainedModel class to use as the backbone
                instead of the default AutoModel. Defaults to None.
        Raises:
            ValueError: If called on the abstract class `LightningIRModel` and no config is passed.
        Returns:
            LightningIRModel: A derived `LightningIRModel` consisting of a backbone model
                and a `LightningIRModel` mixin.
        """
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if config is not None:
                ConfigClass = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                ConfigClass = _config.__class__
                if config is None:
                    config = _config
            elif cls is not LightningIRModel:
                ConfigClass = cls.config_class
            else:
                ConfigClass = type(LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path))
                if ConfigClass is type(None):
                    raise ValueError("Pass a config to `from_pretrained`.")
            if BackboneModel is None:
                if model_name_or_path in BACKBONE_MAPPING:
                    BackboneModel = BACKBONE_MAPPING[str(model_name_or_path)]
                else:
                    backbone_config = LightningIRModelClassFactory.get_backbone_config(
                        model_name_or_path
                    ).from_pretrained(model_name_or_path)
                    BackboneModel = _get_model_class(backbone_config)
            cls = LightningIRModelClassFactory(ConfigClass).from_backbone_class(BackboneModel)
            if config is not None:
                if all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                    derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                    derived_config.update(config.to_diff_dict())
                    config = derived_config
                kwargs["config"] = config
                # NOTE 'config' is contained in kwargs, so we can update it
                config, kwargs = _update_config_with_kwargs(**kwargs)
                kwargs["config"] = config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            kwargs["add_pooling_layer"] = False
        key_mapping = kwargs.pop("key_mapping", {})
        if model_name_or_path in STATE_DICT_KEY_MAPPING:
            key_mapping.update(STATE_DICT_KEY_MAPPING[str(model_name_or_path)])
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        if model_name_or_path in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[str(model_name_or_path)](model)

        # Initialize adapters after the model is fully loaded
        model._initialize_adapters()

        return model
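
# A usage sketch mirroring the doctest in ``from_pretrained`` (``CrossEncoderConfig``
# is assumed to be importable from ``lightning_ir``; extra kwargs that match config
# attributes, such as ``query_length`` here, are merged into the config via
# ``_update_config_with_kwargs``):
#
#   model = LightningIRModel.from_pretrained(
#       "bert-base-uncased", config=CrossEncoderConfig(), query_length=16
#   )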


T = TypeVar("T")


def _cat_outputs(
    outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: type[T] | None
) -> torch.Tensor | T | None:
    """Helper method to concatenate outputs of the model.

    Args:
        outputs (Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None]): Outputs from the model.
        OutputClass (type[T] | None): Class to return the concatenated output as.
    Returns:
        torch.Tensor | T | None: Concatenated output.
    """
    if len(outputs) == 1:
        return outputs[0]
    if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
        return None
    if isinstance(outputs[0], torch.Tensor):
        return torch.cat(outputs, dim=0)
    agg = defaultdict(list)
    types = {}
    for output in outputs:
        for key, value in output.items():
            agg[key].append(value)
            types[key] = type(value)
    kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()}
    if OutputClass is BatchEncoding:
        return OutputClass(kwargs)
    return OutputClass(**kwargs)

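# A minimal sketch of ``_cat_outputs`` (illustrative values): tensors are concatenated
# along the batch dimension, and structured outputs are merged field by field.
#
#   a = LightningIROutput(scores=torch.tensor([0.8]))
#   b = LightningIROutput(scores=torch.tensor([0.1]))
#   merged = _cat_outputs([a, b], LightningIROutput)  # merged.scores -> tensor([0.8000, 0.1000])
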

class BatchEncodingWrapper(Protocol):
    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...


def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
    """Decorator to enable sub-batching for models that support it. Lowers the batch size of the input batch encoding
    if the model runs out of memory.

    Args:
        func (BatchEncodingWrapper): Function to wrap that takes a batch encoding.
    Returns:
        BatchEncodingWrapper: Wrapped function that handles sub-batching.
    Raises:
        RuntimeError: If CUDA runs out of memory and the batch size cannot be lowered further.
        ValueError: If no output was generated.
    """

    @wraps(func)
    def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
        if not self.ALLOW_SUB_BATCHING:
            return func(self, encoding, *args, **kwargs)
        sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
        sub_encoding = encoding
        remaining_encoding = encoding
        OutputClass = None
        outputs = []
        while True:
            try:
                # ceil division
                num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
                for _ in range(num_batches):
                    sub_encoding = BatchEncoding(
                        {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
                    )
                    output = func(self, sub_encoding, *args, **kwargs)
                    OutputClass = output.__class__
                    outputs.append(output)
                    remaining_encoding = BatchEncoding(
                        {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
                    )
                break
            except RuntimeError as e:
                if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
                    # halve the sub-batch size and retry on the remaining encoding
                    self._sub_batch_size = sub_batch_size = sub_batch_size // 2
                    if sub_batch_size == 0:
                        raise e
                else:
                    raise e
        if OutputClass is None:
            raise ValueError("No output was generated.")
        return _cat_outputs(outputs, OutputClass)

    return wrapper
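
# A usage sketch (hypothetical subclass): decorating a method that takes a
# ``BatchEncoding`` makes it halve the batch on CUDA OOM and re-concatenate the
# partial outputs afterwards.
#
#   class MyModel(LightningIRModel):
#       @batch_encoding_wrapper
#       def encode(self, encoding: BatchEncoding) -> LightningIROutput:
#           ...  # run the backbone on ``encoding`` and return a LightningIROutput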