Source code for lightning_ir.models.t5_cross_encoder

from typing import Dict, Literal, Sequence, Type

import torch
from transformers import BatchEncoding

from ..cross_encoder.cross_encoder_config import CrossEncoderConfig
from ..cross_encoder.cross_encoder_model import CrossEncoderModel, CrossEncoderOutput
from ..cross_encoder.cross_encoder_tokenizer import CrossEncoderTokenizer


class T5CrossEncoderConfig(CrossEncoderConfig):

    model_type = "encoder-decoder-cross-encoder"

    def __init__(
        self,
        query_length: int = 32,
        doc_length: int = 512,
        decoder_strategy: Literal["mono", "rank"] = "mono",
        **kwargs,
    ) -> None:
        kwargs["pooling_strategy"] = "first"
        super().__init__(query_length=query_length, doc_length=doc_length, **kwargs)
        self.decoder_strategy = decoder_strategy


class ScaleLinear(torch.nn.Linear):

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586  # noqa
        input = input * (input.shape[-1] ** -0.5)
        return super().forward(input)

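# Illustrative note (not part of the original module): ScaleLinear only multiplies its
# input by hidden_size ** -0.5 before the usual linear projection, mirroring the
# Mesh-TensorFlow rescaling applied ahead of T5's (tied) lm head. For example, with an
# assumed hidden size of 4 the factor is 4 ** -0.5 = 0.5, so the layer computes
# ``super().forward(0.5 * hidden_states)`` rather than ``super().forward(hidden_states)``.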

class T5CrossEncoderModel(CrossEncoderModel):
    config_class = T5CrossEncoderConfig

    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "linear.weight"]

    def __init__(self, config: T5CrossEncoderConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config: T5CrossEncoderConfig
        if self.config.decoder_strategy == "mono":
            self.linear = ScaleLinear(config.hidden_size, 2, bias=config.linear_bias)
        else:
            self.linear = ScaleLinear(config.hidden_size, 1, bias=config.linear_bias)

    # TODO tying of weights does not work when setting linear to only use a slice of the lm head for efficiency
    # def get_output_embeddings(self):
    #     shared = self.shared
    #     if self.config.decoder_strategy == "mono":
    #         self.linear.weight.data = shared.weight.data[[1176, 6136]]
    #     elif self.config.decoder_strategy == "rank":
    #         self.linear.weight.data = shared.weight.data[[32089]]
    #     else:
    #         raise ValueError("Unknown decoder strategy")
    #     return shared

    def forward(self, encoding: BatchEncoding) -> CrossEncoderOutput:
        decoder_input_ids = torch.zeros(
            (encoding["input_ids"].shape[0], 1), device=encoding["input_ids"].device, dtype=torch.long
        )
        encoding["decoder_input_ids"] = decoder_input_ids
        output = super().forward(encoding)
        if output.scores is None:
            raise ValueError("Scores are None")
        if self.config.decoder_strategy == "mono":
            scores = output.scores.view(-1, 2)
            scores = torch.nn.functional.log_softmax(scores, dim=-1)[:, 0]
            output.scores = scores.view(-1)
        return output

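# Illustrative note (not part of the original module): under the "mono" strategy the
# decoder receives a single start token (the zeros built in ``forward``), the
# ScaleLinear head emits two logits per query-document pair, and the score kept is the
# log-probability of the first logit, e.g.
#
#     logits = torch.tensor([[2.0, 0.0]])                        # made-up pair
#     torch.nn.functional.log_softmax(logits, dim=-1)[:, 0]      # ≈ -0.1269
#
# Under the "rank" strategy the single logit is returned unchanged.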

class T5CrossEncoderTokenizer(CrossEncoderTokenizer):

    config_class: Type[T5CrossEncoderConfig] = T5CrossEncoderConfig

    def __init__(
        self,
        *args,
        query_length: int = 32,
        doc_length: int = 512,
        decoder_strategy: Literal["mono", "rank"] = "mono",
        **kwargs,
    ):
        super().__init__(
            *args, query_length=query_length, doc_length=doc_length, decoder_strategy=decoder_strategy, **kwargs
        )
        self.decoder_strategy = decoder_strategy

    def tokenize(
        self,
        queries: str | Sequence[str] | None = None,
        docs: str | Sequence[str] | None = None,
        num_docs: Sequence[int] | int | None = None,
        **kwargs,
    ) -> Dict[str, BatchEncoding]:
        expanded_queries, docs = self._preprocess(queries, docs, num_docs)
        if self.decoder_strategy == "mono":
            pattern = "Query: {query} Document: {doc} Relevant:"
        elif self.decoder_strategy == "rank":
            pattern = "Query: {query} Document: {doc}"
        else:
            raise ValueError(f"Unknown decoder strategy: {self.decoder_strategy}")
        input_texts = [pattern.format(query=query, doc=doc) for query, doc in zip(expanded_queries, docs)]

        return_tensors = kwargs.get("return_tensors", None)
        if return_tensors is not None:
            kwargs["pad_to_multiple_of"] = 8
        return {"encoding": self(input_texts, **kwargs)}
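

# Minimal usage sketch (illustrative, not part of the library): it reproduces the
# "mono" prompt built by the tokenizer and the scoring applied in
# T5CrossEncoderModel.forward with plain torch, so it runs without a trained
# checkpoint. The query, document, and logits below are made-up values.
if __name__ == "__main__":
    query = "what is the capital of france"
    doc = "Paris is the capital and largest city of France."

    # Prompt used for decoder_strategy="mono"
    pattern = "Query: {query} Document: {doc} Relevant:"
    print(pattern.format(query=query, doc=doc))

    # Pretend the ScaleLinear head produced these two logits for the pair; the "mono"
    # branch keeps the log-probability assigned to the first (relevant) logit.
    fake_logits = torch.tensor([[2.3, -1.1]])
    score = torch.nn.functional.log_softmax(fake_logits, dim=-1)[:, 0].view(-1)
    print(score)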