1"""Configuration, model, and tokenizer for MVR (Multi-View Representation) type models. Originally proposed in
2`Multi-View Document Representation Learning for Open-Domain Dense Retrieval \
3<https://aclanthology.org/2022.acl-long.414/>`_.
4"""
5
6from typing import Literal
7
8import torch
9from tokenizers.processors import TemplateProcessing
10from transformers import BatchEncoding
11
12from lightning_ir.bi_encoder.bi_encoder_model import BiEncoderEmbedding
13
14from ...bi_encoder import BiEncoderTokenizer, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
15
16
[docs]
class MvrConfig(MultiVectorBiEncoderConfig):
    """Configuration of a MVR (Multi-View Representation) bi-encoder."""

    model_type = "mvr"
    """Model type identifier of MVR models."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization: Literal["l2"] | None = None,
        add_marker_tokens: bool = False,
        embedding_dim: int | None = None,
        projection: Literal["linear", "linear_no_bias"] | None = "linear",
        num_viewer_tokens: int | None = 8,
        **kwargs,
    ):
        """Set up the configuration of a MVR model.

        MVR encodes queries and documents independently. A query is represented by a single vector, a document by
        multiple vectors. The document vectors are taken from n viewer tokens ([VIE]) that are prepended to the
        document. A contrastive loss applied during training pushes the viewer token representations apart so that
        each one captures a different "view" of the document. The relevance score only uses the maximum similarity
        between the query vector and the viewer token vectors.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate.
                Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            embedding_dim (int | None): Dimension of the final embeddings. If None, it will be set to the hidden size
                of the backbone model. Defaults to None.
            projection (Literal["linear", "linear_no_bias"] | None): Type of projection layer to apply on the pooled
                embeddings. If None, no projection is applied. Defaults to "linear".
            num_viewer_tokens (int | None): Number of viewer tokens to prepend to the document. Defaults to 8.
            **kwargs: Additional keyword arguments forwarded to the parent configuration.
        """
        # Options handled by the parent multi-vector bi-encoder configuration.
        bi_encoder_options = {
            "query_length": query_length,
            "doc_length": doc_length,
            "similarity_function": similarity_function,
            "normalization": normalization,
            "add_marker_tokens": add_marker_tokens,
            "embedding_dim": embedding_dim,
            "projection": projection,
        }
        super().__init__(**bi_encoder_options, **kwargs)
        # The only MVR-specific option; stored after the parent has been initialized.
        self.num_viewer_tokens = num_viewer_tokens
66
67
[docs]
class MvrModel(MultiVectorBiEncoderModel):
    """MVR model for multi-view representation learning.

    A query is represented by a single vector (first token) and a document by one vector per viewer token.
    """

    config_class = MvrConfig
    """Configuration class for MVR models."""

    def __init__(self, config: MvrConfig, *args, **kwargs):
        """Initialize a MVR model.

        Args:
            config (MvrConfig): Configuration for the MVR model.
            *args: Positional arguments forwarded to the parent model.
            **kwargs: Keyword arguments forwarded to the parent model.
        Raises:
            ValueError: If a projection is configured but the embedding dimension is not set.
        """
        super().__init__(config, *args, **kwargs)
        if self.config.projection is None:
            self.projection: torch.nn.Module = torch.nn.Identity()
        else:
            if self.config.embedding_dim is None:
                raise ValueError("Unable to determine embedding dimension.")
            self.projection = torch.nn.Linear(
                self.config.hidden_size,
                self.config.embedding_dim,
                # "linear_no_bias" disables the bias term, "linear" keeps it.
                bias="no_bias" not in self.config.projection,
            )

    def _num_viewer_tokens(self) -> int:
        """Return the configured number of viewer tokens, failing clearly if it is not set."""
        num_viewer_tokens = self.config.num_viewer_tokens
        if num_viewer_tokens is None:
            raise ValueError("MVR models require `num_viewer_tokens` to be set in the configuration.")
        return num_viewer_tokens

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        All vectors are kept: the mask is all-ones with one column for queries and one column per viewer token
        for documents.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Boolean scoring mask of shape (batch size, number of vectors).
        Raises:
            ValueError: If `input_type` is neither "query" nor "doc", or if the number of viewer tokens is not
                configured for a document input.
        """
        if input_type == "query":
            num_vectors = 1
        elif input_type == "doc":
            num_vectors = self._num_viewer_tokens()
        else:
            raise ValueError(f"Invalid input type: {input_type}")
        input_ids = encoding.input_ids
        return torch.ones(input_ids.shape[0], num_vectors, dtype=torch.bool, device=input_ids.device)

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        Raises:
            ValueError: If `input_type` is neither "query" nor "doc", or if the number of viewer tokens is not
                configured for a document input.
        """
        # Reject invalid input types before running the (expensive) backbone forward pass.
        if input_type not in ("query", "doc"):
            raise ValueError(f"Invalid input type: {input_type}")
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        if input_type == "query":
            embeddings = self.pooling(embeddings, None, "first")
        else:
            # The document template built by MvrTokenizer is `[CLS] ([D]) [VIE0] ... [VIEn-1] <doc> [SEP]`, so the
            # viewer tokens start at position 1, or at position 2 when the [D] marker token is prepended.
            # NOTE(review): assumes the parent config stores `add_marker_tokens` as an attribute — confirm.
            first_viewer = 1 + int(bool(getattr(self.config, "add_marker_tokens", False)))
            embeddings = embeddings[:, first_viewer : first_viewer + self._num_viewer_tokens()]
        if self.config.normalization == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)
133
134
[docs]
class MvrTokenizer(BiEncoderTokenizer):
    """Tokenizer for MVR models that prepends viewer tokens ([VIE0], [VIE1], ...) to every document."""

    config_class = MvrConfig
    """Configuration class for MVR tokenizers."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        num_viewer_tokens: int | None = 8,
        **kwargs,
    ):
        """Initialize a MVR tokenizer.

        If `num_viewer_tokens` is not None, the viewer tokens are added to the vocabulary as special tokens and the
        document post-processor is replaced by a template that inserts them in front of the document text.

        Args:
            *args: Positional arguments forwarded to the parent tokenizer.
            query_length (int | None): Maximum number of tokens per query. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. Defaults to 512.
            add_marker_tokens (bool): Whether to prepend extra marker tokens to queries / documents.
                Defaults to False.
            num_viewer_tokens (int | None): Number of viewer tokens to prepend to the document. If None, no viewer
                tokens are added and the document post-processor is left unchanged. Defaults to 8.
            **kwargs: Keyword arguments forwarded to the parent tokenizer.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            num_viewer_tokens=num_viewer_tokens,
            **kwargs,
        )
        self.num_viewer_tokens = num_viewer_tokens
        if num_viewer_tokens is not None:
            viewer_tokens = [f"[VIE{idx}]" for idx in range(num_viewer_tokens)]
            self.add_tokens(viewer_tokens, special_tokens=True)
            # Every literal token used in the templates below must be registered together with its id.
            special_tokens = [
                ("[CLS]", self.cls_token_id),
                ("[SEP]", self.sep_token_id),
            ]
            special_tokens.extend((token, self.viewer_token_id(idx)) for idx, token in enumerate(viewer_tokens))
            # Single-sequence template: [CLS] ([D]) [VIE0] ... [VIEn-1] <document> [SEP]
            template = ["[CLS]"]
            if self.doc_token_id is not None:
                # The document marker token, when present, goes right after [CLS].
                template.append(self.DOC_TOKEN)
                special_tokens.append((self.DOC_TOKEN, self.doc_token_id))
            template.extend(viewer_tokens)
            template.extend(["$0", "[SEP]"])
            self.doc_post_processor = TemplateProcessing(
                single=" ".join(template),
                pair="[CLS] $A [SEP] $B:1 [SEP]:1",
                special_tokens=special_tokens,
            )

    def viewer_token_id(self, viewer_token_id: int) -> int | None:
        """The token id of the viewer token with the given index.

        Args:
            viewer_token_id (int): Index of the viewer token, i.e., ``3`` for ``[VIE3]``.
        Returns:
            int | None: Token id of the viewer token, or None if it has not been added to the tokenizer.
        """
        # A single `.get` avoids resolving `added_tokens_encoder` twice.
        return self.added_tokens_encoder.get(f"[VIE{viewer_token_id}]")