Source code for lightning_ir.base.validation_utils
1"""Validation utilities module for Lightning IR.
2
3This module contains utility functions for validation and evaluation of Lightning IR models."""
4
5from typing import Dict, Sequence
6
7import ir_measures
8import numpy as np
9import pandas as pd
10import torch
11
12
[docs]
13def create_run_from_scores(
14 query_ids: Sequence[str], doc_ids: Sequence[Sequence[str]], scores: torch.Tensor
15) -> pd.DataFrame:
16 """Convience function to create a run DataFrame from query and document ids and scores.
17
18 :param query_ids: Query ids
19 :type query_ids: Sequence[str]
20 :param doc_ids: Document ids
21 :type doc_ids: Sequence[Sequence[str]]
22 :param scores: Scores
23 :type scores: torch.Tensor
24 :return: DataFrame with query_id, q0, doc_id, score, rank, and system columns
25 :rtype: pd.DataFrame
26 """
    num_docs = [len(ids) for ids in doc_ids]
    df = pd.DataFrame(
        {
            # repeat each query id once per document scored for that query
            "query_id": np.array(query_ids).repeat(num_docs),
            "q0": 0,
            # flatten the per-query document ids into a single column
            "doc_id": [doc_id for ids in doc_ids for doc_id in ids],
            "score": scores.float().numpy(force=True).reshape(-1),
            "system": "lightning_ir",
        }
    )
    # rank documents per query by descending score; ties keep their original order
    df["rank"] = df.groupby("query_id")["score"].rank(ascending=False, method="first")

    def key(series: pd.Series) -> pd.Series:
        # sort queries in their original input order rather than lexicographically
        if series.name == "query_id":
            return series.map({query_id: i for i, query_id in enumerate(query_ids)})
        return series

    df = df.sort_values(["query_id", "rank"], ascending=[True, True], key=key)
    return df
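# Illustrative sketch (not part of the module): how a run DataFrame might be built
# from scores. The query ids, document ids, and scores below are made up for
# demonstration only; `scores` holds one value per (query, document) pair.
#
#     scores = torch.tensor([0.9, 0.1, 0.5])
#     run = create_run_from_scores(
#         query_ids=["q1", "q2"],
#         doc_ids=[["d1", "d2"], ["d3"]],
#         scores=scores,
#     )
#     # run now holds one row per document with columns
#     # query_id, q0, doc_id, score, system, and rank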


def create_qrels_from_dicts(qrels: Sequence[Dict[str, int]]) -> pd.DataFrame:
    """Convenience function to create a qrels DataFrame from a list of dictionaries.

    :param qrels: Relevance judgments, one record dictionary per query-document pair
    :type qrels: Sequence[Dict[str, int]]
    :return: DataFrame with query_id, q0, doc_id, and relevance columns
    :rtype: pd.DataFrame
    """
    return pd.DataFrame.from_records(qrels)
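# Illustrative sketch (not part of the module): the input is assumed here to be one
# record dictionary per relevance judgment, matching the columns produced by
# pd.DataFrame.from_records and named in the docstring above. The ids and labels
# are made up.
#
#     qrels = create_qrels_from_dicts(
#         [
#             {"query_id": "q1", "doc_id": "d1", "relevance": 1},
#             {"query_id": "q1", "doc_id": "d2", "relevance": 0},
#             {"query_id": "q2", "doc_id": "d3", "relevance": 2},
#         ]
#     )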


def evaluate_run(run: pd.DataFrame, qrels: pd.DataFrame, measures: Sequence[str]) -> Dict[str, float]:
    """Convenience function to evaluate a run against qrels using a set of measures.

    .. _ir-measures: https://ir-measur.es/en/latest/index.html

    :param run: Parsed TREC run
    :type run: pd.DataFrame
    :param qrels: Parsed TREC qrels
    :type qrels: pd.DataFrame
    :param measures: Metrics corresponding to ir-measures_ measure strings
    :type measures: Sequence[str]
    :return: Calculated metrics
    :rtype: Dict[str, float]
    """
    # parse measure strings (e.g. "nDCG@10") into ir_measures measure objects
    parsed_measures = [ir_measures.parse_measure(measure) for measure in measures]
    # aggregate each measure over all queries in the run
    metrics = {str(measure): measure.calc_aggregate(qrels, run) for measure in parsed_measures}
    return metrics
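# Illustrative sketch (not part of the module): evaluating a run against qrels.
# "nDCG@10" is one example of an ir-measures measure string; `run` and `qrels` are
# assumed to be DataFrames such as those returned by create_run_from_scores and
# create_qrels_from_dicts above.
#
#     metrics = evaluate_run(run, qrels, ["nDCG@10"])
#     # metrics maps each measure string to its aggregate value over all queries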