ApproxRankMSE

class lightning_ir.loss.approximate.ApproxRankMSE(temperature: float = 1, discount: Literal['log2', 'reciprocal'] | None = None)

Bases: ApproxLossFunction

Approximate Rank Mean Squared Error (RankMSE) loss function for ranking tasks.

Rank Mean Squared Error (RankMSE) penalizes the squared differences between the ranks a model assigns to documents and their ground-truth ranks. Ranks obtained by discrete sorting are piecewise constant in the scores, so their gradient is zero almost everywhere and cannot drive gradient descent. Approximate RankMSE therefore replaces them with a smooth, differentiable approximation. It computes the mean squared error between these continuous approximate ranks and the true target ranks, optionally applying position-based discounting (log2 or reciprocal weights) so that errors at the top of the list are penalized more heavily.

Originally proposed in: Rank-DistiLLM: Closing the Effectiveness Gap Between Cross-Encoders and LLMs for Passage Re-ranking
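The computation described above can be sketched in plain Python for a single ranked list. This is a dependency-free illustration of the technique, not lightning_ir's implementation. It assumes the common sigmoid-based rank approximation (a document's approximate rank is 1 plus the sum of sigmoid((s_j - s_i) / temperature) over all other documents j), that true ranks come from sorting the targets in descending order, that the log2 and reciprocal discounts weight a document with true rank r by 1 / log2(r + 1) and 1 / r, and that the per-document terms are averaged. The helper names `sigmoid`, `approx_ranks`, `true_ranks`, and `approx_rank_mse` are hypothetical.

```python
import math
from typing import List, Literal, Optional, Sequence


def sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def approx_ranks(scores: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Smooth stand-in for each document's 1-based rank (assumed sigmoid form)."""
    # rank_i ~= 1 + sum over j != i of sigmoid((s_j - s_i) / temperature),
    # i.e. 1 plus a "soft count" of the documents scored above document i.
    return [
        1.0
        + sum(
            sigmoid((s_j - s_i) / temperature)
            for j, s_j in enumerate(scores)
            if j != i
        )
        for i, s_i in enumerate(scores)
    ]


def true_ranks(targets: Sequence[float]) -> List[int]:
    """Exact 1-based ranks; the highest target gets rank 1 (ties keep input order)."""
    order = sorted(range(len(targets)), key=lambda i: targets[i], reverse=True)
    ranks = [0] * len(targets)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


def approx_rank_mse(
    scores: Sequence[float],
    targets: Sequence[float],
    temperature: float = 1.0,
    discount: Optional[Literal["log2", "reciprocal"]] = None,
) -> float:
    """Optionally discounted MSE between approximate and true ranks for one list."""
    approx = approx_ranks(scores, temperature)
    ranks = true_ranks(targets)
    total = 0.0
    for a, r in zip(approx, ranks):
        # Assumed discounts: weight each squared error by the document's true rank r.
        if discount == "log2":
            weight = 1.0 / math.log2(r + 1)
        elif discount == "reciprocal":
            weight = 1.0 / r
        else:
            weight = 1.0
        total += weight * (a - r) ** 2
    # Assumed reduction: plain mean over the documents in the list.
    return total / len(ranks)
```

With a small temperature the approximate ranks are nearly exact, so a fully reversed list such as `approx_rank_mse([1.0, 2.0, 3.0], [3.0, 2.0, 1.0], temperature=0.01)` gives roughly 8/3 (squared errors 4, 0, 4). Passing `discount="log2"` lowers it to about 2.0, because the error on the document with true rank 3 is weighted by 1 / log2(4) = 0.5.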

__init__(temperature: float = 1, discount: Literal['log2', 'reciprocal'] | None = None)

Initialize the ApproxRankMSE loss function.

Parameters:
  • temperature (float) – Temperature parameter for scaling the scores. Defaults to 1.

  • discount (Literal["log2", "reciprocal"] | None) – Position-based discounting applied to the per-document errors so that mistakes near the top of the ranking weigh more. If None, no discounting is applied. Defaults to None.
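The following dependency-free sketch illustrates what the two parameters control. It assumes the standard forms of the two discounts (weight 1 / log2(r + 1) or 1 / r for a document whose true rank is r) and that the temperature divides the score difference inside a sigmoid soft comparison. `discount_weight` and `soft_above` are hypothetical helpers, not lightning_ir functions.

```python
import math
from typing import Optional


def discount_weight(rank: int, discount: Optional[str] = None) -> float:
    """Assumed weight for the error of the document whose true rank is `rank` (1-based)."""
    if discount == "log2":
        return 1.0 / math.log2(rank + 1)
    if discount == "reciprocal":
        return 1.0 / rank
    if discount is None:
        return 1.0
    raise ValueError(f"unknown discount: {discount!r}")


def soft_above(score_j: float, score_i: float, temperature: float = 1.0) -> float:
    """Assumed soft indicator that document j is scored above document i."""
    x = (score_j - score_i) / temperature
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


if __name__ == "__main__":
    for name in (None, "log2", "reciprocal"):
        print(name, [round(discount_weight(r, name), 3) for r in range(1, 6)])
    # None [1.0, 1.0, 1.0, 1.0, 1.0]
    # log2 [1.0, 0.631, 0.5, 0.431, 0.387]
    # reciprocal [1.0, 0.5, 0.333, 0.25, 0.2]

    # The same 0.5-point score gap at three temperatures.
    for t in (10.0, 1.0, 0.1):
        print(t, round(soft_above(1.5, 1.0, t), 3))
```

Both discounts equal 1 at rank 1 and shrink further down the list, with reciprocal weights decaying faster than log2. For the temperature, the soft comparison moves from close to 0.5 at `t=10.0` (the gap barely registers) toward 1.0 at `t=0.1`, so smaller temperatures track the exact ranks more closely at the cost of a less smooth objective.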

Methods

__init__([temperature, discount]) – Initialize the ApproxRankMSE loss function.

compute_loss(output, batch) – Compute the ApproxRankMSE loss.

compute_loss(output: LightningIROutput, batch: TrainBatch) → torch.Tensor

Compute the ApproxRankMSE loss.

Parameters:
  • output (LightningIROutput) – The output from the model containing scores.

  • batch (TrainBatch) – The training batch containing targets.

Returns:

The computed loss.

Return type:

torch.Tensor