1"""Configuration, model, and tokenizer for MVR (Multi-View Representation) type models. Originally proposed in
2`Multi-View Document Representation Learning for Open-Domain Dense Retrieval \
3<https://aclanthology.org/2022.acl-long.414/>`_.
4"""
5
6from typing import Literal
7
8import torch
9from tokenizers.processors import TemplateProcessing
10from transformers import BatchEncoding
11
12from lightning_ir.bi_encoder.bi_encoder_model import BiEncoderEmbedding
13
14from ...bi_encoder import BiEncoderTokenizer, MultiVectorBiEncoderConfig, MultiVectorBiEncoderModel
15from ...modeling_utils.embedding_post_processing import Pooler
16
17
class MvrConfig(MultiVectorBiEncoderConfig):
    """Configuration class for an MVR model."""

    model_type = "mvr"
    """Model type for an MVR model."""

    def __init__(
        self,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        similarity_function: Literal["cosine", "dot"] = "dot",
        normalization_strategy: Literal["l2"] | None = None,
        add_marker_tokens: bool = False,
        pooling_strategy: Literal["first", "mean", "max", "sum"] = "first",
        embedding_dim: int | None = None,
        projection: Literal["linear", "linear_no_bias"] | None = "linear",
        num_viewer_tokens: int | None = 8,
        **kwargs,
    ):
37 """A MVR model encodes queries and document separately. It uses a single vector to represent the query and
38 multiple vectors to represent the document. The document representation is obtained from n viewer tokens ([VIE])
39 prepended to the document. During training, a contrastive loss pushes the viewer token representations away
40 from one another, such that they represent different "views" of the document. Only the maximum similarity
41 between the query vector and the viewer token vectors is used to compute the relevance score.
42
43 Args:
44 query_length (int | None): Maximum number of tokens per query. If None does not truncate. Defaults to 32.
45 doc_length (int | None): Maximum number of tokens per document. If None does not truncate. Defaults to 512.
46 similarity_function (Literal['cosine', 'dot']): Similarity function to compute scores between query and
47 document embeddings. Defaults to "dot".
48 normalization_strategy (Literal['l2'] | None): Whether to normalize query and document embeddings.
49 Defaults to None.
50 add_marker_tokens (bool): Whether to prepend extra marker tokens [Q] / [D] to queries / documents.
51 Defaults to False.
52 pooling_strategy (Literal['first', 'mean', 'max', 'sum']): Pooling strategy for query embeddings. Defaults
53 to "first".
54 embedding_dim (int | None): Dimension of the final embeddings. If None, it will be set to the hidden size
55 of the backbone model. Defaults to None.
56 projection (Literal["linear", "linear_no_bias"] | None): type of projection layer to apply on the pooled
57 embeddings. If None, no projection is applied. Defaults to "linear".
58 num_viewer_tokens (int | None): Number of viewer tokens to prepend to the document. Defaults to 8.
59 """
        super().__init__(
            query_length=query_length,
            doc_length=doc_length,
            similarity_function=similarity_function,
            normalization_strategy=normalization_strategy,
            add_marker_tokens=add_marker_tokens,
            embedding_dim=embedding_dim,
            projection=projection,
            **kwargs,
        )
        self.pooling_strategy = pooling_strategy
        self.num_viewer_tokens = num_viewer_tokens


class MvrModel(MultiVectorBiEncoderModel):
    """MVR model for multi-view representation learning."""

    config_class = MvrConfig
    """Configuration class for MVR models."""

    def __init__(self, config: MvrConfig, *args, **kwargs):
        """Initialize an MVR model.

        Args:
            config (MvrConfig): Configuration for the MVR model.
        """
        super().__init__(config, *args, **kwargs)
        if self.config.projection is None:
            self.projection: torch.nn.Module = torch.nn.Identity()
        else:
            if self.config.embedding_dim is None:
                raise ValueError("Unable to determine embedding dimension.")
            self.projection = torch.nn.Linear(
                self.config.hidden_size,
                self.config.embedding_dim,
                bias="no_bias" not in self.config.projection,
            )
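        # Pools the query token embeddings into a single query vector according to the configured
        # ``pooling_strategy``.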
        self.query_pooler = Pooler(config)

    def scoring_mask(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> torch.Tensor:
        """Computes a scoring mask for batched tokenized text sequences which is used in the scoring function to
        mask out vectors during scoring.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            torch.Tensor: Scoring mask.
        """
        if input_type == "query":
            return torch.ones(encoding.input_ids.shape[0], 1, dtype=torch.bool, device=encoding.input_ids.device)
        elif input_type == "doc":
            return torch.ones(
                encoding.input_ids.shape[0],
                self.config.num_viewer_tokens,
                dtype=torch.bool,
                device=encoding.input_ids.device,
            )
        else:
            raise ValueError(f"Invalid input type: {input_type}")

    def encode(self, encoding: BatchEncoding, input_type: Literal["query", "doc"]) -> BiEncoderEmbedding:
        """Encodes batched tokenized text sequences and returns the embeddings and scoring mask.

        Args:
            encoding (BatchEncoding): Tokenizer encodings for the text sequence.
            input_type (Literal["query", "doc"]): type of input, either "query" or "doc".
        Returns:
            BiEncoderEmbedding: Embeddings and scoring mask.
        """
        embeddings = self._backbone_forward(**encoding).last_hidden_state
        embeddings = self.projection(embeddings)
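        # Queries are pooled into a single vector; documents keep only the viewer-token embeddings,
        # which, without marker tokens, occupy the positions directly after [CLS].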
        if input_type == "query":
            embeddings = self.query_pooler(embeddings, None)
        elif input_type == "doc":
            embeddings = embeddings[:, 1 : self.config.num_viewer_tokens + 1]
        else:
            raise ValueError(f"Invalid input type: {input_type}")
        if self.config.normalization_strategy == "l2":
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        scoring_mask = self.scoring_mask(encoding, input_type)
        return BiEncoderEmbedding(embeddings, scoring_mask, encoding)


class MvrTokenizer(BiEncoderTokenizer):
    config_class = MvrConfig

    def __init__(
        self,
        *args,
        query_length: int | None = 32,
        doc_length: int | None = 512,
        add_marker_tokens: bool = False,
        num_viewer_tokens: int = 8,
        **kwargs,
    ):
        super().__init__(
            *args,
            query_length=query_length,
            doc_length=doc_length,
            add_marker_tokens=add_marker_tokens,
            num_viewer_tokens=num_viewer_tokens,
            **kwargs,
        )
        self.num_viewer_tokens = num_viewer_tokens
        if num_viewer_tokens is not None:
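            # One distinct special token per view, e.g. [VIE0] ... [VIE7] for the default of eight views.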
            viewer_tokens = [f"[VIE{idx}]" for idx in range(num_viewer_tokens)]
            self.add_tokens(viewer_tokens, special_tokens=True)
            special_tokens = [
                ("[CLS]", self.cls_token_id),
                ("[SEP]", self.sep_token_id),
            ] + [
                (viewer_tokens[viewer_token_id], self.viewer_token_id(viewer_token_id))
                for viewer_token_id in range(num_viewer_tokens)
            ]
            viewer_tokens_string = " ".join(viewer_tokens)
            if self.doc_token_id is not None:
                prefix = f"[CLS] {self.DOC_TOKEN}"
                special_tokens.append((self.DOC_TOKEN, self.doc_token_id))
            else:
                prefix = "[CLS]"
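            # Template that prepends the viewer tokens (after [CLS] and the optional document marker)
            # to every single-sequence document encoding; ``MvrModel.encode`` later slices these
            # positions back out as the document's multi-view embeddings.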
            self.doc_post_processor = TemplateProcessing(
                single=f"{prefix} {viewer_tokens_string} $0 [SEP]",
                pair="[CLS] $A [SEP] $B:1 [SEP]:1",
                special_tokens=special_tokens,
            )

    def viewer_token_id(self, viewer_token_id: int) -> int | None:
        """The token id of the viewer token at the given index, if viewer tokens were added to the vocabulary.

        Args:
            viewer_token_id (int): Index of the viewer token.
        Returns:
            int | None: Token id of the viewer token, or None if it was not added.
        """
        if f"[VIE{viewer_token_id}]" in self.added_tokens_encoder:
            return self.added_tokens_encoder[f"[VIE{viewer_token_id}]"]
        return None
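

# Minimal sketch (not part of the library's scoring API): how the relevance score described in
# ``MvrConfig`` can be computed from the embeddings ``MvrModel.encode`` produces. The tensor
# shapes and the dot-product similarity below are illustrative assumptions.
if __name__ == "__main__":
    batch_size, num_viewer_tokens, embedding_dim = 2, 8, 128
    # One vector per query, ``num_viewer_tokens`` vectors per document.
    query_embeddings = torch.randn(batch_size, 1, embedding_dim)
    doc_embeddings = torch.randn(batch_size, num_viewer_tokens, embedding_dim)
    # Dot-product similarity between the query vector and every viewer-token vector ...
    similarities = torch.einsum("bqd,bvd->bqv", query_embeddings, doc_embeddings)
    # ... of which only the maximum is kept as the query-document relevance score.
    scores = similarities.max(dim=-1).values.squeeze(-1)
    print(scores.shape)  # torch.Size([2])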