Source code for lightning_ir.base.validation_utils

 1"""Validation utilities module for Lightning IR.
 2
 3This module contains utility functions for validation and evaluation of Lightning IR models."""
 4
 5from typing import Dict, Sequence
 6
 7import ir_measures
 8import numpy as np
 9import pandas as pd
10import torch
11
12
[docs] 13def create_run_from_scores( 14 query_ids: Sequence[str], doc_ids: Sequence[Sequence[str]], scores: torch.Tensor 15) -> pd.DataFrame: 16 """Convience function to create a run DataFrame from query and document ids and scores. 17 18 :param query_ids: Query ids 19 :type query_ids: Sequence[str] 20 :param doc_ids: Document ids 21 :type doc_ids: Sequence[Sequence[str]] 22 :param scores: Scores 23 :type scores: torch.Tensor 24 :return: DataFrame with query_id, q0, doc_id, score, rank, and system columns 25 :rtype: pd.DataFrame 26 """ 27 num_docs = [len(ids) for ids in doc_ids] 28 df = pd.DataFrame( 29 { 30 "query_id": np.array(query_ids).repeat(num_docs), 31 "q0": 0, 32 "doc_id": sum(map(lambda x: list(x), doc_ids), []), 33 "score": scores.float().numpy(force=True).reshape(-1), 34 "system": "lightning_ir", 35 } 36 ) 37 df["rank"] = df.groupby("query_id")["score"].rank(ascending=False, method="first") 38 39 def key(series: pd.Series) -> pd.Series: 40 if series.name == "query_id": 41 return series.map({query_id: i for i, query_id in enumerate(query_ids)}) 42 return series 43 44 df = df.sort_values(["query_id", "rank"], ascending=[True, True], key=key) 45 return df
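
A minimal usage sketch follows; the query ids, document ids, and scores are made-up illustration data, not part of Lightning IR:

# Two queries with two scored documents each (illustrative values only).
query_ids = ["q1", "q2"]
doc_ids = [["d1", "d2"], ["d3", "d4"]]
scores = torch.tensor([[0.9, 0.1], [0.3, 0.7]])

run = create_run_from_scores(query_ids, doc_ids, scores)
# One row per (query, document) pair; within each query the highest score gets
# rank 1, and queries keep the order given in query_ids.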

def create_qrels_from_dicts(qrels: Sequence[Dict[str, int]]) -> pd.DataFrame:
    """Convenience function to create a qrels DataFrame from a list of dictionaries.

    :param qrels: Mappings of doc_id -> relevance for each query
    :type qrels: Sequence[Dict[str, int]]
    :return: DataFrame with query_id, q0, doc_id, and relevance columns
    :rtype: pd.DataFrame
    """
    return pd.DataFrame.from_records(qrels)
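
A matching qrels sketch; because the function passes the sequence directly to pd.DataFrame.from_records, each dictionary below is one qrel record carrying the query_id, q0, doc_id, and relevance keys named in the return description (the judgments are made up):

qrels = create_qrels_from_dicts(
    [
        {"query_id": "q1", "q0": 0, "doc_id": "d1", "relevance": 1},
        {"query_id": "q1", "q0": 0, "doc_id": "d2", "relevance": 0},
        {"query_id": "q2", "q0": 0, "doc_id": "d4", "relevance": 1},
    ]
)
# Produces a DataFrame with query_id, q0, doc_id, and relevance columns.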

def evaluate_run(run: pd.DataFrame, qrels: pd.DataFrame, measures: Sequence[str]) -> Dict[str, float]:
    """Convenience function to evaluate a run against qrels using a set of measures.

    .. _ir-measures: https://ir-measur.es/en/latest/index.html

    :param run: Parsed TREC run
    :type run: pd.DataFrame
    :param qrels: Parsed TREC qrels
    :type qrels: pd.DataFrame
    :param measures: Metrics corresponding to ir-measures_ measure strings
    :type measures: Sequence[str]
    :return: Calculated metrics
    :rtype: Dict[str, float]
    """
    parsed_measures = [ir_measures.parse_measure(measure) for measure in measures]
    metrics = {str(measure): measure.calc_aggregate(qrels, run) for measure in parsed_measures}
    return metrics
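
Tying the helpers together with the run and qrels from the sketches above; "nDCG@10" and "RR@10" are examples of ir-measures measure strings, and the resulting values depend on the toy data:

metrics = evaluate_run(run, qrels, ["nDCG@10", "RR@10"])
# Returns a dict keyed by measure string, e.g. {"nDCG@10": ..., "RR@10": ...},
# with each value aggregated over the queries.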