1"""Configuration, model, and tokenizer for MVR (Multi-View Representation) type models.
2
3MVR (Multi-View Representation) is an information retrieval approach that strikes a balance between the simplicity of
4single-vector models and the complexity of token-level architectures by representing a document using a small, fixed
number of dense vectors. Instead of compressing an entire passage into one embedding or storing a separate vector for
every single word, MVR generates multiple distinct representations of the text, with each vector capturing a different
7semantic facet or topic. During a search, the query is compared against these multiple document views to find the
8strongest match.
9
10Originally proposed in
11`Multi-View Document Representation Learning for Open-Domain Dense Retrieval \
12<https://aclanthology.org/2022.acl-long.414/>`_.
13"""
14
15from typing import Literal
16
17import torch
18from tokenizers.processors import TemplateProcessing
19from transformers import BatchEncoding
20
21from lightning_ir.bi_encoder.bi_encoder_model import BiEncoderEmbedding
22
23from ...bi_encoder import BiEncoderTokenizer, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
24from ...modeling_utils.embedding_post_processing import Pooler
25
26
class MvrConfig(MultiVectorBiEncoderConfig):
    """Configuration class for a MVR model."""

    model_type = "mvr"
    """Model type for a MVR model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        add_marker_tokens: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        embedding_dim: int | None = None,
        projection: Literal["linear", "linear_no_bias"] | None = "linear",
        num_viewer_tokens: int | None = 8,
        **kwargs,
    ):
        """A MVR model encodes queries and documents separately: the query is represented by a single pooled vector,
        while the document is represented by multiple vectors, one per viewer token ([VIE]) prepended to the document
        text. Training applies a contrastive loss that pushes the viewer-token representations apart so that each one
        captures a different "view" of the document. At scoring time, only the maximum similarity between the query
        vector and the viewer-token vectors contributes to the relevance score.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
            similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between query and
                document embeddings. Defaults to "dot".
            normalization_strategy (Literal['l2'] | None): Whether to normalize query and document embeddings.
                Defaults to None.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            pooling_strategy (Literal['first', 'mean', 'max', 'sum']): Pooling strategy for query embeddings. Defaults
                to "first".
            embedding_dim (int | None): Dimension of the final embeddings. If None, it will be set to the hidden size
                of the backbone model. Defaults to None.
            projection (Literal["linear", "linear_no_bias"] | None): Type of projection layer to apply on the pooled
                embeddings. If None, no projection is applied. Defaults to "linear".
            num_viewer_tokens (int | None): Number of viewer tokens to prepend to the document. Defaults to 8.
        """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            add_marker_tokens=add_marker_tokens,
            embedding_dim=embedding_dim,
            projection=projection,
            **kwargs,
        )
        # These two options are specific to MVR and are not consumed by the parent config.
        self.num_viewer_tokens = num_viewer_tokens
        self.pooling_strategy = pooling_strategy


class MvrModel(MultiVectorBiEncoderModel):
    """MVR model for multi-view representation learning."""

    config_class = MvrConfig
    """Configuration class for MVR models."""

    def __init__(self, config: MvrConfig, *args, **kwargs):
        """Initialize a MVR model.

        Args:
            config (MvrConfig): Configuration for the MVR model.

        Raises:
            ValueError: If a projection layer is configured but the embedding dimension cannot be determined.
        """
        # NOTE: the docstring previously sat *after* the super().__init__ call, making it a discarded string
        # expression rather than the method's docstring; it now precedes any code.
        super().__init__(config, *args, **kwargs)
        if self.config.projection is None:
            self.projection: torch.nn.Module = torch.nn.Identity()
        else:
            if self.config.embedding_dim is None:
                raise ValueError("Unable to determine embedding dimension.")
            self.projection = torch.nn.Linear(
                self.config.hidden_size,
                self.config.embedding_dim,
                # "linear_no_bias" disables the bias term; plain "linear" keeps it
                bias="no_bias" not in self.config.projection,
            )
        self.query_pooler = Pooler(config)

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to mask
        out vectors during scoring.

        Queries contribute exactly one (pooled) vector and documents exactly ``num_viewer_tokens`` vectors, so the
        mask is all ones of the corresponding width.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        Raises:
            ValueError: If ``input_type`` is neither "query" nor "doc".
        """
        if input_type == "query":
            return torch.ones(encoding.input_ids.shape[0], 1, dtype=torch.bool, device=encoding.input_ids.device)
        elif input_type == "doc":
            return torch.ones(
                encoding.input_ids.shape[0],
                self.config.num_viewer_tokens,
                dtype=torch.bool,
                device=encoding.input_ids.device,
            )
        else:
            raise ValueError(f"Invalid input type: {input_type}")

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes a batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): Type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        Raises:
            ValueError: If ``input_type`` is neither "query" nor "doc".
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
        if input_type == "query":
            embeddings = self.query_pooler(embeddings, None)
        elif input_type == "doc":
            # Keep only the viewer-token embeddings, assumed to sit directly after [CLS] at positions
            # 1 .. num_viewer_tokens. NOTE(review): if add_marker_tokens is enabled, the [D] marker shifts the
            # viewer tokens by one position — confirm this offset against the tokenizer's doc post-processor.
            embeddings = embeddings[:, 1 : self.config.num_viewer_tokens + 1]
        else:
            raise ValueError(f"Invalid input type: {input_type}")
        if self.config.normalization_strategy == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)

class MvrTokenizer(BiEncoderTokenizer):
    """Tokenizer for MVR models. Registers ``[VIE{n}]`` viewer tokens as special tokens and prepends them to every
    single-segment document encoding via a custom post-processor."""

    config_class = MvrConfig
    """Configuration class for MVR tokenizers."""

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        # NOTE: annotated int | None (was int) — the body explicitly guards `is not None`
        num_viewer_tokens: int | None = 8,
        **kwargs,
    ):
        """Initialize a MVR tokenizer.

        Args:
            query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
            doc_length (int | None): Maximum number of tokens per document. If None does not truncate.
                Defaults to 512.
            add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
                Defaults to False.
            num_viewer_tokens (int | None): Number of viewer tokens to prepend to the document. If None, no viewer
                tokens are added. Defaults to 8.
        """
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            num_viewer_tokens=num_viewer_tokens,
            **kwargs,
        )
        self.num_viewer_tokens = num_viewer_tokens
        if num_viewer_tokens is not None:
            viewer_tokens = [f"[VIE{idx}]" for idx in range(num_viewer_tokens)]
            self.add_tokens(viewer_tokens, special_tokens=True)
            special_tokens = [
                ("[CLS]", self.cls_token_id),
                ("[SEP]", self.sep_token_id),
            ] + [
                (viewer_tokens[viewer_token_id], self.viewer_token_id(viewer_token_id))
                for viewer_token_id in range(num_viewer_tokens)
            ]
            viewer_tokens_string = " ".join(viewer_tokens)
            if self.doc_token_id is not None:
                prefix = f"[CLS] {self.DOC_TOKEN}"
                special_tokens.append((self.DOC_TOKEN, self.doc_token_id))
            else:
                prefix = "[CLS]"
            # NOTE(review): the pair template does not prepend viewer tokens — confirm text pairs are never used as
            # documents.
            self.doc_post_processor = TemplateProcessing(
                single=f"{prefix} {viewer_tokens_string} $0 [SEP]",
                pair="[CLS] $A [SEP] $B:1 [SEP]:1",
                special_tokens=special_tokens,
            )

    def viewer_token_id(self, viewer_token_id: int) -> int | None:
        """The token id of the viewer token with the given index, if it was added to the vocabulary.

        Args:
            viewer_token_id (int): Index of the viewer token (the ``n`` in ``[VIE{n}]``).
        Returns:
            int | None: Token id of the viewer token, or None if the token is not in the vocabulary.
        """
        # Build the key once instead of formatting the same f-string twice.
        token = f"[VIE{viewer_token_id}]"
        if token in self.added_tokens_encoder:
            return self.added_tokens_encoder[token]
        return None