1"""Main entry point for Lightning IR using the Lightning IR CLI.
2
3The module also defines several helper classes for configuring and running experiments.
4"""
5
6import os
7import sys
8from pathlib import Path
9from typing import Any, Dict, List, Mapping, Set
10
11import torch
12from lightning import LightningDataModule, LightningModule, Trainer
13from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment
14from lightning.pytorch.cli import LightningCLI, SaveConfigCallback
15from lightning.pytorch.loggers import WandbLogger
16from typing_extensions import override
17
18import lightning_ir # noqa: F401
19from lightning_ir.schedulers.lr_schedulers import LR_SCHEDULERS, WarmupLRScheduler
20
# Trade off float32 matmul precision for speed on CUDA devices
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("medium")

# Make modules in the current working directory importable (e.g., custom classes referenced in config files)
sys.path.append(str(Path.cwd()))

# Disable Hugging Face tokenizers parallelism to avoid fork-related warnings with multiple dataloader workers
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class LightningIRSaveConfigCallback(SaveConfigCallback):
    """Lightning IR configuration saving callback with intelligent save conditions.

    This callback extends PyTorch Lightning's SaveConfigCallback_ to provide smarter configuration
    file saving behavior specifically designed for Lightning IR workflows. It only saves YAML
    configuration files during the 'fit' stage and when a logger is properly configured, preventing
    unnecessary file creation during inference operations like indexing, searching, or re-ranking.

    The callback automatically saves the complete experiment configuration including model, data,
    trainer, and optimizer settings to enable full experiment reproducibility.

    .. _SaveConfigCallback: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.SaveConfigCallback.html

    Examples:
        Automatic usage through LightningIRCLI:

        .. code-block:: python

            from lightning_ir.main import LightningIRCLI, LightningIRSaveConfigCallback

            # The callback is automatically configured in the CLI
            cli = LightningIRCLI(
                save_config_callback=LightningIRSaveConfigCallback,
                save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True}
            )

        Manual usage with trainer:

        .. code-block:: python

            from lightning_ir import LightningIRTrainer, LightningIRSaveConfigCallback

            # Add callback to trainer
            callback = LightningIRSaveConfigCallback(
                config_filename="experiment_config.yaml",
                overwrite=True
            )
            trainer = LightningIRTrainer(callbacks=[callback])

        Configuration file output example:

        .. code-block:: yaml

            # Generated pl_config.yaml
            model:
              class_path: lightning_ir.BiEncoderModule
              init_args:
                model_name_or_path: bert-base-uncased
            data:
              class_path: lightning_ir.LightningIRDataModule
              init_args:
                train_batch_size: 32
            trainer:
              max_steps: 100000
    """

    @override
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Set up the callback with intelligent save conditions.

        This method implements the core logic for conditional configuration saving. It only
        proceeds with configuration file saving when both conditions are met:

        1. The training stage is 'fit' (not inference stages like index/search/re_rank)
        2. A logger is properly configured on the trainer

        This prevents unnecessary configuration file creation during inference operations
        while ensuring that training experiments are properly documented for reproducibility.

        Args:
            trainer (Trainer): The Lightning trainer instance containing training configuration
                and logger settings.
            pl_module (LightningModule): The Lightning module instance being trained or used
                for inference.
            stage (str): The current training stage. Expected values include 'fit', 'validate',
                'test', 'predict', as well as Lightning IR specific stages like 'index',
                'search', 're_rank'.

        Examples:
            The method automatically handles different stages:

            .. code-block:: python

                # During training - config will be saved
                trainer.fit(module, datamodule)  # stage='fit', saves config

                # During inference - config will NOT be saved
                trainer.index(module, datamodule)  # stage='index', skips saving
                trainer.search(module, datamodule)  # stage='search', skips saving
                trainer.re_rank(module, datamodule)  # stage='re_rank', skips saving
        """
        if stage != "fit" or trainer.logger is None:
            return
        super().setup(trainer, pl_module, stage)


class LightningIRWandbLogger(WandbLogger):
    """Lightning IR extension of the Weights & Biases Logger for enhanced experiment tracking.

    This logger extends the PyTorch Lightning WandbLogger_ to provide improved file management
    and experiment tracking specifically tailored for Lightning IR experiments. It ensures that
    experiment files are properly saved in the WandB run's files directory and handles the
    save directory management correctly.

    .. _WandbLogger: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.WandbLogger.html
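
    Example:
        A minimal usage sketch (the ``project`` name is illustrative; any WandbLogger_
        keyword arguments can be passed through):

        .. code-block:: python

            from lightning_ir import LightningIRTrainer
            from lightning_ir.main import LightningIRWandbLogger

            # Attach the logger to the trainer like any other Lightning logger
            logger = LightningIRWandbLogger(project="lightning-ir")
            trainer = LightningIRTrainer(logger=logger, max_steps=100_000)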
133 """

    @property
    def save_dir(self) -> str | None:
        """Gets the save directory for experiment files and artifacts.

        This property returns the directory where WandB saves experiment files, logs, and
        artifacts. It handles the case where the experiment might not be properly initialized
        (DummyExperiment) and returns None in such cases to prevent errors.

        Returns:
            str | None: The absolute path to the WandB experiment directory where files
                are saved, or None if the experiment is not properly initialized
                or WandB is running in offline/disabled mode.
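
        Example:
            A minimal sketch, assuming the logger is attached to a trainer and the WandB
            run has been initialized:

            .. code-block:: python

                if logger.save_dir is not None:
                    print(f"WandB run files are stored in {logger.save_dir}")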
147 """
148 super().save_dir
149 if isinstance(self.experiment, DummyExperiment):
150 return None
151 return self.experiment.dir
152
153
class LightningIRTrainer(Trainer):
    """Lightning IR Trainer that extends PyTorch Lightning Trainer with information retrieval specific methods.

    This trainer inherits all functionality from the PyTorch Lightning Trainer_ and adds specialized methods
    for information retrieval tasks including document indexing, searching, and re-ranking. It provides a
    unified interface for both training neural ranking models and performing inference across different
    IR stages.

    The trainer seamlessly integrates with Lightning IR callbacks and supports all standard Lightning features
    including distributed training, mixed precision, gradient accumulation, and checkpointing.

    .. _Trainer: https://lightning.ai/docs/pytorch/stable/common/trainer.html

    Examples:
        Basic usage for fine-tuning and inference:

        .. code-block:: python

            from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule

            # Initialize trainer with Lightning configuration
            trainer = LightningIRTrainer(
                max_steps=100_000,
                precision="16-mixed",
                devices=2,
                accelerator="gpu"
            )

            # Fine-tune a model
            module = BiEncoderModule(model_name_or_path="bert-base-uncased")
            datamodule = LightningIRDataModule(...)
            trainer.fit(module, datamodule)

            # Index documents
            trainer.index(module, datamodule)

            # Search for relevant documents
            trainer.search(module, datamodule)

            # Re-rank retrieved documents
            trainer.re_rank(module, datamodule)

    Note:
        The trainer requires appropriate callbacks to be configured for each IR task:

        - IndexCallback for indexing operations
        - SearchCallback for search operations
        - ReRankCallback for re-ranking operations
    """

    # TODO check that correct callbacks are registered for each subcommand

    def index(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Index a collection of documents using a fine-tuned bi-encoder model.

        This method performs document indexing by running inference on a document collection and
        storing the resulting embeddings in an index structure. It requires an IndexCallback to
        be configured in the trainer to handle the actual indexing process.

        Args:
            model (LightningModule | None): The LightningIRModule containing the bi-encoder model
                to use for encoding documents. If None, uses the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the document collection to index. Should contain DocDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before indexing.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during indexing. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            List[Mapping[str, float]]: List of dictionaries containing indexing metrics and results.

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule
                from lightning_ir import IndexCallback, TorchDenseIndexConfig, DocDataset

                # Setup trainer with index callback
                callback = IndexCallback(
                    index_dir="./index",
                    index_config=TorchDenseIndexConfig()
                )
                trainer = LightningIRTrainer(callbacks=[callback])

                # Setup model and data
                module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[DocDataset("msmarco-passage")]
                )

                # Index the documents
                trainer.index(module, datamodule)

        Note:
            - Requires IndexCallback to be configured in trainer callbacks
            - Only works with bi-encoder models that can encode documents
            - The index type and configuration are specified in the IndexCallback
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def search(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Search for relevant documents using a bi-encoder model and pre-built index.

        This method performs dense or sparse retrieval by encoding queries and searching through
        a pre-built index to find the most relevant documents. It requires a SearchCallback to
        be configured in the trainer to handle the search process and optionally a RankCallback
        to save results.

        Args:
            model (LightningModule | None): The LightningIRModule containing the bi-encoder model
                to use for encoding queries. If None, uses the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the queries to search for. Should contain QueryDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before searching.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during searching. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            List[Mapping[str, float]]: List of dictionaries containing search metrics and effectiveness
                results (if relevance judgments are available).

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule
                from lightning_ir import SearchCallback, RankCallback, QueryDataset
                from lightning_ir import TorchDenseSearchConfig

                # Setup trainer with search and rank callbacks
                search_callback = SearchCallback(
                    index_dir="./index",
                    search_config=TorchDenseSearchConfig(k=100)
                )
                rank_callback = RankCallback(results_dir="./results")
                trainer = LightningIRTrainer(callbacks=[search_callback, rank_callback])

                # Setup model and data
                module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[QueryDataset("trec-dl-2019/queries")]
                )

                # Search for relevant documents
                results = trainer.search(module, datamodule)

        Note:
            - Requires SearchCallback to be configured in trainer callbacks
            - Index must be built beforehand using the index() method
            - Search configuration must match the index configuration used during indexing
            - Add RankCallback to save search results to disk
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def re_rank(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Re-rank a set of retrieved documents using bi-encoder or cross-encoder models.

        This method performs re-ranking by scoring query-document pairs and reordering them
        based on relevance scores. Cross-encoders typically provide higher effectiveness for
        re-ranking tasks compared to bi-encoders. It requires a ReRankCallback to be configured
        in the trainer to handle saving the re-ranked results.

        Args:
            model (LightningModule | None): The LightningIRModule containing the model to use for
                re-ranking. Can be either BiEncoderModule or CrossEncoderModule. If None, uses
                the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the query-document pairs to re-rank. Should contain RunDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before re-ranking.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during re-ranking. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            List[Mapping[str, float]]: List of dictionaries containing re-ranking metrics and
                effectiveness results (if relevance judgments are available).

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, CrossEncoderModule, LightningIRDataModule
                from lightning_ir import ReRankCallback, RunDataset

                # Setup trainer with re-rank callback
                rerank_callback = ReRankCallback(results_dir="./reranked_results")
                trainer = LightningIRTrainer(callbacks=[rerank_callback])

                # Setup model and data
                module = CrossEncoderModule(model_name_or_path="webis/bert-cross-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[RunDataset("path/to/run/file.txt")]
                )

                # Re-rank the documents
                results = trainer.re_rank(module, datamodule)

        Note:
            - Requires ReRankCallback to be configured in trainer callbacks
            - Input data should be in run file format (query-document pairs with initial scores)
            - Cross-encoders typically provide better effectiveness than bi-encoders for re-ranking
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)


class LightningIRCLI(LightningCLI):
    """Lightning IR Command Line Interface that extends PyTorch LightningCLI_ for information retrieval tasks.

    This CLI provides a unified command-line interface for fine-tuning neural ranking models and running
    information retrieval experiments. It extends the PyTorch LightningCLI_ with IR-specific subcommands
    and automatic configuration management for seamless integration between models, data, and training.

    .. _LightningCLI: https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html

    Examples:
        Command line usage:

        .. code-block:: bash

            # Fine-tune a model
            lightning-ir fit --config fine-tune.yaml

            # Index documents
            lightning-ir index --config index.yaml

            # Search for documents
            lightning-ir search --config search.yaml

            # Re-rank documents
            lightning-ir re_rank --config re-rank.yaml

            # Generate default configuration
            lightning-ir fit --print_config > config.yaml

        Programmatic usage:

        .. code-block:: python

            from lightning_ir.main import LightningIRCLI, LightningIRTrainer, LightningIRSaveConfigCallback

            # Create CLI instance
            cli = LightningIRCLI(
                trainer_class=LightningIRTrainer,
                save_config_callback=LightningIRSaveConfigCallback,
                save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True}
            )

        YAML configuration example:

        .. code-block:: yaml

            model:
              class_path: lightning_ir.BiEncoderModule
              init_args:
                model_name_or_path: bert-base-uncased
                loss_functions:
                  - class_path: lightning_ir.InBatchCrossEntropy

            data:
              class_path: lightning_ir.LightningIRDataModule
              init_args:
                train_dataset:
                  class_path: lightning_ir.TupleDataset
                  init_args:
                    dataset_id: msmarco-passage/train/triples-small
                train_batch_size: 32

            trainer:
              max_steps: 100000
              precision: "16-mixed"

            optimizer:
              class_path: torch.optim.AdamW
              init_args:
                lr: 5e-5

    Note:
        - Automatically links model and data configurations (model_name_or_path, config)
        - Links trainer max_steps to learning rate scheduler num_training_steps
        - Supports all PyTorch Lightning CLI features including class path instantiation
        - Built-in support for warmup learning rate schedulers
        - Saves configuration files automatically during training
    """

    def add_arguments_to_parser(self, parser):
        """Add Lightning IR specific arguments and links to the CLI parser.

        This method extends the base Lightning CLI parser with IR-specific learning rate
        schedulers and automatically links related configuration arguments to ensure
        consistency between model, data, and trainer configurations.

        Args:
            parser: The CLI argument parser to extend.

        Note:
            Automatic argument linking (see the example below):

            - model.init_args.model_name_or_path -> data.init_args.model_name_or_path
            - model.init_args.config -> data.init_args.config
            - trainer.max_steps -> lr_scheduler.init_args.num_training_steps
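
        Example:
            With these links in place, shared values only need to be specified once in a
            user config. A minimal sketch (the scheduler class path and its init args are
            illustrative; use one of the classes registered in ``LR_SCHEDULERS``):

            .. code-block:: yaml

                model:
                  class_path: lightning_ir.BiEncoderModule
                  init_args:
                    model_name_or_path: bert-base-uncased  # also applied to the data module

                trainer:
                  max_steps: 100000  # also applied to the lr_scheduler's num_training_steps

                lr_scheduler:
                  # illustrative class path; any warmup scheduler from LR_SCHEDULERS works
                  class_path: lightning_ir.schedulers.lr_schedulers.LinearLRSchedulerWithLinearWarmup
                  init_args:
                    num_warmup_steps: 1000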
507 """
508 parser.add_lr_scheduler_args(tuple(LR_SCHEDULERS))
509 parser.link_arguments("model.init_args.model_name_or_path", "data.init_args.model_name_or_path")
510 parser.link_arguments("model.init_args.config", "data.init_args.config")
511 parser.link_arguments("trainer.max_steps", "lr_scheduler.init_args.num_training_steps")
512
    @staticmethod
    def subcommands() -> Dict[str, Set[str]]:
        """Defines the list of available subcommands and the arguments to skip.

        Returns a dictionary mapping subcommand names to the set of configuration sections
        they require. This extends the base Lightning CLI with IR-specific subcommands for
        indexing, searching, and re-ranking operations.

        Returns:
            Dict[str, Set[str]]: Dictionary mapping subcommand names to required config sections.

            - fit: Standard Lightning training subcommand with all sections
            - index: Document indexing requiring model, dataloaders, and datamodule
            - search: Document search requiring model, dataloaders, and datamodule
            - re_rank: Document re-ranking requiring model, dataloaders, and datamodule
        """
        return {
            "fit": LightningCLI.subcommands()["fit"],
            "index": {"model", "dataloaders", "datamodule"},
            "search": {"model", "dataloaders", "datamodule"},
            "re_rank": {"model", "dataloaders", "datamodule"},
        }

    def _add_configure_optimizers_method_to_model(self, subcommand: str | None) -> None:
        import warnings

        # Suppress warnings emitted while the parent CLI attaches a configure_optimizers
        # method to the Lightning module
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return super()._add_configure_optimizers_method_to_model(subcommand)


def main():
    """Entry point for the Lightning IR command line interface.

    Initializes and runs the LightningIRCLI with the LightningIRTrainer and configuration
    callback. This function serves as the main entry point when Lightning IR is run from
    the command line using the 'lightning-ir' command.

    The CLI is configured with:

    - LightningIRTrainer as the trainer class for all IR operations
    - LightningIRSaveConfigCallback for automatic config file saving during training
    - Configuration to save configs as 'pl_config.yaml' with overwrite enabled

    Examples:
        This function is called when using Lightning IR from the command line:

        .. code-block:: bash

            lightning-ir fit --config fine-tune.yaml
            lightning-ir index --config index.yaml
            lightning-ir search --config search.yaml
            lightning-ir re_rank --config re-rank.yaml

    Note:
        - Configuration files are automatically saved during fit operations
        - All PyTorch Lightning CLI features are available
        - Supports YAML configuration files and command line argument overrides
    """
    LightningIRCLI(
        trainer_class=LightningIRTrainer,
        save_config_callback=LightningIRSaveConfigCallback,
        save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True},
    )


if __name__ == "__main__":
    main()