Source code for lightning_ir.modeling_utils.lm_head

 1from functools import partial
 2
 3import torch
 4from transformers import PretrainedConfig
 5from transformers.activations import get_activation
 6
 7
class LMHead(torch.nn.Module):
    """Masked-language-model prediction head (dense -> activation -> norm -> decoder).

    The head is model-family agnostic: instead of hard-coding config attribute
    names, it receives the *keys* under which a given ``PretrainedConfig``
    stores its hidden size, activation name, and bias flags, and looks them up
    at construction time.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_dim_key: str,
        activation_key: str,
        classifier_bias_key: str | None = None,
        norm_bias_key: str | None = None,
    ):
        """Build the head from a Hugging Face config.

        :param config: Model configuration to read hyperparameters from.
        :param hidden_dim_key: Config attribute holding the hidden dimension.
        :param activation_key: Config attribute holding the activation name.
        :param classifier_bias_key: Config attribute holding the dense-layer
            bias flag; ``None`` means "always use a bias".
        :param norm_bias_key: Config attribute holding the layer-norm bias
            flag; ``None`` means "always use a bias".
        """
        super().__init__()
        hidden_dim = getattr(config, hidden_dim_key)
        activation_name = getattr(config, activation_key)

        # No key configured -> bias on. A configured key is read from the
        # config but defaults to True when the attribute is absent.
        if classifier_bias_key is None:
            use_classifier_bias = True
        else:
            use_classifier_bias = getattr(config, classifier_bias_key, True)
        if norm_bias_key is None:
            use_norm_bias = True
        else:
            use_norm_bias = getattr(config, norm_bias_key, True)

        # NOTE: attribute names (dense / act / norm / decoder) are part of the
        # checkpoint-loading contract (see the state-dict key mappings) and
        # must not be renamed.
        self.dense = torch.nn.Linear(hidden_dim, hidden_dim, bias=use_classifier_bias)
        self.act = get_activation(activation_name)
        self.norm = torch.nn.LayerNorm(hidden_dim, eps=1e-12, bias=use_norm_bias)
        self.decoder = torch.nn.Linear(hidden_dim, config.vocab_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits."""
        projected = self.dense(hidden_states)
        activated = self.act(projected)
        normalized = self.norm(activated)
        return self.decoder(normalized)
# Maps a Hugging Face ``model_type`` string to a factory building the matching
# LMHead: each entry pre-binds the config attribute names that model family
# uses for its hidden size, activation, and (where applicable) bias flags.
MODEL_TYPE_TO_LM_HEAD = {
    "bert": partial(LMHead, hidden_dim_key="hidden_size", activation_key="hidden_act"),
    "distilbert": partial(LMHead, hidden_dim_key="hidden_size", activation_key="activation"),
    "modernbert": partial(
        LMHead,
        hidden_dim_key="hidden_size",
        activation_key="classifier_activation",
        classifier_bias_key="classifier_bias",
        norm_bias_key="norm_bias",
    ),
    "roberta": partial(LMHead, hidden_dim_key="hidden_size", activation_key="hidden_act"),
}

# Maps ``model_type`` to a renaming of pretrained checkpoint state-dict key
# prefixes onto this head's attribute layout (``*.projection.dense`` /
# ``.norm`` / ``.decoder``), so upstream MLM weights load into LMHead.
# BERT and RoBERTa additionally alias the standalone output bias onto the
# decoder's bias parameter.
MODEL_TYPE_TO_STATE_DICT_KEY_MAPPING = {
    "bert": {
        "cls.predictions.transform.dense": "bert.projection.dense",
        "cls.predictions.transform.LayerNorm": "bert.projection.norm",
        "cls.predictions.decoder": "bert.projection.decoder",
        "cls.predictions.bias": "bert.projection.decoder.bias",
    },
    "distilbert": {
        "vocab_transform": "distilbert.projection.dense",
        "vocab_layer_norm": "distilbert.projection.norm",
        "vocab_projector": "distilbert.projection.decoder",
    },
    "modernbert": {
        "head.dense": "model.projection.dense",
        "head.norm": "model.projection.norm",
        "decoder": "model.projection.decoder",
    },
    "roberta": {
        "lm_head.dense": "roberta.projection.dense",
        "lm_head.layer_norm": "roberta.projection.norm",
        "lm_head.decoder": "roberta.projection.decoder",
        "lm_head.bias": "roberta.projection.decoder.bias",
    },
}