Skip to content

Commit 13dac96

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Add support for custom result parsing in LLM-based evaluation metrics
FUTURE_COPYBARA_INTEGRATE_REVIEW=#6486 from googleapis:release-please--branches--main fb49c58 PiperOrigin-RevId: 892593906
1 parent 1fba45b commit 13dac96

3 files changed

Lines changed: 56 additions & 58 deletions

File tree

tests/unit/vertexai/genai/replays/test_evaluate.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -361,41 +361,41 @@ def test_evaluation_metric_resource_name(client):
361361
)
362362
tone_check_metric = types.LLMMetric(
363363
name="tone_check",
364-
prompt_template="""
365-
# Instruction
366-
You are a professional writing evaluator. Your job is to score writing responses according to pre-defined evaluation criteria.
367-
368-
# Criteria
369-
Analyze the tone of the response based on these two criteria:
370-
1. Professionalism: The response should use appropriate language and maintain a business-like demeanor.
371-
2. Empathy: The response should acknowledge the user's feelings and show understanding.
372-
373-
# Input
374-
Prompt: {agent_data.turns[0].events[0]}
375-
Response: {agent_data.turns[0].events[1]}
376-
377-
# Output Format
378-
Respond in a JSON format with the following schema:
379-
{
380-
"type": "OBJECT",
381-
"properties": {
382-
"score": {"type": "NUMBER"},
383-
"explanation": {"type": "STRING"},
384-
},
385-
"required": ["score", "explanation"],
386-
}
387-
Return the JSON format output in a string representation of a Python dictionary directly, without strings like '```json' or '```'.
388-
389-
The output would include the following fields:
390-
score: based on your evaluation, the score should be a number based on the rating rubrics.
391-
explanation: your explanation for the score rating, in one line.
392-
393-
## Example Output Format:
394-
{"score" : -1, "explanation": "Here is the reason that the response is given a score of -1 based on the rating rubric."}
395-
{"score" : 3, "explanation": "Here is the reason that the response is given a score of 3 based on the rating rubric."}
396-
{"score" : 0, "explanation": "Here is the reason that the response is given a score of 0 based on the rating rubric."}
397-
{"score" : 5, "explanation": "Here is the reason that the response is given a score of 5 based on the rating rubric."}
398-
""",
364+
prompt_template="""Analyze the tone of the response based on these two criteria:\n
365+
1. Professionalism: The response should use appropriate language and maintain a business-like demeanor.\n
366+
2. Empathy: The response should acknowledge the user's feelings and show understanding.\n\n
367+
Prompt: {agent_data.turns[0].events[0]}
368+
Response: {agent_data.turns[0].events[1]}
369+
Return ONLY a JSON list of objects for these two properties:
370+
'[{"property": "Professionalism", "verdict": true, "reasoning": "..."}, '
371+
'{"property": "Empathy", "verdict": true, "reasoning": "..."}]'
372+
""",
373+
result_parsing_function="""
374+
import json, re
375+
def parse_results(responses):
376+
text = responses[0]
377+
# Use robust regex to find the JSON list block
378+
match = re.search("[\\[].*[]]", text, re.DOTALL)
379+
if not match: return {"score": 0.0, "explanation": "No valid JSON found"}
380+
381+
try:
382+
data = json.loads(match.group(0))
383+
# Calculate an overall score (e.g., average of verdicts)
384+
passed_count = sum(1 for r in data if r.get("verdict", False))
385+
total_count = len(data)
386+
score = passed_count / total_count if total_count > 0 else 0.0
387+
388+
# Consolidate reasoning into a single explanation string
389+
explanation = "\\n".join([f"{r.get('property')}: {r.get('reasoning')}" for r in data])
390+
391+
# IMPORTANT: Return a dictionary, not a list
392+
return {
393+
"score": float(score),
394+
"explanation": explanation
395+
}
396+
except Exception as e:
397+
return {"score": 0.0, "explanation": f"Parsing failed: {str(e)}"}
398+
""",
399399
)
400400
metric_resource_name = client.evals.create_evaluation_metric(
401401
metric=tone_check_metric,

tests/unit/vertexai/genai/replays/test_evaluation_metric.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,12 @@
2424

2525
def test_create_and_get_evaluation_metric(client):
2626
client._api_client._http_options.api_version = "v1beta1"
27-
client._api_client._http_options.base_url = (
28-
"https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
29-
)
3027
result = client.evals.create_evaluation_metric(
3128
display_name="test_metric",
3229
description="test_description",
33-
metric=types.RubricMetric.GENERAL_QUALITY,
30+
metric=types.LLMMetric(
31+
name="custom_llm_metric", prompt_template="test_prompt_template"
32+
),
3433
)
3534
assert isinstance(result, str)
3635
assert re.match(
@@ -44,9 +43,6 @@ def test_create_and_get_evaluation_metric(client):
4443

4544
def test_list_evaluation_metrics(client):
4645
client._api_client._http_options.api_version = "v1beta1"
47-
client._api_client._http_options.base_url = (
48-
"https://us-central1-staging-aiplatform.sandbox.googleapis.com/"
49-
)
5046
response = client.evals.list_evaluation_metrics()
5147
assert isinstance(response, types.ListEvaluationMetricsResponse)
5248
assert len(response.evaluation_metrics) >= 0

vertexai/_genai/_transformers.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@ def t_metrics(
119119
if autorater_config:
120120
llm_based_spec["judge_autorater_config"] = autorater_config
121121

122+
result_parsing_function = getv(metric, ["result_parsing_function"])
123+
if result_parsing_function:
124+
llm_based_spec["result_parser_config"] = {
125+
"custom_code_parser_config": {
126+
"parsing_function": result_parsing_function
127+
}
128+
}
129+
122130
metric_payload_item["llm_based_metric_spec"] = llm_based_spec
123131
elif getattr(metric, "metric_resource_name", None) is not None:
124132
# Safe pass
@@ -187,22 +195,8 @@ def t_metric_for_registry(
187195
if metric_name:
188196
metric_name = metric_name.lower()
189197

190-
# Handle standard computation metrics
191-
if metric_name == "exact_match":
192-
metric_payload_item["exact_match_spec"] = {}
193-
elif metric_name == "bleu":
194-
metric_payload_item["bleu_spec"] = {}
195-
elif metric_name and metric_name.startswith("rouge"):
196-
rouge_type = metric_name.replace("_", "")
197-
metric_payload_item["rouge_spec"] = {"rouge_type": rouge_type}
198-
# API Pre-defined metrics
199-
elif metric_name and metric_name in _evals_constant.SUPPORTED_PREDEFINED_METRICS:
200-
metric_payload_item["predefined_metric_spec"] = {
201-
"metric_spec_name": metric_name,
202-
"metric_spec_parameters": metric.metric_spec_parameters,
203-
}
204198
# Custom Code Execution Metric
205-
elif hasattr(metric, "remote_custom_function") and metric.remote_custom_function:
199+
if hasattr(metric, "remote_custom_function") and metric.remote_custom_function:
206200
metric_payload_item["custom_code_execution_spec"] = {
207201
"evaluation_function": metric.remote_custom_function
208202
}
@@ -217,7 +211,7 @@ def t_metric_for_registry(
217211
"evaluation_function": metric.custom_function
218212
}
219213

220-
# Map LLM-based metrics to the new llm_based_metric_spec
214+
# LLM-based metric
221215
elif (hasattr(metric, "prompt_template") and metric.prompt_template) or (
222216
hasattr(metric, "rubric_group_name") and metric.rubric_group_name
223217
):
@@ -249,6 +243,14 @@ def t_metric_for_registry(
249243
if autorater_config:
250244
llm_based_spec["judge_autorater_config"] = autorater_config
251245

246+
result_parsing_function = getv(metric, ["result_parsing_function"])
247+
if result_parsing_function:
248+
llm_based_spec["result_parser_config"] = {
249+
"custom_code_parser_config": {
250+
"parsing_function": result_parsing_function
251+
}
252+
}
253+
252254
metric_payload_item["llm_based_metric_spec"] = llm_based_spec
253255

254256
else:

0 commit comments

Comments (0)