Source code for lightning_ir.modeling_utils.embedding_post_processing

import torch

from ..base import LightningIRConfig


class Pooler(torch.nn.Module):
    """Applies pooling to the embeddings based on the pooling strategy defined in the configuration."""

    def __init__(self, config: LightningIRConfig) -> None:
        """Initializes the pooler.

        Args:
            config (LightningIRConfig): Configuration containing the pooling strategy to apply
        """
        super().__init__()
        self.pooling_strategy = getattr(config, "pooling_strategy", None)

    def forward(self, embeddings: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Applies optional pooling to the embeddings.

        Args:
            embeddings (torch.Tensor): Query, document, or joint query-document embeddings
            attention_mask (torch.Tensor | None): Query, document, or joint query-document attention mask
        Returns:
            torch.Tensor: (Optionally) pooled embeddings
        Raises:
            ValueError: If an unknown pooling strategy is passed
        """
        if self.pooling_strategy is None:
            return embeddings
        if self.pooling_strategy == "first":
            # keep only the first token embedding; the list index preserves the sequence dimension
            return embeddings[:, [0]]
        if self.pooling_strategy in ("sum", "mean"):
            if attention_mask is not None:
                # zero out padded positions before aggregating
                embeddings = embeddings * attention_mask.unsqueeze(-1)
            embeddings = embeddings.sum(dim=1, keepdim=True)
            if self.pooling_strategy == "mean":
                if attention_mask is not None:
                    # divide by the number of non-padded tokens
                    embeddings = embeddings / attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
            return embeddings
        if self.pooling_strategy == "max":
            if attention_mask is not None:
                # set padded positions to -inf so they never win the max
                embeddings = embeddings.masked_fill(~attention_mask.bool().unsqueeze(-1), float("-inf"))
            return embeddings.amax(dim=1, keepdim=True)
        raise ValueError(f"Unknown pooling strategy: {self.pooling_strategy}")

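A minimal usage sketch (illustrative, not part of the module): mean pooling over a padded batch. ``types.SimpleNamespace`` stands in for a ``LightningIRConfig`` here, since ``Pooler`` only reads the ``pooling_strategy`` attribute via ``getattr``.

    from types import SimpleNamespace

    import torch

    config = SimpleNamespace(pooling_strategy="mean")  # stand-in for a LightningIRConfig
    pooler = Pooler(config)

    embeddings = torch.randn(2, 4, 8)  # (batch, sequence length, embedding dim)
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 0 marks padding
    pooled = pooler(embeddings, attention_mask)
    print(pooled.shape)  # torch.Size([2, 1, 8]); the sequence dim is kept with size 1
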
class Sparsifier(torch.nn.Module):
    """Applies sparsification to the embeddings based on the sparsification strategy defined in the configuration."""

    def __init__(self, config: LightningIRConfig) -> None:
        """Initializes the sparsifier.

        Args:
            config (LightningIRConfig): Configuration containing the sparsification strategy to apply
        """
        super().__init__()
        self.sparsification_strategy = getattr(config, "sparsification_strategy", None)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Applies optional sparsification to the embeddings.

        Args:
            embeddings (torch.Tensor): Query, document, or joint query-document embeddings
        Returns:
            torch.Tensor: (Optionally) sparsified embeddings
        Raises:
            ValueError: If an unknown sparsification strategy is passed
        """
        if self.sparsification_strategy is None:
            return embeddings
        if self.sparsification_strategy == "relu":
            return torch.relu(embeddings)
        if self.sparsification_strategy == "relu_log":
            return torch.log1p(torch.relu(embeddings))
        if self.sparsification_strategy == "relu_2xlog":
            return torch.log1p(torch.log1p(torch.relu(embeddings)))
        raise ValueError(f"Unknown sparsification strategy: {self.sparsification_strategy}")
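
A similar sketch for the sparsifier (illustrative, not part of the module): the ``relu_log`` strategy, log1p applied after a ReLU, zeroes out negative activations and dampens large positive ones, which encourages sparse embeddings.

    from types import SimpleNamespace

    import torch

    config = SimpleNamespace(sparsification_strategy="relu_log")  # stand-in config
    sparsifier = Sparsifier(config)

    embeddings = torch.tensor([[-1.0, 0.0, 1.0, 3.0]])
    print(sparsifier(embeddings))  # tensor([[0.0000, 0.0000, 0.6931, 1.3863]]), i.e. log1p(relu(x))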