FLOPSRegularization
- class lightning_ir.loss.regularization.FLOPSRegularization(query_weight: float = 0.0001, doc_weight: float = 0.0001)[source]
Bases: RegularizationLossFunction

FLOPS Regularization loss function for query and document embeddings.
FLOPS regularization adds a penalty to the loss function that encourages the model to produce sparse embeddings. Sparse embeddings contain fewer non-zero values, which leads to more efficient inference; this is particularly beneficial for large-scale retrieval systems where computational efficiency is crucial. The regularization term minimizes an estimate of the number of floating-point operations (FLOPS) required during retrieval by promoting sparsity, effectively encouraging the model to concentrate weight on the most important dimensions while zeroing out less relevant ones.
Originally proposed in: Minimizing FLOPS to Learn Efficient Sparse Representations
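As described in the paper above, the FLOPS penalty for a batch of embeddings is the sum, over embedding dimensions, of the squared mean absolute activation. The following is a minimal sketch of that computation in plain PyTorch; the function name `flops_penalty` is illustrative and not part of the lightning_ir API.

```python
import torch


def flops_penalty(embeddings: torch.Tensor) -> torch.Tensor:
    """FLOPS penalty: sum over dimensions of the squared mean
    absolute activation. Dimensions that are active across many
    examples in the batch are penalized quadratically, pushing
    the model toward sparse, low-overlap activations."""
    # embeddings: (batch_size, embedding_dim)
    mean_activation = embeddings.abs().mean(dim=0)
    return (mean_activation**2).sum()


# A dimension active in every example contributes heavily;
# a dimension that is always zero contributes nothing.
dense_dim = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
print(flops_penalty(dense_dim))  # tensor(1.)
```

Because the penalty is quadratic in the per-dimension mean, it rewards spreading activations across different dimensions for different inputs rather than reusing the same few dimensions everywhere.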
Methods
compute_loss(output): Compute the FLOPS regularization loss.
- compute_loss(output: BiEncoderOutput) → torch.Tensor[source]
Compute the FLOPS regularization loss.
- Parameters:
output (BiEncoderOutput) – The output from the model containing query and document embeddings.
- Returns:
The computed loss.
- Return type:
torch.Tensor
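The class weights the query and document penalties separately via `query_weight` and `doc_weight` (both defaulting to 1e-4). A hedged sketch of how such a combined loss could be formed, using plain tensors in place of a `BiEncoderOutput` and an illustrative `flops` helper rather than the library's internal implementation:

```python
import torch


def flops(emb: torch.Tensor) -> torch.Tensor:
    # FLOPS penalty: squared per-dimension mean absolute
    # activation, summed over all embedding dimensions.
    return (emb.abs().mean(dim=0) ** 2).sum()


# Hypothetical sparse query/document embeddings over a
# 30k-dimensional (e.g. vocabulary-sized) output space.
query_embeddings = torch.rand(8, 30_000)
doc_embeddings = torch.rand(8, 30_000)

# Defaults mirror the constructor signature shown above.
query_weight, doc_weight = 1e-4, 1e-4
loss = query_weight * flops(query_embeddings) + doc_weight * flops(doc_embeddings)
print(loss)  # scalar tensor, non-negative
```

In training, this scalar would simply be added to the ranking loss before backpropagation.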