Source code for lightning_ir.modeling_utils.embedding_post_processing
import torch

from ..base import LightningIRConfig


class Pooler(torch.nn.Module):
    """Applies pooling to the embeddings based on the pooling strategy defined in the configuration."""

    def __init__(self, config: LightningIRConfig) -> None:
        """Initializes the pooler.

        Args:
            config (LightningIRConfig): Configuration containing the pooling strategy to apply
        """
        super().__init__()
        self.pooling_strategy = getattr(config, "pooling_strategy", None)

    def forward(self, embeddings: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Applies optional pooling to the embeddings.

        Args:
            embeddings (torch.Tensor): Query, document, or joint query-document embeddings
            attention_mask (torch.Tensor | None): Query, document, or joint query-document attention mask
        Returns:
            torch.Tensor: (Optionally) pooled embeddings
        Raises:
            ValueError: If an unknown pooling strategy is passed
        """
        if self.pooling_strategy is None:
            return embeddings
        if self.pooling_strategy == "first":
            # Keep only the first token's embedding (e.g., [CLS]), preserving the sequence dimension.
            return embeddings[:, [0]]
        if self.pooling_strategy in ("sum", "mean"):
            if attention_mask is not None:
                # Zero out padded positions so they do not contribute to the sum.
                embeddings = embeddings * attention_mask.unsqueeze(-1)
                num_tokens = attention_mask.sum(dim=1, keepdim=True).unsqueeze(-1)
            else:
                # Without a mask, every position counts toward the average.
                num_tokens = embeddings.shape[1]
            embeddings = embeddings.sum(dim=1, keepdim=True)
            if self.pooling_strategy == "mean":
                embeddings = embeddings / num_tokens
            return embeddings
        if self.pooling_strategy == "max":
            if attention_mask is not None:
                # Exclude padded positions from the maximum.
                embeddings = embeddings.masked_fill(~attention_mask.bool().unsqueeze(-1), float("-inf"))
            return embeddings.amax(dim=1, keepdim=True)
        raise ValueError(f"Unknown pooling strategy: {self.pooling_strategy}")
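
A minimal usage sketch: because Pooler only reads the pooling_strategy attribute via getattr, a plain SimpleNamespace can stand in for a full LightningIRConfig here::

    import torch
    from types import SimpleNamespace

    # Stand-in config for illustration; Pooler only reads `pooling_strategy`.
    pooler = Pooler(SimpleNamespace(pooling_strategy="mean"))

    embeddings = torch.randn(2, 5, 8)  # (batch, seq_len, emb_dim)
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding
    pooled = pooler(embeddings, attention_mask)
    print(pooled.shape)  # torch.Size([2, 1, 8]): one pooled vector per sequence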


class Sparsifier(torch.nn.Module):
    """Applies sparsification to the embeddings based on the sparsification strategy defined in the configuration."""

    def __init__(self, config: LightningIRConfig) -> None:
        """Initializes the sparsifier.

        Args:
            config (LightningIRConfig): Configuration containing the sparsification strategy to apply
        """
        super().__init__()
        self.sparsification_strategy = getattr(config, "sparsification_strategy", None)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Applies optional sparsification to the embeddings.

        Args:
            embeddings (torch.Tensor): Query, document, or joint query-document embeddings
        Returns:
            torch.Tensor: (Optionally) sparsified embeddings
        Raises:
            ValueError: If an unknown sparsification strategy is passed
        """
        if self.sparsification_strategy is None:
            return embeddings
        if self.sparsification_strategy == "relu":
            # Keep only positive activations; everything else becomes exactly zero.
            return torch.relu(embeddings)
        if self.sparsification_strategy == "relu_log":
            # log(1 + relu(x)): the log-saturation used by SPLADE-style models.
            return torch.log1p(torch.relu(embeddings))
        if self.sparsification_strategy == "relu_2xlog":
            # Applying log1p twice dampens large activations even more strongly.
            return torch.log1p(torch.log1p(torch.relu(embeddings)))
        raise ValueError(f"Unknown sparsification strategy: {self.sparsification_strategy}")
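
The same pattern works for Sparsifier, which only reads the sparsification_strategy attribute::

    import torch
    from types import SimpleNamespace

    # Stand-in config for illustration; Sparsifier only reads `sparsification_strategy`.
    sparsifier = Sparsifier(SimpleNamespace(sparsification_strategy="relu_log"))

    embeddings = torch.tensor([[-1.0, 0.0, 2.0]])
    print(sparsifier(embeddings))  # tensor([[0.0000, 0.0000, 1.0986]]), i.e. log1p(relu(x))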