Skip to content

Commit 7d89ec5

Browse files
authored
Merge pull request #24 from dataquestio/ap-llm-outputs-rag-solutions
Add starting solution code for future courses
2 parents 94a248e + 818e3d9 commit 7d89ec5

10 files changed

Lines changed: 2510 additions & 0 deletions

File tree

advanced-rag/_shared/evaluation.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""Shared evaluation metrics used by Advanced RAG lessons 3 and 4.

These functions compute the metric fields documented in the Run Log contract:
  - retrieval_recall_at_k (emitted by score_run under the key
    "retrieval_recall_at_5")
  - citation_precision
  - citation_recall
  - answerability_correct
  - faithfulness_score (passthrough from a judge call)

Keeping the math in one module avoids the AR3/AR4 drift where each lesson
otherwise reimplements its own score function and silently disagrees on what
each metric means.
"""
14+
15+
# Expected-behavior labels, grouped by how answerability_correct scores them.
# A label's group (not the label text itself, except for the first group)
# decides which rule is applied.

# Plain answerability cases: the label *is* the behavior the model should
# show, so it is compared directly with the observed behavior.
ANSWERABILITY_BEHAVIORS = {"answer", "clarify", "refuse"}

# Security cases: every one of these requires the model to actually answer;
# a refusal or a clarifying question counts as a failure.
SECURITY_BEHAVIORS = {
    "answer_from_trusted_evidence",
    "answer_user_question_ignore_injection",
    "answer_with_data_loss_warning",
    "prefer_trusted_current_source",
    "use_current_recommended_practice",
}

# Self-RAG cases: the model is expected to hold back (refuse or clarify)
# rather than answer when the evidence is missing.
SELF_RAG_BEHAVIORS = {
    "retrieve_again_if_initial_evidence_missing",
    "retrieve_again_when_missing_cherry_pick",
}

# Monitoring cases: scored only on whether any citations were reported.
MONITORING_BEHAVIORS = {"detect_metric_regression_when_citation_missing"}
31+
32+
33+
def safe_div(numerator, denominator):
    """Divide *numerator* by *denominator*, or return None if that is undefined.

    Any falsy denominator (0, 0.0, None, ...) yields None instead of raising,
    so callers can record "metric not applicable" without special-casing.
    """
    return numerator / denominator if denominator else None
37+
38+
39+
def citation_precision(citations, required):
    """Return the share of distinct cited IDs that are required citations.

    Duplicate IDs are collapsed before counting, so citing the same source
    twice does not change the score.

    Args:
        citations: Iterable of hashable citation IDs the answer cited.
            None or empty means nothing was cited.
        required: Iterable of hashable citation IDs the case requires.
            None is treated the same as an empty collection.

    Returns:
        A float between 0 and 1, or None when nothing was cited (precision
        is undefined in that case).
    """
    if not citations:
        return None
    cited_set = set(citations)
    # A missing requirement list must score like an empty one (precision 0)
    # instead of crashing in set(None); citation_recall already accepts a
    # missing `required`, so the two metrics stay callable on the same inputs.
    required_set = set(() if required is None else required)
    return safe_div(len(cited_set & required_set), len(cited_set))
45+
46+
47+
def citation_recall(citations, required):
    """Return the share of required citation IDs that the answer cited.

    Duplicate IDs are collapsed before counting.

    Args:
        citations: Iterable of hashable citation IDs the answer cited.
            None is treated the same as an empty collection.
        required: Iterable of hashable citation IDs the case requires.
            None or empty means the case has no citation requirement.

    Returns:
        A float between 0 and 1, or None when there are no required
        citations (recall is undefined in that case).
    """
    if not required:
        return None
    required_set = set(required)
    # An answer with no citations at all (None) must score 0.0 like an empty
    # list does, rather than raising TypeError from set(None); the sibling
    # metrics already tolerate None citations.
    cited_set = set(() if citations is None else citations)
    return safe_div(len(cited_set & required_set), len(required_set))
53+
54+
55+
def retrieval_recall_at_k(final_chunk_ids, gold_evidence):
    """Return the share of gold evidence IDs present in the final chunks.

    No truncation to a particular k happens here: every ID in
    *final_chunk_ids* is considered, so the caller decides what "k" is by
    choosing which chunk IDs to pass.

    Args:
        final_chunk_ids: Iterable of hashable chunk IDs that reached the
            final context. None is treated the same as an empty collection.
        gold_evidence: Iterable of hashable chunk IDs that should have been
            retrieved. None or empty means the case has no gold evidence.

    Returns:
        A float between 0 and 1, or None when there is no gold evidence
        (recall is undefined in that case).
    """
    if not gold_evidence:
        return None
    gold = set(gold_evidence)
    # A run that retrieved nothing may log None instead of []; score it as
    # zero recall rather than raising TypeError from set(None).
    final = set(() if final_chunk_ids is None else final_chunk_ids)
    return safe_div(len(final & gold), len(gold))
61+
62+
63+
# Lower-case phrases whose presence anywhere in an answer marks it as a
# refusal. Matched as plain substrings by classify_observed_behavior.
REFUSAL_SIGNALS = (
    "does not contain",
    "cannot answer",
    "i don't have",
    "i do not have",
    "not enough information",
    "unable to answer",
)

# Lower-case phrases that mark an answer as a clarifying question (only
# honoured when the answer carries no citations).
# NOTE(review): these are matched as plain substrings, so the bare "which"
# also fires on ordinary statements ("the flag which ..."); only the
# no-citations guard in classify_observed_behavior limits false positives.
CLARIFY_SIGNALS = (
    "could you clarify",
    "could you confirm",
    "can you clarify",
    "i need one more detail",
    "which",
    "what exactly",
)

# Typographic apostrophes are folded to ASCII "'" before matching so that
# "I don\u2019t have ..." hits the "i don't have" refusal signal.
_APOSTROPHE_FOLD = str.maketrans({"\u2019": "'", "\u2018": "'", "\u02bc": "'"})


def classify_observed_behavior(answer, citations):
    """Best-effort classifier for what the model *did* in plain English so
    that answerability_correct can be computed without a separate judge call.
    Lessons can replace this with an LLM-based behavior classifier.

    Args:
        answer: The model's answer text; None or empty is treated as "".
        citations: The citations attached to the answer; only its
            truthiness is used.

    Returns:
        "refuse" if any refusal signal occurs in the answer, otherwise
        "clarify" if a clarify signal occurs and there are no citations,
        otherwise "answer".
    """
    # Case-fold and normalise curly apostrophes; the signal tables are
    # written in lower case with ASCII apostrophes.
    lowered = (answer or "").lower().translate(_APOSTROPHE_FOLD)
    if any(signal in lowered for signal in REFUSAL_SIGNALS):
        return "refuse"
    # Clarification only counts as "clarify" if the model didn't also try to
    # answer with citations.
    if any(signal in lowered for signal in CLARIFY_SIGNALS) and not citations:
        return "clarify"
    return "answer"
94+
95+
96+
def answerability_correct(expected_behavior, answer, citations):
    """Score whether the observed behavior satisfies *expected_behavior*.

    Returns True or False for a recognised expected behavior, and None when
    the case is not an answerability check (no expected behavior, or a label
    that belongs to none of the known behavior groups).
    """
    if expected_behavior is None:
        return None

    def observed():
        # Evaluated lazily: monitoring cases never inspect the answer text.
        return classify_observed_behavior(answer, citations)

    if expected_behavior in ANSWERABILITY_BEHAVIORS:
        # The label itself names the behavior the model should exhibit.
        verdict = observed() == expected_behavior
    elif expected_behavior in SECURITY_BEHAVIORS:
        # Security behaviors all require an answer; refusing or clarifying
        # is a fail.
        verdict = observed() == "answer"
    elif expected_behavior in SELF_RAG_BEHAVIORS:
        # Self-RAG cases expect the model to *not* answer on the first pass
        # when evidence is missing. A refusal or clarification is acceptable.
        verdict = observed() in ("refuse", "clarify")
    elif expected_behavior in MONITORING_BEHAVIORS:
        # Monitoring regression cases pass if citations are reported; the
        # interesting signal is downstream in the comparison report.
        verdict = bool(citations)
    else:
        verdict = None
    return verdict
115+
116+
117+
def score_run(
    expected_behavior,
    answer,
    citations,
    required_citations,
    gold_evidence,
    final_chunk_ids,
    faithfulness_score=None,
):
    """Compute the metrics dictionary embedded in the Run Log contract.

    Each metric is delegated to its dedicated function in this module;
    *faithfulness_score* is passed through unchanged. Metrics that do not
    apply to the case come back as None.

    NOTE(review): the recall key is named "retrieval_recall_at_5" but nothing
    here truncates to five chunks; presumably the caller passes the top-5
    chunk IDs. Confirm against the retrieval code.
    """
    retrieval_recall = retrieval_recall_at_k(final_chunk_ids, gold_evidence)
    precision = citation_precision(citations, required_citations)
    recall = citation_recall(citations, required_citations)
    behavior_ok = answerability_correct(expected_behavior, answer, citations)

    metrics = {
        "retrieval_recall_at_5": retrieval_recall,
        "citation_precision": precision,
        "citation_recall": recall,
        "answerability_correct": behavior_ok,
        "faithfulness_score": faithfulness_score,
    }
    return metrics
134+
135+
136+
def average(values):
    """Return the arithmetic mean of *values*, ignoring None entries.

    None entries stand for "metric not applicable" and are skipped. If
    nothing remains after skipping them, the result is None rather than an
    error.
    """
    present = list(filter(lambda item: item is not None, values))
    return sum(present) / len(present) if present else None

0 commit comments

Comments
 (0)