Source code for lightning_ir.loss.embedding

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from .base import EmbeddingLossFunction

if TYPE_CHECKING:
    from ..bi_encoder import BiEncoderOutput


class ContrastiveLocalLoss(EmbeddingLossFunction):
    """Loss function that computes a contrastive loss between a query and multiple document embeddings, such that
    only one document embedding has a high similarity to the query embedding, while all other document embeddings
    have a low similarity. Originally proposed in:
    `Multi-View Document Representation Learning for Open-Domain Dense Retrieval \
    <https://aclanthology.org/2022.acl-long.414/>`_"""

    def __init__(self, temperature: float = 1.0) -> None:
        super().__init__()
        self.temperature = temperature

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the loss based on the embeddings in the output.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.

        Returns:
            torch.Tensor: The computed loss.
        """
        similarity = output.similarity
        if similarity is None:
            raise ValueError("Expected similarity in BiEncoderOutput")
        targets = similarity.argmax(-1)
        loss = torch.nn.functional.cross_entropy(similarity, targets)
        return loss
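
For intuition, the sketch below reproduces the core computation of compute_loss on a toy similarity tensor using plain PyTorch. The shape (2 queries with 4 document-view scores each) and the example values are assumptions chosen only for illustration; in practice the similarity matrix comes from output.similarity produced by the bi-encoder.

import torch

# Assumed toy similarity matrix: 2 queries, 4 document-view scores per query.
similarity = torch.tensor(
    [
        [0.9, 0.2, 0.1, 0.3],
        [0.1, 0.8, 0.2, 0.4],
    ]
)

# The highest-scoring view of each query is taken as the target class ...
targets = similarity.argmax(-1)  # tensor([0, 1])

# ... and cross-entropy over the views pushes that view's similarity up
# relative to the other views of the same query.
loss = torch.nn.functional.cross_entropy(similarity, targets)
print(loss)  # scalar tensor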