SetEncoderModel
- class lightning_ir.models.set_encoder.SetEncoderModel(config: SetEncoderConfig, *args, **kwargs)[source]
Bases: CrossEncoderModel
- __init__(config: SetEncoderConfig, *args, **kwargs)[source]
A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are aggregated into a single vector and fed to a linear layer, which computes the final relevance score.
- Parameters:
config (SetEncoderConfig) – Configuration for the set-encoder model
Methods
__init__(config, *args, **kwargs) – A cross-encoder model that jointly encodes a query and document(s).
attention_forward(_self, self, ...)
cat_other_doc_hidden_states(hidden_states, ...)
forward(encoding) – Computes contextualized embeddings for the joint query-document input sequence and computes a relevance score.
get_extended_attention_mask(attention_mask, ...) – Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Attributes
ALLOW_SUB_BATCHING – Flag to allow mini batches of documents for a single query.
self_attention_pattern
training
- ALLOW_SUB_BATCHING = False
Flag to allow mini batches of documents for a single query. Set to False for listwise models to ensure correctness.
- config_class
alias of SetEncoderConfig
- forward(encoding: BatchEncoding) CrossEncoderOutput [source]
Computes contextualized embeddings for the joint query-document input sequence and computes a relevance score.
- Parameters:
encoding (BatchEncoding) – Tokenizer encoding for the joint query-document input sequence
- Returns:
Output of the model
- Return type:
CrossEncoderOutput
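The aggregate-then-score step that forward performs can be sketched in plain PyTorch. This is an illustrative sketch only, not the library's implementation: the shapes and the first-token aggregation are assumptions based on the docstring above.

```python
import torch

# Sketch of the docstring's description: aggregate the contextualized
# embeddings into a single vector, then score it with a linear layer.
# Shapes and the first-token aggregation are assumptions, not library code.
batch_size, seq_len, hidden_size = 2, 8, 16
embeddings = torch.randn(batch_size, seq_len, hidden_size)
score_head = torch.nn.Linear(hidden_size, 1)

cls_vectors = embeddings[:, 0]                 # aggregate: take the first token
scores = score_head(cls_vectors).squeeze(-1)   # one relevance score per input
assert scores.shape == (batch_size,)
```

In the real model the embeddings come from the backbone transformer's forward pass over the joint query-document encoding.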
- classmethod from_pretrained(model_name_or_path: str | Path, *args, **kwargs) Self
Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained method to return a derived LightningIRModel. See LightningIRModelClassFactory for more details.
- Parameters:
model_name_or_path (str | Path) – Name or path of the pretrained model
- Raises:
ValueError – If called on the abstract class LightningIRModel and no config is passed
- Returns:
A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin
- Return type:
LightningIRModel
>>> # Loading using model class and backbone checkpoint
>>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
>>> # Loading using base class and backbone checkpoint
>>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
- get_extended_attention_mask(attention_mask: Tensor, input_shape: Tuple[int, ...], device: device | None = None, dtype: dtype | None = None, num_docs: Sequence[int] | None = None) Tensor [source]
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
- Parameters:
attention_mask (torch.Tensor) – Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (tuple[int]) – The shape of the input to the model.
- Returns:
The extended attention mask, with the same dtype as attention_mask.dtype.
- Return type:
torch.Tensor
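What an "extended" attention mask looks like can be sketched with the standard transformers pattern (the exact values lightning_ir produces may differ; this is an assumed illustration): a (batch, seq) 0/1 mask becomes a broadcastable (batch, 1, 1, seq) additive mask with a large negative value where tokens must be ignored.

```python
import torch

# Standard broadcastable-mask pattern (assumed, not lightning_ir's own code):
# positions to ignore get a very large negative additive bias, so their
# attention weights vanish after the softmax.
attention_mask = torch.tensor([[1, 1, 0]])  # last token is padding
extended = (1.0 - attention_mask[:, None, None, :].float()) * torch.finfo(torch.float32).min
assert extended.shape == (1, 1, 1, 3)
assert extended[0, 0, 0, 2] < -1e30  # masked position: large negative bias
assert extended[0, 0, 0, 0] == 0.0   # attended position: no bias
```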
- pooling(embeddings: Tensor, attention_mask: Tensor | None, pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None) Tensor
Helper method to apply pooling to the embeddings.
- Parameters:
embeddings (torch.Tensor) – Query or document embeddings
attention_mask (torch.Tensor | None) – Query or document attention mask
pooling_strategy (Literal['first', 'mean', 'max', 'sum'] | None) – The pooling strategy. No pooling is applied if None.
- Raises:
ValueError – If an unknown pooling strategy is passed
- Returns:
(Optionally) pooled embeddings
- Return type:
torch.Tensor
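The "mean" strategy with an attention mask can be sketched as follows. The semantics are assumed (average only over unmasked positions); this is not the library's own implementation.

```python
import torch

# Assumed "mean" pooling semantics: average only over unmasked positions.
embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])  # (1, 3, 2)
attention_mask = torch.tensor([[1, 1, 0]])                          # last token is padding
masked = embeddings * attention_mask.unsqueeze(-1)                  # zero out padded positions
mean_pooled = masked.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
# mean of the two unmasked vectors: [[2.0, 3.0]]
assert torch.allclose(mean_pooled, torch.tensor([[2.0, 3.0]]))
```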
- sparsification(embeddings: Tensor, sparsification_strategy: Literal['relu', 'relu_log'] | None = None) Tensor
Helper method to apply sparsification to the embeddings.
- Parameters:
embeddings (torch.Tensor) – Query or document embeddings
sparsification_strategy (Literal['relu', 'relu_log'] | None, optional) – The sparsification strategy. No sparsification is applied if None, defaults to None
- Raises:
ValueError – If an unknown sparsification strategy is passed
- Returns:
(Optionally) sparsified embeddings
- Return type:
torch.Tensor
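The two strategies can be sketched as follows. The semantics are assumed ("relu" clips negatives to zero; "relu_log" additionally applies log(1 + x), as in SPLADE-style sparse models) and may differ from the library's exact implementation.

```python
import torch

# Assumed sparsification semantics (not the library's own code):
# "relu" zeroes out negative values; "relu_log" dampens large values
# by additionally applying log(1 + x).
embeddings = torch.tensor([[-1.0, 0.0, 2.0]])
relu = torch.relu(embeddings)                  # negatives clipped to zero
relu_log = torch.log1p(torch.relu(embeddings))  # log(1 + relu(x))
assert torch.equal(relu, torch.tensor([[0.0, 0.0, 2.0]]))
assert torch.allclose(relu_log[0, 2], torch.log(torch.tensor(3.0)))
```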