Source code for lightning_ir.loss.neural
1"""
2Neural sorting-based loss functions for the Lightning IR framework.
3
4This module contains loss functions that use neural sorting techniques
5to compute differentiable ranking-based losses.
6"""
7
8from __future__ import annotations
9
10import torch
11
12from .base import ListwiseLossFunction
13
14
[docs]
class NeuralLossFunction(ListwiseLossFunction):
    """Base class for neural loss functions that compute ranks from scores using neural sorting."""

    # TODO add neural loss functions

    def __init__(self, temperature: float = 1, tol: float = 1e-5, max_iter: int = 50) -> None:
        """Initialize the NeuralLossFunction.

        Args:
            temperature (float): Temperature the NeuralSort logits are divided by before the softmax.
                Defaults to 1.
            tol (float): Maximum absolute deviation of any row/column sum from 1 at which Sinkhorn scaling is
                considered converged. Defaults to 1e-5.
            max_iter (int): Maximum number of Sinkhorn normalisation rounds. Defaults to 50.
        """
        super().__init__()
        self.temperature = temperature
        self.tol = tol
        self.max_iter = max_iter

    def neural_sort(self, scores: torch.Tensor) -> torch.Tensor:
        """Compute the neural sort permutation matrix from scores.

        Args:
            scores (torch.Tensor): The input scores tensor. Must be 2-dimensional, ``(batch, n)``; a trailing
                axis is added internally and the result is permuted as a 3-dimensional tensor.
        Returns:
            torch.Tensor: The relaxed permutation matrix of shape ``(batch, n, n)`` after Sinkhorn scaling. The
                softmax is taken over the last axis, so each row is a distribution over the ``n`` input items.
        """
        # https://github.com/ermongroup/neuralsort/blob/master/pytorch/neuralsort.py
        scores = scores.unsqueeze(-1)
        dim = scores.shape[1]

        # Pairwise absolute score differences, shape (batch, n, n).
        A_scores = torch.abs(scores - scores.permute(0, 2, 1))
        # The reference implementation computes ``A @ (1 1^T)``, which is just the row sum of ``A`` repeated in
        # every column. Summing directly broadcasts to the same values, avoids an O(n^3) matmul and -- unlike a
        # default-dtype (float32) ones matrix -- cannot raise a dtype mismatch for float64/half-precision scores.
        B = A_scores.sum(dim=-1, keepdim=True)
        # Per-rank coefficients n + 1 - 2i for i = 1..n, cast to the dtype/device of the scores.
        scaling = dim + 1 - 2 * (torch.arange(dim, device=scores.device) + 1)
        C = torch.matmul(scores, scaling.to(scores).unsqueeze(0))

        # Transpose so that rows index rank positions and columns index input items.
        P_max = (C - B).permute(0, 2, 1)
        P_hat = torch.nn.functional.softmax(P_max / self.temperature, dim=-1)

        # The softmax only normalises rows; Sinkhorn scaling additionally pushes column sums towards 1.
        return self.sinkhorn_scaling(P_hat)

    def sinkhorn_scaling(self, mat: torch.Tensor) -> torch.Tensor:
        """Apply Sinkhorn scaling to the permutation matrix.

        Alternately normalises over dim 1 and dim 2 until all sums along both dims are within ``self.tol`` of 1
        or ``self.max_iter`` normalisation rounds have been performed.

        Args:
            mat (torch.Tensor): The input permutation matrix of shape ``(batch, n, n)``.
        Returns:
            torch.Tensor: The scaled permutation matrix.
        """
        # https://github.com/allegro/allRank/blob/master/allrank/models/losses/loss_utils.py#L8
        if mat.numel() == 0:
            # ``torch.max`` is undefined for empty tensors; there is nothing to normalise anyway.
            return mat

        # ``range(self.max_iter)`` performs at most ``max_iter`` rounds (the previous ``idx > max_iter`` check
        # allowed one extra round).
        for _ in range(self.max_iter):
            if (
                torch.max(torch.abs(mat.sum(dim=2) - 1.0)) < self.tol
                and torch.max(torch.abs(mat.sum(dim=1) - 1.0)) < self.tol
            ):
                break
            # Clamp the denominators so that all-zero rows/columns do not produce NaNs.
            mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=1e-12)
            mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=1e-12)

        return mat

    def get_sorted_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Get the sorted targets based on the neural sort permutation matrix.

        Args:
            scores (torch.Tensor): The input scores tensor of shape ``(batch, n)``.
            targets (torch.Tensor): The targets tensor; it is cast to the permutation matrix's dtype/device and
                must be multiplicable with it after a trailing axis is added (presumably ``(batch, n)``).
        Returns:
            torch.Tensor: The targets re-ordered (softly) by the relaxed permutation derived from ``scores``.
        """
        permutation_matrix = self.neural_sort(scores)
        # Add a trailing axis so the targets act as column vectors in the batched matmul, then drop it again.
        target_column = targets[..., None].to(permutation_matrix)
        pred_sorted_targets = torch.matmul(permutation_matrix, target_column).squeeze(-1)
        return pred_sorted_targets