1"""
2Model module for Lightning IR.
3
4This module contains the main model class and output class for the Lightning IR library.
5"""
6
7from collections import defaultdict
8from collections.abc import Mapping, Sequence
9from dataclasses import dataclass
10from functools import wraps
11from pathlib import Path
12from typing import Any, Protocol, Self, TypeVar
13
14import torch
15from transformers import BatchEncoding, BertModel, PreTrainedModel
16from transformers.modeling_outputs import ModelOutput
17
18from .adapter import LightningIRAdapterMixin
19from .class_factory import LightningIRModelClassFactory, _get_model_class
20from .config import LightningIRConfig
21from .external_model_hub import (
22 BACKBONE_MAPPING,
23 CHECKPOINT_MAPPING,
24 POST_LOAD_CALLBACKS,
25 STATE_DICT_KEY_MAPPING,
26)
27
28
def _update_config_with_kwargs(config: LightningIRConfig, **kwargs):
    """Merge ``kwargs`` into ``config`` and strip the absorbed keys from the kwargs.

    Any keyword argument whose key appears in the config after the update is
    considered consumed by the config and is not part of the returned kwargs.

    Args:
        config (LightningIRConfig): Configuration to update in place.
        **kwargs: Keyword arguments to merge into the config.
    Returns:
        tuple: The updated config and the remaining, unconsumed kwargs.
    """
    config.update(kwargs)

    consumed = set(config.to_dict()) & set(kwargs)
    remaining = {key: value for key, value in kwargs.items() if key not in consumed}

    return config, remaining
38
39
@dataclass
class LightningIROutput(ModelOutput):
    """Output container for Lightning IR models, a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    Attributes:
        scores (torch.Tensor | None): Relevance scores for query--document pairs, if computed. Defaults to None.
    """

    scores: torch.Tensor | None = None
51
52
class LightningIRModel(LightningIRAdapterMixin, PreTrainedModel):
    """Base class for Lightning IR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel

    Attributes:
        config_class (type[LightningIRConfig]): Configuration class for the model.
        ALLOW_SUB_BATCHING (bool): Flag to allow mini batches of documents for a single query.
            Set to False for listwise models to ensure correctness.
    """

    config_class: type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini batches of documents for a single query. Set to False for listwise models to ensure
    correctness."""

    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        Args:
            config (LightningIRConfig): Configuration class for the model.
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        # Last sub-batch size that fit in memory; maintained by ``batch_encoding_wrapper`` on CUDA OOM.
        self._sub_batch_size: int | None = None

    def _initialize_adapters(self) -> None:
        """Initialize adapters based on configuration. No-op when adapters are disabled."""
        if not self.config.use_adapter:
            return

        # Enable adapters if configuration is provided
        if self.config.adapter_config is not None:
            self.init_adapters(self.config.adapter_config)

        # Load adapter weights if path is provided
        if self.config.pretrained_adapter_name_or_path is not None:
            self.load_adapter(self.config.pretrained_adapter_name_or_path)

    def _init_weights(self, module: torch.nn.Module) -> None:
        super()._init_weights(module)
        if self.config.backbone_model_type == "modernbert":
            # NOTE modernbert does not initialize the weights of non-modernbert layers
            # So we need to initialize them separately using the default initialization of the PreTrainedModel
            PreTrainedModel._init_weights(self, module)

    def _backbone_forward(self, *args, **kwargs):
        """Runs the forward method of the backbone model. Is overridden in
        :class:`~lightning_ir.base.class_factory.LightningIRModelClassFactory`.

        Raises:
            NotImplementedError: If not overridden in the derived class
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: str | Path, *args, BackboneModel: type[PreTrainedModel] | None = None, **kwargs
    ) -> Self:
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
            BackboneModel (type[PreTrainedModel] | None): Huggingface PreTrainedModel class to use as backbone
                instead of the default AutoModel. Defaults to None.
        Raises:
            ValueError: If called on the abstract class `LightningIRModel` and no config is passed.
        Returns:
            LightningIRModel: A derived `LightningIRModel` consisting of a backbone model
            and a `LightningIRModel` mixin.
        """
        # The external-model-hub registries are keyed by string model names; normalize once so that
        # membership tests and lookups agree even when a ``Path`` is passed (previously membership
        # used the raw object while lookups used ``str(...)``, so Path inputs could miss entries).
        model_id = str(model_name_or_path)
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if config is not None:
                ConfigClass = config.__class__
            elif model_id in CHECKPOINT_MAPPING:
                # config is necessarily None in this branch, so adopt the registered checkpoint config
                config = CHECKPOINT_MAPPING[model_id]
                ConfigClass = config.__class__
            elif cls is not LightningIRModel:
                ConfigClass = cls.config_class
            else:
                lightning_ir_config = LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path)
                # FIX: previously ``type(None)`` was compared against ``None`` after the fact, making the
                # error unreachable; check the config object itself before taking its type.
                if lightning_ir_config is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
                ConfigClass = type(lightning_ir_config)
            if BackboneModel is None:
                if model_id in BACKBONE_MAPPING:
                    BackboneModel = BACKBONE_MAPPING[model_id]
                else:
                    backbone_config = LightningIRModelClassFactory.get_backbone_config(
                        model_name_or_path
                    ).from_pretrained(model_name_or_path)
                    BackboneModel = _get_model_class(backbone_config)
            cls = LightningIRModelClassFactory(ConfigClass).from_backbone_class(BackboneModel)
            if config is not None:
                if all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                    derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                    derived_config.update(config.to_diff_dict())
                    config = derived_config
                kwargs["config"] = config
            # NOTE 'config' is contained in kwargs, so we can update it
            # NOTE(review): if no config was resolved above, kwargs lacks "config" and this call raises
            # TypeError — verify upstream callers always end up supplying one on this path.
            config, kwargs = _update_config_with_kwargs(**kwargs)
            kwargs["config"] = config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            # The backbone pooling layer is unused by Lightning IR models; skip loading it.
            kwargs["add_pooling_layer"] = False
        key_mapping = kwargs.pop("key_mapping", {})
        if model_id in STATE_DICT_KEY_MAPPING:
            key_mapping.update(STATE_DICT_KEY_MAPPING[model_id])
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        if model_id in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[model_id](model)

        # Initialize adapters after model is fully loaded
        model._initialize_adapters()

        return model
197
198
199T = TypeVar("T")
200
201
202def _cat_outputs(
203 outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: type[T] | None
204) -> torch.Tensor | T | None:
205 """Helper method to concatenate outputs of the model.
206
207 Args:
208 outputs (Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None]): Outputs from the model.
209 OutputClass (type[T] | None): Class to return the concatenated output as.
210 Returns:
211 torch.Tensor | T | None: Concatenated output.
212 """
213 if len(outputs) == 1:
214 return outputs[0]
215 if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
216 return None
217 if isinstance(outputs[0], torch.Tensor):
218 return torch.cat(outputs, dim=0)
219 agg = defaultdict(list)
220 types = {}
221 for output in outputs:
222 for key, value in output.items():
223 agg[key].append(value)
224 types[key] = type(value)
225 kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()}
226 if OutputClass is BatchEncoding:
227 return OutputClass(kwargs)
228 return OutputClass(**kwargs)
229
230
class BatchEncodingWrapper(Protocol):
    """Structural type for callables that take a BatchEncoding as their first positional argument,
    e.g. model methods decorated with :func:`batch_encoding_wrapper`."""

    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...
234
def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
    """Decorator to enable sub-batching for models that support it. Lowers the batch size of the input batch encoding
    if the model runs out of memory.

    Args:
        func (BatchEncodingWrapper): Function to wrap that takes a batch encoding.
    Returns:
        BatchEncodingWrapper: Wrapped function that handles sub-batching.
    Raises:
        RuntimeError: If CUDA runs out of memory and the batch size cannot be lowered further.
        ValueError: If no output was generated.
    """

    @wraps(func)
    def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
        # Models that must see the full batch at once (e.g. listwise) bypass sub-batching entirely.
        if not self.ALLOW_SUB_BATCHING:
            return func(self, encoding, *args, **kwargs)
        # Start from the last sub-batch size that worked (cached on the model), or the full batch size.
        sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
        sub_encoding = encoding
        remaining_encoding = encoding
        OutputClass = None
        outputs = []
        # Retry loop: on CUDA OOM the sub-batch size is halved and processing resumes on the
        # not-yet-processed remainder; outputs collected so far are kept, not recomputed.
        while True:
            try:
                # ceil division
                num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
                for _ in range(num_batches):
                    # Slice the next sub-batch off the front of every field of the encoding.
                    sub_encoding = BatchEncoding(
                        {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
                    )
                    output = func(self, sub_encoding, *args, **kwargs)
                    OutputClass = output.__class__
                    outputs.append(output)
                    # Drop the processed rows so a retry after OOM continues where we left off.
                    remaining_encoding = BatchEncoding(
                        {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
                    )
                break
            except RuntimeError as e:
                if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
                    # Halve the sub-batch size and remember it for subsequent calls on this model.
                    self._sub_batch_size = sub_batch_size = sub_batch_size // 2
                    if sub_batch_size == 0:
                        # Even a single sample does not fit into memory; give up.
                        raise e
                else:
                    raise e
        if OutputClass is None:
            # Reached when the encoding has zero rows (num_batches == 0 on the first pass).
            raise ValueError("No output was generated.")
        # Stitch the per-sub-batch outputs back into a single output of the original class.
        return _cat_outputs(outputs, OutputClass)

    return wrapper