T5CrossEncoderModel

class lightning_ir.models.t5_cross_encoder.T5CrossEncoderModel(config: T5CrossEncoderConfig, *args, **kwargs)

Bases: CrossEncoderModel

__init__(config: T5CrossEncoderConfig, *args, **kwargs)

A cross-encoder model that jointly encodes a query and document(s). The contextualized embeddings are aggregated into a single vector, which a linear layer maps to a final relevance score.
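The scoring path described above can be sketched in plain Python. This is a framework-free illustration, not the library's implementation: the helper names `first_token_pool` and `linear_score`, the choice of first-token pooling, and the toy embeddings and weights are all hypothetical.

```python
from typing import List, Sequence

Vector = List[float]


def first_token_pool(embeddings: Sequence[Vector]) -> Vector:
    """Hypothetical helper: aggregate token embeddings by keeping the first token's vector."""
    return list(embeddings[0])


def linear_score(pooled: Sequence[float], weights: Sequence[float], bias: float) -> float:
    """Hypothetical helper: a linear layer with a single output, i.e. w . v + b."""
    if len(pooled) != len(weights):
        raise ValueError("pooled vector and weights must have the same dimension")
    return sum(w * v for w, v in zip(weights, pooled)) + bias


# Made-up contextualized embeddings for one joint query-document sequence:
# 3 tokens, hidden size 2.
embeddings = [[0.5, -1.0], [2.0, 0.0], [1.0, 1.0]]
weights = [2.0, 1.0]  # made-up "learned" weights
bias = 0.25

score = linear_score(first_token_pool(embeddings), weights, bias)
print(score)  # 0.25
```

A real model applies the same two steps to batched tensors, with the linear layer's weights learned during fine-tuning.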

Parameters:

config (T5CrossEncoderConfig) – Configuration for the cross-encoder model

Methods

__init__(config, *args, **kwargs)

A cross-encoder model that jointly encodes a query and document(s).

forward(encoding)

Computes contextualized embeddings for the joint query-document input sequence and computes a relevance score.

Attributes

training

ALLOW_SUB_BATCHING = True

Flag to allow splitting the documents for a single query into mini-batches. Set to False for listwise models, whose scores depend on all documents jointly, to ensure correctness.
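To illustrate why sub-batching is only safe for models that score each query-document pair independently, the sketch below splits a query's documents into chunks and concatenates the per-chunk scores. The functions `score_in_sub_batches`, `pointwise_scores`, and `listwise_scores` are hypothetical stand-ins, not part of lightning_ir; the toy scorers simply count query-term overlap.

```python
from typing import Callable, List, Sequence

# Hypothetical scorer signature: (query, documents) -> one score per document.
ScoreFn = Callable[[str, Sequence[str]], List[float]]


def score_in_sub_batches(
    query: str, documents: Sequence[str], score_fn: ScoreFn, sub_batch_size: int
) -> List[float]:
    """Score a query's documents in chunks of at most ``sub_batch_size``."""
    if sub_batch_size < 1:
        raise ValueError("sub_batch_size must be positive")
    scores: List[float] = []
    for start in range(0, len(documents), sub_batch_size):
        chunk = documents[start : start + sub_batch_size]
        scores.extend(score_fn(query, chunk))
    return scores


def pointwise_scores(query: str, docs: Sequence[str]) -> List[float]:
    # Toy pointwise scorer: each score depends only on its own (query, document) pair.
    return [float(sum(word in doc.split() for word in query.split())) for doc in docs]


def listwise_scores(query: str, docs: Sequence[str]) -> List[float]:
    # Toy listwise scorer: scores are normalized over the whole list,
    # so each one depends on the other documents in the same call.
    raw = pointwise_scores(query, docs)
    total = sum(raw) or 1.0
    return [r / total for r in raw]


query = "t5 ranking"
docs = ["t5 model", "ranking with t5", "unrelated text", "t5"]

print(score_in_sub_batches(query, docs, pointwise_scores, sub_batch_size=2))
# [1.0, 2.0, 0.0, 1.0], identical to scoring all documents at once
print(listwise_scores(query, docs))
# [0.25, 0.5, 0.0, 0.25]
print(score_in_sub_batches(query, docs, listwise_scores, sub_batch_size=2))
# [0.3333333333333333, 0.6666666666666666, 0.0, 1.0], different from the full-list scores
```

With the pointwise scorer, chunking leaves the scores unchanged; with the listwise scorer, each chunk is normalized separately, so the result no longer matches scoring the full list.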

config_class

alias of T5CrossEncoderConfig

forward(encoding: BatchEncoding) → CrossEncoderOutput

Computes contextualized embeddings for the joint query-document input sequence and computes a relevance score.

Parameters:

encoding (BatchEncoding) – Tokenizer encoding for the joint query-document input sequence

Returns:

Output of the model

Return type:

CrossEncoderOutput

classmethod from_pretrained(model_name_or_path: str | Path, *args, **kwargs) → Self

Loads a pretrained model. Wraps the transformers.PreTrainedModel.from_pretrained method to return a derived LightningIRModel. See LightningIRModelClassFactory for more details.

Parameters:

model_name_or_path (str | Path) – Name or path of the pretrained model

Raises:

ValueError – If called on the abstract class LightningIRModel and no config is passed

Returns:

A derived LightningIRModel consisting of a backbone model and a LightningIRModel mixin

Return type:

LightningIRModel

>>> from lightning_ir import CrossEncoderConfig, CrossEncoderModel, LightningIRModel
>>> # Loading using model class and backbone checkpoint
>>> type(CrossEncoderModel.from_pretrained("bert-base-uncased"))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>
>>> # Loading using base class and backbone checkpoint
>>> type(LightningIRModel.from_pretrained("bert-base-uncased", config=CrossEncoderConfig()))
<class 'lightning_ir.base.class_factory.CrossEncoderBertModel'>

pooling(embeddings: Tensor, attention_mask: Tensor | None, pooling_strategy: Literal['first', 'mean', 'max', 'sum'] | None) → Tensor

Helper method to apply pooling to the embeddings.

Parameters:
  • embeddings (torch.Tensor) – Query or document embeddings

  • attention_mask (torch.Tensor | None) – Query or document attention mask

  • pooling_strategy (Literal['first', 'mean', 'max', 'sum'] | None) – The pooling strategy. No pooling is applied if None.

Raises:

ValueError – If an unknown pooling strategy is passed

Returns:

(Optionally) pooled embeddings

Return type:

torch.Tensor
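The four pooling strategies can be sketched for a single sequence, using nested Python lists in place of tensors. This is an illustration under stated assumptions rather than the library's code: `pool` is a hypothetical name, it handles one unbatched sequence of shape (seq_len, dim), and it assumes padding tokens (mask value 0) are excluded from `mean`, `max`, and `sum`, while `first` always takes the first token.

```python
from typing import List, Literal, Optional, Sequence

Vector = List[float]
PoolingStrategy = Optional[Literal["first", "mean", "max", "sum"]]


def pool(
    embeddings: Sequence[Vector],
    attention_mask: Optional[Sequence[int]],
    pooling_strategy: PoolingStrategy,
) -> list:
    """Hypothetical sketch: pool one (seq_len x dim) sequence into a single vector."""
    if pooling_strategy is None:
        # No pooling: return the per-token embeddings unchanged.
        return [list(vec) for vec in embeddings]
    if pooling_strategy not in ("first", "mean", "max", "sum"):
        raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")
    if pooling_strategy == "first":
        # First token (e.g. a [CLS]-style token), regardless of the mask.
        return list(embeddings[0])
    if attention_mask is None:
        # Without a mask, every token counts.
        attention_mask = [1] * len(embeddings)
    # Keep only the non-padding tokens.
    kept = [vec for vec, m in zip(embeddings, attention_mask) if m]
    if not kept:
        raise ValueError("attention mask excludes every token")
    dim = len(kept[0])
    if pooling_strategy == "sum":
        return [sum(vec[i] for vec in kept) for i in range(dim)]
    if pooling_strategy == "mean":
        return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]
    # pooling_strategy == "max"
    return [max(vec[i] for vec in kept) for i in range(dim)]


emb = [[1.0, 2.0], [3.0, -4.0], [100.0, 100.0]]  # last token is padding
mask = [1, 1, 0]
print(pool(emb, mask, "first"))  # [1.0, 2.0]
print(pool(emb, mask, "sum"))    # [4.0, -2.0]
print(pool(emb, mask, "mean"))   # [2.0, -1.0]
print(pool(emb, mask, "max"))    # [3.0, 2.0]
```

Masking matters most for `mean` and `max`: without it, the padding row above would dominate the result.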

sparsification(embeddings: Tensor, sparsification_strategy: Literal['relu', 'relu_log'] | None = None) → Tensor

Helper method to apply sparsification to the embeddings.

Parameters:
  • embeddings (torch.Tensor) – Query or document embeddings

  • sparsification_strategy (Literal['relu', 'relu_log'] | None, optional) – The sparsification strategy. No sparsification is applied if None. Defaults to None.

Raises:

ValueError – If an unknown sparsification strategy is passed

Returns:

(Optionally) sparsified embeddings

Return type:

torch.Tensor
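Similarly, the two sparsification strategies can be sketched element-wise on a flat list of floats. The hypothetical `sparsify` helper assumes that `relu` clamps negative values to zero and that `relu_log` computes log(1 + relu(x)), the saturating activation used by SPLADE-style sparse encoders; the library source is the authority on its exact formula.

```python
import math
from typing import List, Literal, Optional, Sequence

SparsificationStrategy = Optional[Literal["relu", "relu_log"]]


def sparsify(
    embeddings: Sequence[float],
    sparsification_strategy: SparsificationStrategy = None,
) -> List[float]:
    """Hypothetical sketch: apply a sparsification strategy element-wise."""
    if sparsification_strategy is None:
        # No sparsification: return an unchanged copy.
        return list(embeddings)
    if sparsification_strategy == "relu":
        # Negative values become exactly zero, which makes the vector sparse.
        return [max(x, 0.0) for x in embeddings]
    if sparsification_strategy == "relu_log":
        # Assumed formula: log(1 + relu(x)); zeros stay zero, large values are damped.
        return [math.log1p(max(x, 0.0)) for x in embeddings]
    raise ValueError(f"Unknown sparsification strategy: {sparsification_strategy}")


values = [-2.0, 0.0, 1.0]
print(sparsify(values, "relu"))      # [0.0, 0.0, 1.0]
print(sparsify(values, "relu_log"))  # first two values are 0.0; the last is log(2), about 0.693
```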