Source code for lightning_ir.modeling_utils.mlm_head
import torch
from transformers.activations import get_activation
from transformers.models.bert.modeling_bert import BertOnlyMLMHead
from transformers.models.distilbert.configuration_distilbert import DistilBertConfig


class DistilBertOnlyMLMHead(torch.nn.Module):
    """Masked language modeling head for DistilBERT, analogous to BertOnlyMLMHead for BERT."""

    def __init__(self, config: DistilBertConfig) -> None:
        super().__init__()
        # Layer names mirror the MLM head of transformers' DistilBertForMaskedLM.
        self.activation = get_activation(config.activation)
        self.vocab_transform = torch.nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = torch.nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = torch.nn.Linear(config.dim, config.vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project hidden states to per-token vocabulary logits."""
        x = self.vocab_transform(x)
        x = self.activation(x)
        x = self.vocab_layer_norm(x)
        x = self.vocab_projector(x)
        return x


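# Illustrative usage sketch (not part of the module; the config values noted below are
# standard DistilBertConfig defaults): the head maps hidden states of shape
# (batch, seq_len, dim) to vocabulary logits of shape (batch, seq_len, vocab_size).
#
#   config = DistilBertConfig()  # defaults: dim=768, vocab_size=30522, activation="gelu"
#   head = DistilBertOnlyMLMHead(config)
#   hidden_states = torch.randn(2, 8, config.dim)
#   logits = head(hidden_states)  # shape: (2, 8, config.vocab_size)
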
# Maps a model type to the corresponding MLM head class.
MODEL_TYPE_TO_LM_HEAD = {
    "bert": BertOnlyMLMHead,
    "distilbert": DistilBertOnlyMLMHead,
}

# Maps a model type to a state dict key remapping: checkpoint key prefixes (left)
# are renamed into the unified `projection` namespace (right).
MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING = {
    "bert": {"cls": "bert.projection"},
    "distilbert": {
        "vocab_transform": "distilbert.projection.vocab_transform",
        "vocab_layer_norm": "distilbert.projection.vocab_layer_norm",
        "vocab_projector": "distilbert.projection.vocab_projector",
    },
}

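# A hedged sketch of how such a mapping might be applied when loading a checkpoint;
# the `rename_keys` helper and `checkpoint_state_dict` are hypothetical, not part of
# lightning_ir:
#
#   def rename_keys(state_dict: dict, mapping: dict) -> dict:
#       renamed = {}
#       for key, value in state_dict.items():
#           for old, new in mapping.items():
#               if key.startswith(old):
#                   key = new + key[len(old):]
#                   break
#           renamed[key] = value
#       return renamed
#
#   mapping = MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING["distilbert"]
#   remapped = rename_keys(checkpoint_state_dict, mapping)
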
# NOTE: In the output embeddings and tied weights keys below, the `cls` key has
# already been unified and replaced by the `projection` key.

# Attribute path of the output embedding layer, relative to the projection module.
MODEL_TYPE_TO_OUTPUT_EMBEDDINGS = {
    "bert": "predictions.decoder",
    "distilbert": "vocab_projector",
}

# Parameter names (relative to the projection module) whose weights are tied to the
# input embeddings.
MODEL_TYPE_TO_TIED_WEIGHTS_KEYS = {
    "bert": ["predictions.decoder.bias", "predictions.decoder.weight"],
    "distilbert": ["vocab_projector.bias", "vocab_projector.weight"],
}
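
# Illustrative sketch (assumed usage, not confirmed by this module): resolving the
# output embedding module from a head instance via its dotted attribute path.
#
#   from functools import reduce
#
#   head = MODEL_TYPE_TO_LM_HEAD["distilbert"](config)
#   path = MODEL_TYPE_TO_OUTPUT_EMBEDDINGS["distilbert"]  # "vocab_projector"
#   output_embeddings = reduce(getattr, path.split("."), head)  # a torch.nn.Linear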