1"""
2DataModule for Lightning IR that handles batching and collation of data samples.
3
4This module defines the LightningIRDataModule class that handles batching and collation of data samples for training and
5inference in Lightning IR.
6"""
7
8from __future__ import annotations
9
10from collections import defaultdict
11from collections.abc import Sequence
12from typing import Any, Literal
13
14import torch
15from lightning import LightningDataModule
16from torch.utils.data import DataLoader, IterableDataset
17
18from .data import IndexBatch, RankBatch, SearchBatch, TrainBatch
19from .dataset import (
20 DocDataset,
21 DocSample,
22 QueryDataset,
23 QuerySample,
24 RankSample,
25 RunDataset,
26 TupleDataset,
27 _DummyIterableDataset,
28)
29
30
class LightningIRDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: RunDataset | TupleDataset | None = None,
        train_batch_size: int | None = None,
        shuffle_train: bool = True,
        inference_datasets: Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None = None,
        inference_batch_size: int | None = None,
        num_workers: int = 0,
    ) -> None:
        """Initializes a new Lightning IR DataModule.

        Args:
            train_dataset (RunDataset | TupleDataset | None): A training dataset. Defaults to None.
            train_batch_size (int | None): Batch size to use for training. Defaults to None.
            shuffle_train (bool): Whether to shuffle the training data. Defaults to True.
            inference_datasets (Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None): List of
                datasets to use for inference (indexing, searching, and re-ranking). Defaults to None.
            inference_batch_size (int | None): Batch size to use for inference. Defaults to None.
            num_workers (int): Number of workers for loading data in parallel. Defaults to 0.
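
        Example:
            A minimal sketch of constructing a DataModule for re-ranking; the run file path is
            hypothetical and the RunDataset constructor arguments are abbreviated::

                datamodule = LightningIRDataModule(
                    inference_datasets=[RunDataset("path/to/run.tsv")],
                    inference_batch_size=32,
                )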
51 """
52 super().__init__()
53 self.num_workers = num_workers
54
55 self.train_dataset = train_dataset
56 self.train_batch_size = train_batch_size
57 self.shuffle_train = shuffle_train
58 self.inference_datasets = None if inference_datasets is None else list(inference_datasets)
59 self.inference_batch_size = inference_batch_size
60
61 if (self.train_batch_size is not None) != (self.train_dataset is not None):
62 raise ValueError("Both train_batch_size and train_dataset must be provided.")
63 if (self.inference_batch_size is not None) != (self.inference_datasets is not None):
64 raise ValueError("Both inference_batch_size and inference_dataset must be provided.")
65
    def _setup_inference(self, stage: Literal["validate", "test"]) -> None:
        """Validates that the inference datasets are compatible with the given stage."""
        if self.inference_datasets is None:
            return
        for inference_dataset in self.inference_datasets:
            if isinstance(inference_dataset, TupleDataset):
                if stage == "test":
                    raise ValueError("Prediction cannot be performed with TupleDataset.")
            elif isinstance(inference_dataset, RunDataset):
                if inference_dataset.sampling_strategy == "single_relevant":
                    raise ValueError("Inference RunDataset cannot use the single_relevant sampling strategy.")
            elif isinstance(inference_dataset, (QueryDataset, DocDataset)):
                pass
            else:
                raise ValueError(
                    "Inference Dataset must be of type RunDataset, TupleDataset, QueryDataset, or DocDataset."
                )

    def prepare_data(self) -> None:
        """Downloads the data using ir_datasets if needed."""
        if self.train_dataset is not None:
            self.train_dataset.prepare_data()
        if self.inference_datasets is not None:
            for inference_dataset in self.inference_datasets:
                inference_dataset.prepare_data()

    def setup(self, stage: Literal["fit", "validate", "test"]) -> None:
        """Sets up the data module for a given stage.

        Args:
            stage (Literal["fit", "validate", "test"]): Stage to set up the data module for.

        Raises:
            ValueError: If the stage is `fit` and no training dataset is provided.
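
        Example:
            A sketch; in normal use the Lightning Trainer calls this hook itself::

                datamodule.setup(stage="fit")  # also validates the inference datasets used for validation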
98 """
99 if stage == "fit":
100 if self.train_dataset is None:
101 raise ValueError("A training dataset and config must be provided.")
102 if stage == "fit":
103 stage = "validate"
104 self._setup_inference(stage)
105
    def train_dataloader(self) -> DataLoader:
        """Returns a dataloader for training.

        Returns:
            DataLoader: Dataloader for training.

        Raises:
            ValueError: If no training dataset is found.
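
        Example:
            A sketch of manual iteration; in normal use the Lightning Trainer drives this loop::

                datamodule.setup(stage="fit")
                for batch in datamodule.train_dataloader():
                    ...  # a TrainBatch (when targets are available) or a RankBatch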
113 """
114 if self.train_dataset is None:
115 raise ValueError("No training dataset found.")
116 return DataLoader(
117 self.train_dataset,
118 batch_size=self.train_batch_size,
119 num_workers=self.num_workers,
120 collate_fn=self._collate_fn,
121 shuffle=(False if isinstance(self.train_dataset, IterableDataset) else self.shuffle_train),
122 prefetch_factor=16 if self.num_workers > 0 else None,
123 )
124
    def val_dataloader(self) -> list[DataLoader]:
        """Returns a list of dataloaders for validation.

        Returns:
            list[DataLoader]: Dataloaders for validation.
        """
        return self.inference_dataloader()

    def test_dataloader(self) -> list[DataLoader]:
        """Returns a list of dataloaders for testing.

        Returns:
            list[DataLoader]: Dataloaders for testing.
        """
        return self.inference_dataloader()

    def predict_dataloader(self) -> list[DataLoader]:
        """Returns a list of dataloaders for predicting.

        Returns:
            list[DataLoader]: Dataloaders for predicting.
        """
        return self.inference_dataloader()

    def inference_dataloader(self) -> list[DataLoader]:
        """Returns a list of dataloaders for inference (validation, testing, or predicting).

        Returns:
            list[DataLoader]: Dataloaders for inference.
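
        Example:
            A sketch of manual iteration; the batch type depends on the dataset types passed to the
            constructor::

                for dataloader in datamodule.inference_dataloader():
                    for batch in dataloader:
                        ...  # an IndexBatch, SearchBatch, RankBatch, or TrainBatch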
154 """
155 inference_datasets = self.inference_datasets or []
156 dataloaders = [
157 DataLoader(
158 dataset,
159 batch_size=self.inference_batch_size,
160 num_workers=self.num_workers,
161 collate_fn=self._collate_fn,
162 prefetch_factor=16 if self.num_workers > 0 else None,
163 )
164 for dataset in inference_datasets
165 if not dataset._SKIP
166 ]
167 if not dataloaders:
168 dataloaders = [DataLoader(_DummyIterableDataset())]
169 return dataloaders
170
    def _aggregate_samples(self, samples: Sequence[RankSample | QuerySample | DocSample]) -> dict[str, Any]:
        """Aggregates the fields of multiple samples into a single dictionary of pluralized keys."""
        aggregated = defaultdict(list)
        field_options = {
            "query_id": {"extend": False},
            "query": {"extend": False},
            "doc_id": {"extend": False},
            "doc_ids": {"extend": False},
            "doc": {"extend": False},
            "docs": {"extend": False},
            "targets": {"extend": True},
            "qrels": {"extend": True},
        }
        for sample in samples:
            for field in sample.__dict__:
                extend = field_options[field]["extend"]
                key = field if field.endswith("s") else f"{field}s"
                value = getattr(sample, field)
                if value is None:
                    continue
                if extend:
                    aggregated[key].extend(value)
                else:
                    aggregated[key].append(value)
        return aggregated

    def _clean_sample(self, aggregated: dict[str, Any]) -> dict[str, Any]:
        """Fixes the pluralization of the query field and stacks the targets into a single tensor."""
        kwargs: dict[str, Any] = dict(aggregated)
        if "querys" in kwargs:
            kwargs["queries"] = kwargs["querys"]
            del kwargs["querys"]
        if "targets" in kwargs:
            kwargs["targets"] = torch.stack(kwargs["targets"])
        return kwargs

    def _parse_batch(
        self, sample: RankSample | QuerySample | DocSample, **kwargs
    ) -> RankBatch | TrainBatch | IndexBatch | SearchBatch:
        """Maps the aggregated fields to a batch class based on the type of the input samples."""
        if isinstance(sample, RankSample):
            if "targets" in kwargs:
                return TrainBatch(**kwargs)
            return RankBatch(**kwargs)
        if isinstance(sample, QuerySample):
            return SearchBatch(**kwargs)
        if isinstance(sample, DocSample):
            return IndexBatch(**kwargs)
        raise ValueError("Invalid dataset configuration.")

    def _collate_fn(
        self,
        samples: Sequence[RankSample | QuerySample | DocSample] | RankSample | QuerySample | DocSample,
    ) -> TrainBatch | RankBatch | IndexBatch | SearchBatch:
        """Collates a sequence of samples (or a single sample) into a batch."""
        if isinstance(samples, (RankSample, QuerySample, DocSample)):
            samples = [samples]
        aggregated = self._aggregate_samples(samples)
        kwargs = self._clean_sample(aggregated)
        return self._parse_batch(samples[0], **kwargs)
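

# A minimal sketch of the collation behavior (the DocSample field names are inferred from
# _aggregate_samples above; sample construction elsewhere in the library may differ):
#
#     samples = [DocSample(doc_id="d1", doc="hello"), DocSample(doc_id="d2", doc="world")]
#     batch = datamodule._collate_fn(samples)
#     # -> IndexBatch(doc_ids=["d1", "d2"], docs=["hello", "world"])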