Source code for lightning_ir.base.model

"""
Model module for Lightning IR.

This module contains the main model class and output class for the Lightning IR library.
"""

from collections import defaultdict
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from typing import Any, Literal, Mapping, Protocol, Self, Sequence, Type, TypeVar

import torch
from transformers import BatchEncoding, BertModel, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .adapter import LightningIRAdapterMixin
from .class_factory import LightningIRModelClassFactory, _get_model_class
from .config import LightningIRConfig
from .external_model_hub import BACKBONE_MAPPING, CHECKPOINT_MAPPING, POST_LOAD_CALLBACKS, STATE_DICT_KEY_MAPPING

def _update_config_with_kwargs(config: LightningIRConfig, **kwargs):
    """Update the config with the given keyword arguments and drop the consumed keys from kwargs."""
    config.update(kwargs)

    used_keys = set(config.to_dict().keys()) & set(kwargs.keys())

    for key in used_keys:
        kwargs.pop(key)

    return config, kwargs
@dataclass
class LightningIROutput(ModelOutput):
    """Base class for the output of the Lightning IR model. It is a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    Attributes:
        scores (torch.Tensor | None): Output relevance scores for query--document pairs. Defaults to None.
    """

    scores: torch.Tensor | None = None
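
# Example (a sketch): like any transformers ModelOutput, LightningIROutput supports
# both attribute and key access to its fields.
#
#   output = LightningIROutput(scores=torch.tensor([0.7, 0.1]))
#   output.scores is output["scores"]  # -> True
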
class LightningIRModel(LightningIRAdapterMixin, PreTrainedModel):
    """Base class for Lightning IR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as a mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel

    Attributes:
        config_class (Type[LightningIRConfig]): Configuration class for the model.
        ALLOW_SUB_BATCHING (bool): Flag to allow mini batches of documents for a single query.
            Set to False for listwise models to ensure correctness.
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini batches of documents for a single query. Set to False for listwise models to ensure
    correctness."""
    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        Args:
            config (LightningIRConfig): Configuration for the model.
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        self._sub_batch_size: int | None = None
    def _initialize_adapters(self) -> None:
        """Initialize adapters based on configuration."""
        if not self.config.use_adapter:
            return

        # Enable adapters if configuration is provided
        if self.config.adapter_config is not None:
            self.init_adapters(self.config.adapter_config)

        # Load adapter weights if path is provided
        if self.config.pretrained_adapter_name_or_path is not None:
            self.load_adapter(self.config.pretrained_adapter_name_or_path)

    def _backbone_forward(self, *args, **kwargs):
        """Runs the forward method of the backbone model. Is overridden in
        :class:`~lightning_ir.base.class_factory.LightningIRModelClassFactory`.

        Raises:
            NotImplementedError: If not overridden in the derived class.
        """
        raise NotImplementedError
    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError
    def sparsification(
        self, embeddings: torch.Tensor, sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = None
    ) -> torch.Tensor:
        """Helper method to apply sparsification to the embeddings.

        Args:
            embeddings (torch.Tensor): Query or document embeddings.
            sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None): The sparsification strategy.
                No sparsification is applied if None. Defaults to None.
        Returns:
            torch.Tensor: (Optionally) sparsified embeddings.
        Raises:
            ValueError: If an unknown sparsification strategy is passed.
        """
        if sparsification_strategy is None:
            return embeddings
        if sparsification_strategy == "relu":
            return torch.relu(embeddings)
        if sparsification_strategy == "relu_log":
            return torch.log1p(torch.relu(embeddings))
        if sparsification_strategy == "relu_2xlog":
            return torch.log1p(torch.log1p(torch.relu(embeddings)))
        raise ValueError(f"Unknown sparsification strategy: {sparsification_strategy}")
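
    # Worked example (a sketch, assuming `model` is any LightningIRModel instance):
    # "relu" zeroes out negative activations; "relu_log" additionally compresses large
    # values via log1p.
    #
    #   x = torch.tensor([[-1.0, 0.0, 1.0]])
    #   model.sparsification(x, "relu")      # -> tensor([[0., 0., 1.]])
    #   model.sparsification(x, "relu_log")  # -> tensor([[0.0000, 0.0000, 0.6931]])
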
    def pooling(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None,
    ) -> torch.Tensor:
        """Helper method to apply pooling to the embeddings.

        Args:
            embeddings (torch.Tensor): Query or document embeddings.
            attention_mask (torch.Tensor | None): Query or document attention mask.
            pooling_strategy (Literal['first', 'mean', 'max', 'sum'] | None): The pooling strategy. No pooling is
                applied if None.
        Returns:
            torch.Tensor: (Optionally) pooled embeddings.
        Raises:
            ValueError: If an unknown pooling strategy is passed.
        """
        if pooling_strategy is None:
            return embeddings
        if pooling_strategy == "first":
            return embeddings[:, [0]]
        if pooling_strategy in ("sum", "mean"):
            if attention_mask is not None:
                embeddings = embeddings * attention_mask.unsqueeze(-1)
            embeddings = embeddings.sum(dim=1, keepdim=True)
            if pooling_strategy == "mean":
                if attention_mask is not None:
                    embeddings = embeddings / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
            return embeddings
        if pooling_strategy == "max":
            if attention_mask is not None:
                embeddings = embeddings.masked_fill(~attention_mask.bool().unsqueeze(-1), float("-inf"))
            return embeddings.amax(dim=1, keepdim=True)
        raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")
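
    # Worked example (a sketch, assuming `model` is any LightningIRModel instance):
    # "mean" pooling ignores padded positions and keeps a singleton sequence dimension
    # (keepdim=True).
    #
    #   embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]])
    #   attention_mask = torch.tensor([[1, 1, 0]])  # last position is padding
    #   model.pooling(embeddings, attention_mask, "mean")  # -> tensor([[[2., 3.]]])
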
    @classmethod
    def from_pretrained(
        cls, model_name_or_path: str | Path, *args, BackboneModel: Type[PreTrainedModel] | None = None, **kwargs
    ) -> Self:
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
            BackboneModel (Type[PreTrainedModel] | None): Huggingface PreTrainedModel class to use as backbone
                instead of the default AutoModel. Defaults to None.
        Raises:
            ValueError: If called on the abstract class `LightningIRModel` and no config is passed.
        Returns:
            LightningIRModel: A derived `LightningIRModel` consisting of a backbone model
            and a `LightningIRModel` mixin.
        """
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if config is not None:
                ConfigClass = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                ConfigClass = _config.__class__
                if config is None:
                    config = _config
            elif cls is not LightningIRModel:
                ConfigClass = cls.config_class
            else:
                ConfigClass = type(LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path))
                if ConfigClass is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            if BackboneModel is None:
                if model_name_or_path in BACKBONE_MAPPING:
                    BackboneModel = BACKBONE_MAPPING[str(model_name_or_path)]
                else:
                    backbone_config = LightningIRModelClassFactory.get_backbone_config(
                        model_name_or_path
                    ).from_pretrained(model_name_or_path)
                    BackboneModel = _get_model_class(backbone_config)
            cls = LightningIRModelClassFactory(ConfigClass).from_backbone_class(BackboneModel)
            if config is not None:
                if all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                    derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                    derived_config.update(config.to_diff_dict())
                    config = derived_config
                kwargs["config"] = config
            # NOTE 'config' is contained in kwargs, so we can update it
            config, kwargs = _update_config_with_kwargs(**kwargs)
            kwargs["config"] = config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            kwargs["add_pooling_layer"] = False
        key_mapping = kwargs.pop("key_mapping", {})
        if model_name_or_path in STATE_DICT_KEY_MAPPING:
            key_mapping.update(STATE_DICT_KEY_MAPPING[str(model_name_or_path)])
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        if model_name_or_path in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[str(model_name_or_path)](model)

        # Initialize adapters after model is fully loaded
        model._initialize_adapters()

        return model
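
    # Usage sketch (hypothetical; CrossEncoderConfig is not imported in this module):
    # a specific backbone class can be forced instead of the AutoModel-based lookup.
    #
    #   model = LightningIRModel.from_pretrained(
    #       "bert-base-uncased", config=CrossEncoderConfig(), BackboneModel=BertModel
    #   )
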
T = TypeVar("T")


def _cat_outputs(
    outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: Type[T] | None
) -> torch.Tensor | T | None:
    """Helper method to concatenate outputs of the model.

    Args:
        outputs (Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None]): Outputs from the model.
        OutputClass (Type[T] | None): Class to return the concatenated output as.
    Returns:
        torch.Tensor | T | None: Concatenated output.
    """
    if len(outputs) == 1:
        return outputs[0]
    if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
        return None
    if isinstance(outputs[0], torch.Tensor):
        return torch.cat(outputs, dim=0)
    agg = defaultdict(list)
    types = {}
    for output in outputs:
        for key, value in output.items():
            agg[key].append(value)
            types[key] = type(value)
    kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()}
    if OutputClass is BatchEncoding:
        return OutputClass(kwargs)
    return OutputClass(**kwargs)
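
# Example (a sketch): _cat_outputs concatenates tensors directly and recurses into
# structured outputs field by field.
#
#   a = LightningIROutput(scores=torch.tensor([1.0]))
#   b = LightningIROutput(scores=torch.tensor([2.0]))
#   _cat_outputs([a, b], LightningIROutput).scores  # -> tensor([1., 2.])
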
class BatchEncodingWrapper(Protocol):
    """Protocol for functions that take a BatchEncoding as their first argument."""

    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...
def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
    """Decorator to enable sub-batching for models that support it. Lowers the batch size of the input batch encoding
    if the model runs out of memory.

    Args:
        func (BatchEncodingWrapper): Function to wrap that takes a batch encoding.
    Returns:
        BatchEncodingWrapper: Wrapped function that handles sub-batching.
    Raises:
        RuntimeError: If CUDA runs out of memory and the batch size cannot be lowered further.
        ValueError: If no output was generated.
    """

    @wraps(func)
    def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
        if not self.ALLOW_SUB_BATCHING:
            return func(self, encoding, *args, **kwargs)
        sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
        sub_encoding = encoding
        remaining_encoding = encoding
        OutputClass = None
        outputs = []
        while True:
            try:
                # ceil division
                num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
                for _ in range(num_batches):
                    sub_encoding = BatchEncoding(
                        {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
                    )
                    output = func(self, sub_encoding, *args, **kwargs)
                    OutputClass = output.__class__
                    outputs.append(output)
                    remaining_encoding = BatchEncoding(
                        {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
                    )
                break
            except RuntimeError as e:
                if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
                    self._sub_batch_size = sub_batch_size = sub_batch_size // 2
                    if sub_batch_size == 0:
                        raise e
                else:
                    raise e
        if OutputClass is None:
            raise ValueError("No output was generated.")
        return _cat_outputs(outputs, OutputClass)

    return wrapper
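
# Usage sketch (hypothetical subclass, not part of the library): decorating an
# encoding-based forward helper enables automatic sub-batching on CUDA OOM.
#
#   class MyModel(LightningIRModel):
#       @batch_encoding_wrapper
#       def _encode(self, encoding: BatchEncoding) -> torch.Tensor:
#           return self._backbone_forward(**encoding).last_hidden_state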