1"""
2DataModule for Lightning IR that handles batching and collation of data samples.
3
4This module defines the LightningIRDataModule class that handles batching and collation of data samples for training and
5inference in Lightning IR.
6"""

from __future__ import annotations

from collections import defaultdict
from typing import Any, Dict, List, Literal, Sequence

import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, IterableDataset

from .data import IndexBatch, RankBatch, SearchBatch, TrainBatch
from .dataset import (
    DocDataset,
    DocSample,
    QueryDataset,
    QuerySample,
    RankSample,
    RunDataset,
    TupleDataset,
    _DummyIterableDataset,
)


class LightningIRDataModule(LightningDataModule):
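    """Lightning IR DataModule that provides dataloaders for training and inference and collates raw
    samples into model-ready batches.

    Depending on the dataset type, the dataloaders collate samples into a :class:`.TrainBatch`,
    :class:`.RankBatch`, :class:`.SearchBatch`, or :class:`.IndexBatch`.

    Example usage (a minimal sketch; the dataset identifiers are illustrative ir-datasets ids and
    constructor defaults are assumed):

    .. code-block:: python

        datamodule = LightningIRDataModule(
            train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
            train_batch_size=32,
            inference_datasets=[RunDataset("msmarco-passage/trec-dl-2019/judged")],
            inference_batch_size=64,
        )
    """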
    def __init__(
        self,
        train_dataset: RunDataset | TupleDataset | None = None,
        train_batch_size: int | None = None,
        shuffle_train: bool = True,
        inference_datasets: Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None = None,
        inference_batch_size: int | None = None,
        num_workers: int = 0,
    ) -> None:
40 """Initializes a new Lightning IR DataModule.
41
42 :param train_dataset: A training dataset, defaults to None
43 :type train_dataset: RunDataset | TupleDataset | None, optional
44 :param train_batch_size: Batch size to use for training, defaults to None
45 :type train_batch_size: int | None, optional
46 :param shuffle_train: Whether to shuffle the training data, defaults to True
47 :type shuffle_train: bool, optional
48 :param inference_datasets: List of datasets to use for inference (indexing, searching, and re-ranking),
49 defaults to None
50 :type inference_datasets: Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None, optional
51 :param inference_batch_size: Batch size to use for inference, defaults to None
52 :type inference_batch_size: int | None, optional
53 :param num_workers: Number of workers for loading data in parallel, defaults to 0
54 :type num_workers: int, optional
55 """
56 super().__init__()
57 self.num_workers = num_workers
58
59 self.train_dataset = train_dataset
60 self.train_batch_size = train_batch_size
61 self.shuffle_train = shuffle_train
62 self.inference_datasets = None if inference_datasets is None else list(inference_datasets)
63 self.inference_batch_size = inference_batch_size

        if (self.train_batch_size is not None) != (self.train_dataset is not None):
            raise ValueError("Both train_batch_size and train_dataset must be provided.")
        if (self.inference_batch_size is not None) != (self.inference_datasets is not None):
            raise ValueError("Both inference_batch_size and inference_datasets must be provided.")

    def _setup_inference(self, stage: Literal["validate", "test"]) -> None:
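        """Validates that the configured inference datasets are compatible with the given stage."""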
        if self.inference_datasets is None:
            return
        for inference_dataset in self.inference_datasets:
            if isinstance(inference_dataset, TupleDataset):
                if stage == "test":
                    raise ValueError("Testing cannot be performed with a TupleDataset.")
            elif isinstance(inference_dataset, RunDataset):
                if inference_dataset.sampling_strategy == "single_relevant":
                    raise ValueError("Inference RunDataset cannot use the single_relevant sampling strategy.")
            elif isinstance(inference_dataset, (QueryDataset, DocDataset)):
                pass
            else:
                raise ValueError(
                    "Inference datasets must be of type RunDataset, TupleDataset, QueryDataset, or DocDataset."
                )

    def prepare_data(self) -> None:
        """Downloads the data using ir_datasets if needed."""
        if self.train_dataset is not None:
            self.train_dataset.prepare_data()
        if self.inference_datasets is not None:
            for inference_dataset in self.inference_datasets:
                inference_dataset.prepare_data()

    def setup(self, stage: Literal["fit", "validate", "test"]) -> None:
        """Sets up the data module for a given stage.

        :param stage: Stage to set up the data module for
        :type stage: Literal['fit', 'validate', 'test']
        :raises ValueError: If the stage is `fit` and no training dataset is provided
        """
        if stage == "fit":
            if self.train_dataset is None:
                raise ValueError("A training dataset must be provided.")
            stage = "validate"
        self._setup_inference(stage)

    def train_dataloader(self) -> DataLoader:
        """Returns a dataloader for training.

        :raises ValueError: If no training dataset is found
        :return: Dataloader for training
        :rtype: DataLoader
        """
        if self.train_dataset is None:
            raise ValueError("No training dataset found.")
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
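            # DataLoader forbids shuffle=True for IterableDatasets, which control their own iteration order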
            shuffle=(False if isinstance(self.train_dataset, IterableDataset) else self.shuffle_train),
            prefetch_factor=16 if self.num_workers > 0 else None,
        )

    def val_dataloader(self) -> List[DataLoader]:
        """Returns a list of dataloaders for validation.

        :return: Dataloaders for validation
        :rtype: List[DataLoader]
        """
        return self.inference_dataloader()

    def test_dataloader(self) -> List[DataLoader]:
        """Returns a list of dataloaders for testing.

        :return: Dataloaders for testing
        :rtype: List[DataLoader]
        """
        return self.inference_dataloader()
142
[docs]
143 def predict_dataloader(self) -> Any:
144 """Returns a list of dataloaders for predicting.
145
146 :return: Dataloaders for predicting
147 :rtype: List[DataLoader]
148 """
149 return self.inference_dataloader()

    def inference_dataloader(self) -> List[DataLoader]:
        """Returns a list of dataloaders for inference (validation, testing, or predicting).

        :return: Dataloaders for inference
        :rtype: List[DataLoader]
        """
        inference_datasets = self.inference_datasets or []
        dataloaders = [
            DataLoader(
                dataset,
                batch_size=self.inference_batch_size,
                num_workers=self.num_workers,
                collate_fn=self._collate_fn,
                prefetch_factor=16 if self.num_workers > 0 else None,
            )
            for dataset in inference_datasets
            if not dataset._SKIP
        ]
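        # Ensure at least one dataloader is always returned, even if all datasets are skipped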
        if not dataloaders:
            dataloaders = [DataLoader(_DummyIterableDataset())]
        return dataloaders

    def _aggregate_samples(self, samples: Sequence[RankSample | QuerySample | DocSample]) -> Dict[str, Any]:
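        """Collects the fields of all samples into lists, pluralizing singular field names (e.g., query -> querys)."""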
        aggregated = defaultdict(list)
        field_options = {
            "query_id": {"extend": False},
            "query": {"extend": False},
            "doc_id": {"extend": False},
            "doc_ids": {"extend": False},
            "doc": {"extend": False},
            "docs": {"extend": False},
            "targets": {"extend": True},
            "qrels": {"extend": True},
        }
        for sample in samples:
            for field in sample.__dict__:
                extend = field_options[field]["extend"]
                key = field if field.endswith("s") else f"{field}s"
                value = getattr(sample, field)
                if value is None:
                    continue
                if extend:
                    aggregated[key].extend(value)
                else:
                    aggregated[key].append(value)
        return aggregated

    def _clean_sample(self, aggregated: Dict[str, Any]) -> Dict[str, Any]:
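        """Renames the irregular plural querys to queries and stacks the per-sample targets into one tensor."""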
        kwargs: Dict[str, Any] = dict(aggregated)
        if "querys" in kwargs:
            kwargs["queries"] = kwargs.pop("querys")
        if "targets" in kwargs:
            kwargs["targets"] = torch.stack(kwargs["targets"])
        return kwargs

    def _parse_batch(
        self, sample: RankSample | QuerySample | DocSample, **kwargs
    ) -> RankBatch | TrainBatch | IndexBatch | SearchBatch:
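        """Selects the batch type from the sample type: RankSamples yield a TrainBatch (when targets are
        present) or a RankBatch, QuerySamples a SearchBatch, and DocSamples an IndexBatch."""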
        if isinstance(sample, RankSample):
            if "targets" in kwargs:
                return TrainBatch(**kwargs)
            else:
                return RankBatch(**kwargs)
        if isinstance(sample, QuerySample):
            return SearchBatch(**kwargs)
        if isinstance(sample, DocSample):
            return IndexBatch(**kwargs)
        raise ValueError("Invalid dataset configuration.")

    def _collate_fn(
        self,
        samples: Sequence[RankSample | QuerySample | DocSample] | RankSample | QuerySample | DocSample,
    ) -> TrainBatch | RankBatch | IndexBatch | SearchBatch:
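        """Collates one or more samples into a batch of the appropriate type."""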
        if isinstance(samples, (RankSample, QuerySample, DocSample)):
            samples = [samples]
        aggregated = self._aggregate_samples(samples)
        kwargs = self._clean_sample(aggregated)
        return self._parse_batch(samples[0], **kwargs)