Source code for lightning_ir.models.bi_encoders.mvr

  1"""Configuration, model, and tokenizer for MVR (Multi-View Representation) type models. Originally proposed in
  2`Multi-View Document Representation Learning for Open-Domain Dense Retrieval \
  3<https://aclanthology.org/2022.acl-long.414/>`_.
  4"""
  5
  6from typing import Literal
  7
  8import torch
  9from tokenizers.processors import TemplateProcessing
 10from transformers import BatchEncoding
 11
 12from lightning_ir.bi_encoder.bi_encoder_model import BiEncoderEmbedding
 13
 14from ...bi_encoder import BiEncoderTokenizer, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
 15from ...modeling_utils.embedding_post_processing import Pooler
 16
 17
class MvrConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a MVR model."""

    model_type = "mvr"
    """Model type for a MVR model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        add_marker_tokens: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        embedding_dim: int | None = None,
        projection: Literal["linear", "linear_no_bias"] | None = "linear",
        num_viewer_tokens: int | None = 8,
        **kwargs,
    ):
        """A MVR model encodes queries and documents separately. It uses a single vector to represent the query and
        multiple vectors to represent the document. The document representation is obtained from n viewer tokens
        ([VIE]) prepended to the document. During training, a contrastive loss pushes the viewer token representations
        away from one another, such that they represent different "views" of the document. Only the maximum similarity
        between the query vector and the viewer token vectors is used to compute the relevance score.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None, does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None, does not truncate.
                Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization_strategy (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            pooling_strategy (Literal['first', 'mean', 'max', 'sum']): Pooling strategy for query embeddings.
                Defaults to "first".
            embedding_dim (int | None): Dimension of the final embeddings. If None, it will be set to the hidden size
                of the backbone model. Defaults to None.
            projection (Literal["linear", "linear_no_bias"] | None): Type of projection layer to apply on the pooled
                embeddings. If None, no projection is applied. Defaults to "linear".
            num_viewer_tokens (int | None): Number of viewer tokens to prepend to the document. Defaults to 8.
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            add_marker_tokens=add_marker_tokens,
            embedding_dim=embedding_dim,
            projection=projection,
            **kwargs,
        )
        self.pooling_strategy = pooling_strategy
        self.num_viewer_tokens = num_viewer_tokens
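
# Usage sketch (not part of the module): constructing an MvrConfig. The values below are
# illustrative, not a recommended configuration; embedding_dim may also be left as None
# to fall back to the backbone model's hidden size.
#
#   config = MvrConfig(
#       query_length=32,
#       doc_length=512,
#       similarity_function="dot",
#       pooling_strategy="first",
#       embedding_dim=768,
#       num_viewer_tokens=8,
#   )
#   config.num_viewer_tokens  # -> 8
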
class MvrModel(MultiVectorBiEncoderModel):
    """MVR model for multi-view representation learning."""

    config_class = MvrConfig
    """Configuration class for MVR models."""

    def __init__(self, config: MvrConfig, *args, **kwargs):
        """Initialize a MVR model.

        Args:
            config (MvrConfig): Configuration for the MVR model.
        """
        super().__init__(config, *args, **kwargs)
        if self.config.projection is None:
            self.projection: torch.nn.Module = torch.nn.Identity()
        else:
            if self.config.embedding_dim is None:
                raise ValueError("Unable to determine embedding dimension.")
            self.projection = torch.nn.Linear(
                self.config.hidden_size,
                self.config.embedding_dim,
                bias="no_bias" not in self.config.projection,
            )
        self.query_pooler = Pooler(config)

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".

        Returns:
            torch.Tensor: Scoring mask.
        """
        if input_type == "query":
            return torch.ones(encoding.input_ids.shape[0], 1, dtype=torch.bool, device=encoding.input_ids.device)
        elif input_type == "doc":
            return torch.ones(
                encoding.input_ids.shape[0],
                self.config.num_viewer_tokens,
                dtype=torch.bool,
                device=encoding.input_ids.device,
            )
        else:
            raise ValueError(f"Invalid input type: {input_type}")

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".

        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        if input_type == "query":
            embeddings = self.query_pooler(embeddings, None)
        elif input_type == "doc":
            embeddings = embeddings[:, 1 : self.config.num_viewer_tokens + 1]
        else:
            raise ValueError(f"Invalid input type: {input_type}")
        if self.config.normalization_strategy == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)
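
# Conceptual sketch (not part of the module): how the multi-view relevance score is formed.
# The actual scoring is presumably handled by the inherited MultiVectorBiEncoderModel
# machinery; this only illustrates the "max over viewer tokens" idea described in the
# MvrConfig docstring, assuming dot-product similarity and illustrative dimensions.
#
#   query_emb = torch.randn(768)          # pooled query embedding (a single vector)
#   doc_embs = torch.randn(8, 768)        # one embedding per [VIE0] ... [VIE7] token
#   similarities = doc_embs @ query_emb   # dot-product similarity per view, shape (8,)
#   score = similarities.max()            # only the best-matching view contributes
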
class MvrTokenizer(BiEncoderTokenizer):
    """Tokenizer for MVR models."""

    config_class = MvrConfig

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        num_viewer_tokens: int = 8,
        **kwargs,
    ):
        """Initialize a MVR tokenizer. Adds num_viewer_tokens [VIE] tokens to the vocabulary and configures the
        document post-processor to prepend them to documents.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            num_viewer_tokens=num_viewer_tokens,
            **kwargs,
        )
        self.num_viewer_tokens = num_viewer_tokens
        if num_viewer_tokens is not None:
            viewer_tokens = [f"[VIE{idx}]" for idx in range(num_viewer_tokens)]
            self.add_tokens(viewer_tokens, special_tokens=True)
            special_tokens = [
                ("[CLS]", self.cls_token_id),
                ("[SEP]", self.sep_token_id),
            ] + [
                (viewer_tokens[viewer_token_id], self.viewer_token_id(viewer_token_id))
                for viewer_token_id in range(num_viewer_tokens)
            ]
            viewer_tokens_string = " ".join(viewer_tokens)
            if self.doc_token_id is not None:
                prefix = f"[CLS] {self.DOC_TOKEN}"
                special_tokens.append((self.DOC_TOKEN, self.doc_token_id))
            else:
                prefix = "[CLS]"
            self.doc_post_processor = TemplateProcessing(
                single=f"{prefix} {viewer_tokens_string} $0 [SEP]",
                pair="[CLS] $A [SEP] $B:1 [SEP]:1",
                special_tokens=special_tokens,
            )

    def viewer_token_id(self, viewer_token_id: int) -> int | None:
        """The token id of the viewer token at the given index if viewer tokens were added to the tokenizer.

        :return: Token id of the viewer token, or None if it was not added
        :rtype: int | None
        """
        if f"[VIE{viewer_token_id}]" in self.added_tokens_encoder:
            return self.added_tokens_encoder[f"[VIE{viewer_token_id}]"]
        return None
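
# Usage sketch (not part of the module): loading the tokenizer and inspecting the viewer
# tokens. The checkpoint name is a placeholder; from_pretrained should forward the keyword
# arguments to __init__, which adds the [VIE*] tokens and sets up the document post-processor
# so documents are encoded as "[CLS] [VIE0] ... [VIE7] <document tokens> [SEP]" (when no
# document marker token is used).
#
#   tokenizer = MvrTokenizer.from_pretrained("bert-base-uncased", num_viewer_tokens=8)
#   tokenizer.viewer_token_id(0)   # id assigned to "[VIE0]"
#   tokenizer.viewer_token_id(8)   # None, since only 8 viewer tokens were added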