Source code for lightning_ir.modeling_utils.mlm_head

import torch
from transformers.activations import get_activation
from transformers.models.bert.modeling_bert import BertOnlyMLMHead
from transformers.models.distilbert.configuration_distilbert import DistilBertConfig


class DistilBertOnlyMLMHead(torch.nn.Module):
    """Masked language modeling head for DistilBERT, mirroring the interface of BertOnlyMLMHead."""

    def __init__(self, config: DistilBertConfig) -> None:
        super().__init__()
        self.activation = get_activation(config.activation)
        self.vocab_transform = torch.nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = torch.nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = torch.nn.Linear(config.dim, config.vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Transform hidden states, apply the activation and layer norm, then
        # project to vocabulary logits.
        x = self.vocab_transform(x)
        x = self.activation(x)
        x = self.vocab_layer_norm(x)
        x = self.vocab_projector(x)
        return x


# Maps a Hugging Face model type to the corresponding MLM head class.
MODEL_TYPE_TO_LM_HEAD = {
    "bert": BertOnlyMLMHead,
    "distilbert": DistilBertOnlyMLMHead,
}

# Maps checkpoint state-dict key prefixes to the unified "projection" naming.
MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING = {
    "bert": {"cls": "bert.projection"},
    "distilbert": {
        "vocab_transform": "distilbert.projection.vocab_transform",
        "vocab_layer_norm": "distilbert.projection.vocab_layer_norm",
        "vocab_projector": "distilbert.projection.vocab_projector",
    },
}

# NOTE: In the output embeddings and tied weights keys below, the "cls" key has
# already been unified and replaced by the "projection" key.

MODEL_TYPE_TO_OUTPUT_EMBEDDINGS = {
    "bert": "predictions.decoder",
    "distilbert": "vocab_projector",
}

MODEL_TYPE_TO_TIED_WEIGHTS_KEYS = {
    "bert": ["predictions.decoder.bias", "predictions.decoder.weight"],
    "distilbert": ["vocab_projector.bias", "vocab_projector.weight"],
}
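
For reference, a minimal sketch of exercising the head on its own. The config values are the stock DistilBertConfig defaults (dim=768, vocab_size=30522, activation="gelu"); the batch and sequence sizes are illustrative, not taken from the module above.

import torch
from transformers.models.distilbert.configuration_distilbert import DistilBertConfig

config = DistilBertConfig()
head = DistilBertOnlyMLMHead(config)

hidden_states = torch.randn(2, 8, config.dim)  # (batch, sequence, hidden)
logits = head(hidden_states)                   # (batch, sequence, vocab_size)
assert logits.shape == (2, 8, config.vocab_size)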
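
The key-mapping dictionary suggests a prefix-rename pass over checkpoint state dicts. A hedged sketch of such a pass, assuming first-matching-prefix substitution (rename_state_dict_keys is illustrative and not part of this module):

def rename_state_dict_keys(state_dict: dict, model_type: str) -> dict:
    # Rewrite each key whose prefix appears in the mapping, e.g. for BERT,
    # "cls.predictions.decoder.weight" -> "bert.projection.predictions.decoder.weight".
    mapping = MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING[model_type]
    renamed = {}
    for key, value in state_dict.items():
        for old, new in mapping.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        renamed[key] = value
    return renamed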