Source code for lightning_ir.base.model

"""
Model module for Lightning IR.

This module contains the main model class and output class for the Lightning IR library.
"""

from collections import defaultdict
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from typing import Any, Literal, Mapping, Protocol, Self, Sequence, Type, TypeVar

import torch
from transformers import BatchEncoding, BertModel, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .class_factory import LightningIRModelClassFactory, _get_model_class
from .config import LightningIRConfig
from .external_model_hub import CHECKPOINT_MAPPING, POST_LOAD_CALLBACKS, STATE_DICT_KEY_MAPPING

def _update_config_with_kwargs(config: LightningIRConfig, **kwargs):
    """Update the config with the given kwargs and remove the keys that are now part of the config from kwargs."""
    config.update(kwargs)

    used_keys = set(config.to_dict().keys()) & set(kwargs.keys())

    for key in used_keys:
        kwargs.pop(key)

    return config, kwargs

@dataclass
class LightningIROutput(ModelOutput):
    """Base class for the output of the Lightning IR model. It is a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    :param scores: Output relevance scores for query--document pairs, defaults to None
    :type scores: torch.Tensor | None, optional
    """

    scores: torch.Tensor | None = None
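
# --- Example (editor's illustration, not part of the module) ----------------
# LightningIROutput behaves like any transformers ModelOutput: the ``scores``
# field can be read as an attribute or via the mapping interface.
#
#   >>> import torch
#   >>> output = LightningIROutput(scores=torch.tensor([0.7, 0.2]))
#   >>> output.scores
#   tensor([0.7000, 0.2000])
#   >>> output["scores"]
#   tensor([0.7000, 0.2000])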

class LightningIRModel(PreTrainedModel):
    """Base class for Lightning IR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as a mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini-batches of documents for a single query. Set to False for listwise models to ensure
    correctness."""

    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        :param config: Configuration for the model
        :type config: LightningIRConfig
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        self._sub_batch_size: int | None = None

    def _backbone_forward(self, *args, **kwargs):
        """Runs the forward method of the backbone model. Is overridden in
        :class:`~lightning_ir.base.class_factory.LightningIRModelClassFactory`.

        :raises NotImplementedError: If not overridden in the derived class
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError

    def sparsification(
        self, embeddings: torch.Tensor, sparsification_strategy: Literal["relu", "relu_log"] | None = None
    ) -> torch.Tensor:
        """Helper method to apply sparsification to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param sparsification_strategy: The sparsification strategy. No sparsification is applied if None,
            defaults to None
        :type sparsification_strategy: Literal['relu', 'relu_log'] | None, optional
        :raises ValueError: If an unknown sparsification strategy is passed
        :return: (Optionally) sparsified embeddings
        :rtype: torch.Tensor
        """
        if sparsification_strategy is None:
            return embeddings
        if sparsification_strategy == "relu":
            return torch.relu(embeddings)
        if sparsification_strategy == "relu_log":
            return torch.log1p(torch.relu(embeddings))
        raise ValueError(f"Unknown sparsification strategy: {sparsification_strategy}")
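
    # --- Example (editor's illustration, not part of the module) -------------
    # What the "relu_log" strategy does to an embedding tensor: negative
    # activations are zeroed by ReLU and the remaining values are compressed
    # with log1p, yielding sparse, non-negative embeddings.
    #
    #   >>> import torch
    #   >>> embeddings = torch.tensor([[-0.5, 0.0, 1.0, 3.0]])
    #   >>> torch.log1p(torch.relu(embeddings))
    #   tensor([[0.0000, 0.0000, 0.6931, 1.3863]])
    #
    # On a model instance the same result is obtained via
    # ``model.sparsification(embeddings, "relu_log")``.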

    def pooling(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None,
    ) -> torch.Tensor:
        """Helper method to apply pooling to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param attention_mask: Query or document attention mask
        :type attention_mask: torch.Tensor | None
        :param pooling_strategy: The pooling strategy. No pooling is applied if None.
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None
        :raises ValueError: If an unknown pooling strategy is passed
        :return: (Optionally) pooled embeddings
        :rtype: torch.Tensor
        """
        if pooling_strategy is None:
            return embeddings
        if pooling_strategy == "first":
            return embeddings.index_select(1, torch.tensor(0, device=embeddings.device))
        if pooling_strategy in ("sum", "mean"):
            if attention_mask is not None:
                embeddings = embeddings * attention_mask.unsqueeze(-1)
            embeddings = embeddings.sum(dim=1, keepdim=True)
            if pooling_strategy == "mean":
                if attention_mask is not None:
                    embeddings = embeddings / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
            return embeddings
        if pooling_strategy == "max":
            if attention_mask is not None:
                embeddings = embeddings.masked_fill(~attention_mask.bool().unsqueeze(-1), float("-inf"))
            return embeddings.amax(dim=1, keepdim=True)
        raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")
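
    # --- Example (editor's illustration, not part of the module) -------------
    # Masked "mean" pooling as implemented above: padding positions are zeroed
    # via the attention mask, summed over the sequence dimension, and divided
    # by the number of non-padding tokens.
    #
    #   >>> import torch
    #   >>> embeddings = torch.tensor([[[1.0, 1.0], [3.0, 5.0], [9.0, 9.0]]])  # (batch, seq, dim)
    #   >>> attention_mask = torch.tensor([[1, 1, 0]])  # last position is padding
    #   >>> masked = embeddings * attention_mask.unsqueeze(-1)
    #   >>> masked.sum(dim=1, keepdim=True) / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
    #   tensor([[[2., 3.]]])
    #
    # On a model instance: ``model.pooling(embeddings, attention_mask, "mean")``.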

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        :param model_name_or_path: Name or path of the pretrained model
        :type model_name_or_path: str | Path
        :raises ValueError: If called on the abstract class :class:`LightningIRModel` and no config is passed
        :return: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
        :rtype: LightningIRModel

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
        """
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if config is not None:
                ConfigClass = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                ConfigClass = _config.__class__
                if config is None:
                    config = _config
            elif cls is not LightningIRModel:
                ConfigClass = cls.config_class
            else:
                ConfigClass = type(LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path))
                if ConfigClass is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            backbone_config = LightningIRModelClassFactory.get_backbone_config(model_name_or_path).from_pretrained(
                model_name_or_path
            )
            BackboneModel = _get_model_class(backbone_config)
            cls = LightningIRModelClassFactory(ConfigClass).from_backbone_class(BackboneModel)
            if config is not None:
                if all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                    derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                    derived_config.update(config.to_dict())
                    config = derived_config
                kwargs["config"] = config
                # NOTE 'config' is contained in kwargs, so we can update it
                config, kwargs = _update_config_with_kwargs(**kwargs)
                kwargs["config"] = config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            kwargs["add_pooling_layer"] = False
        key_mapping = kwargs.pop("key_mapping", {})
        if model_name_or_path in STATE_DICT_KEY_MAPPING:
            key_mapping.update(STATE_DICT_KEY_MAPPING[str(model_name_or_path)])
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        if model_name_or_path in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[str(model_name_or_path)](model)
        return model

T = TypeVar("T")


def _cat_outputs(
    outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: Type[T] | None
) -> torch.Tensor | T | None:
    """Helper method to concatenate outputs of the model."""
    if len(outputs) == 1:
        return outputs[0]
    if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
        return None
    if isinstance(outputs[0], torch.Tensor):
        return torch.cat(outputs, dim=0)
    agg = defaultdict(list)
    types = {}
    for output in outputs:
        for key, value in output.items():
            agg[key].append(value)
            types[key] = type(value)
    kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()}
    if OutputClass is BatchEncoding:
        return OutputClass(kwargs)
    return OutputClass(**kwargs)
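
# --- Example (editor's illustration, not part of the module) ----------------
# How _cat_outputs merges per-sub-batch outputs back into a single output
# object: tensor fields are gathered by key, concatenated along the batch
# dimension, and the original output class is re-instantiated.
#
#   >>> import torch
#   >>> out_a = LightningIROutput(scores=torch.tensor([0.1, 0.2]))
#   >>> out_b = LightningIROutput(scores=torch.tensor([0.3]))
#   >>> _cat_outputs([out_a, out_b], LightningIROutput).scores
#   tensor([0.1000, 0.2000, 0.3000])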

class BatchEncodingWrapper(Protocol):
    """Protocol for model methods that take a batch encoding as input."""

    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...

def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
    """Decorator to enable sub-batching for models that support it. Lowers the batch size of the input batch encoding
    if the model runs out of memory.

    :param func: Function to wrap that takes a batch encoding
    :type func: BatchEncodingWrapper
    :raises RuntimeError: If CUDA runs out of memory even after lowering the batch size to 1
    :raises ValueError: If no output was generated
    :return: Wrapped function
    :rtype: BatchEncodingWrapper
    """

    @wraps(func)
    def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
        if not self.ALLOW_SUB_BATCHING:
            return func(self, encoding, *args, **kwargs)
        sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
        sub_encoding = encoding
        remaining_encoding = encoding
        OutputClass = None
        outputs = []
        while True:
            try:
                # ceil division
                num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
                for _ in range(num_batches):
                    sub_encoding = BatchEncoding(
                        {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
                    )
                    output = func(self, sub_encoding, *args, **kwargs)
                    OutputClass = output.__class__
                    outputs.append(output)
                    remaining_encoding = BatchEncoding(
                        {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
                    )
                break
            except RuntimeError as e:
                if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
                    self._sub_batch_size = sub_batch_size = sub_batch_size // 2
                    if sub_batch_size == 0:
                        raise e
                else:
                    raise e
        if OutputClass is None:
            raise ValueError("No output was generated.")
        return _cat_outputs(outputs, OutputClass)

    return wrapper
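
# --- Example (editor's sketch, not part of the module) ----------------------
# Hypothetical illustration of how the decorator is applied: a model method
# that takes a BatchEncoding as its first argument is wrapped so that, on a
# CUDA out-of-memory error, the encoding is split into progressively smaller
# sub-batches and the per-sub-batch outputs are merged again via _cat_outputs.
# The class and method names below are made up for illustration only.
#
#   class MySubBatchingModel(LightningIRModel):
#       @batch_encoding_wrapper
#       def encode(self, encoding: BatchEncoding) -> LightningIROutput:
#           hidden_states = self._backbone_forward(**encoding).last_hidden_state
#           return LightningIROutput(scores=hidden_states[:, 0, 0])
#
# Listwise models should instead set ``ALLOW_SUB_BATCHING = False`` on the
# class, in which case the wrapper passes the full batch through unchanged.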