1"""
2Model module for Lightning IR.
3
4This module contains the main model class and output class for the Lightning IR library.
5"""
6
7from collections import defaultdict
8from dataclasses import dataclass
9from functools import wraps
10from pathlib import Path
11from typing import Any, Literal, Mapping, Protocol, Self, Sequence, Type, TypeVar
12
13import torch
14from transformers import BatchEncoding, BertModel, PreTrainedModel
15from transformers.modeling_outputs import ModelOutput
16
17from .adapter import LightningIRAdapterMixin
18from .class_factory import LightningIRModelClassFactory, _get_model_class
19from .config import LightningIRConfig
20from .external_model_hub import BACKBONE_MAPPING, CHECKPOINT_MAPPING, POST_LOAD_CALLBACKS, STATE_DICT_KEY_MAPPING
21
22
def _update_config_with_kwargs(config: LightningIRConfig, **kwargs):
    """Updates the config with the given kwargs and removes all kwargs that were consumed by the config.

    Args:
        config (LightningIRConfig): Configuration to update.
    Returns:
        Tuple of the updated config and the remaining kwargs.
    """
    config.update(kwargs)

    used_keys = set(config.to_dict().keys()) & set(kwargs.keys())

    for key in used_keys:
        kwargs.pop(key)

    return config, kwargs

@dataclass
class LightningIROutput(ModelOutput):
    """Base class for the output of the Lightning IR model. It is a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    Attributes:
        scores (torch.Tensor | None): Output relevance scores for query--document pairs. Defaults to None.
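
    Example (illustrative):

    .. code-block:: python

        >>> output = LightningIROutput(scores=torch.rand(2))
        >>> output.scores.shape
        torch.Size([2])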
42 """
43
44 scores: torch.Tensor | None = None
45
46
class LightningIRModel(LightningIRAdapterMixin, PreTrainedModel):
    """Base class for Lightning IR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as a mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel

    Attributes:
        config_class (Type[LightningIRConfig]): Configuration class for the model.
        ALLOW_SUB_BATCHING (bool): Flag to allow mini batches of documents for a single query.
            Set to ``False`` for listwise models to ensure correctness.
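
    Example (a minimal sketch of a derived model; the scoring logic is purely illustrative):

    .. code-block:: python

        >>> class MyModel(LightningIRModel):
        ...     def forward(self, encoding: BatchEncoding) -> LightningIROutput:
        ...         embeddings = self._backbone_forward(**encoding).last_hidden_state
        ...         return LightningIROutput(scores=embeddings[:, 0, 0])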
58 """
59
60 config_class: Type[LightningIRConfig] = LightningIRConfig
61 """Configuration class for the model."""
62
63 ALLOW_SUB_BATCHING = True
64 """Flag to allow mini batches of documents for a single query. Set to false for listwise models to ensure
65 correctness."""
66
    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        Args:
            config (LightningIRConfig): Configuration for the model.
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        self._sub_batch_size: int | None = None

    def _initialize_adapters(self) -> None:
        """Initialize adapters based on configuration."""
        if not self.config.use_adapter:
            return

        # Enable adapters if configuration is provided
        if self.config.adapter_config is not None:
            self.init_adapters(self.config.adapter_config)

        # Load adapter weights if path is provided
        if self.config.pretrained_adapter_name_or_path is not None:
            self.load_adapter(self.config.pretrained_adapter_name_or_path)

    def _backbone_forward(self, *args, **kwargs):
        """Runs the forward method of the backbone model. It is overridden in
        :class:`~lightning_ir.base.class_factory.LightningIRModelClassFactory`.

        Raises:
            NotImplementedError: If not overridden in the derived class.
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError

    def sparsification(
        self, embeddings: torch.Tensor, sparsification_strategy: Literal["relu", "relu_log", "relu_2xlog"] | None = None
    ) -> torch.Tensor:
        """Helper method to apply sparsification to the embeddings.

        Args:
            embeddings (torch.Tensor): Query or document embeddings
            sparsification_strategy (Literal['relu', 'relu_log', 'relu_2xlog'] | None): The sparsification strategy.
                No sparsification is applied if None. Defaults to None.
        Returns:
            torch.Tensor: (Optionally) sparsified embeddings.
        Raises:
            ValueError: If an unknown sparsification strategy is passed.
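
        Example (illustrative; assumes ``model`` is an instantiated Lightning IR model):

        .. code-block:: python

            >>> embeddings = torch.tensor([[-1.0, 0.0, 1.0]])
            >>> model.sparsification(embeddings, "relu")
            tensor([[0., 0., 1.]])
            >>> model.sparsification(embeddings, "relu_log")
            tensor([[0.0000, 0.0000, 0.6931]])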
117 """
118 if sparsification_strategy is None:
119 return embeddings
120 if sparsification_strategy == "relu":
121 return torch.relu(embeddings)
122 if sparsification_strategy == "relu_log":
123 return torch.log1p(torch.relu(embeddings))
124 if sparsification_strategy == "relu_2xlog":
125 return torch.log1p(torch.log1p(torch.relu(embeddings)))
126 raise ValueError(f"Unknown sparsification strategy: {sparsification_strategy}")
127
    def pooling(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None,
    ) -> torch.Tensor:
        """Helper method to apply pooling to the embeddings.

        Args:
            embeddings (torch.Tensor): Query or document embeddings
            attention_mask (torch.Tensor | None): Query or document attention mask
            pooling_strategy (Literal['first', 'mean', 'max', 'sum'] | None):
                The pooling strategy. No pooling is applied if None.
        Returns:
            torch.Tensor: (Optionally) pooled embeddings.
        Raises:
            ValueError: If an unknown pooling strategy is passed.
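
        Example (illustrative; assumes ``model`` is an instantiated Lightning IR model):

        .. code-block:: python

            >>> embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])
            >>> attention_mask = torch.tensor([[1, 1]])
            >>> model.pooling(embeddings, attention_mask, "mean")
            tensor([[[2., 3.]]])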
145 """
146 if pooling_strategy is None:
147 return embeddings
148 if pooling_strategy == "first":
149 return embeddings[:, [0]]
150 if pooling_strategy in ("sum", "mean"):
151 if attention_mask is not None:
152 embeddings = embeddings * attention_mask.unsqueeze(-1)
153 embeddings = embeddings.sum(dim=1, keepdim=True)
154 if pooling_strategy == "mean":
155 if attention_mask is not None:
156 embeddings = embeddings / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
157 return embeddings
158 if pooling_strategy == "max":
159 if attention_mask is not None:
160 embeddings = embeddings.masked_fill(~attention_mask.bool().unsqueeze(-1), float("-inf"))
161 return embeddings.amax(dim=1, keepdim=True)
162 raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")
163
    @classmethod
    def from_pretrained(
        cls, model_name_or_path: str | Path, *args, BackboneModel: Type[PreTrainedModel] | None = None, **kwargs
    ) -> Self:
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
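            >>> # Loading with an explicit backbone class (illustrative)
            >>> model = LightningIRModel.from_pretrained(
            ...     "bert-base-uncased", config=CrossEncoderConfig(), BackboneModel=BertModel
            ... )
            >>> type(model)
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>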

        Args:
            model_name_or_path (str | Path): Name or path of the pretrained model.
            BackboneModel (Type[PreTrainedModel] | None): Huggingface PreTrainedModel class to use as backbone
                instead of the default AutoModel. Defaults to None.
        Raises:
            ValueError: If called on the abstract class `LightningIRModel` and no config is passed.
        Returns:
            LightningIRModel: A derived `LightningIRModel` consisting of a backbone model
            and a `LightningIRModel` mixin.
        """
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if config is not None:
                ConfigClass = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                ConfigClass = _config.__class__
                if config is None:
                    config = _config
            elif cls is not LightningIRModel:
                ConfigClass = cls.config_class
            else:
                lightning_ir_config = LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path)
                if lightning_ir_config is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
                ConfigClass = type(lightning_ir_config)
            if BackboneModel is None:
                if model_name_or_path in BACKBONE_MAPPING:
                    BackboneModel = BACKBONE_MAPPING[str(model_name_or_path)]
                else:
                    backbone_config = LightningIRModelClassFactory.get_backbone_config(
                        model_name_or_path
                    ).from_pretrained(model_name_or_path)
                    BackboneModel = _get_model_class(backbone_config)
            cls = LightningIRModelClassFactory(ConfigClass).from_backbone_class(BackboneModel)
            if config is not None:
                if all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                    derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                    derived_config.update(config.to_diff_dict())
                    config = derived_config
                kwargs["config"] = config
            # NOTE 'config' is contained in kwargs, so we can update it
            config, kwargs = _update_config_with_kwargs(**kwargs)
            kwargs["config"] = config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            kwargs["add_pooling_layer"] = False
        key_mapping = kwargs.pop("key_mapping", {})
        if model_name_or_path in STATE_DICT_KEY_MAPPING:
            key_mapping.update(STATE_DICT_KEY_MAPPING[str(model_name_or_path)])
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        if model_name_or_path in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[str(model_name_or_path)](model)

        # Initialize adapters after model is fully loaded
        model._initialize_adapters()

        return model


T = TypeVar("T")

def _cat_outputs(
    outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: Type[T] | None
) -> torch.Tensor | T | None:
    """Helper function to concatenate outputs of the model along the batch dimension.

    Args:
        outputs (Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None]): Outputs from the model.
        OutputClass (Type[T] | None): Class to return the concatenated output as.
    Returns:
        torch.Tensor | T | None: Concatenated output.
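
    Example (illustrative):

    .. code-block:: python

        >>> _cat_outputs([torch.ones(1, 2), torch.zeros(1, 2)], torch.Tensor)
        tensor([[1., 1.],
                [0., 0.]])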
259 """
260 if len(outputs) == 1:
261 return outputs[0]
262 if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
263 return None
264 if isinstance(outputs[0], torch.Tensor):
265 return torch.cat(outputs, dim=0)
266 agg = defaultdict(list)
267 types = {}
268 for output in outputs:
269 for key, value in output.items():
270 agg[key].append(value)
271 types[key] = type(value)
272 kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()}
273 if OutputClass is BatchEncoding:
274 return OutputClass(kwargs)
275 return OutputClass(**kwargs)
276
277
class BatchEncodingWrapper(Protocol):
    """Protocol for functions that take a :class:`~transformers.BatchEncoding` as their first argument."""

    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...

def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
    """Decorator to enable sub-batching for models that support it. Lowers the batch size of the input batch encoding
    if the model runs out of memory.

    Args:
        func (BatchEncodingWrapper): Function to wrap that takes a batch encoding.
    Returns:
        BatchEncodingWrapper: Wrapped function that handles sub-batching.
    Raises:
        RuntimeError: If CUDA runs out of memory and the batch size cannot be lowered further.
        ValueError: If no output was generated.
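
    Example (a minimal usage sketch; the ``encode`` method name is hypothetical):

    .. code-block:: python

        >>> class MyModel(LightningIRModel):
        ...     @batch_encoding_wrapper
        ...     def encode(self, encoding: BatchEncoding) -> torch.Tensor:
        ...         return self._backbone_forward(**encoding).last_hidden_state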
293 """
294
295 @wraps(func)
296 def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
297 if not self.ALLOW_SUB_BATCHING:
298 return func(self, encoding, *args, **kwargs)
299 sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
300 sub_encoding = encoding
301 remaining_encoding = encoding
302 OutputClass = None
303 outputs = []
304 while True:
305 try:
306 # ceil division
307 num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
308 for _ in range(num_batches):
309 sub_encoding = BatchEncoding(
310 {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
311 )
312 output = func(self, sub_encoding, *args, **kwargs)
313 OutputClass = output.__class__
314 outputs.append(output)
315 remaining_encoding = BatchEncoding(
316 {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
317 )
318 break
319 except RuntimeError as e:
320 if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
321 self._sub_batch_size = sub_batch_size = sub_batch_size // 2
322 if sub_batch_size == 0:
323 raise e
324 else:
325 raise e
326 if OutputClass is None:
327 raise ValueError("No output was generated.")
328 return _cat_outputs(outputs, OutputClass)
329
330 return wrapper