Source code for lightning_ir.loss.listwise

"""
Listwise loss functions for the Lightning IR framework.

This module contains loss functions that operate on entire lists of items,
considering the ranking of all items simultaneously.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from .base import ListwiseLossFunction

if TYPE_CHECKING:
    from ..base import LightningIROutput
    from ..data import TrainBatch
20
class KLDivergence(ListwiseLossFunction):
    """Kullback-Leibler Divergence loss for listwise ranking tasks.

    Both the ground-truth relevance labels and the predicted scores are interpreted as probability
    distributions over the full list of candidates; the loss minimizes the divergence between the
    two distributions, aligning the global ranking structure rather than individual pairwise
    comparisons.

    Originally proposed in: `On Information and Sufficiency \
<https://projecteuclid.org/journals/annals-of-mathematical-statistics/volume-22/issue-1/On-Information-and-Sufficiency/10.1214/aoms/1177729694.full>`_
    """

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the Kullback-Leibler Divergence loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        # Turn both scores and targets into log-probability distributions over the list;
        # `.to(...)` matches the targets' dtype/device to the scores before normalization.
        log_pred = torch.nn.functional.log_softmax(scores, dim=-1)
        log_true = torch.nn.functional.log_softmax(targets.to(scores), dim=-1)
        # log_target=True because both inputs are in log-space; batchmean is the
        # mathematically correct reduction for KL divergence.
        return torch.nn.functional.kl_div(log_pred, log_true, log_target=True, reduction="batchmean")
class PearsonCorrelation(ListwiseLossFunction):
    """Pearson Correlation loss for listwise ranking tasks.

    Maximizes the linear alignment between the vector of predicted scores and the vector of
    ground-truth relevance labels, preserving the relative trends across the entire list
    regardless of the absolute scale of the scores.
    """

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the Pearson Correlation loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).to(scores)
        # Pearson correlation equals the cosine similarity of the mean-centered vectors.
        score_deviation = scores - scores.mean(dim=-1, keepdim=True)
        target_deviation = targets - targets.mean(dim=-1, keepdim=True)
        correlation = torch.nn.functional.cosine_similarity(score_deviation, target_deviation, dim=-1)
        # Perfect correlation (r=1) yields zero loss; anti-correlation (r=-1) yields 2.
        return (1 - correlation).mean()
class InfoNCE(ListwiseLossFunction):
    """InfoNCE loss for listwise ranking tasks.

    Information Noise-Contrastive Estimation adapts contrastive learning to ranking by treating
    the most relevant item as the positive signal and every other item in the list as negative
    noise, maximizing the likelihood of the correct document relative to the whole candidate set
    through a softmax-normalized objective.

    Originally proposed in: `Representation Learning with Contrastive Predictive Coding \
<https://arxiv.org/abs/1807.03748>`_
    """

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the InfoNCE loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        # The index of the highest-labeled item is the positive class; cross-entropy then
        # softmax-normalizes the scores over the candidate list (the InfoNCE objective).
        positive_indices = targets.argmax(dim=1)
        return torch.nn.functional.cross_entropy(scores, positive_indices)