ApproxMRR
- class lightning_ir.loss.approximate.ApproxMRR(temperature: float = 1)
Bases: ApproxLossFunction
Approximate Mean Reciprocal Rank (MRR) loss function for ranking tasks.
Mean Reciprocal Rank (MRR) is a metric for evaluating ranking systems that considers only the position of the first relevant result. This makes it well suited to tasks such as question answering, where the user wants one correct answer as early as possible. For each query it assigns a score of 1/k, where k is the rank of the first relevant document: if the first relevant result is at position 1, the score is 1; if it is at position 10, the score drops to 0.1. The final MRR is the average of these reciprocal ranks over all queries in the dataset.
Approximate MRR replaces the non-differentiable, discrete ranking operation with a smooth, differentiable surrogate based on pairwise score comparisons. This lets the model directly maximize the reciprocal rank of the relevant document via gradient descent.
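To check the arithmetic concretely, the minimal plain-Python sketch below computes the exact (non-differentiable) MRR from the 1-based rank of the first relevant document for each query. The helper `mean_reciprocal_rank` is hypothetical and not part of lightning_ir; it only illustrates the metric that ApproxMRR approximates.

```python
def mean_reciprocal_rank(first_relevant_ranks: list[int]) -> float:
    """Exact MRR: the average of 1/k over queries, where k is the 1-based rank
    of the first relevant document for that query (hypothetical helper)."""
    if not first_relevant_ranks:
        raise ValueError("need at least one query")
    if any(k < 1 for k in first_relevant_ranks):
        raise ValueError("ranks are 1-based")
    reciprocal_ranks = [1.0 / k for k in first_relevant_ranks]
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


# Query A: first relevant result at rank 1  -> reciprocal rank 1.0
# Query B: first relevant result at rank 10 -> reciprocal rank 0.1
print(mean_reciprocal_rank([1, 10]))  # approximately 0.55, the mean of 1.0 and 0.1
```

Because the rank k is an integer that changes in jumps as scores move, this exact quantity has zero gradient almost everywhere, which is why a smooth surrogate is needed for training.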
- __init__(temperature: float = 1)
Initialize the ApproxMRR loss function.
- Parameters:
temperature (float) – Temperature parameter for scaling the scores. Defaults to 1.
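To make the role of `temperature` concrete, the sketch below implements one common pairwise surrogate for ranks, the approximate rank used in ApproxNDCG (Qin et al., 2010): the rank of item i is approximated as 1 + Σ_{j≠i} σ((s_j − s_i) / T). That lightning_ir uses exactly this form, and that it divides score differences by the temperature, are assumptions; the name `approx_ranks` is hypothetical, and plain Python lists stand in for torch.Tensor.

```python
import math


def _stable_sigmoid(x: float) -> float:
    """Logistic sigmoid that does not overflow math.exp for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def approx_ranks(scores: list[float], temperature: float = 1.0) -> list[float]:
    """Smooth 1-based rank of each item (hypothetical helper, assumed form):
    rank_i = 1 + sum over j != i of sigmoid((s_j - s_i) / temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    ranks = []
    for i, s_i in enumerate(scores):
        # A higher-scoring item j adds almost 1 to item i's rank,
        # a lower-scoring item adds almost 0, and a tie adds exactly 0.5.
        rank = 1.0 + sum(
            _stable_sigmoid((s_j - s_i) / temperature)
            for j, s_j in enumerate(scores)
            if j != i
        )
        ranks.append(rank)
    return ranks


scores = [3.0, 1.0, 2.0]
# Low temperature: close to the hard ranks [1, 3, 2].
print([round(r, 3) for r in approx_ranks(scores, temperature=0.05)])
# High temperature: every rank drifts toward the middle rank (3 + 1) / 2 = 2.
print([round(r, 3) for r in approx_ranks(scores, temperature=100.0)])
```

Under this assumed form, a lower temperature makes the surrogate track the true ranking more closely but saturates the sigmoid, which typically yields sharper and sparser gradients; a higher temperature smooths the ranks and the gradients at the cost of fidelity.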
Methods
__init__([temperature]) – Initialize the ApproxMRR loss function.
compute_loss(output, batch) – Compute the ApproxMRR loss.
get_mrr(ranks, targets[, k]) – Compute the Mean Reciprocal Rank (MRR) for the given ranks and targets.
- compute_loss(output: LightningIROutput, batch: TrainBatch) → torch.Tensor
Compute the ApproxMRR loss.
- Parameters:
output (LightningIROutput) – The output from the model containing scores.
batch (TrainBatch) – The training batch containing targets.
- Returns:
The computed loss.
- Return type:
torch.Tensor
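The following is a minimal sketch of what compute_loss plausibly does, assuming the loss is 1 minus the approximate reciprocal rank, averaged over the queries in the batch, and that graded targets are binarized (any positive label counts as relevant). Plain nested lists stand in for the scores in LightningIROutput and the targets in TrainBatch; `approx_mrr_loss` and its helpers are hypothetical, and the approximate-rank formula is the assumed ApproxNDCG-style form.

```python
import math


def _stable_sigmoid(x: float) -> float:
    """Logistic sigmoid that does not overflow math.exp for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def _approx_ranks(scores: list[float], temperature: float) -> list[float]:
    """Smooth 1-based ranks: 1 + sum over j != i of sigmoid((s_j - s_i) / T)."""
    return [
        1.0 + sum(
            _stable_sigmoid((s_j - s_i) / temperature)
            for j, s_j in enumerate(scores)
            if j != i
        )
        for i, s_i in enumerate(scores)
    ]


def approx_mrr_loss(
    batch_scores: list[list[float]],
    batch_targets: list[list[float]],
    temperature: float = 1.0,
) -> float:
    """Hypothetical sketch: mean over queries of (1 - approximate reciprocal
    rank of the best-ranked relevant document)."""
    losses = []
    for scores, targets in zip(batch_scores, batch_targets):
        ranks = _approx_ranks(scores, temperature)
        # Treat any positive label as relevant; the smallest approximate rank
        # among relevant items gives the largest reciprocal rank.
        # A query with no relevant item gets reciprocal rank 0 (loss 1).
        reciprocal_rank = max(
            (1.0 / r for r, t in zip(ranks, targets) if t > 0), default=0.0
        )
        losses.append(1.0 - reciprocal_rank)
    return sum(losses) / len(losses)


targets = [[1, 0, 0]]  # one query; the first document is relevant
good = approx_mrr_loss([[5.0, 0.0, -5.0]], targets, temperature=0.1)
bad = approx_mrr_loss([[-5.0, 0.0, 5.0]], targets, temperature=0.1)
print(good < bad)  # True: ranking the relevant document first lowers the loss
```

Taking the maximum over relevant items mirrors MRR's focus on the first relevant result. In a torch version, the gradient would flow only through that item's approximate rank, which still depends on every other score through the sigmoid terms.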
- static get_mrr(ranks: Tensor, targets: Tensor, k: int | None = None) → Tensor
Compute the Mean Reciprocal Rank (MRR) for the given ranks and targets.
- Parameters:
ranks (torch.Tensor) – The ranks of the items.
targets (torch.Tensor) – The relevance scores of the items.
k (int | None) – Optional rank cutoff. If provided, only relevant items ranked within the top k contribute to the MRR.
- Returns:
The computed MRR values.
- Return type:
torch.Tensor
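The sketch below shows one plausible reading of the documented get_mrr behavior for a single query: binarize the targets, take the reciprocal of each relevant item's rank, ignore items ranked beyond k, and keep the largest value, which belongs to the first relevant item. The name `get_mrr_sketch` and its plain-list signature are hypothetical; the actual static method takes batched torch.Tensor inputs and returns one value per query. Ranks may be fractional, as produced by the approximate-rank surrogate.

```python
from typing import Optional


def get_mrr_sketch(
    ranks: list[float], targets: list[float], k: Optional[int] = None
) -> float:
    """Reciprocal rank of the best-ranked relevant item, or 0.0 if no item is
    relevant or no relevant item lies within the cutoff k (hypothetical)."""
    best = 0.0
    for rank, target in zip(ranks, targets):
        if target <= 0:
            continue  # not relevant
        if k is not None and rank > k:
            continue  # outside the top-k cutoff
        best = max(best, 1.0 / rank)
    return best


ranks = [1.0, 2.0, 3.0, 4.0]
targets = [0, 0, 1, 1]
print(get_mrr_sketch(ranks, targets))       # first relevant item at rank 3 -> 1/3
print(get_mrr_sketch(ranks, targets, k=2))  # no relevant item in the top 2 -> 0.0
```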