1"""
2Model module for Lightning IR.
3
4This module contains the main model class and output class for the Lightning IR library.
5"""
6
7from collections import defaultdict
8from collections.abc import Mapping, Sequence
9from dataclasses import dataclass
10from functools import wraps
11from pathlib import Path
12from typing import Any, Protocol, Self, TypeVar
13
14import torch
15from transformers import BatchEncoding, BertModel, PreTrainedModel
16from transformers.modeling_outputs import ModelOutput
17
18from .adapter import LightningIRAdapterMixin
19from .class_factory import LightningIRModelClassFactory, _get_model_class
20from .config import LightningIRConfig
21from .external_model_hub import (
22 BACKBONE_MAPPING,
23 CHECKPOINT_MAPPING,
24 POST_LOAD_CALLBACKS,
25 STATE_DICT_KEY_MAPPING,
26)
27
28
def _update_config_with_kwargs(config: LightningIRConfig, **kwargs):
    """Update ``config`` with ``kwargs`` and remove every kwarg that is absorbed by the updated config.

    Returns:
        tuple[LightningIRConfig, dict]: The updated config and the remaining, unconsumed kwargs.
    """
    config.update(kwargs)

    used_keys = set(config.to_dict().keys()) & set(kwargs.keys())

    for key in used_keys:
        kwargs.pop(key)

    return config, kwargs


@dataclass
class LightningIROutput(ModelOutput):
    """Base class for the output of the Lightning IR model. It is a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    Attributes:
        scores (torch.Tensor | None): Output relevance scores for query--document pairs. Defaults to None.
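
    A minimal, illustrative example of constructing an output for two query--document pairs:

    .. code-block:: python

        >>> output = LightningIROutput(scores=torch.tensor([0.8, 0.1]))
        >>> output.scores
        tensor([0.8000, 0.1000])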
48 """
49
50 scores: torch.Tensor | None = None
51
52
class LightningIRModel(LightningIRAdapterMixin, PreTrainedModel):
    """Base class for Lightning IR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as a mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel

    Attributes:
        config_class (type[LightningIRConfig]): Configuration class for the model.
        ALLOW_SUB_BATCHING (bool): Flag to allow mini-batches of documents for a single query.
            Set to False for listwise models to ensure correctness.
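
    For example, a hypothetical listwise model would disable sub-batching like this:

    .. code-block:: python

        >>> class MyListwiseModel(LightningIRModel):
        ...     ALLOW_SUB_BATCHING = False  # scores depend on the full list of documents per query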
64 """

    config_class: type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini-batches of documents for a single query. Set to False for listwise models to ensure
    correctness."""

    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        Args:
            config (LightningIRConfig): Configuration for the model.
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        self._sub_batch_size: int | None = None

    def _initialize_adapters(self) -> None:
        """Initialize adapters based on configuration."""
        if not self.config.use_adapter:
            return

        # Enable adapters if configuration is provided
        if self.config.adapter_config is not None:
            self.init_adapters(self.config.adapter_config)

        # Load adapter weights if path is provided
        if self.config.pretrained_adapter_name_or_path is not None:
            self.load_adapter(self.config.pretrained_adapter_name_or_path)

    def _backbone_forward(self, *args, **kwargs):
        """Runs the forward method of the backbone model. Is overridden in
        :class:`~lightning_ir.base.class_factory.LightningIRModelClassFactory`.

        Raises:
            NotImplementedError: If not overridden in the derived class.
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: str | Path, *args, BackboneModel: type[PreTrainedModel] | None = None, **kwargs
    ) -> Self:
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
            BackboneModel (type[PreTrainedModel] | None): Huggingface PreTrainedModel class to use as backbone
                instead of the default AutoModel. Defaults to None.
        Raises:
            ValueError: If called on the abstract class `LightningIRModel` and no config is passed.
        Returns:
            LightningIRModel: A derived `LightningIRModel` consisting of a backbone model
                and a `LightningIRModel` mixin.
        """
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if config is not None:
                ConfigClass = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                ConfigClass = _config.__class__
                if config is None:
                    config = _config
            elif cls is not LightningIRModel:
                ConfigClass = cls.config_class
            else:
                lightning_ir_config = LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path)
                if lightning_ir_config is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
                ConfigClass = type(lightning_ir_config)
            if BackboneModel is None:
                if str(model_name_or_path) in BACKBONE_MAPPING:
                    BackboneModel = BACKBONE_MAPPING[str(model_name_or_path)]
                else:
                    backbone_config = LightningIRModelClassFactory.get_backbone_config(
                        model_name_or_path
                    ).from_pretrained(model_name_or_path)
                    BackboneModel = _get_model_class(backbone_config)
            cls = LightningIRModelClassFactory(ConfigClass).from_backbone_class(BackboneModel)
            if config is not None:
                if all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                    derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                    derived_config.update(config.to_diff_dict())
                    config = derived_config
                kwargs["config"] = config
            # NOTE 'config' is contained in kwargs, so we can update it
            config, kwargs = _update_config_with_kwargs(**kwargs)
            kwargs["config"] = config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            kwargs["add_pooling_layer"] = False
        key_mapping = kwargs.pop("key_mapping", {})
        if str(model_name_or_path) in STATE_DICT_KEY_MAPPING:
            key_mapping.update(STATE_DICT_KEY_MAPPING[str(model_name_or_path)])
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        if str(model_name_or_path) in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[str(model_name_or_path)](model)

        # Initialize adapters after the model is fully loaded
        model._initialize_adapters()

        return model


T = TypeVar("T")


def _cat_outputs(
    outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: type[T] | None
) -> torch.Tensor | T | None:
    """Helper method to concatenate outputs of the model.

    Args:
        outputs (Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None]): Outputs from the model.
        OutputClass (type[T] | None): Class to return the concatenated output as.
    Returns:
        torch.Tensor | T | None: Concatenated output.
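
    A small sketch of the behavior on :class:`LightningIROutput` instances:

    .. code-block:: python

        >>> a = LightningIROutput(scores=torch.tensor([0.1]))
        >>> b = LightningIROutput(scores=torch.tensor([0.2]))
        >>> _cat_outputs([a, b], LightningIROutput).scores
        tensor([0.1000, 0.2000])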
205 """
206 if len(outputs) == 1:
207 return outputs[0]
208 if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
209 return None
210 if isinstance(outputs[0], torch.Tensor):
211 return torch.cat(outputs, dim=0)
212 agg = defaultdict(list)
213 types = {}
214 for output in outputs:
215 for key, value in output.items():
216 agg[key].append(value)
217 types[key] = type(value)
218 kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()}
219 if OutputClass is BatchEncoding:
220 return OutputClass(kwargs)
221 return OutputClass(**kwargs)
222
223
class BatchEncodingWrapper(Protocol):
    """Protocol for callables that take a :class:`~transformers.BatchEncoding` as their first argument."""

    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...


def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
    """Decorator to enable sub-batching for models that support it. Lowers the batch size of the input batch
    encoding if the model runs out of memory.

    Args:
        func (BatchEncodingWrapper): Function to wrap that takes a batch encoding.
    Returns:
        BatchEncodingWrapper: Wrapped function that handles sub-batching.
    Raises:
        RuntimeError: If CUDA runs out of memory and the batch size cannot be lowered further.
        ValueError: If no output was generated.
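
    A minimal usage sketch; ``MyModel`` and its scoring logic are hypothetical:

    .. code-block:: python

        >>> class MyModel(LightningIRModel):
        ...     @batch_encoding_wrapper
        ...     def encode(self, encoding: BatchEncoding) -> LightningIROutput:
        ...         # automatically re-run on smaller sub-batches if CUDA runs out of memory
        ...         scores = self.score(encoding)  # hypothetical scoring helper
        ...         return LightningIROutput(scores=scores)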
239 """
240
241 @wraps(func)
242 def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
243 if not self.ALLOW_SUB_BATCHING:
244 return func(self, encoding, *args, **kwargs)
245 sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
246 sub_encoding = encoding
247 remaining_encoding = encoding
248 OutputClass = None
249 outputs = []
250 while True:
251 try:
252 # ceil division
253 num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
254 for _ in range(num_batches):
255 sub_encoding = BatchEncoding(
256 {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
257 )
258 output = func(self, sub_encoding, *args, **kwargs)
259 OutputClass = output.__class__
260 outputs.append(output)
261 remaining_encoding = BatchEncoding(
262 {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
263 )
264 break
265 except RuntimeError as e:
266 if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
267 self._sub_batch_size = sub_batch_size = sub_batch_size // 2
268 if sub_batch_size == 0:
269 raise e
270 else:
271 raise e
272 if OutputClass is None:
273 raise ValueError("No output was generated.")
274 return _cat_outputs(outputs, OutputClass)
275
276 return wrapper