Source code for lightning_ir.base.validation_utils

"""Validation utilities module for Lightning IR.

This module contains utility functions for validation and evaluation of Lightning IR models."""
 4
 5from typing import Dict, Sequence
 6
 7import ir_measures
 8import numpy as np
 9import pandas as pd
10import torch
11
12
def create_run_from_scores(
    query_ids: Sequence[str], doc_ids: Sequence[Sequence[str]], scores: torch.Tensor
) -> pd.DataFrame:
    """Convenience function to create a run DataFrame from query and document ids and scores.

    The i-th entry of ``query_ids`` owns the i-th list in ``doc_ids``. ``scores`` is flattened and must hold
    exactly one score per (query, document) pair, in the same order as the flattened ``doc_ids``.

    Args:
        query_ids (Sequence[str]): List of query IDs.
        doc_ids (Sequence[Sequence[str]]): List of lists containing document IDs for each query.
        scores (torch.Tensor): Tensor containing scores for each query-document pair. It is flattened before
            use, so its number of elements must equal the total number of document IDs.
    Returns:
        pd.DataFrame: DataFrame containing the run information with columns:
            query_id, q0, doc_id, score, system, and rank. Rows are ordered by the position of the query in
            ``query_ids`` and then by rank (1 = highest score within the query; ties keep input order).
    Raises:
        ValueError: If ``query_ids`` and ``doc_ids`` differ in length, or if the number of scores does not
            match the total number of document IDs.
    """
    if len(query_ids) != len(doc_ids):
        raise ValueError(
            f"Got {len(query_ids)} query ids but {len(doc_ids)} document id lists; "
            "expected one document id list per query."
        )
    num_docs = [len(ids) for ids in doc_ids]
    # Flatten in a single pass; ``sum(lists, [])`` copies the accumulator for every query (quadratic).
    flat_doc_ids = [doc_id for ids in doc_ids for doc_id in ids]
    flat_scores = scores.float().numpy(force=True).reshape(-1)
    if flat_scores.shape[0] != len(flat_doc_ids):
        raise ValueError(
            f"Got {flat_scores.shape[0]} scores for {len(flat_doc_ids)} document ids; "
            "expected exactly one score per query-document pair."
        )
    df = pd.DataFrame(
        {
            "query_id": np.array(query_ids).repeat(num_docs),
            "q0": 0,
            "doc_id": flat_doc_ids,
            "score": flat_scores,
            "system": "lightning_ir",
        }
    )
    # method="first" gives tied scores distinct ranks in order of appearance.
    df["rank"] = df.groupby("query_id")["score"].rank(ascending=False, method="first")

    # Sort queries by their position in ``query_ids`` instead of lexicographically.
    query_order = {query_id: i for i, query_id in enumerate(query_ids)}

    def key(series: pd.Series) -> pd.Series:
        if series.name == "query_id":
            return series.map(query_order)
        return series

    df = df.sort_values(["query_id", "rank"], ascending=[True, True], key=key)
    return df
45 46
def create_qrels_from_dicts(qrels: Sequence[Dict[str, int]]) -> pd.DataFrame:
    """Convenience function to create a qrels DataFrame from a list of dictionaries.

    Each dictionary is treated as one qrel record and becomes one row of the result; the dictionary keys
    become the DataFrame columns.

    Args:
        qrels (Sequence[Dict[str, int]]): Qrel records, one dictionary per judged query-document pair. To
            obtain the columns listed below, each record should provide ``query_id``, ``q0``, ``doc_id`` and
            ``relevance`` keys.
    Returns:
        pd.DataFrame: DataFrame with one row per record, with columns taken from the record keys
            (expected: query_id, q0, doc_id, and relevance).
    """
    # NOTE(review): the ``Dict[str, int]`` hint looks too narrow if records carry string ids — confirm.
    return pd.DataFrame.from_records(qrels)
56 57
def evaluate_run(run: pd.DataFrame, qrels: pd.DataFrame, measures: Sequence[str]) -> Dict[str, float]:
    """Convenience function to evaluate a run against qrels using a set of measures.

    .. _ir-measures: https://ir-measur.es/en/latest/index.html

    Args:
        run (pd.DataFrame): Parsed TREC run, e.g. as produced by ``create_run_from_scores``.
        qrels (pd.DataFrame): Parsed TREC qrels, e.g. as produced by ``create_qrels_from_dicts``.
        measures (Sequence[str]): Metrics corresponding to ir-measures_ measure strings. A single measure
            string is also accepted and treated as a one-element sequence.
    Returns:
        Dict[str, float]: Aggregated value of each measure, keyed by the string form of the parsed measure.
    """
    # A bare ``str`` is itself a Sequence[str]; without this guard it would be iterated character by
    # character and every character parsed as a separate measure.
    if isinstance(measures, str):
        measures = [measures]
    parsed_measures = [ir_measures.parse_measure(measure) for measure in measures]
    # Each measure is aggregated independently via its own ``calc_aggregate``.
    metrics = {str(measure): measure.calc_aggregate(qrels, run) for measure in parsed_measures}
    return metrics