Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix cold start searching timeout bug
  • Loading branch information
Tianyang-Zhang committed Mar 26, 2026
commit 254c54e68ec91a9d4dc2624225096750699718d1
35 changes: 27 additions & 8 deletions evaluation/retrieval_skill/hotpotQA_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,25 @@ async def hotpotqa_ingest(
print(
f"Completed HotpotQA ingestion, added {len(dataset)} questions, {added_content} episodes."
)
warmup_query = next(
(
str(record.get("question", "")).strip()
for record in dataset
if str(record.get("question", "")).strip()
),
"",
)
if warmup_query:
elapsed = await skill_utils.warmup_rest_evaluation_search(
session_id=session_id,
query=warmup_query,
)
if elapsed is None:
print(
f"Warmup search for {session_id} did not finish before the retry budget was exhausted."
)
else:
print(f"Warmed up search for {session_id} in {elapsed:.3f}s")


async def hotpotqa_search(
Expand All @@ -130,7 +149,7 @@ async def hotpotqa_search(
length: int | None = None,
runner_kwargs: dict | None = None,
session_id: str = "hotpotqa_group",
concurrency: int | None = None,
concurrency: int = 10,
answer_llm: object | None = None,
):
if dataset is None:
Expand All @@ -151,12 +170,13 @@ async def hotpotqa_search(
attribute_matrix = skill_utils.init_attribute_matrix()
responses: list[tuple[int, dict[str, any]]] = []
num_searched = 0
default_batch = 30 if pure_llm else 2
question_batch_size = concurrency if concurrency is not None else default_batch
question_batch_size = concurrency
vector_graph_store = skill_utils.init_vector_graph_store(
neo4j_uri="bolt://localhost:7687"
)
answer_model_name = answer_llm.model_name if answer_llm is not None else "gpt-5-mini"
answer_model_name = (
answer_llm.model_name if answer_llm is not None else "gpt-5-mini"
)
_, model, query_skill = await skill_utils.init_memmachine_params(
vector_graph_store=vector_graph_store,
model_name=answer_model_name,
Expand Down Expand Up @@ -280,8 +300,8 @@ async def main():
parser.add_argument(
"--concurrency",
type=int,
default=None,
help="Maximum number of concurrent search requests (default: 2 for retrieval_skill, 30 for llm)",
default=10,
help="Maximum number of concurrent search requests (default: 10)",
)
parser.add_argument(
"--config",
Expand Down Expand Up @@ -311,8 +331,7 @@ async def main():
print(f"Length: {args.length}")
print(f"Dataset split: {args.split_name}")
print(f"Test target: {args.test_target}")
if args.concurrency is not None:
print(f"Concurrency: {args.concurrency}")
print(f"Concurrency: {args.concurrency}")
if answer_llm is not None:
print(f"Answer model: {answer_llm.provider} / {answer_llm.model_name}")

Expand Down
34 changes: 30 additions & 4 deletions evaluation/retrieval_skill/locomo_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def datetime_from_locomo_time(locomo_time_str: str) -> datetime:
)


async def main():
async def main(): # noqa: C901
parser = argparse.ArgumentParser()

parser.add_argument("--data-path", required=True, help="Path to the data file")
Expand All @@ -37,7 +37,7 @@ async def main():

async def process_conversation(idx, item):
if "conversation" not in item:
return
return None

conversation = item["conversation"]
speaker_a = conversation["speaker_a"]
Expand All @@ -48,6 +48,14 @@ async def process_conversation(idx, item):
)

group_id = f"group_{idx}"
warmup_query = ""
for qa in item.get("qa", []):
if str(qa.get("category")) == "5":
continue
question = str(qa.get("question", "")).strip()
if question:
warmup_query = question
break

session_idx = 0

Expand Down Expand Up @@ -102,11 +110,29 @@ async def process_conversation(idx, item):
},
)
for message_index, message in enumerate(session)
]
],
)

return group_id, warmup_query

tasks = [process_conversation(idx, item) for idx, item in enumerate(locomo_data)]
await asyncio.gather(*tasks)
session_warmups = await asyncio.gather(*tasks)
for session_warmup in session_warmups:
if session_warmup is None:
continue
group_id, warmup_query = session_warmup
if not warmup_query:
continue
elapsed = await skill_utils.warmup_rest_evaluation_search(
session_id=group_id,
query=warmup_query,
)
if elapsed is None:
print(
f"Warmup search for {group_id} did not finish before the retry budget was exhausted."
)
else:
print(f"Warmed up search for {group_id} in {elapsed:.3f}s")


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions evaluation/retrieval_skill/locomo_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ async def run_locomo( # noqa: C901
parser.add_argument(
"--concurrency",
type=int,
default=1,
help="Maximum number of concurrent search requests (default: 1)",
default=10,
help="Maximum number of concurrent search requests (default: 10)",
)
parser.add_argument(
"--config",
Expand Down
33 changes: 26 additions & 7 deletions evaluation/retrieval_skill/longmemeval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,22 +157,43 @@ async def longmemeval_ingest(dataset: list[dict[str, Any]], session_id: str):
print(
f"Completed LongMemEval ingestion, added {len(dataset)} questions, {added_content} episodes."
)
warmup_query = next(
(
str(sample.get("question", "")).strip()
for sample in dataset
if str(sample.get("question", "")).strip()
),
"",
)
if warmup_query:
elapsed = await skill_utils.warmup_rest_evaluation_search(
session_id=session_id,
query=warmup_query,
)
if elapsed is None:
print(
f"Warmup search for {session_id} did not finish before the retry budget was exhausted."
)
else:
print(f"Warmed up search for {session_id} in {elapsed:.3f}s")


async def longmemeval_search(
dataset: list[dict[str, Any]],
session_id: str,
eval_result_path: str | None = None,
pure_llm: bool = False,
concurrency: int = 30,
concurrency: int = 10,
answer_llm: object | None = None,
):
tasks = []
attribute_matrix = skill_utils.init_attribute_matrix()
responses: list[tuple[str, dict[str, Any]]] = []
num_searched = 0

answer_model_name = answer_llm.model_name if answer_llm is not None else "gpt-5-mini"
answer_model_name = (
answer_llm.model_name if answer_llm is not None else "gpt-5-mini"
)
vector_graph_store = skill_utils.init_vector_graph_store(
neo4j_uri="bolt://localhost:7687"
)
Expand All @@ -183,9 +204,7 @@ async def longmemeval_search(
build_runner=not pure_llm,
)
if not pure_llm and query_skill is None:
raise RuntimeError(
"LongMemEval benchmark requires an initialized SkillRunner."
)
raise RuntimeError("LongMemEval benchmark requires an initialized SkillRunner.")

for sample in dataset:
question = str(sample.get("question", "")).strip()
Expand Down Expand Up @@ -333,8 +352,8 @@ async def main():
parser.add_argument(
"--concurrency",
type=int,
default=30,
help="Maximum number of concurrent search requests (default: 30)",
default=10,
help="Maximum number of concurrent search requests (default: 10)",
)
parser.add_argument(
"--config",
Expand Down
10 changes: 5 additions & 5 deletions evaluation/retrieval_skill/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ usage_locomo() {
echo " TEST_TARGET [retrieval_skill | llm]"
echo "Options:"
echo " --search-concurrency N"
echo " Max concurrent search requests (default: 1)"
echo " Max concurrent search requests (default: 10)"
echo " --judge-concurrency N"
echo " Max concurrent LLM judge workers (default: 30)"
echo " --config PATH Path to benchmark_config.yml for answer/evaluation model"
Expand Down Expand Up @@ -46,7 +46,7 @@ usage_hotpotqa() {
echo " LENGTH Number of examples to run [train set 1 - 90447 | validation set 1 - 7405]"
echo "Options:"
echo " --search-concurrency N"
echo " Max concurrent search requests (default: 30)"
echo " Max concurrent search requests (default: 10)"
echo " --judge-concurrency N"
echo " Max concurrent LLM judge workers (default: 30)"
echo " --config PATH Path to benchmark_config.yml for answer/evaluation model"
Expand All @@ -64,7 +64,7 @@ usage_longmemeval() {
echo " LENGTH Number of examples to run [1 - split size]"
echo "Options:"
echo " --search-concurrency N"
echo " Max concurrent search requests (default: 30)"
echo " Max concurrent search requests (default: 10)"
echo " --judge-concurrency N"
echo " Max concurrent LLM judge workers (default: 30)"
echo " --config PATH Path to benchmark_config.yml for answer/evaluation model"
Expand Down Expand Up @@ -107,13 +107,13 @@ show_help() {
}

POSITIONAL_ARGS=()
SEARCH_CONCURRENCY=""
SEARCH_CONCURRENCY="10"
JUDGE_CONCURRENCY=""
BENCHMARK_CONFIG=""

parse_optional_flags() {
POSITIONAL_ARGS=()
SEARCH_CONCURRENCY=""
SEARCH_CONCURRENCY="10"
JUDGE_CONCURRENCY=""
BENCHMARK_CONFIG=""

Expand Down
15 changes: 14 additions & 1 deletion evaluation/retrieval_skill/wikimultihop_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def main():

data_path = args.data_path

contexts, _, _, _, _ = load_data(
contexts, questions, _, _, _ = load_data(
data_path=data_path, start_line=1, end_line=args.length, randomize="SENTENCE"
)
print("Loaded", len(contexts), "contexts, start ingestion...")
Expand Down Expand Up @@ -159,6 +159,19 @@ async def main():

print(f"Completed WIKI-Multihop ingestion, added {len(added_contexts)} episodes.")

warmup_query = next((question.strip() for question in questions if question), "")
if warmup_query:
elapsed = await skill_utils.warmup_rest_evaluation_search(
session_id=args.session_id,
query=warmup_query,
)
if elapsed is None:
print(
f"Warmup search for {args.session_id} did not finish before the retry budget was exhausted."
)
else:
print(f"Warmed up search for {args.session_id} in {elapsed:.3f}s")


if __name__ == "__main__":
load_dotenv()
Expand Down
29 changes: 20 additions & 9 deletions evaluation/retrieval_skill/wikimultihop_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
"""


async def run_wiki(
async def run_wiki( # noqa: C901
dpath: str | None = None,
epath: str | None = None,
data_path: str | None = None,
Expand All @@ -84,13 +84,15 @@ async def run_wiki(
model_name: str = "gpt-5.2",
runner_kwargs: dict | None = None,
session_id: str = "group1",
concurrency: int = 10,
answer_llm: object | None = None,
) -> tuple[str, dict[str, Any]]:
if data_path is not None:
_data_path = data_path
_eval_result_path = eval_result_path
_length = length or 100
_test_target = "retrieval_skill"
_concurrency = concurrency
else:
parser = argparse.ArgumentParser()

Expand Down Expand Up @@ -121,8 +123,8 @@ async def run_wiki(
parser.add_argument(
"--concurrency",
type=int,
default=None,
help="Maximum number of concurrent search requests (default: 2 for retrieval_skill, 10 for llm)",
default=10,
help="Maximum number of concurrent search requests (default: 10)",
)
parser.add_argument(
"--config",
Expand All @@ -137,6 +139,7 @@ async def run_wiki(
_length = args.length
_test_target = args.test_target
session_id = args.session_id
_concurrency = args.concurrency

if answer_llm is None and args.config:
from evaluation.retrieval_skill.benchmark_config import (
Expand All @@ -155,10 +158,9 @@ async def run_wiki(
print(f"Evaluation result path: {_eval_result_path}")
print(f"Length: {_length}")
print(f"Test target: {_test_target}")
print(f"Concurrency: {_concurrency}")
if answer_llm is not None:
print(f"Answer model: {answer_llm.provider} / {answer_llm.model_name}")
if data_path is None and args.concurrency is not None:
print(f"Concurrency: {args.concurrency}")
effective_runner_kwargs = dict(runner_kwargs or {})
if _test_target == "retrieval_skill":
effective_runner_kwargs.setdefault("stage_result_mode", True)
Expand All @@ -178,21 +180,30 @@ async def run_wiki(
build_runner=_test_target == "retrieval_skill",
)
if _test_target == "retrieval_skill" and query_skill is None:
raise RuntimeError("WikiMultiHop benchmark requires an initialized SkillRunner.")
raise RuntimeError(
"WikiMultiHop benchmark requires an initialized SkillRunner."
)

contexts, questions, answers, types, supporting_facts = load_data(
data_path=_data_path, start_line=1, end_line=_length, randomize="NONE"
)
warmup_query = next((question.strip() for question in questions if question), "")
if _test_target == "retrieval_skill" and warmup_query:
elapsed = await skill_utils.warmup_rest_evaluation_search(
session_id=session_id,
query=warmup_query,
raise_on_failure=True,
)
if elapsed is not None:
print(f"Search backend ready for {session_id} in {elapsed:.3f}s")
print(f"Loaded {len(questions)} questions, start querying...")

tasks = []
results: dict[str, Any] = {}
attribute_matrix = skill_utils.init_attribute_matrix()
full_content = "\n".join(contexts)
num_processed = 0
default_batch = 2 if _test_target == "retrieval_skill" else 10
_concurrency = (args.concurrency if data_path is None and args.concurrency is not None else None)
question_batch_size = _concurrency if _concurrency is not None else default_batch
question_batch_size = _concurrency
for q, a, t, f_list in zip(
questions, answers, types, supporting_facts, strict=True
):
Expand Down
Loading
Loading