Decision Guide

This guide helps you navigate Lightning IR’s configuration space. It is structured as a series of decision trees: start with what you want to do, then follow the branches to pick the right model architecture, index type, loss function, and data format. Each section ends with concrete CLI and Python examples you can copy and adapt.

What Do You Want to Do?

Start here. Lightning IR supports four top-level workflows, exposed as sub-commands of the lightning-ir CLI and as methods on LightningIRTrainer.

                    ┌──────────────────────┐
                    │  What is your goal?  │
                    └──────────┬───────────┘
         ┌─────────────────────┼─────────────────────┐
         │                     │                     │
         ▼                     ▼                     ▼
┌──────────────────┐  ┌──────────────────┐  ┌──────────────────┐
│    Fine-Tune     │  │   Retrieve docs  │  │ Improve existing │
│     a model      │  │   from a large   │  │     rankings     │
│                  │  │    collection    │  │                  │
│      ► fit       │  │     ► index      │  │    ► re_rank     │
│                  │  │     ► search     │  │                  │
└──────────────────┘  └──────────────────┘  └──────────────────┘

The table below summarizes the key ingredients for each workflow.

Workflow          | CLI Sub-command | Module Type                           | Dataset Type                       | Required Callback
------------------|-----------------|---------------------------------------|------------------------------------|---------------------------------
Fine-tune a model | fit             | BiEncoderModule or CrossEncoderModule | TupleDataset or RunDataset (train) | (none; ModelCheckpoint optional)
Index documents   | index           | BiEncoderModule                       | DocDataset                         | IndexCallback
Search (retrieve) | search          | BiEncoderModule                       | QueryDataset                       | SearchCallback
Re-rank           | re_rank         | BiEncoderModule or CrossEncoderModule | RunDataset                         | ReRankCallback

Tip

A typical end-to-end pipeline chains several workflows:

  1. fit — Fine-tune a model

  2. index — Encode all documents into an index (bi-encoder only)

  3. search — Retrieve candidate documents for queries

  4. re_rank — Re-score candidates with a more powerful model (often a cross-encoder)

You can enter the pipeline at any point. For example, if you already have a fine-tuned model from the Model Zoo, skip straight to index or re_rank.
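
In Python, the same pipeline is four successive LightningIRTrainer calls. The sketch below condenses Recipe 1 from the end of this guide; dataset IDs, paths, and hyperparameters are illustrative:

from torch.optim import AdamW
from lightning_ir import (
    BiEncoderModule, CrossEncoderModule, DocDataset, QueryDataset, RunDataset,
    TupleDataset, IndexCallback, SearchCallback, ReRankCallback,
    LightningIRDataModule, LightningIRTrainer, RankNet,
    FaissFlatIndexConfig, FaissSearchConfig,
)

# 1. fit: fine-tune a bi-encoder on training triples
module = BiEncoderModule(
    model_name_or_path="bert-base-uncased", loss_functions=[RankNet()]
)
module.set_optimizer(AdamW, lr=1e-5)
LightningIRTrainer(max_steps=100_000).fit(
    module,
    LightningIRDataModule(
        train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
        train_batch_size=32,
    ),
)

# 2. index: encode every document (bi-encoder only); the recipes below
# load the fine-tuned checkpoint from disk at this point instead
LightningIRTrainer(
    callbacks=[IndexCallback(index_dir="./index", index_config=FaissFlatIndexConfig())],
    logger=False, enable_checkpointing=False,
).index(
    module,
    LightningIRDataModule(
        inference_datasets=[DocDataset("msmarco-passage")], inference_batch_size=256
    ),
)

# 3. search: retrieve the top-100 candidates per query into a run file
LightningIRTrainer(
    callbacks=[SearchCallback(
        index_dir="./index", search_config=FaissSearchConfig(k=100), save_dir="./runs"
    )],
    logger=False, enable_checkpointing=False,
).search(
    module,
    LightningIRDataModule(
        inference_datasets=[QueryDataset("msmarco-passage/trec-dl-2019/judged")],
        inference_batch_size=4,
    ),
)

# 4. re_rank: re-score the candidates with a stronger cross-encoder
LightningIRTrainer(
    callbacks=[ReRankCallback(save_dir="./re-ranked-runs")],
    logger=False, enable_checkpointing=False,
).re_rank(
    CrossEncoderModule(model_name_or_path="webis/monoelectra-base"),
    LightningIRDataModule(
        inference_datasets=[RunDataset("./runs/msmarco-passage-trec-dl-2019-judged.run")],
        inference_batch_size=4,
    ),
)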

Which Model Architecture to Use?

This is usually the most impactful decision. The diagram below captures the main trade-offs.

Do you need to retrieve from a large collection (millions of docs)?
├── YES ─► Use a Bi-Encoder
│          │
│          ├── Need sparse / lexical matching with term expansion?
│          │   └── YES ─► SPLADE (SpladeConfig)
│          │
│          ├── Need highest bi-encoder quality via token-level matching?
│          │   └── YES ─► ColBERT (ColConfig)
│          │
│          └── Want simplest dense single-vector retrieval?
│              └── YES ─► DPR (DprConfig)
└── NO ─► You only need to re-rank an existing candidate list
          ├── Pointwise scoring (one doc at a time)?
          │   └── YES ─► MonoEncoder (MonoConfig)
          └── Listwise scoring (all candidates at once)?
              └── YES ─► SetEncoder (SetEncoderConfig)

Architecture Comparison

Architecture | Config Class     | Encoding         | Vectors       | Retrieval | Re-ranking | Key Trade-off
-------------|------------------|------------------|---------------|-----------|------------|------------------------------------------------------
DPR          | DprConfig        | Separate         | Single dense  | ✅        | ✅         | Fastest index & search; lower quality
SPLADE       | SpladeConfig     | Separate         | Single sparse | ✅        | ✅         | Interpretable lexical matching; needs regularization
ColBERT      | ColConfig        | Separate         | Multi dense   | ✅        | ✅         | Best bi-encoder quality; larger index
MonoEncoder  | MonoConfig       | Joint            | n/a           | ❌        | ✅         | Highest quality; cannot index
SetEncoder   | SetEncoderConfig | Joint (listwise) | n/a           | ❌        | ✅         | Sees all candidates at once; highest re-rank quality

Note

Lightning IR also supports several other bi-encoder variants (CoilConfig, MvrConfig) and external models such as XTR for advanced use cases. See the Model page and the models API reference for a full list.

Quick Examples: Picking a Model

DPR bi-encoder — simplest dense retrieval:

# model-dpr.yaml
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.models.DprConfig

from lightning_ir import BiEncoderModule
from lightning_ir.models import DprConfig

module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=DprConfig(),
)

ColBERT — multi-vector late interaction:

# model-colbert.yaml
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.models.ColConfig
      init_args:
        similarity_function: dot
        query_aggregation_function: sum
        query_expansion: true
        query_length: 32
        doc_length: 256
        normalization_strategy: l2
        embedding_dim: 128
        projection: linear_no_bias
        add_marker_tokens: true

from lightning_ir import BiEncoderModule
from lightning_ir.models import ColConfig

module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=ColConfig(
        similarity_function="dot",
        query_aggregation_function="sum",
        query_expansion=True,
        query_length=32,
        doc_length=256,
        normalization_strategy="l2",
        embedding_dim=128,
        projection="linear_no_bias",
        add_marker_tokens=True,
    ),
)

SPLADE — learned sparse retrieval:

# model-splade.yaml
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.models.SpladeConfig

from lightning_ir import BiEncoderModule
from lightning_ir.models import SpladeConfig

module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=SpladeConfig(),
)

Cross-encoder (MonoEncoder) — highest quality re-ranking:

# model-cross-encoder.yaml
model:
  class_path: lightning_ir.CrossEncoderModule
  init_args:
    model_name_or_path: webis/monoelectra-base

from lightning_ir import CrossEncoderModule

module = CrossEncoderModule(
    model_name_or_path="webis/monoelectra-base",
)
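
SetEncoder listwise cross-encoder (re-ranks all candidates jointly): a minimal sketch, assuming SetEncoderConfig is exposed in lightning_ir.models alongside the other config classes:

from lightning_ir import CrossEncoderModule
from lightning_ir.models import SetEncoderConfig  # assumed location, by analogy with DprConfig

module = CrossEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=SetEncoderConfig(),
)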

Which Index Type to Use?

Indexing is only relevant for bi-encoder models (cross-encoders score query–document pairs on the fly). The index type determines the speed–quality trade-off at search time.

What kind of bi-encoder embeddings do you have?
├── Dense single-vector (DPR)
│   │
│   ├── Small collection or need exact results?
│   │   └── TorchDenseIndexConfig or FaissFlatIndexConfig
│   │
│   ├── Large collection, approximate is OK?
│   │   └── FaissIVFIndexConfig (tune num_centroids)
│   │
│   └── Large collection, need compressed index?
│       └── FaissIVFPQIndexConfig (tune num_centroids, num_subquantizers)
├── Dense multi-vector (ColBERT)
│   │
│   ├── Small collection or prototyping?
│   │   └── TorchDenseIndexConfig
│   │
│   └── Large collection, production speed?
│       └── PlaidIndexConfig
└── Sparse (SPLADE, UniCOIL)
    ├── Simple inverted index?
    │   └── TorchSparseIndexConfig
    └── Fast approximate sparse retrieval?
        └── SeismicIndexConfig

Index Type Comparison

Index Config           | Search Config           | Speed   | Memory | Exact?       | Compatible Models
-----------------------|-------------------------|---------|--------|--------------|-------------------
TorchDenseIndexConfig  | TorchDenseSearchConfig  | Slow    | High   | ✅           | DPR, ColBERT
FaissFlatIndexConfig   | FaissSearchConfig       | Medium  | High   | ✅           | DPR, ColBERT
FaissIVFIndexConfig    | FaissSearchConfig       | Fast    | High   | ❌ (approx.) | DPR, ColBERT
FaissIVFPQIndexConfig  | FaissSearchConfig       | Fastest | Low    | ❌ (approx.) | DPR, ColBERT
PlaidIndexConfig       | PlaidSearchConfig       | Fast    | Medium | ❌ (approx.) | ColBERT only
TorchSparseIndexConfig | TorchSparseSearchConfig | Medium  | Medium | ✅           | SPLADE, UniCOIL
SeismicIndexConfig     | SeismicSearchConfig     | Fast    | Medium | ❌ (approx.) | SPLADE, UniCOIL

Important

The Search Config must match the Index Config used during indexing. You cannot build a FAISS index and search it with a Torch searcher, or vice versa.
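
In practice the configs come in fixed pairs (see the comparison table above). For example, an index built with FaissIVFIndexConfig must later be opened with FaissSearchConfig pointing at the same directory:

from lightning_ir import (
    FaissIVFIndexConfig, FaissSearchConfig, IndexCallback, SearchCallback,
)

# Indexing step: builds a FAISS IVF index under ./my-index
index_callback = IndexCallback(
    index_dir="./my-index",
    index_config=FaissIVFIndexConfig(num_centroids=65536),
)

# Search step: must use the matching FAISS searcher on the same directory;
# a TorchDenseSearchConfig cannot load this index.
search_callback = SearchCallback(
    index_dir="./my-index",
    search_config=FaissSearchConfig(k=100),
    save_dir="./runs",
)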

Quick Examples: Indexing & Searching

FAISS IVF index (approximate nearest-neighbor for large dense collections):

# index-faiss-ivf.yaml
trainer:
  logger: false
  callbacks:
    - class_path: lightning_ir.IndexCallback
      init_args:
        index_dir: ./my-index
        index_config:
          class_path: lightning_ir.FaissIVFIndexConfig
          init_args:
            num_centroids: 65536
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: webis/bert-bi-encoder
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    inference_datasets:
      - class_path: lightning_ir.DocDataset
        init_args:
          doc_dataset: msmarco-passage
    inference_batch_size: 256

from lightning_ir import (
    BiEncoderModule, DocDataset, IndexCallback,
    LightningIRDataModule, LightningIRTrainer,
    FaissIVFIndexConfig,
)

module = BiEncoderModule(model_name_or_path="webis/bert-bi-encoder")
data_module = LightningIRDataModule(
    inference_datasets=[DocDataset("msmarco-passage")],
    inference_batch_size=256,
)
callback = IndexCallback(
    index_dir="./my-index",
    index_config=FaissIVFIndexConfig(num_centroids=65536),
)
trainer = LightningIRTrainer(
    callbacks=[callback], logger=False, enable_checkpointing=False
)
trainer.index(module, data_module)
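
FAISS IVFPQ index (compressed, for very large dense collections). A sketch, assuming FaissIVFPQIndexConfig takes the num_centroids and num_subquantizers parameters named in the decision tree above:

from lightning_ir import FaissIVFPQIndexConfig, IndexCallback

callback = IndexCallback(
    index_dir="./my-compressed-index",
    index_config=FaissIVFPQIndexConfig(
        num_centroids=65536,
        # More subquantizers means finer (less lossy) quantization but a
        # larger index; parameter name assumed per the decision tree.
        num_subquantizers=16,
    ),
)

The rest of the setup (module, data module, trainer.index call) is identical to the FAISS IVF example above.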

Sparse index for SPLADE:

# index-sparse.yaml
trainer:
  logger: false
  callbacks:
    - class_path: lightning_ir.IndexCallback
      init_args:
        index_dir: ./splade-index
        index_config:
          class_path: lightning_ir.TorchSparseIndexConfig
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: webis/splade  # hypothetical model
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    inference_datasets:
      - class_path: lightning_ir.DocDataset
        init_args:
          doc_dataset: msmarco-passage
    inference_batch_size: 256

from lightning_ir import (
    BiEncoderModule, DocDataset, IndexCallback,
    LightningIRDataModule, LightningIRTrainer,
    TorchSparseIndexConfig,
)

module = BiEncoderModule(model_name_or_path="webis/splade")
data_module = LightningIRDataModule(
    inference_datasets=[DocDataset("msmarco-passage")],
    inference_batch_size=256,
)
callback = IndexCallback(
    index_dir="./splade-index",
    index_config=TorchSparseIndexConfig(),
)
trainer = LightningIRTrainer(
    callbacks=[callback], logger=False, enable_checkpointing=False
)
trainer.index(module, data_module)
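
Seismic index (fast approximate sparse retrieval). A minimal sketch, assuming SeismicIndexConfig can be constructed with defaults, as TorchSparseIndexConfig is above; search the resulting index with the matching SeismicSearchConfig:

from lightning_ir import IndexCallback, SeismicIndexConfig

# Drop-in replacement for TorchSparseIndexConfig in the example above.
callback = IndexCallback(
    index_dir="./splade-seismic-index",
    index_config=SeismicIndexConfig(),
)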

Which Loss Function to Use?

The choice of loss function depends on your training data format and your training objective.

What does your training data look like?
├── Triples (query, positive doc, negative doc) — TupleDataset
│   │
│   ├── Want a simple pairwise objective?
│   │   └── RankNet or ConstantMarginMSE
│   │
│   ├── Want contrastive learning with in-batch negatives?
│   │   └── InBatchCrossEntropy
│   │
│   └── Want to directly optimize a ranking metric (e.g., nDCG)?
│       └── ApproxNDCG or ApproxMRR
├── Ranked list with teacher scores — RunDataset (targets: score)
│   │
│   ├── Knowledge distillation from a teacher model?
│   │   └── KLDivergence or RankNet
│   │
│   └── Want contrastive learning over the scored candidates?
│       └── InfoNCE
└── Training a sparse model (SPLADE)?
    └── Add a regularization loss alongside your main loss:
        FLOPSRegularization (+ GenericConstantSchedulerWithLinearWarmup callback)

Loss Function Reference

Loss                          | Category       | When to Use
------------------------------|----------------|---------------------------------------------------------------------------------------------------------
RankNet                       | Pairwise       | Default choice for training with triples (pos/neg pairs). Optimizes pairwise ranking accuracy.
ConstantMarginMSE             | Pairwise       | Pairwise MSE with a fixed margin between positive and negative scores.
SupervisedMarginMSE           | Pairwise       | Pairwise MSE where the margin is derived from relevance labels.
KLDivergence                  | Listwise       | Knowledge distillation from a teacher model’s score distribution. Requires RunDataset with targets: score.
InfoNCE                       | Listwise       | Contrastive loss over a list of scored candidates.
PearsonCorrelation            | Listwise       | Optimizes correlation between predicted and target scores.
ApproxNDCG                    | Approximate    | Differentiable approximation of nDCG. Directly optimizes the target metric.
ApproxMRR                     | Approximate    | Differentiable approximation of MRR.
ApproxRankMSE                 | Approximate    | MSE on approximate rank positions.
InBatchCrossEntropy           | In-batch       | Uses other queries’ positives as negatives within a batch. Very effective with large batch sizes.
ScoreBasedInBatchCrossEntropy | In-batch       | In-batch negatives weighted by teacher scores.
FLOPSRegularization           | Regularization | Encourages sparsity in SPLADE embeddings. Always combine with a primary loss and a warmup scheduler.
L1Regularization              | Regularization | L1 penalty on embedding values.
L2Regularization              | Regularization | L2 penalty on embedding values.
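
Quick Example: Directly Optimizing a Ranking Metric

Following the decision tree above, an approximate metric loss such as ApproxNDCG can be trained on triples. A minimal sketch, assuming ApproxNDCG is importable from lightning_ir like the other loss functions:

from torch.optim import AdamW
from lightning_ir import (
    ApproxNDCG, BiEncoderModule, LightningIRDataModule,
    LightningIRTrainer, TupleDataset,
)
from lightning_ir.models import DprConfig

module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=DprConfig(),
    loss_functions=[ApproxNDCG()],  # differentiable nDCG surrogate
)
module.set_optimizer(AdamW, lr=1e-5)

data_module = LightningIRDataModule(
    train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    train_batch_size=32,
)
trainer = LightningIRTrainer(max_steps=100_000)
trainer.fit(module, data_module)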

Quick Example: Combining Losses for SPLADE

SPLADE models typically require a primary ranking loss plus a FLOPS regularization loss with a warmup schedule:

# splade-training.yaml
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.models.SpladeConfig
    loss_functions:
      - lightning_ir.InBatchCrossEntropy
      - class_path: lightning_ir.FLOPSRegularization
        init_args:
          query_weight: 0.0008
          doc_weight: 0.0006
trainer:
  max_steps: 100_000
  callbacks:
    - class_path: lightning_ir.GenericConstantSchedulerWithLinearWarmup
      init_args:
        keys:
          - loss_functions.1.query_weight
          - loss_functions.1.doc_weight
        num_warmup_steps: 20_000
        num_delay_steps: 50_000

from lightning_ir import (
    BiEncoderModule, LightningIRTrainer, LightningIRDataModule,
    TupleDataset, InBatchCrossEntropy, FLOPSRegularization,
    GenericConstantSchedulerWithLinearWarmup,
)
from lightning_ir.models import SpladeConfig
from torch.optim import AdamW

module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=SpladeConfig(),
    loss_functions=[
        InBatchCrossEntropy(),
        FLOPSRegularization(query_weight=0.0008, doc_weight=0.0006),
    ],
)
module.set_optimizer(AdamW, lr=1e-5)

data_module = LightningIRDataModule(
    train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    train_batch_size=32,
)
scheduler = GenericConstantSchedulerWithLinearWarmup(
    keys=["loss_functions.1.query_weight", "loss_functions.1.doc_weight"],
    num_warmup_steps=20_000,
    num_delay_steps=50_000,
)
trainer = LightningIRTrainer(max_steps=100_000, callbacks=[scheduler])
trainer.fit(module, data_module)

Quick Example: Knowledge Distillation

To distill from a teacher model’s run file scores into a student model:

# distillation.yaml
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.models.DprConfig
    loss_functions:
      - lightning_ir.KLDivergence
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    train_dataset:
      class_path: lightning_ir.RunDataset
      init_args:
        run_path_or_id: msmarco-passage/train/rank-distillm/set-encoder
        depth: 50
        sample_size: 8
        sampling_strategy: random
        targets: score
    train_batch_size: 16

from lightning_ir import (
    BiEncoderModule, LightningIRTrainer, LightningIRDataModule,
    RunDataset, KLDivergence,
)
from lightning_ir.models import DprConfig
from torch.optim import AdamW

module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=DprConfig(),
    loss_functions=[KLDivergence()],
)
module.set_optimizer(AdamW, lr=1e-5)

data_module = LightningIRDataModule(
    train_dataset=RunDataset(
        run_path_or_id="msmarco-passage/train/rank-distillm/set-encoder",
        depth=50,
        sample_size=8,
        sampling_strategy="random",
        targets="score",
    ),
    train_batch_size=16,
)
trainer = LightningIRTrainer(max_steps=50_000)
trainer.fit(module, data_module)

Which Dataset Format to Use?

Lightning IR provides four dataset classes. The right one depends on your workflow and the shape of your data.

What are you trying to do?
├── Fine-tune a model (fit)
│   │
│   ├── Have query + positive + negative triples?
│   │   └── TupleDataset
│   │       (uses an ir_datasets ID, e.g. "msmarco-passage/train/triples-small")
│   │
│   └── Have a run file with ranked docs and teacher scores?
│       └── RunDataset (targets: score, sampling_strategy: random)
├── Index documents (index)
│   └── DocDataset
│       (uses an ir_datasets ID, e.g. "msmarco-passage")
├── Search / retrieve (search)
│   └── QueryDataset
│       (uses an ir_datasets ID, e.g. "msmarco-passage/trec-dl-2019/judged")
└── Re-rank (re_rank)
    └── RunDataset
        (path to a TREC-format run file or an ir_datasets ID)

Dataset Class Reference

Dataset      | Workflow     | Description
-------------|--------------|---------------------------------------------------------------------------------------------
TupleDataset | fit          | Iterates over (query, doc₁, doc₂, …) tuples with relevance targets. Backed by ir_datasets.
RunDataset   | fit, re_rank | Loads a ranked list of documents per query from a TREC-format run file or an ir_datasets ID. Key parameters: depth (max rank to load), sample_size (docs per query), sampling_strategy (top or random), targets (relevance or score).
DocDataset   | index        | Iterates over all documents in a collection. Backed by ir_datasets.
QueryDataset | search       | Iterates over queries in a dataset split. Backed by ir_datasets.

Tip

When using RunDataset for training (knowledge distillation), set sampling_strategy: random so the model sees diverse negatives, and targets: score to use the teacher’s relevance scores.

When using RunDataset for re-ranking (inference), set sampling_strategy: top and increase depth / sample_size to cover the full candidate list.
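
The two RunDataset modes side by side (a sketch; the run IDs and paths are illustrative, and the depth/sample_size of 100 match the k: 100 retrieval used in the recipes below):

from lightning_ir import RunDataset

# Training (knowledge distillation): random sampling for diverse negatives,
# teacher scores as targets.
train_dataset = RunDataset(
    run_path_or_id="msmarco-passage/train/rank-distillm/set-encoder",
    depth=50,
    sample_size=8,
    sampling_strategy="random",
    targets="score",
)

# Re-ranking (inference): take candidates from the top of the ranked list
# and cover the full candidate depth.
rerank_dataset = RunDataset(
    run_path_or_id="./runs/msmarco-passage-trec-dl-2019-judged.run",
    depth=100,
    sample_size=100,
    sampling_strategy="top",
)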

End-to-End Recipes

The following recipes chain together the decisions above into complete, copy-pasteable pipelines. Each recipe shows both the CLI (YAML) and programmatic (Python) approach.

Recipe 1: DPR Dense Retrieval Pipeline

Goal: Fine-tune a simple dense bi-encoder, index a collection, search, then re-rank with a cross-encoder.

Step 1 — Fine-tune the DPR model:

lightning-ir fit --config recipe-dpr-fit.yaml
recipe-dpr-fit.yaml
trainer:
  max_steps: 100_000
  precision: bf16-mixed
  accumulate_grad_batches: 4
  gradient_clip_val: 1
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.models.DprConfig
    loss_functions:
      - lightning_ir.RankNet
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    train_dataset:
      class_path: lightning_ir.TupleDataset
      init_args:
        tuples_dataset: msmarco-passage/train/triples-small
    train_batch_size: 32
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 1e-5
recipe_dpr_fit.py
from torch.optim import AdamW
from lightning_ir import (
    BiEncoderModule, LightningIRDataModule,
    LightningIRTrainer, RankNet, TupleDataset,
)
from lightning_ir.models import DprConfig

module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=DprConfig(),
    loss_functions=[RankNet()],
)
module.set_optimizer(AdamW, lr=1e-5)

data_module = LightningIRDataModule(
    train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    train_batch_size=32,
)
trainer = LightningIRTrainer(
    max_steps=100_000,
    precision="bf16-mixed",
    accumulate_grad_batches=4,
    gradient_clip_val=1,
)
trainer.fit(module, data_module)

Step 2 — Index the collection:

lightning-ir index --config recipe-dpr-index.yaml
recipe-dpr-index.yaml
trainer:
  logger: false
  callbacks:
    - class_path: lightning_ir.IndexCallback
      init_args:
        index_dir: ./msmarco-passage-index
        index_config:
          class_path: lightning_ir.FaissFlatIndexConfig
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: ./my-dpr-checkpoint  # or a Model Zoo model
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    inference_datasets:
      - class_path: lightning_ir.DocDataset
        init_args:
          doc_dataset: msmarco-passage
    inference_batch_size: 256
recipe_dpr_index.py
from lightning_ir import (
    BiEncoderModule, DocDataset, IndexCallback,
    LightningIRDataModule, LightningIRTrainer,
    FaissFlatIndexConfig,
)

module = BiEncoderModule(model_name_or_path="./my-dpr-checkpoint")
data_module = LightningIRDataModule(
    inference_datasets=[DocDataset("msmarco-passage")],
    inference_batch_size=256,
)
callback = IndexCallback(
    index_dir="./msmarco-passage-index",
    index_config=FaissFlatIndexConfig(),
)
trainer = LightningIRTrainer(
    callbacks=[callback], logger=False, enable_checkpointing=False
)
trainer.index(module, data_module)

Step 3 — Search:

lightning-ir search --config recipe-dpr-search.yaml
recipe-dpr-search.yaml
trainer:
  logger: false
  callbacks:
    - class_path: lightning_ir.SearchCallback
      init_args:
        index_dir: ./msmarco-passage-index
        search_config:
          class_path: lightning_ir.FaissSearchConfig
          init_args:
            k: 100
        save_dir: ./runs
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: ./my-dpr-checkpoint
    evaluation_metrics:
      - nDCG@10
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    inference_datasets:
      - class_path: lightning_ir.QueryDataset
        init_args:
          query_dataset: msmarco-passage/trec-dl-2019/judged
    inference_batch_size: 4
recipe_dpr_search.py
from lightning_ir import (
    BiEncoderModule, QueryDataset, SearchCallback,
    LightningIRDataModule, LightningIRTrainer,
    FaissSearchConfig,
)

module = BiEncoderModule(
    model_name_or_path="./my-dpr-checkpoint",
    evaluation_metrics=["nDCG@10"],
)
data_module = LightningIRDataModule(
    inference_datasets=[
        QueryDataset("msmarco-passage/trec-dl-2019/judged"),
    ],
    inference_batch_size=4,
)
callback = SearchCallback(
    index_dir="./msmarco-passage-index",
    search_config=FaissSearchConfig(k=100),
    save_dir="./runs",
)
trainer = LightningIRTrainer(
    callbacks=[callback], logger=False, enable_checkpointing=False
)
trainer.search(module, data_module)

Step 4 — Re-rank with a cross-encoder:

lightning-ir re_rank --config recipe-dpr-rerank.yaml
recipe-dpr-rerank.yaml
trainer:
  logger: false
  callbacks:
    - class_path: lightning_ir.ReRankCallback
      init_args:
        save_dir: ./re-ranked-runs
model:
  class_path: lightning_ir.CrossEncoderModule
  init_args:
    model_name_or_path: webis/monoelectra-base
    evaluation_metrics:
      - nDCG@10
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    inference_datasets:
      - class_path: lightning_ir.RunDataset
        init_args:
          run_path_or_id: ./runs/msmarco-passage-trec-dl-2019-judged.run
    inference_batch_size: 4
recipe_dpr_rerank.py
from lightning_ir import (
    CrossEncoderModule, RunDataset, ReRankCallback,
    LightningIRDataModule, LightningIRTrainer,
)

module = CrossEncoderModule(
    model_name_or_path="webis/monoelectra-base",
    evaluation_metrics=["nDCG@10"],
)
data_module = LightningIRDataModule(
    inference_datasets=[
        RunDataset("./runs/msmarco-passage-trec-dl-2019-judged.run"),
    ],
    inference_batch_size=4,
)
callback = ReRankCallback(save_dir="./re-ranked-runs")
trainer = LightningIRTrainer(
    callbacks=[callback], logger=False, enable_checkpointing=False
)
trainer.re_rank(module, data_module)

Recipe 2: SPLADE Sparse Retrieval Pipeline

Goal: Train a SPLADE model with proper regularization, build a sparse index, and retrieve.

Step 1 — Fine-tune SPLADE with FLOPS regularization:

lightning-ir fit --config recipe-splade-fit.yaml
recipe-splade-fit.yaml
trainer:
  max_steps: 100_000
  precision: bf16-mixed
  callbacks:
    - class_path: lightning_ir.GenericConstantSchedulerWithLinearWarmup
      init_args:
        keys:
          - loss_functions.1.query_weight
          - loss_functions.1.doc_weight
        num_warmup_steps: 20_000
        num_delay_steps: 50_000
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.models.SpladeConfig
    loss_functions:
      - lightning_ir.InBatchCrossEntropy
      - class_path: lightning_ir.FLOPSRegularization
        init_args:
          query_weight: 0.0008
          doc_weight: 0.0006
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    train_dataset:
      class_path: lightning_ir.TupleDataset
      init_args:
        tuples_dataset: msmarco-passage/train/triples-small
    train_batch_size: 32
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 1e-5
recipe_splade_fit.py
from torch.optim import AdamW
from lightning_ir import (
    BiEncoderModule, LightningIRDataModule, LightningIRTrainer,
    TupleDataset, InBatchCrossEntropy, FLOPSRegularization,
    GenericConstantSchedulerWithLinearWarmup,
)
from lightning_ir.models import SpladeConfig

module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=SpladeConfig(),
    loss_functions=[
        InBatchCrossEntropy(),
        FLOPSRegularization(query_weight=0.0008, doc_weight=0.0006),
    ],
)
module.set_optimizer(AdamW, lr=1e-5)

data_module = LightningIRDataModule(
    train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    train_batch_size=32,
)
scheduler = GenericConstantSchedulerWithLinearWarmup(
    keys=[
        "loss_functions.1.query_weight",
        "loss_functions.1.doc_weight",
    ],
    num_warmup_steps=20_000,
    num_delay_steps=50_000,
)
trainer = LightningIRTrainer(
    max_steps=100_000,
    precision="bf16-mixed",
    callbacks=[scheduler],
)
trainer.fit(module, data_module)

Step 2 — Build a sparse index:

lightning-ir index --config recipe-splade-index.yaml
recipe-splade-index.yaml
trainer:
  logger: false
  callbacks:
    - class_path: lightning_ir.IndexCallback
      init_args:
        index_dir: ./splade-index
        index_config:
          class_path: lightning_ir.TorchSparseIndexConfig
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: ./my-splade-checkpoint
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    inference_datasets:
      - class_path: lightning_ir.DocDataset
        init_args:
          doc_dataset: msmarco-passage
    inference_batch_size: 256
recipe_splade_index.py
from lightning_ir import (
    BiEncoderModule, DocDataset, IndexCallback,
    LightningIRDataModule, LightningIRTrainer,
    TorchSparseIndexConfig,
)

module = BiEncoderModule(model_name_or_path="./my-splade-checkpoint")
data_module = LightningIRDataModule(
    inference_datasets=[DocDataset("msmarco-passage")],
    inference_batch_size=256,
)
callback = IndexCallback(
    index_dir="./splade-index",
    index_config=TorchSparseIndexConfig(),
)
trainer = LightningIRTrainer(
    callbacks=[callback], logger=False, enable_checkpointing=False
)
trainer.index(module, data_module)

Step 3 — Sparse search:

lightning-ir search --config recipe-splade-search.yaml
recipe-splade-search.yaml
trainer:
  logger: false
  callbacks:
    - class_path: lightning_ir.SearchCallback
      init_args:
        index_dir: ./splade-index
        search_config:
          class_path: lightning_ir.TorchSparseSearchConfig
          init_args:
            k: 100
        save_dir: ./runs
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: ./my-splade-checkpoint
    evaluation_metrics:
      - nDCG@10
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    inference_datasets:
      - class_path: lightning_ir.QueryDataset
        init_args:
          query_dataset: msmarco-passage/trec-dl-2019/judged
    inference_batch_size: 4
recipe_splade_search.py
from lightning_ir import (
    BiEncoderModule, QueryDataset, SearchCallback,
    LightningIRDataModule, LightningIRTrainer,
    TorchSparseSearchConfig,
)

module = BiEncoderModule(
    model_name_or_path="./my-splade-checkpoint",
    evaluation_metrics=["nDCG@10"],
)
data_module = LightningIRDataModule(
    inference_datasets=[
        QueryDataset("msmarco-passage/trec-dl-2019/judged"),
    ],
    inference_batch_size=4,
)
callback = SearchCallback(
    index_dir="./splade-index",
    search_config=TorchSparseSearchConfig(k=100),
    save_dir="./runs",
)
trainer = LightningIRTrainer(
    callbacks=[callback], logger=False, enable_checkpointing=False
)
trainer.search(module, data_module)

Recipe 3: ColBERT Multi-Vector Pipeline

Goal: Fine-tune a ColBERT model, build a PLAID index for fast retrieval, and search.

Step 1 — Fine-tune ColBERT:

lightning-ir fit --config recipe-colbert-fit.yaml
recipe-colbert-fit.yaml
trainer:
  max_steps: 100_000
  precision: bf16-mixed
  accumulate_grad_batches: 4
  gradient_clip_val: 1
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.models.ColConfig
      init_args:
        similarity_function: dot
        query_aggregation_function: sum
        query_expansion: true
        query_length: 32
        doc_length: 256
        normalization_strategy: l2
        embedding_dim: 128
        projection: linear_no_bias
        add_marker_tokens: true
    loss_functions:
      - lightning_ir.RankNet
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    train_dataset:
      class_path: lightning_ir.TupleDataset
      init_args:
        tuples_dataset: msmarco-passage/train/triples-small
    train_batch_size: 32
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 1e-5
recipe_colbert_fit.py
from torch.optim import AdamW
from lightning_ir import (
    BiEncoderModule, LightningIRDataModule,
    LightningIRTrainer, RankNet, TupleDataset,
)
from lightning_ir.models import ColConfig

module = BiEncoderModule(
    model_name_or_path="bert-base-uncased",
    config=ColConfig(
        similarity_function="dot",
        query_aggregation_function="sum",
        query_expansion=True,
        query_length=32,
        doc_length=256,
        normalization_strategy="l2",
        embedding_dim=128,
        projection="linear_no_bias",
        add_marker_tokens=True,
    ),
    loss_functions=[RankNet()],
)
module.set_optimizer(AdamW, lr=1e-5)

data_module = LightningIRDataModule(
    train_dataset=TupleDataset("msmarco-passage/train/triples-small"),
    train_batch_size=32,
)
trainer = LightningIRTrainer(
    max_steps=100_000,
    precision="bf16-mixed",
    accumulate_grad_batches=4,
    gradient_clip_val=1,
)
trainer.fit(module, data_module)

Step 2 — Build a PLAID index:

lightning-ir index --config recipe-colbert-index.yaml
recipe-colbert-index.yaml
trainer:
  logger: false
  callbacks:
    - class_path: lightning_ir.IndexCallback
      init_args:
        index_dir: ./colbert-index
        index_config:
          class_path: lightning_ir.PlaidIndexConfig
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: ./my-colbert-checkpoint
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    inference_datasets:
      - class_path: lightning_ir.DocDataset
        init_args:
          doc_dataset: msmarco-passage
    inference_batch_size: 256
recipe_colbert_index.py
from lightning_ir import (
    BiEncoderModule, DocDataset, IndexCallback,
    LightningIRDataModule, LightningIRTrainer,
    PlaidIndexConfig,
)

module = BiEncoderModule(model_name_or_path="./my-colbert-checkpoint")
data_module = LightningIRDataModule(
    inference_datasets=[DocDataset("msmarco-passage")],
    inference_batch_size=256,
)
callback = IndexCallback(
    index_dir="./colbert-index",
    index_config=PlaidIndexConfig(),
)
trainer = LightningIRTrainer(
    callbacks=[callback], logger=False, enable_checkpointing=False
)
trainer.index(module, data_module)

Step 3 — Search with PLAID:

lightning-ir search --config recipe-colbert-search.yaml
recipe-colbert-search.yaml
trainer:
  logger: false
  callbacks:
    - class_path: lightning_ir.SearchCallback
      init_args:
        index_dir: ./colbert-index
        search_config:
          class_path: lightning_ir.PlaidSearchConfig
          init_args:
            k: 100
        save_dir: ./runs
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: ./my-colbert-checkpoint
    evaluation_metrics:
      - nDCG@10
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    inference_datasets:
      - class_path: lightning_ir.QueryDataset
        init_args:
          query_dataset: msmarco-passage/trec-dl-2019/judged
    inference_batch_size: 4
recipe_colbert_search.py
from lightning_ir import (
    BiEncoderModule, QueryDataset, SearchCallback,
    LightningIRDataModule, LightningIRTrainer,
    PlaidSearchConfig,
)

module = BiEncoderModule(
    model_name_or_path="./my-colbert-checkpoint",
    evaluation_metrics=["nDCG@10"],
)
data_module = LightningIRDataModule(
    inference_datasets=[
        QueryDataset("msmarco-passage/trec-dl-2019/judged"),
    ],
    inference_batch_size=4,
)
callback = SearchCallback(
    index_dir="./colbert-index",
    search_config=PlaidSearchConfig(k=100),
    save_dir="./runs",
)
trainer = LightningIRTrainer(
    callbacks=[callback], logger=False, enable_checkpointing=False
)
trainer.search(module, data_module)

Quick Reference: Compatibility

Use this table as a cheat sheet when composing configurations.

Model Config     | Module             | Compatible Index                                                                                          | Compatible Search                                             | Supported Workflows
-----------------|--------------------|-----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------|-----------------------------
DprConfig        | BiEncoderModule    | TorchDenseIndexConfig, FaissFlatIndexConfig, FaissIVFIndexConfig, FaissIVFPQIndexConfig                    | TorchDenseSearchConfig, FaissSearchConfig                     | fit, index, search, re_rank
SpladeConfig     | BiEncoderModule    | TorchSparseIndexConfig, SeismicIndexConfig                                                                 | TorchSparseSearchConfig, SeismicSearchConfig                  | fit, index, search, re_rank
ColConfig        | BiEncoderModule    | TorchDenseIndexConfig, FaissFlatIndexConfig, FaissIVFIndexConfig, FaissIVFPQIndexConfig, PlaidIndexConfig  | TorchDenseSearchConfig, FaissSearchConfig, PlaidSearchConfig  | fit, index, search, re_rank
MonoConfig       | CrossEncoderModule | n/a (no indexing)                                                                                          | n/a                                                           | fit, re_rank
SetEncoderConfig | CrossEncoderModule | n/a (no indexing)                                                                                          | n/a                                                           | fit, re_rank