Source code for lightning_ir.models.bi_encoders.mvr

  1"""Configuration, model, and tokenizer for MVR (Multi-View Representation) type models.
  2
  3MVR (Multi-View Representation) is an information retrieval approach that strikes a balance between the simplicity of
  4single-vector models and the complexity of token-level architectures by representing a document using a small, fixed
  5number of dense vectors.  Instead of compressing an entire passage into one embedding or storing a separate vector for
  6every single word, MVR generates multiple distinct representations of the text, with each vector capturing a different
  7semantic facet or topic. During a search, the query is compared against these multiple document views to find the
  8strongest match.
  9
 10Originally proposed in
 11`Multi-View Document Representation Learning for Open-Domain Dense Retrieval \
 12<https://aclanthology.org/2022.acl-long.414/>`_.
 13"""
 14
 15from typing import Literal
 16
 17import torch
 18from tokenizers.processors import TemplateProcessing
 19from transformers import BatchEncoding
 20
 21from lightning_ir.bi_encoder.bi_encoder_model import BiEncoderEmbedding
 22
 23from ...bi_encoder import BiEncoderTokenizer, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
 24from ...modeling_utils.embedding_post_processing import Pooler
 25
 26
class MvrConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a MVR model."""

    model_type = "mvr"
    """Model type for a MVR model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        add_marker_tokens: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        embedding_dim: int | None = None,
        projection: Literal["linear", "linear_no_bias"] | None = "linear",
        num_viewer_tokens: int | None = 8,
        **kwargs,
    ):
        """A MVR model encodes queries and documents separately. A single vector represents the query, while
        multiple vectors represent the document. The document representation is obtained from n viewer tokens
        ([VIE]) prepended to the document. During training, a contrastive loss pushes the viewer token
        representations away from one another, such that they represent different "views" of the document. Only
        the maximum similarity between the query vector and the viewer token vectors is used to compute the
        relevance score.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization_strategy (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            pooling_strategy (Literal['first', 'mean', 'max', 'sum']): Pooling strategy for query embeddings.
                Defaults to "first".
            embedding_dim (int | None): Dimension of the final embeddings. If None, it will be set to the hidden
                size of the backbone model. Defaults to None.
            projection (Literal["linear", "linear_no_bias"] | None): type of projection layer to apply on the
                pooled embeddings. If None, no projection is applied. Defaults to "linear".
            num_viewer_tokens (int | None): Number of viewer tokens to prepend to the document. Defaults to 8.
        """
        # Shared bi-encoder options are handled by the parent config.
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            add_marker_tokens=add_marker_tokens,
            embedding_dim=embedding_dim,
            projection=projection,
            **kwargs,
        )
        # MVR-specific options: how the query is pooled and how many document views exist.
        self.num_viewer_tokens = num_viewer_tokens
        self.pooling_strategy = pooling_strategy
81 82
class MvrModel(MultiVectorBiEncoderModel):
    """MVR model for multi-view representation learning."""

    config_class = MvrConfig
    """Configuration class for MVR models."""

    def __init__(self, config: MvrConfig, *args, **kwargs):
        """Initialize a MVR model.

        Args:
            config (MvrConfig): Configuration for the MVR model.

        Raises:
            ValueError: If a projection layer is configured but ``embedding_dim`` is None.
        """
        # FIX: the docstring previously appeared *after* the super().__init__ call, making it a
        # no-op bare string instead of the method docstring. It now precedes all statements.
        super().__init__(config, *args, **kwargs)
        if self.config.projection is None:
            self.projection: torch.nn.Module = torch.nn.Identity()
        else:
            if self.config.embedding_dim is None:
                raise ValueError("Unable to determine embedding dimension.")
            self.projection = torch.nn.Linear(
                self.config.hidden_size,
                self.config.embedding_dim,
                # "linear_no_bias" disables the bias term; plain "linear" keeps it.
                bias="no_bias" not in self.config.projection,
            )
        self.query_pooler = Pooler(config)

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to
        mask out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        Raises:
            ValueError: If ``input_type`` is neither "query" nor "doc".
        """
        if input_type == "query":
            # Queries are pooled to a single vector, so exactly one position is scored per query.
            return torch.ones(encoding.input_ids.shape[0], 1, dtype=torch.bool, device=encoding.input_ids.device)
        elif input_type == "doc":
            # Documents contribute exactly num_viewer_tokens vectors, all of which are scored.
            return torch.ones(
                encoding.input_ids.shape[0],
                self.config.num_viewer_tokens,
                dtype=torch.bool,
                device=encoding.input_ids.device,
            )
        else:
            raise ValueError(f"Invalid input type: {input_type}")

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        Raises:
            ValueError: If ``input_type`` is neither "query" nor "doc".
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        if input_type == "query":
            # Pool the query to a single vector; no attention mask is passed to the pooler.
            embeddings = self.query_pooler(embeddings, None)
        elif input_type == "doc":
            # Keep only the viewer-token embeddings, assumed to sit directly after [CLS].
            # NOTE(review): if a [DOC] marker token is prepended (add_marker_tokens=True), the
            # viewer tokens start at position 2, and this slice would be off by one — confirm.
            embeddings = embeddings[:, 1 : self.config.num_viewer_tokens + 1]
        else:
            raise ValueError(f"Invalid input type: {input_type}")
        if self.config.normalization_strategy == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)
151 152
class MvrTokenizer(BiEncoderTokenizer):
    """Tokenizer for MVR models. Prepends ``num_viewer_tokens`` special [VIE{i}] tokens to documents."""

    config_class = MvrConfig
    """Configuration class for MVR tokenizers."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        num_viewer_tokens: int = 8,
        **kwargs,
    ):
        """Initialize a MVR tokenizer.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate.
                Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate.
                Defaults to 512.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            num_viewer_tokens (int): Number of viewer tokens to prepend to the document. Defaults to 8.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            num_viewer_tokens=num_viewer_tokens,
            **kwargs,
        )
        self.num_viewer_tokens = num_viewer_tokens
        if num_viewer_tokens is not None:
            # Register one special token per view so documents always carry exactly this many views.
            viewer_tokens = [f"[VIE{idx}]" for idx in range(num_viewer_tokens)]
            self.add_tokens(viewer_tokens, special_tokens=True)
            special_tokens = [
                ("[CLS]", self.cls_token_id),
                ("[SEP]", self.sep_token_id),
            ] + [
                (viewer_tokens[viewer_token_id], self.viewer_token_id(viewer_token_id))
                for viewer_token_id in range(num_viewer_tokens)
            ]
            viewer_tokens_string = " ".join(viewer_tokens)
            if self.doc_token_id is not None:
                prefix = f"[CLS] {self.DOC_TOKEN}"
                special_tokens.append((self.DOC_TOKEN, self.doc_token_id))
            else:
                prefix = "[CLS]"
            # NOTE(review): the pair template does not prepend viewer tokens; presumably documents
            # are always encoded as single sequences — confirm against callers.
            self.doc_post_processor = TemplateProcessing(
                single=f"{prefix} {viewer_tokens_string} $0 [SEP]",
                pair="[CLS] $A [SEP] $B:1 [SEP]:1",
                special_tokens=special_tokens,
            )

    def viewer_token_id(self, viewer_token_id: int) -> int | None:
        """The token id of the viewer token at the given index, if it has been added to the vocabulary.

        Args:
            viewer_token_id (int): Index of the viewer token (0-based).
        Returns:
            int | None: Token id of the [VIE{viewer_token_id}] token, or None if it was not added.
        """
        # FIX: the docstring previously described "the query token" (copy-paste error) and used
        # reST field style inconsistent with the rest of the file.
        if f"[VIE{viewer_token_id}]" in self.added_tokens_encoder:
            return self.added_tokens_encoder[f"[VIE{viewer_token_id}]"]
        return None