Source code for lightning_ir.data.datamodule

  1"""
  2DataModule for Lightning IR that handles batching and collation of data samples.
  3
  4This module defines the LightningIRDataModule class that handles batching and collation of data samples for training and
  5inference in Lightning IR.
  6"""
  7
  8from __future__ import annotations
  9
 10from collections import defaultdict
 11from typing import Any, Dict, List, Literal, Sequence
 12
 13import torch
 14from lightning import LightningDataModule
 15from torch.utils.data import DataLoader, IterableDataset
 16
 17from .data import IndexBatch, RankBatch, SearchBatch, TrainBatch
 18from .dataset import (
 19    DocDataset,
 20    DocSample,
 21    QueryDataset,
 22    QuerySample,
 23    RankSample,
 24    RunDataset,
 25    TupleDataset,
 26    _DummyIterableDataset,
 27)
 28
 29
class LightningIRDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: RunDataset | TupleDataset | None = None,
        train_batch_size: int | None = None,
        shuffle_train: bool = True,
        inference_datasets: Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None = None,
        inference_batch_size: int | None = None,
        num_workers: int = 0,
    ) -> None:
        """Initializes a new Lightning IR DataModule.

        :param train_dataset: A training dataset, defaults to None
        :type train_dataset: RunDataset | TupleDataset | None, optional
        :param train_batch_size: Batch size to use for training, defaults to None
        :type train_batch_size: int | None, optional
        :param shuffle_train: Whether to shuffle the training data, defaults to True
        :type shuffle_train: bool, optional
        :param inference_datasets: List of datasets to use for inference (indexing, searching, and re-ranking),
            defaults to None
        :type inference_datasets: Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None, optional
        :param inference_batch_size: Batch size to use for inference, defaults to None
        :type inference_batch_size: int | None, optional
        :param num_workers: Number of workers for loading data in parallel, defaults to 0
        :type num_workers: int, optional
        """
        super().__init__()
        self.num_workers = num_workers

        self.train_dataset = train_dataset
        self.train_batch_size = train_batch_size
        self.shuffle_train = shuffle_train
        self.inference_datasets = None if inference_datasets is None else list(inference_datasets)
        self.inference_batch_size = inference_batch_size

        # A batch size is meaningless without a dataset and vice versa, so each
        # pair must be provided together.
        if (self.train_batch_size is not None) != (self.train_dataset is not None):
            raise ValueError("Both train_batch_size and train_dataset must be provided.")
        if (self.inference_batch_size is not None) != (self.inference_datasets is not None):
            raise ValueError("Both inference_batch_size and inference_datasets must be provided.")

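    # A minimal construction sketch (illustrative only; the dataset constructor
    # arguments are assumptions, see lightning_ir.data.dataset for the actual
    # signatures). A batch size must always be paired with its dataset(s):
    #
    #   datamodule = LightningIRDataModule(
    #       train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    #       train_batch_size=32,
    #       inference_datasets=[RunDataset("path/to/run.tsv")],
    #       inference_batch_size=4,
    #   )
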
    def _setup_inference(self, stage: Literal["validate", "test"]) -> None:
        # Validate that each inference dataset is compatible with the given stage.
        if self.inference_datasets is None:
            return
        for inference_dataset in self.inference_datasets:
            if isinstance(inference_dataset, TupleDataset):
                if stage == "test":
                    raise ValueError("Prediction cannot be performed with TupleDataset.")
            elif isinstance(inference_dataset, RunDataset):
                if inference_dataset.sampling_strategy == "single_relevant":
                    raise ValueError("Inference RunDataset cannot use the single_relevant sampling strategy.")
            elif isinstance(inference_dataset, (QueryDataset, DocDataset)):
                pass
            else:
                raise ValueError(
                    "Inference dataset must be of type RunDataset, TupleDataset, QueryDataset, or DocDataset."
                )

    def prepare_data(self) -> None:
        """Downloads the data using ir_datasets if needed."""
        if self.train_dataset is not None:
            self.train_dataset.prepare_data()
        if self.inference_datasets is not None:
            for inference_dataset in self.inference_datasets:
                inference_dataset.prepare_data()

    def setup(self, stage: Literal["fit", "validate", "test"]) -> None:
        """Sets up the data module for a given stage.

        :param stage: Stage to set up the data module for
        :type stage: Literal['fit', 'validate', 'test']
        :raises ValueError: If the stage is `fit` and no training dataset is provided
        """
        if stage == "fit":
            if self.train_dataset is None:
                raise ValueError("A training dataset must be provided for fitting.")
            # During fitting, the inference datasets are used for validation.
            stage = "validate"
        self._setup_inference(stage)

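    # Stage-handling sketch (illustrative): Lightning calls setup("fit") before
    # training; the inference datasets are then set up as they would be for
    # validation, in addition to the training-dataset check:
    #
    #   datamodule.setup("fit")   # requires train_dataset; inference datasets
    #                             # are set up as for setup("validate")
    #   datamodule.setup("test")  # inference datasets are set up for testing
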
    def train_dataloader(self) -> DataLoader:
        """Returns a dataloader for training.

        :raises ValueError: If no training dataset is found
        :return: Dataloader for training
        :rtype: DataLoader
        """
        if self.train_dataset is None:
            raise ValueError("No training dataset found.")
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
            shuffle=(False if isinstance(self.train_dataset, IterableDataset) else self.shuffle_train),
            prefetch_factor=16 if self.num_workers > 0 else None,
        )

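    # Usage sketch (illustrative): batches are produced by _collate_fn, so the
    # dataloader yields Lightning IR batch objects rather than raw tensors:
    #
    #   dataloader = datamodule.train_dataloader()
    #   batch = next(iter(dataloader))  # e.g. a TrainBatch when the samples
    #                                   # carry targets, otherwise a RankBatch
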
    def val_dataloader(self) -> List[DataLoader]:
        """Returns a list of dataloaders for validation.

        :return: Dataloaders for validation
        :rtype: List[DataLoader]
        """
        return self.inference_dataloader()

    def test_dataloader(self) -> List[DataLoader]:
        """Returns a list of dataloaders for testing.

        :return: Dataloaders for testing
        :rtype: List[DataLoader]
        """
        return self.inference_dataloader()

    def predict_dataloader(self) -> Any:
        """Returns a list of dataloaders for predicting.

        :return: Dataloaders for predicting
        :rtype: List[DataLoader]
        """
        return self.inference_dataloader()

    def inference_dataloader(self) -> List[DataLoader]:
        """Returns a list of dataloaders for inference (validation, testing, or predicting).

        :return: Dataloaders for inference
        :rtype: List[DataLoader]
        """
        inference_datasets = self.inference_datasets or []
        dataloaders = [
            DataLoader(
                dataset,
                batch_size=self.inference_batch_size,
                num_workers=self.num_workers,
                collate_fn=self._collate_fn,
                prefetch_factor=16 if self.num_workers > 0 else None,
            )
            for dataset in inference_datasets
            if not dataset._SKIP
        ]
        if not dataloaders:
            # Provide a dummy dataloader so Lightning's loops still run when no
            # inference datasets are configured.
            dataloaders = [DataLoader(_DummyIterableDataset())]
        return dataloaders

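    # Sketch (illustrative): one dataloader is returned per non-skipped inference
    # dataset, in the order the datasets were passed:
    #
    #   for dataloader in datamodule.val_dataloader():
    #       for batch in dataloader:
    #           ...  # RankBatch, SearchBatch, or IndexBatch, depending on dataset
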
    def _aggregate_samples(self, samples: Sequence[RankSample | QuerySample | DocSample]) -> Dict[str, Any]:
        # Gather each sample field into a batch-level list. Field names are
        # pluralized (e.g., "query_id" -> "query_ids"); fields marked "extend"
        # already hold sequences per sample and are flattened into the batch.
        aggregated = defaultdict(list)
        field_options = {
            "query_id": {"extend": False},
            "query": {"extend": False},
            "doc_id": {"extend": False},
            "doc_ids": {"extend": False},
            "doc": {"extend": False},
            "docs": {"extend": False},
            "targets": {"extend": True},
            "qrels": {"extend": True},
        }
        for sample in samples:
            for field in sample.__dict__:
                extend = field_options[field]["extend"]
                key = field if field.endswith("s") else f"{field}s"
                value = getattr(sample, field)
                if value is None:
                    continue
                if extend:
                    aggregated[key].extend(value)
                else:
                    aggregated[key].append(value)
        return aggregated

    def _clean_sample(self, aggregated: Dict[str, Any]) -> Dict[str, Any]:
        # Fix the irregular plural produced by _aggregate_samples and stack the
        # per-sample target tensors into a single batch tensor.
        kwargs: Dict[str, Any] = dict(aggregated)
        if "querys" in kwargs:
            kwargs["queries"] = kwargs["querys"]
            del kwargs["querys"]
        if "targets" in kwargs:
            kwargs["targets"] = torch.stack(kwargs["targets"])
        return kwargs

    def _parse_batch(
        self, sample: RankSample | QuerySample | DocSample, **kwargs
    ) -> RankBatch | TrainBatch | IndexBatch | SearchBatch:
        # The type of the first sample determines the batch type: RankSample
        # yields a TrainBatch (when targets are present) or a RankBatch,
        # QuerySample a SearchBatch, and DocSample an IndexBatch.
        if isinstance(sample, RankSample):
            if "targets" in kwargs:
                return TrainBatch(**kwargs)
            return RankBatch(**kwargs)
        if isinstance(sample, QuerySample):
            return SearchBatch(**kwargs)
        if isinstance(sample, DocSample):
            return IndexBatch(**kwargs)
        raise ValueError("Invalid dataset configuration.")

    def _collate_fn(
        self,
        samples: Sequence[RankSample | QuerySample | DocSample] | RankSample | QuerySample | DocSample,
    ) -> TrainBatch | RankBatch | IndexBatch | SearchBatch:
        # Normalize a single sample to a list, aggregate the fields, and build
        # the batch object matching the sample type.
        if isinstance(samples, (RankSample, QuerySample, DocSample)):
            samples = [samples]
        aggregated = self._aggregate_samples(samples)
        kwargs = self._clean_sample(aggregated)
        return self._parse_batch(samples[0], **kwargs)
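
# Collation sketch (illustrative, not part of the module): the sample type
# determines the batch type. The DocSample fields below mirror the aggregation
# keys used in _aggregate_samples; the actual dataclass signatures are defined
# in lightning_ir.data.dataset.
#
#   samples = [DocSample(doc_id="d1", doc="first doc"), DocSample(doc_id="d2", doc="second doc")]
#   batch = datamodule._collate_fn(samples)  # "doc_id" -> "doc_ids", "doc" -> "docs"
#   assert isinstance(batch, IndexBatch)     # batch.doc_ids == ["d1", "d2"]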