Source code for lightning_ir.main

  1"""Main entry point for Lightning IR using the Lightning IR CLI.
  2
  3The module also defines several helper classes for configuring and running experiments.
  4"""
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Mapping, Set

import torch
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback
from lightning.pytorch.loggers import WandbLogger
from typing_extensions import override

import lightning_ir  # noqa: F401
from lightning_ir.schedulers.lr_schedulers import LR_SCHEDULERS, WarmupLRScheduler
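# Trade float32 matmul precision for speed on CUDA devices (allows lower-precision kernels such as TF32/bfloat16).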
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("medium")
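# Make the current working directory importable, e.g. so that classes defined there can be referenced by class path in CLI configs.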
sys.path.append(str(Path.cwd()))
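# Silence the Hugging Face tokenizers warning about parallelism in forked processes.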
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class LightningIRSaveConfigCallback(SaveConfigCallback):
    """Lightning IR configuration saving callback with intelligent save conditions.

    This callback extends PyTorch Lightning's SaveConfigCallback_ to provide smarter configuration
    file saving behavior specifically designed for Lightning IR workflows. It only saves YAML
    configuration files during the 'fit' stage and when a logger is properly configured, preventing
    unnecessary file creation during inference operations like indexing, searching, or re-ranking.

    The callback automatically saves the complete experiment configuration including model, data,
    trainer, and optimizer settings to enable full experiment reproducibility.

    .. _SaveConfigCallback: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.SaveConfigCallback.html

    Examples:
        Automatic usage through LightningIRCLI:

        .. code-block:: python

            from lightning_ir.main import LightningIRCLI, LightningIRSaveConfigCallback

            # The callback is automatically configured in the CLI
            cli = LightningIRCLI(
                save_config_callback=LightningIRSaveConfigCallback,
                save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True}
            )

        Manual usage with trainer:

        .. code-block:: python

            from lightning_ir import LightningIRTrainer, LightningIRSaveConfigCallback

            # Add callback to trainer
            callback = LightningIRSaveConfigCallback(
                config_filename="experiment_config.yaml",
                overwrite=True
            )
            trainer = LightningIRTrainer(callbacks=[callback])

        Configuration file output example:

        .. code-block:: yaml

            # Generated pl_config.yaml
            model:
              class_path: lightning_ir.BiEncoderModule
              init_args:
                model_name_or_path: bert-base-uncased
            data:
              class_path: lightning_ir.LightningIRDataModule
              init_args:
                train_batch_size: 32
            trainer:
              max_steps: 100000
    """
    @override
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Set up the callback with intelligent save conditions.

        This method implements the core logic for conditional configuration saving. It only
        proceeds with configuration file saving when both conditions are met:

        1. The training stage is 'fit' (not inference stages like index/search/re_rank)
        2. A logger is properly configured on the trainer

        This prevents unnecessary configuration file creation during inference operations
        while ensuring that training experiments are properly documented for reproducibility.

        Args:
            trainer (Trainer): The Lightning trainer instance containing training configuration
                and logger settings.
            pl_module (LightningModule): The Lightning module instance being trained or used
                for inference.
            stage (str): The current training stage. Expected values include 'fit', 'validate',
                'test', and 'predict', as well as Lightning IR specific stages like 'index',
                'search', and 're_rank'.

        Examples:
            The method automatically handles different stages:

            .. code-block:: python

                # During training - config will be saved
                trainer.fit(module, datamodule)  # stage='fit', saves config

                # During inference - config will NOT be saved
                trainer.index(module, datamodule)  # stage='index', skips saving
                trainer.search(module, datamodule)  # stage='search', skips saving
                trainer.re_rank(module, datamodule)  # stage='re_rank', skips saving
        """
        if stage != "fit" or trainer.logger is None:
            return
        super().setup(trainer, pl_module, stage)
class LightningIRWandbLogger(WandbLogger):
    """Lightning IR extension of the Weights & Biases Logger for enhanced experiment tracking.

    This logger extends the PyTorch Lightning WandbLogger_ to provide improved file management
    and experiment tracking specifically tailored for Lightning IR experiments. It ensures that
    experiment files are properly saved in the WandB run's files directory and handles the
    save directory management correctly.

    .. _WandbLogger: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.WandbLogger.html
    """

    @property
    def save_dir(self) -> str | None:
        """Gets the save directory for experiment files and artifacts.

        This property returns the directory where WandB saves experiment files, logs, and
        artifacts. It handles the case where the experiment might not be properly initialized
        (DummyExperiment) and returns None in such cases to prevent errors.

        Returns:
            str | None: The absolute path to the WandB experiment directory where files
                are saved, or None if the experiment is not properly initialized
                or WandB is running in offline/disabled mode.
        """
        super().save_dir
        if isinstance(self.experiment, DummyExperiment):
            return None
        return self.experiment.dir
class LightningIRTrainer(Trainer):
    """Lightning IR Trainer that extends PyTorch Lightning Trainer with information retrieval specific methods.

    This trainer inherits all functionality from the `PyTorch Lightning Trainer`_ and adds specialized methods
    for information retrieval tasks including document indexing, searching, and re-ranking. It provides a
    unified interface for both training neural ranking models and performing inference across different
    IR stages.

    The trainer seamlessly integrates with Lightning IR callbacks and supports all standard Lightning features
    including distributed training, mixed precision, gradient accumulation, and checkpointing.

    .. _PyTorch Lightning Trainer: https://lightning.ai/docs/pytorch/stable/common/trainer.html

    Examples:
        Basic usage for fine-tuning and inference:

        .. code-block:: python

            from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule

            # Initialize trainer with Lightning configuration
            trainer = LightningIRTrainer(
                max_steps=100_000,
                precision="16-mixed",
                devices=2,
                accelerator="gpu"
            )

            # Fine-tune a model
            module = BiEncoderModule(model_name_or_path="bert-base-uncased")
            datamodule = LightningIRDataModule(...)
            trainer.fit(module, datamodule)

            # Index documents
            trainer.index(module, datamodule)

            # Search for relevant documents
            trainer.search(module, datamodule)

            # Re-rank retrieved documents
            trainer.re_rank(module, datamodule)

    Note:
        The trainer requires appropriate callbacks to be configured for each IR task:

        - IndexCallback for indexing operations
        - SearchCallback for search operations
        - ReRankCallback for re-ranking operations
    """

    # TODO check that correct callbacks are registered for each subcommand
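    # The index, search, and re_rank methods below reuse Lightning's test loop (super().test());
    # the task-specific work is performed by the registered callbacks
    # (IndexCallback, SearchCallback, ReRankCallback).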
    def index(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Index a collection of documents using a fine-tuned bi-encoder model.

        This method performs document indexing by running inference on a document collection and
        storing the resulting embeddings in an index structure. It requires an IndexCallback to
        be configured in the trainer to handle the actual indexing process.

        Args:
            model (LightningModule | None): The LightningIRModule containing the bi-encoder model
                to use for encoding documents. If None, uses the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the document collection to index. Should contain DocDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before indexing.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during indexing. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            List[Mapping[str, float]]: List of dictionaries containing indexing metrics and results.

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule
                from lightning_ir import IndexCallback, TorchDenseIndexConfig, DocDataset

                # Setup trainer with index callback
                callback = IndexCallback(
                    index_dir="./index",
                    index_config=TorchDenseIndexConfig()
                )
                trainer = LightningIRTrainer(callbacks=[callback])

                # Setup model and data
                module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[DocDataset("msmarco-passage")]
                )

                # Index the documents
                trainer.index(module, datamodule)

        Note:
            - Requires IndexCallback to be configured in trainer callbacks
            - Only works with bi-encoder models that can encode documents
            - The index type and configuration are specified in the IndexCallback
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)
    def search(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Search for relevant documents using a bi-encoder model and pre-built index.

        This method performs dense or sparse retrieval by encoding queries and searching through
        a pre-built index to find the most relevant documents. It requires a SearchCallback to
        be configured in the trainer to handle the search process and optionally a RankCallback
        to save results.

        Args:
            model (LightningModule | None): The LightningIRModule containing the bi-encoder model
                to use for encoding queries. If None, uses the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the queries to search for. Should contain QueryDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before searching.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during searching. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            List[Mapping[str, float]]: List of dictionaries containing search metrics and effectiveness
                results (if relevance judgments are available).

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, BiEncoderModule, LightningIRDataModule
                from lightning_ir import SearchCallback, RankCallback, QueryDataset
                from lightning_ir import TorchDenseSearchConfig

                # Setup trainer with search and rank callbacks
                search_callback = SearchCallback(
                    index_dir="./index",
                    search_config=TorchDenseSearchConfig(k=100)
                )
                rank_callback = RankCallback(results_dir="./results")
                trainer = LightningIRTrainer(callbacks=[search_callback, rank_callback])

                # Setup model and data
                module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[QueryDataset("trec-dl-2019/queries")]
                )

                # Search for relevant documents
                results = trainer.search(module, datamodule)

        Note:
            - Requires SearchCallback to be configured in trainer callbacks
            - Index must be built beforehand using the index() method
            - Search configuration must match the index configuration used during indexing
            - Add RankCallback to save search results to disk
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)
    def re_rank(
        self,
        model: LightningModule | None = None,
        dataloaders: Any | LightningDataModule | None = None,
        ckpt_path: str | Path | None = None,
        verbose: bool = True,
        datamodule: LightningDataModule | None = None,
    ) -> List[Mapping[str, float]]:
        """Re-rank a set of retrieved documents using bi-encoder or cross-encoder models.

        This method performs re-ranking by scoring query-document pairs and reordering them
        based on relevance scores. Cross-encoders typically provide higher effectiveness for
        re-ranking tasks compared to bi-encoders. It requires a ReRankCallback to be configured
        in the trainer to handle saving the re-ranked results.

        Args:
            model (LightningModule | None): The LightningIRModule containing the model to use for
                re-ranking. Can be either BiEncoderModule or CrossEncoderModule. If None, uses
                the model from the datamodule.
            dataloaders (Any | LightningDataModule | None): DataLoader(s) or LightningIRDataModule
                containing the query-document pairs to re-rank. Should contain RunDataset instances.
            ckpt_path (str | Path | None): Path to a model checkpoint to load before re-ranking.
                If None, uses the current model state.
            verbose (bool): Whether to display progress during re-ranking. Defaults to True.
            datamodule (LightningDataModule | None): LightningIRDataModule instance. Alternative
                to passing dataloaders directly.

        Returns:
            List[Mapping[str, float]]: List of dictionaries containing re-ranking metrics and
                effectiveness results (if relevance judgments are available).

        Example:
            .. code-block:: python

                from lightning_ir import LightningIRTrainer, CrossEncoderModule, LightningIRDataModule
                from lightning_ir import ReRankCallback, RunDataset

                # Setup trainer with re-rank callback
                rerank_callback = ReRankCallback(results_dir="./reranked_results")
                trainer = LightningIRTrainer(callbacks=[rerank_callback])

                # Setup model and data
                module = CrossEncoderModule(model_name_or_path="webis/bert-cross-encoder")
                datamodule = LightningIRDataModule(
                    inference_datasets=[RunDataset("path/to/run/file.txt")]
                )

                # Re-rank the documents
                results = trainer.re_rank(module, datamodule)

        Note:
            - Requires ReRankCallback to be configured in trainer callbacks
            - Input data should be in run file format (query-document pairs with initial scores)
            - Cross-encoders typically provide better effectiveness than bi-encoders for re-ranking
        """
        return super().test(model, dataloaders, ckpt_path, verbose, datamodule)
class LightningIRCLI(LightningCLI):
    """Lightning IR Command Line Interface that extends PyTorch LightningCLI_ for information retrieval tasks.

    This CLI provides a unified command-line interface for fine-tuning neural ranking models and running
    information retrieval experiments. It extends the PyTorch LightningCLI_ with IR-specific subcommands
    and automatic configuration management for seamless integration between models, data, and training.

    .. _LightningCLI: https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html

    Examples:
        Command line usage:

        .. code-block:: bash

            # Fine-tune a model
            lightning-ir fit --config fine-tune.yaml

            # Index documents
            lightning-ir index --config index.yaml

            # Search for documents
            lightning-ir search --config search.yaml

            # Re-rank documents
            lightning-ir re_rank --config re-rank.yaml

            # Generate default configuration
            lightning-ir fit --print_config > config.yaml

        Programmatic usage:

        .. code-block:: python

            from lightning_ir.main import LightningIRCLI, LightningIRTrainer, LightningIRSaveConfigCallback

            # Create CLI instance
            cli = LightningIRCLI(
                trainer_class=LightningIRTrainer,
                save_config_callback=LightningIRSaveConfigCallback,
                save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True}
            )

        YAML configuration example:

        .. code-block:: yaml

            model:
              class_path: lightning_ir.BiEncoderModule
              init_args:
                model_name_or_path: bert-base-uncased
                loss_functions:
                  - class_path: lightning_ir.InBatchCrossEntropy

            data:
              class_path: lightning_ir.LightningIRDataModule
              init_args:
                train_dataset:
                  class_path: lightning_ir.TupleDataset
                  init_args:
                    dataset_id: msmarco-passage/train/triples-small
                train_batch_size: 32

            trainer:
              max_steps: 100000
              precision: "16-mixed"

            optimizer:
              class_path: torch.optim.AdamW
              init_args:
                lr: 5e-5

    Note:
        - Automatically links model and data configurations (model_name_or_path, config)
        - Links trainer max_steps to learning rate scheduler num_training_steps
        - Supports all PyTorch Lightning CLI features including class path instantiation
        - Built-in support for warmup learning rate schedulers
        - Saves configuration files automatically during training
    """
    @staticmethod
    def configure_optimizers(
        lightning_module: LightningModule,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: WarmupLRScheduler | None = None,
    ) -> Any:
        """Configure optimizers and learning rate schedulers for Lightning training.

        This method automatically configures the optimizer and learning rate scheduler combination
        for Lightning training. It handles warmup learning rate schedulers by setting the
        appropriate interval and returning the correct format expected by Lightning.

        Args:
            lightning_module (LightningModule): The Lightning module being trained.
            optimizer (torch.optim.Optimizer): The optimizer instance to use for training.
            lr_scheduler (WarmupLRScheduler | None): Optional warmup learning rate scheduler.
                If None, only the optimizer is returned.

        Returns:
            Any: The optimizer alone if no scheduler is configured, otherwise a tuple of an
                optimizer list and a scheduler-configuration list in Lightning's expected format.

        Note:
            - Warmup schedulers automatically set the correct interval based on the scheduler type
            - Returns a format compatible with Lightning's configure_optimizers method
        """
        if lr_scheduler is None:
            return optimizer

        return [optimizer], [{"scheduler": lr_scheduler, "interval": lr_scheduler.interval}]
    def add_arguments_to_parser(self, parser):
        """Add Lightning IR specific arguments and links to the CLI parser.

        This method extends the base Lightning CLI parser with IR-specific learning rate
        schedulers and automatically links related configuration arguments to ensure
        consistency between model, data, and trainer configurations.

        Args:
            parser: The CLI argument parser to extend.

        Note:
            Automatic argument linking:

            - model.init_args.model_name_or_path -> data.init_args.model_name_or_path
            - model.init_args.config -> data.init_args.config
            - trainer.max_steps -> lr_scheduler.init_args.num_training_steps
        """
        parser.add_lr_scheduler_args(tuple(LR_SCHEDULERS))
        parser.link_arguments("model.init_args.model_name_or_path", "data.init_args.model_name_or_path")
        parser.link_arguments("model.init_args.config", "data.init_args.config")
        parser.link_arguments("trainer.max_steps", "lr_scheduler.init_args.num_training_steps")
    @staticmethod
    def subcommands() -> Dict[str, Set[str]]:
        """Defines the available subcommands and the configuration sections they require.

        Returns a dictionary mapping subcommand names to the set of configuration sections
        they require. This extends the base Lightning CLI with IR-specific subcommands for
        indexing, searching, and re-ranking operations.

        Returns:
            Dict[str, Set[str]]: Dictionary mapping subcommand names to required config sections.

            - fit: Standard Lightning training subcommand with all sections
            - index: Document indexing requiring model, dataloaders, and datamodule
            - search: Document search requiring model, dataloaders, and datamodule
            - re_rank: Document re-ranking requiring model, dataloaders, and datamodule
        """
        return {
            "fit": LightningCLI.subcommands()["fit"],
            "index": {"model", "dataloaders", "datamodule"},
            "search": {"model", "dataloaders", "datamodule"},
            "re_rank": {"model", "dataloaders", "datamodule"},
        }
    def _add_configure_optimizers_method_to_model(self, subcommand: str | None) -> None:
        """Attach the configure_optimizers method to the model while suppressing warnings.

        Delegates to the parent implementation but silences the warnings it may emit
        (e.g., about overriding an existing configure_optimizers method).
        """
        import warnings

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return super()._add_configure_optimizers_method_to_model(subcommand)
def main():
    """Entry point for the Lightning IR command line interface.

    Initializes and runs the LightningIRCLI with the LightningIRTrainer and configuration
    callback. This function serves as the main entry point when Lightning IR is run from
    the command line using the 'lightning-ir' command.

    The CLI is configured with:

    - LightningIRTrainer as the trainer class for all IR operations
    - LightningIRSaveConfigCallback for automatic config file saving during training
    - Configuration to save configs as 'pl_config.yaml' with overwrite enabled

    Examples:
        This function is called when using Lightning IR from the command line:

        .. code-block:: bash

            lightning-ir fit --config fine-tune.yaml
            lightning-ir index --config index.yaml
            lightning-ir search --config search.yaml
            lightning-ir re_rank --config re-rank.yaml

    Note:
        - Configuration files are automatically saved during fit operations
        - All PyTorch Lightning CLI features are available
        - Supports YAML configuration files and command line argument overrides
    """
    LightningIRCLI(
        trainer_class=LightningIRTrainer,
        save_config_callback=LightningIRSaveConfigCallback,
        save_config_kwargs={"config_filename": "pl_config.yaml", "overwrite": True},
    )
if __name__ == "__main__":
    main()