Source code for lightning_ir.loss.listwise
"""
Listwise loss functions for the Lightning IR framework.

This module contains loss functions that operate on entire lists of items,
considering the ranking of all items simultaneously.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from .base import ListwiseLossFunction

if TYPE_CHECKING:
    from ..base import LightningIROutput
    from ..data import TrainBatch

class KLDivergence(ListwiseLossFunction):
    """Kullback-Leibler Divergence loss for listwise ranking tasks.
    Originally proposed in: `On Information and Sufficiency \
    <https://projecteuclid.org/journals/annals-of-mathematical-statistics/volume-22/issue-1/On-Information-and-Sufficiency/10.1214/aoms/1177729694.full>`_
    """

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the Kullback-Leibler Divergence loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.

        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        targets = torch.nn.functional.log_softmax(targets.to(scores), dim=-1)
        loss = torch.nn.functional.kl_div(scores, targets, log_target=True, reduction="batchmean")
        return loss

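
# Minimal usage sketch (not part of the library): the toy tensors below stand in for
# what process_scores / process_targets would produce from a LightningIROutput and a
# TrainBatch, so only the arithmetic inside KLDivergence.compute_loss is exercised.
# The helper name and tensor values are illustrative assumptions.
def _example_kl_divergence() -> torch.Tensor:
    scores = torch.tensor([[4.0, 2.0, 1.0]])   # model scores, shape (num_queries, num_docs)
    targets = torch.tensor([[1.0, 0.5, 0.0]])  # graded relevance labels, same shape
    log_scores = torch.nn.functional.log_softmax(scores, dim=-1)
    log_targets = torch.nn.functional.log_softmax(targets, dim=-1)
    # Divergence between the two listwise distributions, averaged over queries.
    return torch.nn.functional.kl_div(log_scores, log_targets, log_target=True, reduction="batchmean")
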
class PearsonCorrelation(ListwiseLossFunction):
    """Pearson Correlation loss for listwise ranking tasks."""

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the Pearson Correlation loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.

        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch).to(scores)
        centered_scores = scores - scores.mean(dim=-1, keepdim=True)
        centered_targets = targets - targets.mean(dim=-1, keepdim=True)
        pearson = torch.nn.functional.cosine_similarity(centered_scores, centered_targets, dim=-1)
        loss = (1 - pearson).mean()
        return loss

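
# Minimal usage sketch (not part of the library): mirrors PearsonCorrelation.compute_loss
# on toy tensors; the helper name and values are illustrative assumptions. Cosine
# similarity of mean-centered vectors equals the Pearson correlation, so perfectly
# correlated score/target lists yield a loss of 0 and anti-correlated lists a loss of 2.
def _example_pearson_correlation() -> torch.Tensor:
    scores = torch.tensor([[4.0, 2.0, 1.0]])   # model scores, shape (num_queries, num_docs)
    targets = torch.tensor([[1.0, 0.5, 0.0]])  # graded relevance labels, same shape
    centered_scores = scores - scores.mean(dim=-1, keepdim=True)
    centered_targets = targets - targets.mean(dim=-1, keepdim=True)
    pearson = torch.nn.functional.cosine_similarity(centered_scores, centered_targets, dim=-1)
    return (1 - pearson).mean()
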
class InfoNCE(ListwiseLossFunction):
    """InfoNCE loss for listwise ranking tasks.
    Originally proposed in: `Representation Learning with Contrastive Predictive Coding \
    <https://arxiv.org/abs/1807.03748>`_
    """

    def compute_loss(self, output: LightningIROutput, batch: TrainBatch) -> torch.Tensor:
        """Compute the InfoNCE loss.

        Args:
            output (LightningIROutput): The output from the model containing scores.
            batch (TrainBatch): The training batch containing targets.

        Returns:
            torch.Tensor: The computed loss.
        """
        scores = self.process_scores(output)
        targets = self.process_targets(scores, batch)
        targets = targets.argmax(dim=1)
        loss = torch.nn.functional.cross_entropy(scores, targets)
        return loss
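
# Minimal usage sketch (not part of the library): mirrors InfoNCE.compute_loss on toy
# tensors; the helper name and values are illustrative assumptions. argmax over the
# targets picks the index of the most relevant document per query, and cross-entropy
# then treats ranking as classification over the candidate documents.
def _example_info_nce() -> torch.Tensor:
    scores = torch.tensor([[4.0, 2.0, 1.0], [0.5, 3.0, 1.0]])   # (num_queries, num_docs)
    targets = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # relevance labels per document
    positive_indices = targets.argmax(dim=1)  # index of the positive document per query
    return torch.nn.functional.cross_entropy(scores, positive_indices)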