Source code for lightning_ir.loss.listwise

 1"""
 2Listwise loss functions for the Lightning IR framework.
 3
 4This module contains loss functions that operate on entire lists of items,
 5considering the ranking of all items simultaneously.
 6"""
 7
 8from __future__ import annotations
 9
10from typing import TYPE_CHECKING
11
12import torch
13
14from .base import ListwiseLossFunction
15
16if TYPE_CHECKING:
17    from ..base import LightningIROutput
18    from ..data import TrainBatch
19
20
class KLDivergence(ListwiseLossFunction):
    """Kullback-Leibler Divergence loss for listwise ranking tasks.

    Both the model scores and the targets of a list are normalized with a
    softmax over the last dimension. The loss is the KL divergence of the
    score distribution from the target distribution.

    Originally proposed in: `On Information and Sufficiency \
<https://projecteuclid.org/journals/annals-of-mathematical-statistics/volume-22/issue-1/On-Information-and-Sufficiency/10.1214/aoms/1177729694.full>`_
    """

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the Kullback-Leibler Divergence loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        functional = torch.nn.functional
        model_scores = self.process_scores(output)
        # Targets are derived with the processed (not yet normalized) scores in hand.
        reference = self.process_targets(model_scores, batch)
        # Match the scores' dtype/device before turning targets into log-probabilities.
        log_target_dist = functional.log_softmax(reference.to(model_scores), dim=-1)
        log_score_dist = functional.log_softmax(model_scores, dim=-1)
        # With log_target=True, both arguments are log-probabilities; kl_div then
        # yields KL(target || scores), and "batchmean" divides the summed loss by
        # the size of the first dimension.
        return functional.kl_div(
            log_score_dist,
            log_target_dist,
            log_target=True,
            reduction="batchmean",
        )
class PearsonCorrelation(ListwiseLossFunction):
    """Pearson Correlation loss for listwise ranking tasks.

    Per list, the Pearson correlation ``r`` between model scores and targets is
    computed; the loss is ``1 - r`` averaged over all lists.
    """

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the Pearson Correlation loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """

        def _center(values: torch.Tensor) -> torch.Tensor:
            # Subtract the per-list mean along the last dimension.
            return values - values.mean(dim=-1, keepdim=True)

        processed = self.process_scores(output)
        # Cast targets to the scores' dtype/device so the arithmetic below is valid.
        reference = self.process_targets(processed, batch).to(processed)
        # Pearson r equals the cosine similarity of the two mean-centered vectors.
        correlation = torch.nn.functional.cosine_similarity(
            _center(processed), _center(reference), dim=-1
        )
        return torch.mean(1 - correlation)
class InfoNCE(ListwiseLossFunction):
    """InfoNCE loss for listwise ranking tasks.

    The highest-valued target in each list is treated as the positive item, and
    the loss is the cross entropy of the model scores against that item.

    Originally proposed in: `Representation Learning with Contrastive Predictive Coding \
<https://arxiv.org/abs/1807.03748>`_
    """

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the InfoNCE loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.
        Returns:
            torch.Tensor: The computed loss.
        """
        logits = self.process_scores(output)
        # Index of the maximal target along dim 1 (first one wins on ties) becomes
        # the class label; dim 1 is also the class axis cross_entropy expects.
        positive_index = self.process_targets(logits, batch).argmax(dim=1)
        return torch.nn.functional.cross_entropy(logits, positive_index)