SetEncoderModel
- class lightning_ir.models.set_encoder.SetEncoderModel(config: SetEncoderConfig, *args, **kwargs)
Bases: MonoModel
SetEncoder model. See SetEncoderConfig for configuration options.
- __init__(config: SetEncoderConfig, *args, **kwargs)
Initializes a SetEncoder model given a SetEncoderConfig.
- Parameters:
config (SetEncoderConfig) – Configuration for the SetEncoder model.
Methods
__init__(config, *args, **kwargs) – Initializes a SetEncoder model given a SetEncoderConfig.
attention_forward(_self, self, ...) – Performs the attention forward pass for the SetEncoder model.
cat_other_doc_hidden_states(hidden_states, ...) – Concatenates the hidden states of other documents to the hidden states of the query and documents.
forward(encoding) – Computes contextualized embeddings for the joint query-document input sequence and computes a relevance score.
get_extended_attention_mask(attention_mask, ...) – Extends the attention mask to account for the number of documents per query.
Attributes
ALLOW_SUB_BATCHING – Flag to allow mini batches of documents for a single query.
self_attention_pattern
training
- ALLOW_SUB_BATCHING = False
Flag to allow mini batches of documents for a single query. Set to False for listwise models, which must see all documents of a query in the same batch to produce correct scores.
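Why sub-batching must stay disabled for a listwise model can be seen with a toy example. The sketch below uses a plain softmax over the document list as a stand-in for cross-document interaction. It is not the SetEncoder scoring function, and `toy_listwise_scores` is a hypothetical name. Splitting one query's documents into two mini batches changes every score, because each document is compared against a different set of neighbours.

```python
import math
from typing import List


def toy_listwise_scores(logits: List[float]) -> List[float]:
    """Toy listwise scorer: softmax over the list, so every document's score
    depends on which other documents are in the same list."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Four candidate documents for a single query (made-up logits)
logits = [2.0, 1.0, 0.5, 0.0]

# Scored as one list, as a listwise model expects
full = toy_listwise_scores(logits)

# Scored in two mini batches of two documents each (sub-batching)
sub = toy_listwise_scores(logits[:2]) + toy_listwise_scores(logits[2:])

print([round(s, 3) for s in full])
print([round(s, 3) for s in sub])
```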
- static attention_forward(_self, self: Module, hidden_states: Tensor, attention_mask: FloatTensor | None, *args, num_docs: Sequence[int], **kwargs) → Tuple[Tensor]
Performs the attention forward pass for the SetEncoder model.
- Parameters:
_self (SetEncoderModel) – Reference to the SetEncoder instance.
self (torch.nn.Module) – Reference to the attention module.
hidden_states (torch.Tensor) – Hidden states from the previous layer.
attention_mask (torch.FloatTensor | None) – Attention mask for the input sequence.
num_docs (Sequence[int]) – Specifies how many documents are passed per query. If a sequence of integers, len(num_docs) should be equal to the number of queries and sum(num_docs) equal to the number of documents, i.e., the sequence contains one value per query specifying the number of documents for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer the number of documents per query by dividing the total number of documents by the number of queries.
- Returns:
Contextualized embeddings.
- Return type:
Tuple[torch.Tensor]
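The three accepted forms of `num_docs` can be normalised into one count per query. This is a minimal sketch that assumes a hypothetical helper `resolve_num_docs`, which is not part of lightning_ir. The real methods may validate their input differently.

```python
from typing import List, Optional, Sequence, Union


def resolve_num_docs(
    num_docs: Optional[Union[Sequence[int], int]],
    num_queries: int,
    total_docs: int,
) -> List[int]:
    """Normalise num_docs into one document count per query (hypothetical helper)."""
    if num_docs is None:
        # Infer an equal split: total documents divided by number of queries
        if total_docs % num_queries != 0:
            raise ValueError("cannot infer num_docs: documents not divisible by queries")
        return [total_docs // num_queries] * num_queries
    if isinstance(num_docs, int):
        # The same number of documents for every query
        if num_docs * num_queries != total_docs:
            raise ValueError("num_docs * num_queries must equal the number of documents")
        return [num_docs] * num_queries
    counts = list(num_docs)
    # One value per query, summing to the total number of documents
    if len(counts) != num_queries or sum(counts) != total_docs:
        raise ValueError("need len(num_docs) == queries and sum(num_docs) == documents")
    return counts


print(resolve_num_docs([2, 3], num_queries=2, total_docs=5))
print(resolve_num_docs(None, num_queries=3, total_docs=6))
```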
- cat_other_doc_hidden_states(hidden_states: Tensor, num_docs: Sequence[int]) → Tensor
Concatenates the hidden states of other documents to the hidden states of the query and documents.
- Parameters:
hidden_states (torch.Tensor) – Hidden states of the query and documents.
num_docs (Sequence[int]) – Specifies how many documents are passed per query. If a sequence of integers, len(num_docs) should be equal to the number of queries and sum(num_docs) equal to the number of documents, i.e., the sequence contains one value per query specifying the number of documents for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer the number of documents per query by dividing the total number of documents by the number of queries.
- Returns:
Concatenated hidden states of the query and documents.
- Return type:
torch.Tensor
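The following minimal NumPy sketch shows the concatenation. It assumes that the states shared between documents are each document's first token (e.g. `[CLS]`), in the spirit of the Set-Encoder's inter-passage attention. The helper name `cat_other_doc_first_tokens`, including the document's own first token, and zero-padding to `max(num_docs)` are all assumptions of this sketch, not confirmed details of lightning_ir. The arrays stand in for `torch.Tensor`.

```python
from typing import Sequence

import numpy as np


def cat_other_doc_first_tokens(
    hidden_states: np.ndarray, num_docs: Sequence[int]
) -> np.ndarray:
    """Append, to every document's sequence, the first-token states of all
    documents of the same query, zero-padded to max(num_docs) slots."""
    total_docs, _, dim = hidden_states.shape
    if sum(num_docs) != total_docs:
        raise ValueError("sum(num_docs) must equal the number of documents")
    max_docs = max(num_docs)
    first_tokens = hidden_states[:, 0]  # (total_docs, dim)
    extra = np.zeros((total_docs, max_docs, dim), dtype=hidden_states.dtype)
    start = 0
    for n in num_docs:
        group = first_tokens[start : start + n]  # (n, dim): this query's documents
        # Broadcasting copies the same n states into each of the n documents' rows
        extra[start : start + n, :n] = group
        start += n
    # Result shape: (total_docs, seq_len + max_docs, dim)
    return np.concatenate([hidden_states, extra], axis=1)


# Two queries with 2 and 3 documents, 3 tokens each, hidden size 2
hidden = np.arange(5 * 3 * 2, dtype=float).reshape(5, 3, 2)
out = cat_other_doc_first_tokens(hidden, [2, 3])
print(out.shape)
```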
- config_class
alias of SetEncoderConfig
- forward(encoding: BatchEncoding) → CrossEncoderOutput
Computes contextualized embeddings for the joint query-document input sequence and computes a relevance score.
- Parameters:
encoding (BatchEncoding) – Tokenizer encoding for the joint query-document input sequence.
- Returns:
Output of the model.
- Return type:
CrossEncoderOutput
- classmethod from_pretrained(model_name_or_path: str | Path, *args, **kwargs) → Self
Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained method to return a derived LightningIRModel. See LightningIRModelClassFactory for more details.
>>> from lightning_ir import CrossEncoderConfig, CrossEncoderModel, LightningIRModel
>>> # Loading using model class and backbone checkpoint
>>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
>>> # Loading using base class and backbone checkpoint
>>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
- Parameters:
model_name_or_path (str | Path) – Name or path of the pretrained model.
- Raises:
ValueError – If called on the abstract class LightningIRModel and no config is passed.
- Returns:
A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin.
- Return type:
LightningIRModel
- get_extended_attention_mask(attention_mask: Tensor, input_shape: Tuple[int, ...], device: device | None = None, dtype: dtype | None = None, num_docs: Sequence[int] | None = None) → Tensor
Extends the attention mask to account for the number of documents per query.
- Parameters:
attention_mask (torch.Tensor) – Attention mask for the input sequence.
input_shape (Tuple[int, ...]) – Shape of the input sequence.
device (torch.device | None) – Device to move the attention mask to. Defaults to None.
dtype (torch.dtype | None) – Data type of the attention mask. Defaults to None.
num_docs (Sequence[int] | None) – Specifies how many documents are passed per query. If a sequence of integers, len(num_docs) should be equal to the number of queries and sum(num_docs) equal to the number of documents, i.e., the sequence contains one value per query specifying the number of documents for that query. If an integer, assumes an equal number of documents per query. If None, tries to infer the number of documents per query by dividing the total number of documents by the number of queries. Defaults to None.
- Returns:
Extended attention mask.
- Return type:
torch.Tensor
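The extended mask also has to cover any appended slots. The sketch below assumes the layout `[own tokens | first tokens of the query's documents, zero-padded to max(num_docs)]`, which is not confirmed here. It follows the Hugging Face convention of an additive mask of shape `(batch, 1, 1, seq_len)`, where kept positions are 0 and masked positions are a large negative number. The helper name is hypothetical, and the `input_shape` and `device` arguments are omitted. The sketch also does not restrict which tokens may attend to the appended slots, which the real implementation may do.

```python
from typing import Sequence

import numpy as np


def extended_set_attention_mask(
    attention_mask: np.ndarray, num_docs: Sequence[int], dtype=np.float32
) -> np.ndarray:
    """Additive mask over [own tokens | first tokens of the query's documents]."""
    total_docs, _ = attention_mask.shape
    max_docs = max(num_docs)
    # 1 where an appended slot holds a document of the same query, 0 for padding
    other = np.zeros((total_docs, max_docs), dtype=attention_mask.dtype)
    start = 0
    for n in num_docs:
        other[start : start + n, :n] = 1
        start += n
    full = np.concatenate([attention_mask, other], axis=1).astype(dtype)
    # Hugging Face convention: 0 keeps a position, a large negative value masks it
    additive = (1.0 - full) * np.finfo(dtype).min
    # Broadcastable over heads and query positions: (total_docs, 1, 1, key_len)
    return additive[:, None, None, :]


# Two queries with 2 and 3 documents of 3 tokens each (0 marks token padding)
mask = np.array([[1, 1, 0], [1, 1, 1], [1, 0, 0], [1, 1, 1], [1, 1, 1]])
ext = extended_set_attention_mask(mask, [2, 3])
print(ext.shape)
```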
- pooling(embeddings: Tensor, attention_mask: Tensor | None, pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None) → Tensor
Helper method to apply pooling to the embeddings.
- Parameters:
embeddings (torch.Tensor) – Query or document embeddings.
attention_mask (torch.Tensor | None) – Query or document attention mask.
pooling_strategy (Literal['first', 'mean', 'max', 'sum'] | None) – The pooling strategy. No pooling is applied if None.
- Returns:
(Optionally) pooled embeddings.
- Return type:
torch.Tensor
- Raises:
ValueError – If an unknown pooling strategy is passed.
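The following NumPy reimplementation of the four pooling strategies is for illustration only. Masked positions are excluded from sum, mean, and max. Returning a `(batch, 1, dim)` array and treating a missing mask as all ones are choices of this sketch. `pool` is a hypothetical name and does not reproduce lightning_ir's code.

```python
from typing import Literal, Optional

import numpy as np

PoolingStrategy = Literal["first", "mean", "max", "sum"]


def pool(
    embeddings: np.ndarray,
    attention_mask: Optional[np.ndarray],
    strategy: Optional[PoolingStrategy],
) -> np.ndarray:
    """Pool (batch, seq_len, dim) embeddings to (batch, 1, dim); None skips pooling."""
    if strategy is None:
        return embeddings
    if attention_mask is None:
        # Sketch assumption: no mask means every position is valid
        attention_mask = np.ones(embeddings.shape[:2])
    mask = attention_mask[..., None].astype(bool)  # (batch, seq_len, 1)
    if strategy == "first":
        return embeddings[:, :1]
    if strategy == "sum":
        return np.where(mask, embeddings, 0.0).sum(axis=1, keepdims=True)
    if strategy == "mean":
        summed = np.where(mask, embeddings, 0.0).sum(axis=1, keepdims=True)
        counts = np.maximum(mask.sum(axis=1, keepdims=True), 1)  # avoid division by zero
        return summed / counts
    if strategy == "max":
        return np.where(mask, embeddings, -np.inf).max(axis=1, keepdims=True)
    raise ValueError(f"Unknown pooling strategy: {strategy}")


# One sequence of three tokens; the last token is padding
emb = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(pool(emb, mask, "mean"))
```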
- sparsification(embeddings: Tensor, sparsification_strategy: Literal['relu', 'relu_log'] | None = None) → Tensor
Helper method to apply sparsification to the embeddings.
- Parameters:
embeddings (torch.Tensor) – Query or document embeddings.
sparsification_strategy (Literal['relu', 'relu_log'] | None) – The sparsification strategy. No sparsification is applied if None. Defaults to None.
- Returns:
(Optionally) sparsified embeddings.
- Return type:
torch.Tensor
- Raises:
ValueError – If an unknown sparsification strategy is passed.
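The sketch below implements the two sparsification strategies. It assumes that `relu` means ReLU(x) and `relu_log` means log(1 + ReLU(x)), the log-saturation used by SPLADE-style sparse models. This page does not state the exact formula lightning_ir uses, and `sparsify` is a hypothetical name.

```python
from typing import Literal, Optional

import numpy as np

SparsificationStrategy = Literal["relu", "relu_log"]


def sparsify(
    embeddings: np.ndarray, strategy: Optional[SparsificationStrategy] = None
) -> np.ndarray:
    """Apply an (assumed) sparsification strategy; None returns the input unchanged."""
    if strategy is None:
        return embeddings
    if strategy == "relu":
        # Zero out negative activations
        return np.maximum(embeddings, 0.0)
    if strategy == "relu_log":
        # Zero out negatives, then dampen large values with log(1 + x)
        return np.log1p(np.maximum(embeddings, 0.0))
    raise ValueError(f"Unknown sparsification strategy: {strategy}")


weights = np.array([[-0.5, 0.0, 1.0, 3.0]])
print(sparsify(weights, "relu"))
print(sparsify(weights, "relu_log"))
```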