Source code for lightning_ir.loss.neural

 1"""
 2Neural sorting-based loss functions for the Lightning IR framework.
 3
 4This module contains loss functions that use neural sorting techniques
 5to compute differentiable ranking-based losses.
 6"""
 7
 8from __future__ import annotations
 9
10import torch
11
12from .base import ListwiseLossFunction
13
14
class NeuralLossFunction(ListwiseLossFunction):
    """Base class for neural loss functions that compute ranks from scores using neural sorting."""

    # TODO add neural loss functions

    def __init__(self, temperature: float = 1, tol: float = 1e-5, max_iter: int = 50) -> None:
        """Initialize the NeuralLossFunction.

        Args:
            temperature (float): Temperature dividing the NeuralSort logits before the softmax; smaller values
                give a sharper (closer to hard) permutation matrix. Defaults to 1.
            tol (float): Sinkhorn scaling stops early once every row sum and every column sum of the permutation
                matrix is within ``tol`` of 1. Defaults to 1e-5.
            max_iter (int): Maximum number of Sinkhorn scaling iterations. Defaults to 50.
        """
        super().__init__()
        self.temperature = temperature
        self.tol = tol
        self.max_iter = max_iter

    def neural_sort(self, scores: torch.Tensor) -> torch.Tensor:
        """Compute the relaxed NeuralSort permutation matrix from scores.

        Row ``i`` of the result is a softmax distribution over the input items that (softly) selects the item
        with the ``i``-th largest score. The matrix is then pushed towards being doubly-stochastic with
        :meth:`sinkhorn_scaling`.

        Args:
            scores (torch.Tensor): Scores of shape ``(..., num_items)``; the original implementation only
                accepted ``(batch_size, num_items)``, which remains the primary use case.
        Returns:
            torch.Tensor: Relaxed permutation matrix of shape ``(..., num_items, num_items)``.
        """
        # https://github.com/ermongroup/neuralsort/blob/master/pytorch/neuralsort.py
        num_items = scores.shape[-1]

        # A_s[..., j, k] = |s_j - s_k| for every pair of items.
        abs_diff = torch.abs(scores.unsqueeze(-1) - scores.unsqueeze(-2))
        # (A_s 1)[..., j] = sum_k |s_j - s_k|. Computed as a plain sum instead of a matmul with a ones matrix,
        # which was O(n^3) and created a float32 tensor that broke non-float32 scores.
        abs_diff_sum = abs_diff.sum(dim=-1)

        # (n + 1 - 2i) for ranks i = 1..n, cast to the dtype/device of the scores.
        scaling = (num_items + 1 - 2 * (torch.arange(num_items, device=scores.device) + 1)).to(scores)

        # P_max[..., i, j] = (n + 1 - 2i) * s_j - sum_k |s_j - s_k|
        p_max = scaling.unsqueeze(-1) * scores.unsqueeze(-2) - abs_diff_sum.unsqueeze(-2)
        p_hat = torch.nn.functional.softmax(p_max / self.temperature, dim=-1)

        return self.sinkhorn_scaling(p_hat)

    def sinkhorn_scaling(self, mat: torch.Tensor) -> torch.Tensor:
        """Apply Sinkhorn scaling to make the permutation matrix (approximately) doubly-stochastic.

        Columns and then rows are alternately normalized to sum to 1. This repeats until both row and column
        sums are within ``self.tol`` of 1, or until ``self.max_iter`` iterations have been performed.

        Args:
            mat (torch.Tensor): Non-negative matrix of shape ``(..., num_items, num_items)``.
        Returns:
            torch.Tensor: The scaled matrix, same shape as ``mat``.
        """
        # https://github.com/allegro/allRank/blob/master/allrank/models/losses/loss_utils.py#L8
        if mat.numel() == 0:
            # Nothing to scale; torch.max would raise on an empty tensor.
            return mat
        # At most max_iter scaling passes (the previous `idx > max_iter` check ran one extra pass).
        for _ in range(self.max_iter):
            if (
                torch.max(torch.abs(mat.sum(dim=-1) - 1.0)) < self.tol
                and torch.max(torch.abs(mat.sum(dim=-2) - 1.0)) < self.tol
            ):
                break
            # Clamp avoids division by zero for all-zero columns/rows.
            mat = mat / mat.sum(dim=-2, keepdim=True).clamp(min=1e-12)
            mat = mat / mat.sum(dim=-1, keepdim=True).clamp(min=1e-12)
        return mat

    def get_sorted_targets(self, scores: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Get the targets soft-sorted by the predicted scores.

        Entry ``i`` of the result is the permutation-weighted average of the targets, i.e. a soft version of the
        target of the item ranked ``i``-th by ``scores``.

        Args:
            scores (torch.Tensor): Predicted scores of shape ``(..., num_items)``.
            targets (torch.Tensor): Targets with the same shape as ``scores``.
        Returns:
            torch.Tensor: The soft-sorted targets, shape ``(..., num_items)``, in the dtype of the scores.
        """
        permutation_matrix = self.neural_sort(scores)
        pred_sorted_targets = torch.matmul(permutation_matrix, targets[..., None].to(permutation_matrix)).squeeze(-1)
        return pred_sorted_targets