Source code for lightning_ir.main

  1"""Main entry point for Lightning IR using the Lightning IR CLI.
  2
  3The module also defines several helper classes for configuring and running experiments.
  4"""
  5
  6import argparse
  7import os
  8import sys
  9from collections.abc import Mapping
 10from pathlib import Path
 11from textwrap import dedent
 12from typing import Any
 13
 14import torch
 15from lightning import LightningDataModule, LightningModule, Trainer
 16from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment
 17from lightning.pytorch.cli import LightningCLI, SaveConfigCallback
 18from lightning.pytorch.loggers import WandbLogger
 19from typing_extensions import override
 20
 21import lightning_ir  # noqa: F401
 22from lightning_ir.schedulers.lr_schedulers import LR_SCHEDULERS, WarmupLRScheduler
 23
 24if torch.cuda.is_available():
 25    torch.set_float32_matmul_precision("medium")
 26
 27sys.path.append(str(Path.cwd()))
 28
 29os.environ["TOKENIZERS_PARALLELISM"] = "false"
 30
 31
class LightningIRSaveConfigCallback(SaveConfigCallback):
    """Lightning IR configuration saving callback with intelligent save conditions.

    This callback extends PyTorch Lightning's SaveConfigCallback_ to provide smarter configuration
    file saving behavior specifically designed for Lightning IR workflows. It only saves YAML
    configuration files during the 'fit' stage and when a logger is properly configured, preventing
    unnecessary file creation during inference operations like indexing, searching, or re-ranking.

    The callback automatically saves the complete experiment configuration including model, data,
    trainer, and optimizer settings to enable full experiment reproducibility.

    .. _SaveConfigCallback: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.SaveConfigCallback.html

    Examples:
        Automatic usage through LightningIRCLI:

        .. code-block:: python

            from lightning_ir.main import LightningIRCLI, LightningIRSaveConfigCallback

            # The callback is automatically configured in the CLI
            cli = LightningIRCLI(
                save_config_callback=LightningIRSaveConfigCallback,
                save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True}
            )

        Manual usage with trainer:

        .. code-block:: python

            from lightning_ir import LightningIRTrainer, LightningIRSaveConfigCallback

            # Add callback to trainer
            callback = LightningIRSaveConfigCallback(
                config_filename="experiment_config.yaml",
                overwrite=True
            )
            trainer = LightningIRTrainer(callbacks=[callback])

        Configuration file output example:

        .. code-block:: yaml

            # Generated pl_config.yaml
            model:
              class_path: lightning_ir.BiEncoderModule
              init_args:
                model_name_or_path: bert-base-uncased
            data:
              class_path: lightning_ir.LightningIRDataModule
              init_args:
                train_batch_size: 32
            trainer:
              max_steps: 100000
    """

    @override
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Set up the callback with intelligent save conditions.

        This method implements the core logic for conditional configuration saving. It only
        proceeds with configuration file saving when both conditions are met:

        1. The training stage is 'fit' (not inference stages like index/search/re_rank)
        2. A logger is properly configured on the trainer

        This prevents unnecessary configuration file creation during inference operations
        while ensuring that training experiments are properly documented for reproducibility.

        Args:
            trainer (Trainer): The Lightning trainer instance containing training configuration
                and logger settings.
            pl_module (LightningModule): The Lightning module instance being trained or used
                for inference.
            stage (str): The current training stage. Expected values include 'fit', 'validate',
                'test', 'predict', as well as Lightning IR specific stages like 'index',
                'search', 're_rank'.

        Examples:
            The method automatically handles different stages:

            .. code-block:: python

                # During training - config will be saved
                trainer.fit(module, datamodule)  # stage='fit', saves config

                # During inference - config will NOT be saved
                trainer.index(module, datamodule)  # stage='index', skips saving
                trainer.search(module, datamodule)  # stage='search', skips saving
                trainer.re_rank(module, datamodule)  # stage='re_rank', skips saving
        """
        if stage != "fit" or trainer.logger is None:
            return
        super().setup(trainer, pl_module, stage)

class LightningIRWandbLogger(WandbLogger):
    """Lightning IR extension of the Weights & Biases Logger for enhanced experiment tracking.

    This logger extends the PyTorch Lightning WandbLogger_ to provide improved file management
    and experiment tracking specifically tailored for Lightning IR experiments. It ensures that
    experiment files are properly saved in the WandB run's files directory and handles the
    save directory management correctly.

    .. _WandbLogger: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.WandbLogger.html
    """

    @property
    def save_dir(self) -> str | None:
        """Gets the save directory for experiment files and artifacts.

        This property returns the directory where WandB saves experiment files, logs, and
        artifacts. It handles the case where the experiment might not be properly initialized
        (DummyExperiment) and returns None in such cases to prevent errors.

        Returns:
            str | None: The absolute path to the WandB experiment directory where files
                are saved, or None if the experiment is not properly initialized
                or WandB is running in offline/disabled mode.
        """
        if isinstance(self.experiment, DummyExperiment):
            return None
        return self.experiment.dir

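# Usage sketch (illustrative, not part of the module): attach the logger to a trainer so that
# ``save_dir`` resolves to the WandB run's files directory, as described above. ``project`` is a
# standard WandbLogger keyword argument; ``module`` and ``datamodule`` stand for any configured
# LightningIRModule and LightningIRDataModule.
#
#   logger = LightningIRWandbLogger(project="my-ir-experiments")
#   trainer = LightningIRTrainer(logger=logger, max_steps=100_000)
#   trainer.fit(module, datamodule)
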
class LightningIRTrainer(Trainer):
    """Lightning IR Trainer that extends PyTorch Lightning Trainer with information retrieval specific methods.

    This trainer inherits all functionality from the `PyTorch Lightning Trainer`_ and adds specialized methods
    for information retrieval tasks including document indexing, searching, and re-ranking. It provides a
    unified interface for both training neural ranking models and performing inference across different
    IR stages.

    The trainer seamlessly integrates with Lightning IR callbacks and supports all standard Lightning features
    including distributed training, mixed precision, gradient accumulation, and checkpointing.

    .. _PyTorch Lightning Trainer: https://lightning.ai/docs/pytorch/stable/common/trainer.html

    Examples:
        Basic usage for fine-tuning and inference:

        .. code-block:: python

            from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule

            # Initialize trainer with Lightning configuration
            trainer = LightningIRTrainer(
                max_steps=100_000,
                precision="16-mixed",
                devices=2,
                accelerator="gpu"
            )

            # Fine-tune a model
            module = BiEncoderModule(model_name_or_path="bert-base-uncased")
            datamodule = LightningIRDataModule(...)
            trainer.fit(module, datamodule)

            # Index documents
            trainer.index(module, datamodule)

            # Search for relevant documents
            trainer.search(module, datamodule)

            # Re-rank retrieved documents
            trainer.re_rank(module, datamodule)

    Note:
        The trainer requires appropriate callbacks to be configured for each IR task:

        - IndexCallback for indexing operations
        - SearchCallback for search operations
        - ReRankCallback for re-ranking operations
    """

    # TODO check that correct callbacks are registered for each subcommand

    def index(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> list[Mapping[str, float]]:
        """Index a collection of documents using a fine-tuned bi-encoder model.

        This method performs document indexing by running inference on a document collection and
        storing the resulting embeddings in an index structure. It requires an IndexCallback to
        be configured in the trainer to handle the actual indexing process.

        Args:
            model (LightningModule | None): The LightningIRModule containing the bi-encoder model
                to use for encoding documents. If None, uses the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the document collection to index. Should contain DocDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before indexing.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during indexing. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            list[Mapping[str, float]]: List of dictionaries containing indexing metrics and results.

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule
                from lightning_ir import IndexCallback, TorchDenseIndexConfig, DocDataset

                # Setup trainer with index callback
                callback = IndexCallback(
                    index_dir="./index",
                    index_config=TorchDenseIndexConfig()
                )
                trainer = LightningIRTrainer(callbacks=[callback])

                # Setup model and data
                module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[DocDataset("msmarco-passage")]
                )

                # Index the documents
                trainer.index(module, datamodule)

        Note:
            - Requires IndexCallback to be configured in trainer callbacks
            - Only works with bi-encoder models that can encode documents
            - The index type and configuration are specified in the IndexCallback
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def search(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> list[Mapping[str, float]]:
        """Search for relevant documents using a bi-encoder model and pre-built index.

        This method performs dense or sparse retrieval by encoding queries and searching through
        a pre-built index to find the most relevant documents. It requires a SearchCallback to
        be configured in the trainer to handle the search process and optionally a RankCallback
        to save results.

        Args:
            model (LightningModule | None): The LightningIRModule containing the bi-encoder model
                to use for encoding queries. If None, uses the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the queries to search for. Should contain QueryDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before searching.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during searching. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            list[Mapping[str, float]]: List of dictionaries containing search metrics and effectiveness
                results (if relevance judgments are available).

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule
                from lightning_ir import SearchCallback, RankCallback, QueryDataset
                from lightning_ir import TorchDenseSearchConfig

                # Setup trainer with search and rank callbacks
                search_callback = SearchCallback(
                    index_dir="./index",
                    search_config=TorchDenseSearchConfig(k=100)
                )
                rank_callback = RankCallback(results_dir="./results")
                trainer = LightningIRTrainer(callbacks=[search_callback, rank_callback])

                # Setup model and data
                module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[QueryDataset("trec-dl-2019/queries")]
                )

                # Search for relevant documents
                results = trainer.search(module, datamodule)

        Note:
            - Requires SearchCallback to be configured in trainer callbacks
            - Index must be built beforehand using the index() method
            - Search configuration must match the index configuration used during indexing
            - Add RankCallback to save search results to disk
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

    def re_rank(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> list[Mapping[str, float]]:
        """Re-rank a set of retrieved documents using bi-encoder or cross-encoder models.

        This method performs re-ranking by scoring query-document pairs and reordering them
        based on relevance scores. Cross-encoders typically provide higher effectiveness for
        re-ranking tasks compared to bi-encoders. It requires a ReRankCallback to be configured
        in the trainer to handle saving the re-ranked results.

        Args:
            model (LightningModule | None): The LightningIRModule containing the model to use for
                re-ranking. Can be either BiEncoderModule or CrossEncoderModule. If None, uses
                the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the query-document pairs to re-rank. Should contain RunDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before re-ranking.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during re-ranking. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            list[Mapping[str, float]]: List of dictionaries containing re-ranking metrics and
                effectiveness results (if relevance judgments are available).

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, CrossEncoderModule, LightningIRDataModule
                from lightning_ir import ReRankCallback, RunDataset

                # Setup trainer with re-rank callback
                rerank_callback = ReRankCallback(results_dir="./reranked_results")
                trainer = LightningIRTrainer(callbacks=[rerank_callback])

                # Setup model and data
                module = CrossEncoderModule(model_name_or_path="webis/bert-cross-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[RunDataset("path/to/run/file.txt")]
                )

                # Re-rank the documents
                results = trainer.re_rank(module, datamodule)

        Note:
            - Requires ReRankCallback to be configured in trainer callbacks
            - Input data should be in run file format (query-document pairs with initial scores)
            - Cross-encoders typically provide better effectiveness than bi-encoders for re-ranking
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)

class LightningIRCLI(LightningCLI):
    """Lightning IR Command Line Interface that extends PyTorch LightningCLI_ for information retrieval tasks.

    This CLI provides a unified command-line interface for fine-tuning neural ranking models and running
    information retrieval experiments. It extends the PyTorch LightningCLI_ with IR-specific subcommands
    and automatic configuration management for seamless integration between models, data, and training.

    .. _LightningCLI: https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html

    Examples:
        Command line usage:

        .. code-block:: bash

            # Fine-tune a model
            lightning-ir fit --config fine-tune.yaml

            # Index documents
            lightning-ir index --config index.yaml

            # Search for documents
            lightning-ir search --config search.yaml

            # Re-rank documents
            lightning-ir re_rank --config re-rank.yaml

            # Generate default configuration
            lightning-ir fit --print_config > config.yaml

        Programmatic usage:

        .. code-block:: python

            from lightning_ir.main import LightningIRCLI, LightningIRTrainer, LightningIRSaveConfigCallback

            # Create CLI instance
            cli = LightningIRCLI(
                trainer_class=LightningIRTrainer,
                save_config_callback=LightningIRSaveConfigCallback,
                save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True}
            )

        YAML configuration example:

        .. code-block:: yaml

            model:
              class_path: lightning_ir.BiEncoderModule
              init_args:
                model_name_or_path: bert-base-uncased
                loss_functions:
                  - class_path: lightning_ir.InBatchCrossEntropy

            data:
              class_path: lightning_ir.LightningIRDataModule
              init_args:
                train_dataset:
                  class_path: lightning_ir.TupleDataset
                  init_args:
                    dataset_id: msmarco-passage/train/triples-small
                train_batch_size: 32

            trainer:
              max_steps: 100000
              precision: "16-mixed"

            optimizer:
              class_path: torch.optim.AdamW
              init_args:
                lr: 5e-5

    Note:
        - Automatically links model and data configurations (model_name_or_path, config)
        - Links trainer max_steps to learning rate scheduler num_training_steps
        - Supports all PyTorch Lightning CLI features including class path instantiation
        - Built-in support for warmup learning rate schedulers
        - Saves configuration files automatically during training
    """

    @staticmethod
    def configure_optimizers(
        lightning_module: LightningModule,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: WarmupLRScheduler | None = None,
    ) -> Any:
        """Configure optimizers and learning rate schedulers for Lightning training.

        This method automatically configures the optimizer and learning rate scheduler combination
        for Lightning training. It handles warmup learning rate schedulers by setting the
        appropriate interval and returning the correct format expected by Lightning.

        Args:
            lightning_module (LightningModule): The Lightning module being trained.
            optimizer (torch.optim.Optimizer): The optimizer instance to use for training.
            lr_scheduler (WarmupLRScheduler | None): Optional warmup learning rate scheduler.
                If None, only the optimizer is returned.

        Returns:
            Any: Either the optimizer alone (if no scheduler) or a tuple of optimizer and
                scheduler lists in Lightning's expected format.

        Note:
            - Warmup schedulers automatically set the correct interval based on scheduler type
            - The return format is compatible with Lightning's configure_optimizers method
        """
        if lr_scheduler is None:
            return optimizer

        return [optimizer], [
            {"scheduler": lr_scheduler, "interval": lr_scheduler.interval}
        ]

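    # Return-shape sketch (illustrative only, mirroring the code above and Lightning's
    # ``configure_optimizers`` contract; ``warmup_scheduler`` is a placeholder name):
    #
    #   configure_optimizers(module, optimizer)
    #       -> optimizer
    #   configure_optimizers(module, optimizer, warmup_scheduler)
    #       -> ([optimizer], [{"scheduler": warmup_scheduler, "interval": warmup_scheduler.interval}])
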
    def add_arguments_to_parser(self, parser):
        """Add Lightning IR specific arguments and links to the CLI parser.

        This method extends the base Lightning CLI parser with IR-specific learning rate
        schedulers and automatically links related configuration arguments to ensure
        consistency between model, data, and trainer configurations.

        Args:
            parser: The CLI argument parser to extend.

        Note:
            Automatic argument linking:

            - model.init_args.model_name_or_path -> data.init_args.model_name_or_path
            - model.init_args.config -> data.init_args.config
            - trainer.max_steps -> lr_scheduler.init_args.num_training_steps
        """
        parser.add_lr_scheduler_args(tuple(LR_SCHEDULERS))
        parser.link_arguments(
            "model.init_args.model_name_or_path", "data.init_args.model_name_or_path"
        )
        parser.link_arguments("model.init_args.config", "data.init_args.config")
        parser.link_arguments(
            "trainer.max_steps", "lr_scheduler.init_args.num_training_steps"
        )

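    # Effect of the links above (illustrative YAML sketch; layout follows the class docstring,
    # values are placeholders):
    #
    #   model:
    #     init_args:
    #       model_name_or_path: bert-base-uncased  # forwarded to data.init_args.model_name_or_path
    #   trainer:
    #     max_steps: 100000                        # forwarded to lr_scheduler.init_args.num_training_steps
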
    @staticmethod
    def subcommands() -> dict[str, set[str]]:
        """Define the available subcommands and the configuration sections each requires.

        Returns a dictionary mapping subcommand names to the set of configuration sections
        they require. This extends the base Lightning CLI with IR-specific subcommands for
        indexing, searching, and re-ranking operations.

        Returns:
            dict[str, set[str]]: Dictionary mapping subcommand names to required config sections.

            - fit: Standard Lightning training subcommand with all sections
            - index: Document indexing requiring model, dataloaders, and datamodule
            - search: Document search requiring model, dataloaders, and datamodule
            - re_rank: Document re-ranking requiring model, dataloaders, and datamodule
        """
        return {
            "fit": LightningCLI.subcommands()["fit"],
            "index": {"model", "dataloaders", "datamodule"},
            "search": {"model", "dataloaders", "datamodule"},
            "re_rank": {"model", "dataloaders", "datamodule"},
        }

    def _add_configure_optimizers_method_to_model(self, subcommand: str | None) -> None:
        import warnings

        # Suppress the warnings Lightning emits while the CLI patches ``configure_optimizers``
        # onto the model.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return super()._add_configure_optimizers_method_to_model(subcommand)

def main():
    """Entry point for the Lightning IR command line interface.

    Initializes and runs the LightningIRCLI with the LightningIRTrainer and configuration
    callback. This function serves as the main entry point when Lightning IR is run from
    the command line using the 'lightning-ir' command.

    The CLI is configured with:

    - LightningIRTrainer as the trainer class for all IR operations
    - LightningIRSaveConfigCallback for automatic config file saving during training
    - Configuration to save configs as 'pl_config.yaml' with overwrite enabled

    Examples:
        This function is called when using Lightning IR from the command line:

        .. code-block:: bash

            lightning-ir fit --config fine-tune.yaml
            lightning-ir index --config index.yaml
            lightning-ir search --config search.yaml
            lightning-ir re_rank --config re-rank.yaml

    Note:
        - Configuration files are automatically saved during fit operations
        - All PyTorch Lightning CLI features are available
        - Supports YAML configuration files and command line argument overrides
    """
    help_epilog = dedent(
        """\
        For full documentation, visit: https://lightning-ir.webis.de
        """
    )

    LightningIRCLI(
        trainer_class=LightningIRTrainer,
        save_config_callback=LightningIRSaveConfigCallback,
        save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True},
        parser_kwargs={
            "epilog": help_epilog,
            "formatter_class": argparse.RawDescriptionHelpFormatter,
        },
    )


if __name__ == "__main__":
    main()