1"""
2Model module for Lightning IR.
3
4This module contains the main model class and output class for the Lightning IR library.
5"""
6
7from collections import defaultdict
8from dataclasses import dataclass
9from functools import wraps
10from pathlib import Path
11from typing import Any, Literal, Mapping, Protocol, Self, Sequence, Type, TypeVar
12
13import torch
14from transformers import BatchEncoding, BertModel, PreTrainedModel
15from transformers.modeling_outputs import ModelOutput
16
17from .class_factory import LightningIRModelClassFactory, _get_model_class
18from .config import LightningIRConfig
19from .external_model_hub import CHECKPOINT_MAPPING, POST_LOAD_CALLBACKS, STATE_DICT_KEY_MAPPING
20
21
def _update_config_with_kwargs(config: LightningIRConfig, **kwargs):
    """Update the config with the given kwargs, remove the consumed keys from the kwargs, and return both."""
    config.update(kwargs)

    used_keys = set(config.to_dict().keys()) & set(kwargs.keys())

    for key in used_keys:
        kwargs.pop(key)

    return config, kwargs


@dataclass
class LightningIROutput(ModelOutput):
    """Base class for the output of the Lightning IR model. It is a subclass of transformers.ModelOutput_.

    .. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput

    :param scores: Output relevance scores for query--document pairs, defaults to None
    :type scores: torch.Tensor | None, optional
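
    Example (an illustrative sketch; the score values are arbitrary):

    .. code-block:: python

        output = LightningIROutput(scores=torch.tensor([0.7, 0.1]))
        output.scores  # tensor([0.7000, 0.1000])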
    """

    scores: torch.Tensor | None = None


class LightningIRModel(PreTrainedModel):
    """Base class for Lightning IR models. Derived classes implement the forward method for handling query
    and document embeddings. It acts as a mixin for a transformers.PreTrainedModel_ backbone model.

    .. _transformers.PreTrainedModel: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel
    """

    config_class: Type[LightningIRConfig] = LightningIRConfig
    """Configuration class for the model."""

    ALLOW_SUB_BATCHING = True
    """Flag to allow mini-batches of documents for a single query. Set to False for listwise models to ensure
    correctness."""

    def __init__(self, config: LightningIRConfig, *args, **kwargs) -> None:
        """Initializes the model.

        :param config: Configuration class for the model
        :type config: LightningIRConfig
        """
        super().__init__(config, *args, **kwargs)
        self.config = config

        self._sub_batch_size: int | None = None

    def _backbone_forward(self, *args, **kwargs):
        """Runs the forward method of the backbone model. It is overridden in
        :class:`~lightning_ir.base.class_factory.LightningIRModelClassFactory`.

        :raises NotImplementedError: If not overridden in the derived class
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs) -> LightningIROutput:
        """Forward method of the model. Must be implemented by the derived class."""
        raise NotImplementedError

    def sparsification(
        self, embeddings: torch.Tensor, sparsification_strategy: Literal["relu", "relu_log"] | None = None
    ) -> torch.Tensor:
        """Helper method to apply sparsification to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param sparsification_strategy: The sparsification strategy. No sparsification is applied if None,
            defaults to None
        :type sparsification_strategy: Literal['relu', 'relu_log'] | None, optional
        :raises ValueError: If an unknown sparsification strategy is passed
        :return: (Optionally) sparsified embeddings
        :rtype: torch.Tensor
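
        Example (an illustrative sketch; ``model`` stands for any :class:`LightningIRModel` instance and the
        embedding values are arbitrary):

        .. code-block:: python

            embeddings = torch.tensor([[[-1.0, 0.0, 2.0]]])  # (batch, seq_len, emb_dim)
            model.sparsification(embeddings, "relu")      # tensor([[[0., 0., 2.]]])
            model.sparsification(embeddings, "relu_log")  # log1p(relu(x)): tensor([[[0.0000, 0.0000, 1.0986]]])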
        """
        if sparsification_strategy is None:
            return embeddings
        if sparsification_strategy == "relu":
            return torch.relu(embeddings)
        if sparsification_strategy == "relu_log":
            return torch.log1p(torch.relu(embeddings))
        raise ValueError(f"Unknown sparsification strategy: {sparsification_strategy}")

    def pooling(
        self,
        embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        pooling_strategy: Literal["first", "mean", "max", "sum"] | None,
    ) -> torch.Tensor:
        """Helper method to apply pooling to the embeddings.

        :param embeddings: Query or document embeddings
        :type embeddings: torch.Tensor
        :param attention_mask: Query or document attention mask
        :type attention_mask: torch.Tensor | None
        :param pooling_strategy: The pooling strategy. No pooling is applied if None.
        :type pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None
        :raises ValueError: If an unknown pooling strategy is passed
        :return: (Optionally) pooled embeddings
        :rtype: torch.Tensor
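
        Example (an illustrative sketch; ``model`` stands for any :class:`LightningIRModel` instance and the
        values are arbitrary):

        .. code-block:: python

            embeddings = torch.arange(6.0).view(1, 3, 2)       # (batch, seq_len, emb_dim)
            attention_mask = torch.tensor([[1, 1, 0]])         # last token is padding
            model.pooling(embeddings, attention_mask, "mean")  # tensor([[[1., 2.]]])
            model.pooling(embeddings, attention_mask, "max")   # tensor([[[2., 3.]]])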
        """
        if pooling_strategy is None:
            return embeddings
        if pooling_strategy == "first":
            return embeddings.index_select(1, torch.tensor(0, device=embeddings.device))
        if pooling_strategy in ("sum", "mean"):
            if attention_mask is not None:
                embeddings = embeddings * attention_mask.unsqueeze(-1)
            embeddings = embeddings.sum(dim=1, keepdim=True)
            if pooling_strategy == "mean":
                if attention_mask is not None:
                    embeddings = embeddings / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
            return embeddings
        if pooling_strategy == "max":
            if attention_mask is not None:
                embeddings = embeddings.masked_fill(~attention_mask.bool().unsqueeze(-1), float("-inf"))
            return embeddings.amax(dim=1, keepdim=True)
        raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")

    @classmethod
    def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> Self:
        """Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained_ method to return a
        derived LightningIRModel. See :class:`LightningIRModelClassFactory` for more details.

        .. _transformers.PreTrainedModel.from_pretrained: \
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

        :param model_name_or_path: Name or path of the pretrained model
        :type model_name_or_path: str | Path
        :raises ValueError: If called on the abstract class :class:`LightningIRModel` and no config is passed
        :return: A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
        :rtype: LightningIRModel

        .. ::doctest
        .. highlight:: python
        .. code-block:: python

            >>> # Loading using model class and backbone checkpoint
            >>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
            >>> # Loading using base class and backbone checkpoint
            >>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
            <class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
        """
        # provides AutoModel.from_pretrained support
        config = kwargs.get("config", None)
        if cls is LightningIRModel or all(issubclass(base, LightningIRModel) for base in cls.__bases__):
            # no backbone models found, create derived lightning-ir model based on backbone model
            if config is not None:
                ConfigClass = config.__class__
            elif model_name_or_path in CHECKPOINT_MAPPING:
                _config = CHECKPOINT_MAPPING[model_name_or_path]
                ConfigClass = _config.__class__
                if config is None:
                    config = _config
            elif cls is not LightningIRModel:
                ConfigClass = cls.config_class
            else:
                ConfigClass = type(LightningIRModelClassFactory.get_lightning_ir_config(model_name_or_path))
                if ConfigClass is None:
                    raise ValueError("Pass a config to `from_pretrained`.")
            backbone_config = LightningIRModelClassFactory.get_backbone_config(model_name_or_path).from_pretrained(
                model_name_or_path
            )
            BackboneModel = _get_model_class(backbone_config)
            cls = LightningIRModelClassFactory(ConfigClass).from_backbone_class(BackboneModel)
            if config is not None:
                if all(issubclass(base, LightningIRConfig) for base in config.__class__.__bases__):
                    derived_config = cls.config_class.from_pretrained(model_name_or_path, config=config)
                    derived_config.update(config.to_dict())
                    config = derived_config
                kwargs["config"] = config
            # NOTE 'config' is contained in kwargs, so we can update it
            config, kwargs = _update_config_with_kwargs(**kwargs)
            kwargs["config"] = config
            return cls.from_pretrained(model_name_or_path, *args, **kwargs)
        if issubclass(cls, BertModel):
            kwargs["add_pooling_layer"] = False
        key_mapping = kwargs.pop("key_mapping", {})
        if model_name_or_path in STATE_DICT_KEY_MAPPING:
            key_mapping.update(STATE_DICT_KEY_MAPPING[str(model_name_or_path)])
        model = super().from_pretrained(model_name_or_path, *args, key_mapping=key_mapping, **kwargs)
        if model_name_or_path in POST_LOAD_CALLBACKS:
            model = POST_LOAD_CALLBACKS[str(model_name_or_path)](model)
        return model


T = TypeVar("T")


def _cat_outputs(
    outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: Type[T] | None
) -> torch.Tensor | T | None:
    """Helper function to concatenate outputs of the model."""
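    # Illustrative sketch of the expected behavior (hypothetical values): concatenating two
    # LightningIROutput objects whose scores have shapes (2,) and (3,) yields a single
    # LightningIROutput whose scores have shape (5,). Nested mappings (e.g. BatchEncoding values)
    # are concatenated recursively, key by key.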
    if len(outputs) == 1:
        return outputs[0]
    if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
        return None
    if isinstance(outputs[0], torch.Tensor):
        return torch.cat(outputs, dim=0)
    agg = defaultdict(list)
    types = {}
    for output in outputs:
        for key, value in output.items():
            agg[key].append(value)
            types[key] = type(value)
    kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()}
    if OutputClass is BatchEncoding:
        return OutputClass(kwargs)
    return OutputClass(**kwargs)


class BatchEncodingWrapper(Protocol):
    """Protocol for a callable that takes a :class:`~transformers.BatchEncoding` as its first argument."""

    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...


def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
    """Decorator to enable sub-batching for models that support it. Lowers the batch size of the input batch encoding
    if the model runs out of memory.

    :param func: Function to wrap that takes a batch encoding
    :type func: BatchEncodingWrapper
    :raises e: If CUDA runs out of memory even after lowering the batch size to 1
    :raises ValueError: If no output was generated
    :return: Wrapped function
    :rtype: BatchEncodingWrapper
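
    Example (an illustrative sketch; ``MyEncoderModel``, its ``encode`` method, and the scoring logic are
    hypothetical and only stand in for a model method that processes a :class:`~transformers.BatchEncoding`):

    .. code-block:: python

        class MyEncoderModel(LightningIRModel):
            @batch_encoding_wrapper
            def encode(self, encoding: BatchEncoding) -> LightningIROutput:
                # If this call raises a CUDA out-of-memory error, the wrapper halves the
                # sub-batch size, retries, and concatenates the per-sub-batch outputs.
                hidden_states = self._backbone_forward(**encoding).last_hidden_state
                return LightningIROutput(scores=hidden_states[:, 0, 0])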
    """

    @wraps(func)
    def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
        if not self.ALLOW_SUB_BATCHING:
            return func(self, encoding, *args, **kwargs)
        sub_batch_size = self._sub_batch_size or encoding.input_ids.shape[0]
        sub_encoding = encoding
        remaining_encoding = encoding
        OutputClass = None
        outputs = []
        while True:
            try:
                # ceil division
                num_batches = -(remaining_encoding.input_ids.shape[0] // -sub_batch_size)
                for _ in range(num_batches):
                    sub_encoding = BatchEncoding(
                        {key: value[:sub_batch_size] for key, value in remaining_encoding.items()}
                    )
                    output = func(self, sub_encoding, *args, **kwargs)
                    OutputClass = output.__class__
                    outputs.append(output)
                    remaining_encoding = BatchEncoding(
                        {key: value[sub_batch_size:] for key, value in remaining_encoding.items()}
                    )
                break
            except RuntimeError as e:
                if "CUDA out of memory" in str(e) or "CUDACachingAllocator.cpp" in str(e):
                    self._sub_batch_size = sub_batch_size = sub_batch_size // 2
                    if sub_batch_size == 0:
                        raise e
                else:
                    raise e
        if OutputClass is None:
            raise ValueError("No output was generated.")
        return _cat_outputs(outputs, OutputClass)

    return wrapper