1"""Main entry point for Lightning IR using the Lightning IR CLI.
2
3The module also defines several helper classes for configuring and running experiments.
4"""
5
import argparse
import os
import sys
from collections.abc import Mapping
from pathlib import Path
from textwrap import dedent
from typing import Any

import torch
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback
from lightning.pytorch.loggers import WandbLogger
from typing_extensions import override

import lightning_ir  # noqa: F401
from lightning_ir.schedulers.lr_schedulers import LR_SCHEDULERS, WarmupLRScheduler

if torch.cuda.is_available():
    # Allow lower-precision (TF32) matmul kernels for a speed-up on supporting GPUs.
    torch.set_float32_matmul_precision("medium")

# Make modules in the current working directory importable, e.g. custom classes referenced in configs.
sys.path.append(str(Path.cwd()))

# Disable tokenizer parallelism to silence the Hugging Face tokenizers fork warning in dataloader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class LightningIRSaveConfigCallback(SaveConfigCallback):
    """Lightning IR configuration saving callback with intelligent save conditions.

    This callback extends PyTorch Lightning's SaveConfigCallback_ to provide smarter configuration
    file saving behavior specifically designed for Lightning IR workflows. It only saves YAML
    configuration files during the 'fit' stage and when a logger is properly configured, preventing
    unnecessary file creation during inference operations like indexing, searching, or re-ranking.

    The callback automatically saves the complete experiment configuration including model, data,
    trainer, and optimizer settings to enable full experiment reproducibility.

    .. _SaveConfigCallback: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.SaveConfigCallback.html

    Examples:
        Automatic usage through LightningIRCLI:

        .. code-block:: python

            from lightning_ir.main import LightningIRCLI, LightningIRSaveConfigCallback

            # The callback is automatically configured in the CLI
            cli = LightningIRCLI(
                save_config_callback=LightningIRSaveConfigCallback,
                save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True}
            )

        Manual usage with trainer:

        .. code-block:: python

            from lightning_ir import LightningIRTrainer, LightningIRSaveConfigCallback

            # Add callback to trainer
            callback = LightningIRSaveConfigCallback(
                config_filename="experiment_config.yaml",
                overwrite=True
            )
            trainer = LightningIRTrainer(callbacks=[callback])

        Configuration file output example:

        .. code-block:: yaml

            # Generated pl_config.yaml
            model:
              class_path: lightning_ir.BiEncoderModule
              init_args:
                model_name_or_path: bert-base-uncased
            data:
              class_path: lightning_ir.LightningIRDataModule
              init_args:
                train_batch_size: 32
            trainer:
              max_steps: 100000
    """

    @override
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Set up the callback with intelligent save conditions.

        This method implements the core logic for conditional configuration saving. It only
        proceeds with configuration file saving when both conditions are met:

        1. The training stage is 'fit' (not inference stages like index/search/re_rank)
        2. A logger is properly configured on the trainer

        This prevents unnecessary configuration file creation during inference operations
        while ensuring that training experiments are properly documented for reproducibility.

        Args:
            trainer (Trainer): The Lightning trainer instance containing training configuration
                and logger settings.
            pl_module (LightningModule): The Lightning module instance being trained or used
                for inference.
            stage (str): The current training stage. Expected values include 'fit', 'validate',
                'test', 'predict', as well as Lightning IR specific stages like 'index',
                'search', 're_rank'.

        Examples:
            The method automatically handles different stages:

            .. code-block:: python

                # During training - config will be saved
                trainer.fit(module, datamodule)  # stage='fit', saves config

                # During inference - config will NOT be saved
                trainer.index(module, datamodule)  # stage='index', skips saving
                trainer.search(module, datamodule)  # stage='search', skips saving
                trainer.re_rank(module, datamodule)  # stage='re_rank', skips saving
        """
        if stage != "fit" or trainer.logger is None:
            return
        super().setup(trainer, pl_module, stage)


class LightningIRWandbLogger(WandbLogger):
    """Lightning IR extension of the Weights & Biases Logger for enhanced experiment tracking.

    This logger extends the PyTorch Lightning WandbLogger_ to provide improved file management
    and experiment tracking specifically tailored for Lightning IR experiments. It ensures that
    experiment files are properly saved in the WandB run's files directory and handles the
    save directory management correctly.

    .. _WandbLogger: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.WandbLogger.html
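
    Examples:
        A minimal sketch of attaching the logger to a Lightning IR trainer (the project and run
        names and the trainer settings are placeholders, not defaults):

        .. code-block:: python

            from lightning_ir import LightningIRTrainer
            from lightning_ir.main import LightningIRWandbLogger

            # Log metrics and experiment files to a Weights & Biases run
            logger = LightningIRWandbLogger(project="lightning-ir", name="bi-encoder-fit")
            trainer = LightningIRTrainer(logger=logger, max_steps=100_000)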
136 """
137
138 @property
139 def save_dir(self) -> str | None:
140 """Gets the save directory for experiment files and artifacts.
141
142 This property returns the directory where WandB saves experiment files, logs, and
143 artifacts. It handles the case where the experiment might not be properly initialized
144 (DummyExperiment) and returns None in such cases to prevent errors.
145
146 Returns:
147 str | None: The absolute path to the WandB experiment directory where files
148 are saved, or None if the experiment is not properly initialized
149 or WandB is running in offline/disabled mode.
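
        Example:
            A minimal sketch of guarding against an uninitialized experiment (assumes WandB is
            installed and a run can be created; the project name is a placeholder):

            .. code-block:: python

                logger = LightningIRWandbLogger(project="lightning-ir")
                if logger.save_dir is not None:
                    print(f"WandB run files are stored in {logger.save_dir}")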
150 """
151 if isinstance(self.experiment, DummyExperiment):
152 return None
153 return self.experiment.dir
154
155
class LightningIRTrainer(Trainer):
    """Lightning IR Trainer that extends the PyTorch Lightning Trainer with information-retrieval-specific methods.

    This trainer inherits all functionality from the PyTorch Lightning Trainer_ and adds specialized methods
    for information retrieval tasks including document indexing, searching, and re-ranking. It provides a
    unified interface for both training neural ranking models and performing inference across different
    IR stages.

    The trainer seamlessly integrates with Lightning IR callbacks and supports all standard Lightning features
    including distributed training, mixed precision, gradient accumulation, and checkpointing.

    .. _Trainer: https://lightning.ai/docs/pytorch/stable/common/trainer.html

    Examples:
        Basic usage for fine-tuning and inference:

        .. code-block:: python

            from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule

            # Initialize trainer with Lightning configuration
            trainer = LightningIRTrainer(
                max_steps=100_000,
                precision="16-mixed",
                devices=2,
                accelerator="gpu"
            )

            # Fine-tune a model
            module = BiEncoderModule(model_name_or_path="bert-base-uncased")
            datamodule = LightningIRDataModule(...)
            trainer.fit(module, datamodule)

            # Index documents
            trainer.index(module, datamodule)

            # Search for relevant documents
            trainer.search(module, datamodule)

            # Re-rank retrieved documents
            trainer.re_rank(module, datamodule)

    Note:
        The trainer requires appropriate callbacks to be configured for each IR task:

        - IndexCallback for indexing operations
        - SearchCallback for search operations
        - ReRankCallback for re-ranking operations
    """

    # TODO check that correct callbacks are registered for each subcommand

    def index(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> list[Mapping[str, float]]:
        """Index a collection of documents using a fine-tuned bi-encoder model.

        This method performs document indexing by running inference on a document collection and
        storing the resulting embeddings in an index structure. It requires an IndexCallback to
        be configured in the trainer to handle the actual indexing process.

        Args:
            model (LightningModule | None): The LightningIRModule containing the bi-encoder model
                to use for encoding documents. If None, uses the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the document collection to index. Should contain DocDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before indexing.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during indexing. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            list[Mapping[str, float]]: A list of dictionaries containing indexing metrics and results.

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule
                from lightning_ir import IndexCallback, TorchDenseIndexConfig, DocDataset

                # Setup trainer with index callback
                callback = IndexCallback(
                    index_dir="./index",
                    index_config=TorchDenseIndexConfig()
                )
                trainer = LightningIRTrainer(callbacks=[callback])

                # Setup model and data
                module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[DocDataset("msmarco-passage")]
                )

                # Index the documents
                trainer.index(module, datamodule)

        Note:
            - Requires IndexCallback to be configured in trainer callbacks
            - Only works with bi-encoder models that can encode documents
            - The index type and configuration are specified in the IndexCallback
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def search(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> list[Mapping[str, float]]:
        """Search for relevant documents using a bi-encoder model and pre-built index.

        This method performs dense or sparse retrieval by encoding queries and searching through
        a pre-built index to find the most relevant documents. It requires a SearchCallback to
        be configured in the trainer to handle the search process and optionally a RankCallback
        to save results.

        Args:
            model (LightningModule | None): The LightningIRModule containing the bi-encoder model
                to use for encoding queries. If None, uses the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the queries to search for. Should contain QueryDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before searching.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during searching. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            list[Mapping[str, float]]: A list of dictionaries containing search metrics and effectiveness
                results (if relevance judgments are available).

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule
                from lightning_ir import SearchCallback, RankCallback, QueryDataset
                from lightning_ir import TorchDenseSearchConfig

                # Setup trainer with search and rank callbacks
                search_callback = SearchCallback(
                    index_dir="./index",
                    search_config=TorchDenseSearchConfig(k=100)
                )
                rank_callback = RankCallback(results_dir="./results")
                trainer = LightningIRTrainer(callbacks=[search_callback, rank_callback])

                # Setup model and data
                module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[QueryDataset("trec-dl-2019/queries")]
                )

                # Search for relevant documents
                results = trainer.search(module, datamodule)

        Note:
            - Requires SearchCallback to be configured in trainer callbacks
            - Index must be built beforehand using the index() method
            - Search configuration must match the index configuration used during indexing
            - Add RankCallback to save search results to disk
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def re_rank(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> list[Mapping[str, float]]:
        """Re-rank a set of retrieved documents using bi-encoder or cross-encoder models.

        This method performs re-ranking by scoring query-document pairs and reordering them
        based on relevance scores. Cross-encoders typically provide higher effectiveness for
        re-ranking tasks compared to bi-encoders. It requires a ReRankCallback to be configured
        in the trainer to handle saving the re-ranked results.

        Args:
            model (LightningModule | None): The LightningIRModule containing the model to use for
                re-ranking. Can be either BiEncoderModule or CrossEncoderModule. If None, uses
                the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the query-document pairs to re-rank. Should contain RunDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before re-ranking.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during re-ranking. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            list[Mapping[str, float]]: A list of dictionaries containing re-ranking metrics and
                effectiveness results (if relevance judgments are available).

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, CrossEncoderModule, LightningIRDataModule
                from lightning_ir import ReRankCallback, RunDataset

                # Setup trainer with re-rank callback
                rerank_callback = ReRankCallback(results_dir="./reranked_results")
                trainer = LightningIRTrainer(callbacks=[rerank_callback])

                # Setup model and data
                module = CrossEncoderModule(model_name_or_path="webis/bert-cross-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[RunDataset("path/to/run/file.txt")]
                )

                # Re-rank the documents
                results = trainer.re_rank(module, datamodule)

        Note:
            - Requires ReRankCallback to be configured in trainer callbacks
            - Input data should be in run file format (query-document pairs with initial scores)
            - Cross-encoders typically provide better effectiveness than bi-encoders for re-ranking
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)


class LightningIRCLI(LightningCLI):
    """Lightning IR Command Line Interface that extends the PyTorch LightningCLI_ for information retrieval tasks.

    This CLI provides a unified command-line interface for fine-tuning neural ranking models and running
    information retrieval experiments. It extends the PyTorch LightningCLI_ with IR-specific subcommands
    and automatic configuration management for seamless integration between models, data, and training.

    .. _LightningCLI: https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html

    Examples:
        Command line usage:

        .. code-block:: bash

            # Fine-tune a model
            lightning-ir fit --config fine-tune.yaml

            # Index documents
            lightning-ir index --config index.yaml

            # Search for documents
            lightning-ir search --config search.yaml

            # Re-rank documents
            lightning-ir re_rank --config re-rank.yaml

            # Generate default configuration
            lightning-ir fit --print_config > config.yaml

        Programmatic usage:

        .. code-block:: python

            from lightning_ir.main import LightningIRCLI, LightningIRTrainer, LightningIRSaveConfigCallback

            # Create CLI instance
            cli = LightningIRCLI(
                trainer_class=LightningIRTrainer,
                save_config_callback=LightningIRSaveConfigCallback,
                save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True}
            )

        YAML configuration example:

        .. code-block:: yaml

            model:
              class_path: lightning_ir.BiEncoderModule
              init_args:
                model_name_or_path: bert-base-uncased
                loss_functions:
                  - class_path: lightning_ir.InBatchCrossEntropy

            data:
              class_path: lightning_ir.LightningIRDataModule
              init_args:
                train_dataset:
                  class_path: lightning_ir.TupleDataset
                  init_args:
                    dataset_id: msmarco-passage/train/triples-small
                train_batch_size: 32

            trainer:
              max_steps: 100000
              precision: "16-mixed"

            optimizer:
              class_path: torch.optim.AdamW
              init_args:
                lr: 5e-5

    Note:
        - Automatically links model and data configurations (model_name_or_path, config)
        - Links trainer max_steps to learning rate scheduler num_training_steps
        - Supports all PyTorch Lightning CLI features including class path instantiation
        - Built-in support for warmup learning rate schedulers
        - Saves configuration files automatically during training
    """

    def add_arguments_to_parser(self, parser):
        """Add Lightning IR specific arguments and links to the CLI parser.

        This method extends the base Lightning CLI parser with IR-specific learning rate
        schedulers and automatically links related configuration arguments to ensure
        consistency between model, data, and trainer configurations.

        Args:
            parser: The CLI argument parser to extend.

        Note:
            Automatic argument linking:

            - model.init_args.model_name_or_path -> data.init_args.model_name_or_path
            - model.init_args.config -> data.init_args.config
            - trainer.max_steps -> lr_scheduler.init_args.num_training_steps
        """
        parser.add_lr_scheduler_args(tuple(LR_SCHEDULERS))
        parser.link_arguments("model.init_args.model_name_or_path", "data.init_args.model_name_or_path")
        parser.link_arguments("model.init_args.config", "data.init_args.config")
        parser.link_arguments("trainer.max_steps", "lr_scheduler.init_args.num_training_steps")

    @staticmethod
    def subcommands() -> dict[str, set[str]]:
        """Defines the list of available subcommands and the arguments to skip.

        Returns a dictionary mapping subcommand names to the set of configuration sections
        they require. This extends the base Lightning CLI with IR-specific subcommands for
        indexing, searching, and re-ranking operations.

        Returns:
            dict[str, set[str]]: Dictionary mapping subcommand names to required config sections.

            - fit: Standard Lightning training subcommand with all sections
            - index: Document indexing requiring model, dataloaders, and datamodule
            - search: Document search requiring model, dataloaders, and datamodule
            - re_rank: Document re-ranking requiring model, dataloaders, and datamodule
        """
        return {
            "fit": LightningCLI.subcommands()["fit"],
            "index": {"model", "dataloaders", "datamodule"},
            "search": {"model", "dataloaders", "datamodule"},
            "re_rank": {"model", "dataloaders", "datamodule"},
        }

    def _add_configure_optimizers_method_to_model(self, subcommand: str | None) -> None:
        import warnings

        # Delegate to the base implementation but silence the warnings it emits.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return super()._add_configure_optimizers_method_to_model(subcommand)


def main():
    """Entry point for the Lightning IR command line interface.

    Initializes and runs the LightningIRCLI with the LightningIRTrainer and configuration
    callback. This function serves as the main entry point when Lightning IR is run from
    the command line using the 'lightning-ir' command.

    The CLI is configured with:

    - LightningIRTrainer as the trainer class for all IR operations
    - LightningIRSaveConfigCallback for automatic config file saving during training
    - Configuration to save configs as 'pl_config.yaml' with overwrite enabled

    Examples:
        This function is called when using Lightning IR from the command line:

        .. code-block:: bash

            lightning-ir fit --config fine-tune.yaml
            lightning-ir index --config index.yaml
            lightning-ir search --config search.yaml
            lightning-ir re_rank --config re-rank.yaml

    Note:
        - Configuration files are automatically saved during fit operations
        - All PyTorch Lightning CLI features are available
        - Supports YAML configuration files and command line argument overrides
    """
    help_epilog = dedent(
        """\
        For full documentation, visit: https://lightning-ir.webis.de
        """
    )

    LightningIRCLI(
        trainer_class=LightningIRTrainer,
        save_config_callback=LightningIRSaveConfigCallback,
        save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True},
        parser_kwargs={
            "epilog": help_epilog,
            "formatter_class": argparse.RawDescriptionHelpFormatter,
        },
    )


if __name__ == "__main__":
    main()