Source code for lightning_ir.data.datamodule

  1"""
  2DataModule for Lightning IR that handles batching and collation of data samples.
  3
  4This module defines the LightningIRDataModule class that handles batching and collation of data samples for training and
  5inference in Lightning IR.
  6"""
  7
  8from __future__ import annotations
  9
 10from collections import defaultdict
 11from typing import Any, Dict, List, Literal, Sequence
 12
 13import torch
 14from lightning import LightningDataModule
 15from torch.utils.data import DataLoader, IterableDataset
 16
 17from .data import IndexBatch, RankBatch, SearchBatch, TrainBatch
 18from .dataset import (
 19    DocDataset,
 20    DocSample,
 21    QueryDataset,
 22    QuerySample,
 23    RankSample,
 24    RunDataset,
 25    TupleDataset,
 26    _DummyIterableDataset,
 27)
 28
 29


class LightningIRDataModule(LightningDataModule):
    """DataModule that handles batching and collation of data samples for training and inference in Lightning IR."""

    def __init__(
        self,
        train_dataset: RunDataset | TupleDataset | None = None,
        train_batch_size: int | None = None,
        shuffle_train: bool = True,
        inference_datasets: Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None = None,
        inference_batch_size: int | None = None,
        num_workers: int = 0,
    ) -> None:
        """Initializes a new Lightning IR DataModule.

        Args:
            train_dataset (RunDataset | TupleDataset | None): A training dataset. Defaults to None.
            train_batch_size (int | None): Batch size to use for training. Defaults to None.
            shuffle_train (bool): Whether to shuffle the training data. Defaults to True.
            inference_datasets (Sequence[RunDataset | TupleDataset | QueryDataset | DocDataset] | None): List of
                datasets to use for inference (indexing, searching, and re-ranking). Defaults to None.
            inference_batch_size (int | None): Batch size to use for inference. Defaults to None.
            num_workers (int): Number of workers for loading data in parallel. Defaults to 0.
        """
        super().__init__()
        self.num_workers = num_workers

        self.train_dataset = train_dataset
        self.train_batch_size = train_batch_size
        self.shuffle_train = shuffle_train
        self.inference_datasets = None if inference_datasets is None else list(inference_datasets)
        self.inference_batch_size = inference_batch_size

        if (self.train_batch_size is not None) != (self.train_dataset is not None):
            raise ValueError("Both train_batch_size and train_dataset must be provided.")
        if (self.inference_batch_size is not None) != (self.inference_datasets is not None):
            raise ValueError("Both inference_batch_size and inference_datasets must be provided.")
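
    # A minimal usage sketch (not part of the module). The dataset identifiers below
    # are illustrative; any RunDataset/TupleDataset with a matching batch size works.
    #
    #   datamodule = LightningIRDataModule(
    #       train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    #       train_batch_size=32,
    #       inference_datasets=[RunDataset("msmarco-passage/trec-dl-2019/judged")],
    #       inference_batch_size=8,
    #   )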

    def _setup_inference(self, stage: Literal["validate", "test"]) -> None:
        if self.inference_datasets is None:
            return
        for inference_dataset in self.inference_datasets:
            if isinstance(inference_dataset, TupleDataset):
                if stage == "test":
                    raise ValueError("A TupleDataset cannot be used for testing or prediction.")
            elif isinstance(inference_dataset, RunDataset):
                if inference_dataset.sampling_strategy == "single_relevant":
                    raise ValueError("Inference RunDataset cannot use the single_relevant sampling strategy.")
            elif isinstance(inference_dataset, (QueryDataset, DocDataset)):
                pass
            else:
                raise ValueError(
                    "Inference Dataset must be of type RunDataset, TupleDataset, QueryDataset, or DocDataset."
                )

    def prepare_data(self) -> None:
        """Downloads the data using ir_datasets if needed."""
        if self.train_dataset is not None:
            self.train_dataset.prepare_data()
        if self.inference_datasets is not None:
            for inference_dataset in self.inference_datasets:
                inference_dataset.prepare_data()
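
    # When used outside a Lightning Trainer, the usual DataModule contract applies:
    # call prepare_data() and setup() manually before requesting dataloaders (sketch):
    #
    #   datamodule.prepare_data()
    #   datamodule.setup(stage="fit")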

    def setup(self, stage: Literal["fit", "validate", "test"]) -> None:
        """Sets up the data module for a given stage.

        Args:
            stage (Literal["fit", "validate", "test"]): Stage to set up the data module for.

        Raises:
            ValueError: If the stage is `fit` and no training dataset is provided.
        """
        if stage == "fit":
            if self.train_dataset is None:
                raise ValueError("A training dataset must be provided.")
            # During fit, the inference datasets are used for validation.
            stage = "validate"
        self._setup_inference(stage)

    def train_dataloader(self) -> DataLoader:
        """Returns a dataloader for training.

        Returns:
            DataLoader: Dataloader for training.

        Raises:
            ValueError: If no training dataset is found.
        """
        if self.train_dataset is None:
            raise ValueError("No training dataset found.")
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
            shuffle=(False if isinstance(self.train_dataset, IterableDataset) else self.shuffle_train),
            prefetch_factor=16 if self.num_workers > 0 else None,
        )
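
    # Illustrative: a single training batch, collated by _collate_fn. Whether a
    # TrainBatch or a RankBatch is produced depends on whether the samples carry targets.
    #
    #   batch = next(iter(datamodule.train_dataloader()))
    #   assert isinstance(batch, (TrainBatch, RankBatch))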

    def val_dataloader(self) -> List[DataLoader]:
        """Returns a list of dataloaders for validation.

        Returns:
            List[DataLoader]: Dataloaders for validation.
        """
        return self.inference_dataloader()

    def test_dataloader(self) -> List[DataLoader]:
        """Returns a list of dataloaders for testing.

        Returns:
            List[DataLoader]: Dataloaders for testing.
        """
        return self.inference_dataloader()

    def predict_dataloader(self) -> List[DataLoader]:
        """Returns a list of dataloaders for predicting.

        Returns:
            List[DataLoader]: Dataloaders for predicting.
        """
        return self.inference_dataloader()

    def inference_dataloader(self) -> List[DataLoader]:
        """Returns a list of dataloaders for inference (validation, testing, or predicting).

        Returns:
            List[DataLoader]: Dataloaders for inference.
        """
        inference_datasets = self.inference_datasets or []
        dataloaders = [
            DataLoader(
                dataset,
                batch_size=self.inference_batch_size,
                num_workers=self.num_workers,
                collate_fn=self._collate_fn,
                prefetch_factor=16 if self.num_workers > 0 else None,
            )
            for dataset in inference_datasets
            if not dataset._SKIP
        ]
        if not dataloaders:
            # Lightning expects at least one dataloader, so fall back to a placeholder.
            dataloaders = [DataLoader(_DummyIterableDataset())]
        return dataloaders
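
    # Illustrative: one dataloader is returned per (non-skipped) inference dataset, and
    # the batch type mirrors the dataset type (QueryDataset -> SearchBatch,
    # DocDataset -> IndexBatch, RunDataset/TupleDataset -> RankBatch or TrainBatch).
    #
    #   for dataloader in datamodule.inference_dataloader():
    #       batch = next(iter(dataloader))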

    def _aggregate_samples(self, samples: Sequence[RankSample | QuerySample | DocSample]) -> Dict[str, Any]:
        aggregated = defaultdict(list)
        field_options = {
            "query_id": {"extend": False},
            "query": {"extend": False},
            "doc_id": {"extend": False},
            "doc_ids": {"extend": False},
            "doc": {"extend": False},
            "docs": {"extend": False},
            "targets": {"extend": True},
            "qrels": {"extend": True},
        }
        for sample in samples:
            for field in sample.__dict__:
                extend = field_options[field]["extend"]
                # Pluralize singular field names, e.g., query_id -> query_ids.
                key = field if field.endswith("s") else f"{field}s"
                value = getattr(sample, field)
                if value is None:
                    continue
                if extend:
                    aggregated[key].extend(value)
                else:
                    aggregated[key].append(value)
        return aggregated

    def _clean_sample(self, aggregated: Dict[str, Any]) -> Dict[str, Any]:
        kwargs: Dict[str, Any] = dict(aggregated)
        # Fix the naive pluralization "querys" produced by _aggregate_samples.
        if "querys" in kwargs:
            kwargs["queries"] = kwargs["querys"]
            del kwargs["querys"]
        if "targets" in kwargs:
            kwargs["targets"] = torch.stack(kwargs["targets"])
        return kwargs

    def _parse_batch(
        self, sample: RankSample | QuerySample | DocSample, **kwargs
    ) -> RankBatch | TrainBatch | IndexBatch | SearchBatch:
        if isinstance(sample, RankSample):
            if "targets" in kwargs:
                return TrainBatch(**kwargs)
            return RankBatch(**kwargs)
        if isinstance(sample, QuerySample):
            return SearchBatch(**kwargs)
        if isinstance(sample, DocSample):
            return IndexBatch(**kwargs)
        raise ValueError("Invalid dataset configuration.")

    def _collate_fn(
        self,
        samples: Sequence[RankSample | QuerySample | DocSample] | RankSample | QuerySample | DocSample,
    ) -> TrainBatch | RankBatch | IndexBatch | SearchBatch:
        if isinstance(samples, (RankSample, QuerySample, DocSample)):
            samples = [samples]
        aggregated = self._aggregate_samples(samples)
        kwargs = self._clean_sample(aggregated)
        return self._parse_batch(samples[0], **kwargs)
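
    # Collation flow as a sketch (the sample values are illustrative): a list of
    # DocSample objects becomes a single IndexBatch whose fields are the pluralized,
    # aggregated sample fields.
    #
    #   samples = [DocSample(doc_id="d1", doc="a doc"), DocSample(doc_id="d2", doc="another")]
    #   batch = datamodule._collate_fn(samples)
    #   # batch == IndexBatch(doc_ids=["d1", "d2"], docs=["a doc", "another"])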