Source code for lightning_ir.base.validation_utils
1"""Validation utilities module for Lightning IR.
2
3This module contains utility functions for validation and evaluation of Lightning IR models."""
4
5from typing import Dict, Sequence
6
7import ir_measures
8import numpy as np
9import pandas as pd
10import torch
11
12
[docs]
def create_run_from_scores(
    query_ids: Sequence[str], doc_ids: Sequence[Sequence[str]], scores: torch.Tensor
) -> pd.DataFrame:
    """Convenience function to create a run DataFrame from query and document ids and scores.

    Rows are ordered by the position of their query in ``query_ids`` and, within a query, by
    ascending rank (highest score first). Tied scores keep their order of appearance.

    Args:
        query_ids (Sequence[str]): List of query IDs.
        doc_ids (Sequence[Sequence[str]]): List of lists containing document IDs for each query.
        scores (torch.Tensor): Tensor containing scores for each query-document pair. The tensor is
            flattened, so it must hold exactly one score per document ID, in the order of ``doc_ids``.
    Returns:
        pd.DataFrame: DataFrame containing the run information with columns:
            query_id, q0, doc_id, score, system, and rank.
    Raises:
        ValueError: If the number of scores differs from the total number of document IDs.
    """
    num_docs = [len(ids) for ids in doc_ids]
    # Cast to float32 and copy to host memory (force=True also detaches) before flattening.
    flat_scores = scores.float().numpy(force=True).reshape(-1)
    total_docs = sum(num_docs)
    if flat_scores.shape[0] != total_docs:
        # Fail early with a readable message instead of a generic pandas length error.
        raise ValueError(
            f"Number of scores ({flat_scores.shape[0]}) does not match number of document ids ({total_docs})."
        )
    df = pd.DataFrame(
        {
            # Repeat each query id once per document retrieved for that query.
            "query_id": np.array(query_ids).repeat(num_docs),
            "q0": 0,
            # Linear-time flattening; summing lists re-copies the accumulator for every query.
            "doc_id": [doc_id for ids in doc_ids for doc_id in ids],
            "score": flat_scores,
            "system": "lightning_ir",
        }
    )
    df["rank"] = df.groupby("query_id")["score"].rank(ascending=False, method="first")

    # Position of each query id in the input, built once and reused by the sort key.
    query_position = {query_id: i for i, query_id in enumerate(query_ids)}

    def key(series: pd.Series) -> pd.Series:
        # Sort queries by their input order rather than lexicographically; other columns sort as-is.
        if series.name == "query_id":
            return series.map(query_position)
        return series

    df = df.sort_values(["query_id", "rank"], ascending=[True, True], key=key)
    return df
45
46
[docs]
def create_qrels_from_dicts(qrels: Sequence[Dict[str, int]]) -> pd.DataFrame:
    """Build a qrels DataFrame from a sequence of dictionaries.

    Every dictionary is treated as one record: it becomes a single row and its keys become the
    column names.

    Args:
        qrels (Sequence[Dict[str, int]]): Relevance judgement records, one dictionary per row.
            NOTE(review): presumably each record carries query_id, doc_id and relevance keys --
            confirm against the caller.
    Returns:
        pd.DataFrame: DataFrame with one row per dictionary and one column per distinct key.
    """
    records = qrels
    qrels_frame = pd.DataFrame.from_records(records)
    return qrels_frame
56
57
[docs]
def evaluate_run(run: pd.DataFrame, qrels: pd.DataFrame, measures: Sequence[str]) -> Dict[str, float]:
    """Evaluate a run against qrels for a set of measures.

    .. _ir-measures: https://ir-measur.es/en/latest/index.html

    Args:
        run (pd.DataFrame): Parsed TREC run.
        qrels (pd.DataFrame): Parsed TREC qrels.
        measures (Sequence[str]): Measure strings understood by ir-measures_.
    Returns:
        Dict[str, float]: Aggregated value of each measure, keyed by the string form of the parsed measure.
    """
    # Parse every measure string up front so an invalid one fails before anything is calculated.
    parsed = [ir_measures.parse_measure(name) for name in measures]
    results: Dict[str, float] = {}
    for parsed_measure in parsed:
        results[str(parsed_measure)] = parsed_measure.calc_aggregate(qrels, run)
    return results
72 return metrics