1from functools import partial
2
3import torch
4from transformers import PretrainedConfig
5from transformers.activations import get_activation
6
7
class LMHead(torch.nn.Module):
    """Masked-language-modeling head mapping encoder hidden states to vocabulary logits.

    The computation is ``decoder(norm(act(dense(hidden_states))))``. Every hyperparameter is
    read from ``config`` through an attribute name supplied by the caller, so a single class
    can reproduce the MLM heads of several encoder families (see ``MODEL_TYPE_TO_LM_HEAD``).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_dim_key: str,
        activation_key: str,
        classifier_bias_key: str | None = None,
        norm_bias_key: str | None = None,
        norm_eps_key: str | None = None,
        decoder_bias_key: str | None = None,
    ):
        """Build the head's layers from ``config``.

        Args:
            config: Model configuration on which the attribute names below are looked up.
                ``config.vocab_size`` is always read for the decoder's output width.
            hidden_dim_key: Attribute holding the hidden size (input and intermediate width).
            activation_key: Attribute holding the activation, resolved with
                ``transformers.activations.get_activation``.
            classifier_bias_key: Optional attribute toggling the bias of ``dense``.
                ``None`` or a missing attribute means ``True``.
            norm_bias_key: Optional attribute toggling the bias of ``norm``.
                ``None`` or a missing attribute means ``True``.
            norm_eps_key: Optional attribute holding the LayerNorm epsilon.
                ``None`` or a missing attribute means ``1e-12``.
            decoder_bias_key: Optional attribute toggling the bias of ``decoder``.
                ``None`` or a missing attribute means ``True``.
        """
        super().__init__()

        def from_config(key, default):
            # An unset key, or a config lacking the attribute, falls back to ``default``.
            return default if key is None else getattr(config, key, default)

        dim = getattr(config, hidden_dim_key)
        activation = getattr(config, activation_key)
        self.dense = torch.nn.Linear(dim, dim, bias=from_config(classifier_bias_key, True))
        self.act = get_activation(activation)
        # The epsilon must match the checkpoint's own head (e.g. roberta-base and ModernBERT use
        # 1e-5, not 1e-12); a fixed value would silently perturb the loaded LayerNorm.
        self.norm = torch.nn.LayerNorm(
            dim, eps=from_config(norm_eps_key, 1e-12), bias=from_config(norm_bias_key, True)
        )
        self.decoder = torch.nn.Linear(
            dim, config.vocab_size, bias=from_config(decoder_bias_key, True)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits.

        Args:
            hidden_states: Tensor whose last dimension equals the configured hidden size.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``config.vocab_size``.
        """
        return self.decoder(self.norm(self.act(self.dense(hidden_states))))


# Constructors for the MLM head of each supported ``model_type``. Each entry binds the config
# attribute names holding that family's head hyperparameters; call it with the model config.
MODEL_TYPE_TO_LM_HEAD = {
    "bert": partial(
        LMHead,
        hidden_dim_key="hidden_size",
        activation_key="hidden_act",
        norm_eps_key="layer_norm_eps",
    ),
    # No eps key is bound here, so LMHead's 1e-12 default applies.
    "distilbert": partial(LMHead, hidden_dim_key="hidden_size", activation_key="activation"),
    "modernbert": partial(
        LMHead,
        hidden_dim_key="hidden_size",
        activation_key="classifier_activation",
        classifier_bias_key="classifier_bias",
        norm_bias_key="norm_bias",
        norm_eps_key="norm_eps",
        decoder_bias_key="decoder_bias",
    ),
    "roberta": partial(
        LMHead,
        hidden_dim_key="hidden_size",
        activation_key="hidden_act",
        norm_eps_key="layer_norm_eps",
    ),
}
43
# Renaming tables used when adapting masked-LM checkpoints, keyed by ``model_type``.
# Each inner table maps a parameter name (presumably a state-dict key prefix) of the original
# MLM head to the matching name under ``<model>.projection``. The ``dense``/``norm``/``decoder``
# suffixes are the submodules of ``LMHead``.
# NOTE(review): whether entries are applied as exact or prefix matches is decided by the
# loader, which is not shown here -- confirm against the caller.
MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING: dict[str, dict[str, str]] = dict(
    bert={
        "cls.predictions.transform.dense": "bert.projection.dense",
        "cls.predictions.transform.LayerNorm": "bert.projection.norm",
        "cls.predictions.decoder": "bert.projection.decoder",
        # The stand-alone output bias is folded into the decoder Linear's bias.
        "cls.predictions.bias": "bert.projection.decoder.bias",
    },
    distilbert={
        "vocab_transform": "distilbert.projection.dense",
        "vocab_layer_norm": "distilbert.projection.norm",
        "vocab_projector": "distilbert.projection.decoder",
    },
    modernbert={
        "head.dense": "model.projection.dense",
        "head.norm": "model.projection.norm",
        "decoder": "model.projection.decoder",
    },
    roberta={
        "lm_head.dense": "roberta.projection.dense",
        "lm_head.layer_norm": "roberta.projection.norm",
        "lm_head.decoder": "roberta.projection.decoder",
        # The stand-alone output bias is folded into the decoder Linear's bias.
        "lm_head.bias": "roberta.projection.decoder.bias",
    },
)