Source code for lightning_ir.base.model

  1"""
  2Model module for Lightning IR.
  3
  4This module contains the main model class and output class for the Lightning IR library.
  5"""
  6
  7from collections import defaultdict
  8from dataclasses import dataclass
  9from functools import wraps
 10from pathlib import Path
 11from typing import Any, Literal, Mapping, Protocol, Self, Sequence, Type, TypeVar
 12
 13import torch
 14from transformers import BatchEncoding, BertModel, PreTrainedModel
 15from transformers.modeling_outputs import ModelOutput
 16
 17from .class_factory import LightningIRModelClassFactory, _get_model_class
 18from .config import LightningIRConfig
 19from .external_model_hub import CHECKPOINT_MAPPING, POST_LOAD_CALLBACKS, STATE_DICT_KEY_MAPPING
 20
 21
 22def _update_config_with_kwargs(config: LightningIRConfig, **kwargs):
 23    config.update(kwargs)
 24
 25    used_keys = set(config.to_dict().keys()) & set(kwargs.keys())
 26
 27    for key in used_keys:
 28        kwargs.pop(key)
 29
 30    return config, kwargs
 31
 32
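# Illustrative sketch, not part of the module: ``_update_config_with_kwargs`` moves keyword
# arguments that the config absorbs out of the remaining kwargs. The option name
# ``query_length`` is only an assumption for demonstration purposes.
#
# >>> config = LightningIRConfig()
# >>> config, remaining = _update_config_with_kwargs(config, query_length=16)
# >>> config.query_length
# 16

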
@dataclass
class LightningIROutput(ModelOutput):
    """Base class for the output of the Lightning IR model. It is a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    Attributes:
        scores (torch.Tensor | None): Output relevance scores for query--document pairs. Defaults to None.
    """

    scores: torch.Tensor | None = None


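# Illustrative sketch, not part of the module: ``LightningIROutput`` behaves like any other
# transformers ``ModelOutput``, so the scores can be read as an attribute or by key.
#
# >>> output = LightningIROutput(scores=torch.tensor([0.5, 0.1]))
# >>> torch.equal(output.scores, output["scores"])
# True

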
class LightningIRModel(PreTrainedModel):
    """Base class for Lightning IR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as a mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel

    Attributes:
        config_class (Type[LightningIRConfig]): Configuration class for the model.
        ALLOW_SUB_BATCHING (bool): Flag to allow mini-batches of documents for a single query.
            Set to False for listwise models to ensure correctness.
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini-batches of documents for a single query. Set to False for listwise models to ensure
    correctness."""

    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        Args:
            config (LightningIRConfig): Configuration for the model.
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        self._sub_batch_size: int | None = None

    def _backbone_forward(self, *args, **kwargs):
        """Runs the forward method of the backbone model. This method is overridden in
        :class:`~lightning_ir.base.class_factory.LightningIRModelClassFactory`.

        Raises:
            NotImplementedError: If not overridden in the derived class.
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError

    def sparsification(
        self, embeddings: torch.Tensor, sparsification_strategy: Literal["relu", "relu_log"] | None = None
    ) -> torch.Tensor:
        """Helper method to apply sparsification to the embeddings.

        Args:
            embeddings (torch.Tensor): Query or document embeddings.
            sparsification_strategy (Literal['relu', 'relu_log'] | None): The sparsification strategy. No
                sparsification is applied if None. Defaults to None.
        Returns:
            torch.Tensor: (Optionally) sparsified embeddings.
        Raises:
            ValueError: If an unknown sparsification strategy is passed.
        """
        if sparsification_strategy is None:
            return embeddings
        if sparsification_strategy == "relu":
            return torch.relu(embeddings)
        if sparsification_strategy == "relu_log":
            return torch.log1p(torch.relu(embeddings))
        raise ValueError(f"Unknown sparsification strategy: {sparsification_strategy}")

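    # Illustrative sketch, not part of the module: both strategies zero out negative values;
    # "relu_log" additionally applies log1p to dampen large activations. ``model`` stands for
    # any instantiated derived model.
    #
    # >>> embeddings = torch.tensor([[-1.0, 0.0, 1.0]])
    # >>> model.sparsification(embeddings, "relu")
    # tensor([[0., 0., 1.]])
    # >>> model.sparsification(embeddings, "relu_log")
    # tensor([[0.0000, 0.0000, 0.6931]])
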
    def pooling(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None,
    ) -> torch.Tensor:
        """Helper method to apply pooling to the embeddings.

        Args:
            embeddings (torch.Tensor): Query or document embeddings.
            attention_mask (torch.Tensor | None): Query or document attention mask.
            pooling_strategy (Literal['first', 'mean', 'max', 'sum'] | None):
                The pooling strategy. No pooling is applied if None.
        Returns:
            torch.Tensor: (Optionally) pooled embeddings.
        Raises:
            ValueError: If an unknown pooling strategy is passed.
        """
        if pooling_strategy is None:
            return embeddings
        if pooling_strategy == "first":
            return embeddings.index_select(1, torch.tensor(0, device=embeddings.device))
        if pooling_strategy in ("sum", "mean"):
            if attention_mask is not None:
                embeddings = embeddings * attention_mask.unsqueeze(-1)
            embeddings = embeddings.sum(dim=1, keepdim=True)
            if pooling_strategy == "mean":
                if attention_mask is not None:
                    embeddings = embeddings / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
            return embeddings
        if pooling_strategy == "max":
            if attention_mask is not None:
                embeddings = embeddings.masked_fill(~attention_mask.bool().unsqueeze(-1), float("-inf"))
            return embeddings.amax(dim=1, keepdim=True)
        raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")

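    # Illustrative sketch, not part of the module: pooling reduces the sequence dimension to 1
    # while padded positions are excluded via the attention mask. ``model`` stands for any
    # instantiated derived model.
    #
    # >>> embeddings = torch.tensor([[[1.0], [3.0], [5.0]]])  # shape (batch, seq, dim)
    # >>> attention_mask = torch.tensor([[1, 1, 0]])  # last position is padding
    # >>> model.pooling(embeddings, attention_mask, "mean")
    # tensor([[[2.]]])
    # >>> model.pooling(embeddings, attention_mask, "max")
    # tensor([[[3.]]])
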
    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

.. _transformers.PreTrainedModel.from_pretrained: \
    https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
        Raises:
            ValueError: If called on the abstract class `LightningIRModel` and no config is passed.
        Returns:
            LightningIRModel: A derived `LightningIRModel` consisting of a backbone model
                and a `LightningIRModel` mixin.
        """
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if config is not None:
                ConfigClass = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                ConfigClass = _config.__class__
                if config is None:
                    config = _config
            elif cls is not LightningIRModel:
                ConfigClass = cls.config_class
            else:
                ConfigClass = type(LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path))
                if ConfigClass is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            backbone_config = LightningIRModelClassFactory.get_backbone_config(model_name_or_path).from_pretrained(
                model_name_or_path
            )
            BackboneModel = _get_model_class(backbone_config)
            cls = LightningIRModelClassFactory(ConfigClass).from_backbone_class(BackboneModel)
            if config is not None:
                if all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                    derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                    derived_config.update(config.to_diff_dict())
                    config = derived_config
                kwargs["config"] = config
                # NOTE 'config' is contained in kwargs, so we can update it
                config, kwargs = _update_config_with_kwargs(**kwargs)
                kwargs["config"] = config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            kwargs["add_pooling_layer"] = False
        key_mapping = kwargs.pop("key_mapping", {})
        if model_name_or_path in STATE_DICT_KEY_MAPPING:
            key_mapping.update(STATE_DICT_KEY_MAPPING[str(model_name_or_path)])
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        if model_name_or_path in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[str(model_name_or_path)](model)
        return model
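
# Illustrative sketch, not part of the module: after a derived model has been saved with
# ``save_pretrained``, the stored Lightning IR config lets ``from_pretrained`` recreate the
# derived class without passing a config again. The path is hypothetical.
#
# >>> model = LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig())
# >>> model.save_pretrained("/tmp/cross-encoder-bert")
# >>> type(LightningIRModel.from_pretrained("/tmp/cross-encoder-bert"))
# <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>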


T = TypeVar("T")


def _cat_outputs(
    outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: Type[T] | None
) -> torch.Tensor | T | None:
    """Helper method to concatenate outputs of the model.

    Args:
        outputs (Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None]): Outputs from the model.
        OutputClass (Type[T] | None): Class to return the concatenated output as.
    Returns:
        torch.Tensor | T | None: Concatenated output.
    """
    if len(outputs) == 1:
        return outputs[0]
    if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
        return None
    if isinstance(outputs[0], torch.Tensor):
        return torch.cat(outputs, dim=0)
    agg = defaultdict(list)
    types = {}
    for output in outputs:
        for key, value in output.items():
            agg[key].append(value)
            types[key] = type(value)
    kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()}
    if OutputClass is BatchEncoding:
        return OutputClass(kwargs)
    return OutputClass(**kwargs)

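
# Illustrative sketch, not part of the module: ``_cat_outputs`` concatenates plain tensors
# directly and recursively concatenates the fields of mapping-style outputs.
#
# >>> first = LightningIROutput(scores=torch.tensor([0.1]))
# >>> second = LightningIROutput(scores=torch.tensor([0.2]))
# >>> _cat_outputs([first, second], LightningIROutput).scores
# tensor([0.1000, 0.2000])
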
class BatchEncodingWrapper(Protocol):
    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...


def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
    """Decorator to enable sub-batching for models that support it. Lowers the batch size of the input batch encoding
    if the model runs out of memory.

    Args:
        func (BatchEncodingWrapper): Function to wrap that takes a batch encoding.
    Returns:
        BatchEncodingWrapper: Wrapped function that handles sub-batching.
    Raises:
        RuntimeError: If CUDA runs out of memory and the batch size cannot be lowered further.
        ValueError: If no output was generated.
    """

    @wraps(func)
    def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
        if not self.ALLOW_SUB_BATCHING:
            return func(self, encoding, *args, **kwargs)
        sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
        sub_encoding = encoding
        remaining_encoding = encoding
        OutputClass = None
        outputs = []
        while True:
            try:
                # ceil division
                num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
                for _ in range(num_batches):
                    sub_encoding = BatchEncoding(
                        {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
                    )
                    output = func(self, sub_encoding, *args, **kwargs)
                    OutputClass = output.__class__
                    outputs.append(output)
                    remaining_encoding = BatchEncoding(
                        {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
                    )
                break
            except RuntimeError as e:
                if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
                    self._sub_batch_size = sub_batch_size = sub_batch_size // 2
                    if sub_batch_size == 0:
                        raise e
                else:
                    raise e
        if OutputClass is None:
            raise ValueError("No output was generated.")
        return _cat_outputs(outputs, OutputClass)

    return wrapper
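
# Illustrative sketch, not part of the module: ``batch_encoding_wrapper`` is meant to decorate
# a model method whose first positional argument is a ``BatchEncoding``. The subclass and the
# method name ``_encode`` are hypothetical.
#
# class MyEncoderModel(LightningIRModel):
#
#     @batch_encoding_wrapper
#     def _encode(self, encoding: BatchEncoding) -> torch.Tensor:
#         # On a CUDA out-of-memory error the wrapper halves the sub-batch size, re-runs the
#         # remaining sub-batches, and concatenates the partial outputs via _cat_outputs.
#         return self._backbone_forward(**encoding).last_hidden_state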