Source code for lightning_ir.loss.embedding
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from .base import EmbeddingLossFunction

if TYPE_CHECKING:
    from ..bi_encoder import BiEncoderOutput


class ContrastiveLocalLoss(EmbeddingLossFunction):
    """Loss function that computes a contrastive loss between a query and multiple document embeddings, such that
    only one document embedding has a high similarity to the query embedding, while all other document embeddings
    have a low similarity. Originally proposed in:
    `Multi-View Document Representation Learning for Open-Domain Dense Retrieval \
    <https://aclanthology.org/2022.acl-long.414/>`_"""

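    # A sketch of the objective: given a query's similarities s_1, ..., s_n to its
    # document embeddings and temperature T, the loss is the softmax cross-entropy
    # with the current argmax as the target,
    #     L = -log( exp(s_max / T) / sum_j exp(s_j / T) ),
    # which sharpens the best-matching embedding and suppresses the others.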
    def __init__(self, temperature: float = 1.0) -> None:
        super().__init__()
        self.temperature = temperature

    def compute_loss(self, output: BiEncoderOutput) -> torch.Tensor:
        """Compute the loss based on the embeddings in the output.

        Args:
            output (BiEncoderOutput): The output from the model containing query and document embeddings.
        Returns:
            torch.Tensor: The computed loss.
        """
        similarity = output.similarity
        if similarity is None:
            raise ValueError("Expected similarity in BiEncoderOutput")
        targets = similarity.argmax(-1)
        # Scale the similarities by the temperature so that the stored hyperparameter
        # actually controls the sharpness of the softmax cross-entropy.
        loss = torch.nn.functional.cross_entropy(similarity / self.temperature, targets)
        return loss
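

# ---------------------------------------------------------------------------
# Usage sketch (not part of the library source): a minimal, hypothetical example
# of exercising the loss with a hand-built similarity matrix. It assumes that
# ``BiEncoderOutput`` accepts ``similarity`` as a keyword argument; the tensor
# shape (num_queries x num_documents) and its values are made up for illustration.
#
#     import torch
#     from lightning_ir.bi_encoder import BiEncoderOutput
#     from lightning_ir.loss.embedding import ContrastiveLocalLoss
#
#     similarity = torch.tensor([[0.9, 0.2, 0.1], [0.3, 0.8, 0.4]])
#     output = BiEncoderOutput(similarity=similarity)
#     loss_fn = ContrastiveLocalLoss(temperature=0.1)
#     loss = loss_fn.compute_loss(output)  # scalar tensor
# ---------------------------------------------------------------------------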