-
Notifications
You must be signed in to change notification settings - Fork 266
Expand file tree
/
Copy pathtest_utils.py
More file actions
91 lines (77 loc) · 2.81 KB
/
test_utils.py
File metadata and controls
91 lines (77 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
Test utilities for OptILLM tests
Provides common functions and constants for consistent testing
"""
import os
import platform
import socket
import subprocess
import sys
import time
from typing import Optional

from openai import OpenAI
# Model identifier shared by every test; deliberately a small checkpoint so
# the suite stays fast.
TEST_MODEL: str = "google/gemma-3-270m-it"
# MLX-format counterpart used when MLX is available (macOS only, see
# is_mlx_available); per its repo name, presumably bf16 weights of the same
# checkpoint -- TODO confirm.
TEST_MODEL_MLX: str = "mlx-community/gemma-3-270m-it-bf16"
def setup_test_env():
    """Prepare the current process for local-inference tests.

    Writes ``OPTILLM_API_KEY=optillm`` into ``os.environ`` (overwriting any
    existing value) and hands back the model name tests should request.

    Returns:
        str: the shared ``TEST_MODEL`` identifier.
    """
    os.environ.update({"OPTILLM_API_KEY": "optillm"})
    return TEST_MODEL
def get_test_client(base_url: str = "http://localhost:8000/v1") -> OpenAI:
    """Build an OpenAI SDK client aimed at a locally running optillm server.

    Args:
        base_url: API root of the server; defaults to port 8000 on localhost.

    Returns:
        An ``OpenAI`` client authenticated with the fixed test key ``optillm``.
    """
    client = OpenAI(base_url=base_url, api_key="optillm")
    return client
def is_mlx_available():
    """Report whether MLX inference can be used in this environment.

    MLX only runs on macOS, so any other platform short-circuits to False.
    On macOS the answer is the ``MLX_AVAILABLE`` flag exported by
    ``optillm.inference``; if that module cannot be imported, False.
    """
    if platform.system() == "Darwin":
        try:
            from optillm.inference import MLX_AVAILABLE
        except ImportError:
            return False
        else:
            return MLX_AVAILABLE
    return False
def start_test_server(model: str = TEST_MODEL, port: int = 8000,
                      startup_timeout: float = 5.0) -> subprocess.Popen:
    """
    Start optillm server for testing.

    Launches ``optillm.py`` from the project root with the current Python
    interpreter. It then waits until one of three things happens, whichever
    comes first:

    * the server accepts a TCP connection on ``localhost:port``;
    * the process exits;
    * ``startup_timeout`` seconds elapse.

    The previous implementation always slept a fixed 5 seconds. The default
    timeout keeps 5 seconds as the upper bound, but returns as soon as the
    server is reachable (or has died).

    Args:
        model: Model identifier forwarded as ``--model``.
        port: Port forwarded as ``--port`` and probed for readiness.
        startup_timeout: Maximum seconds to wait for the port to open.
            Nothing is raised on timeout or early exit; check ``proc.poll()``
            if you need to distinguish those cases.

    Returns:
        The process handle. Its stdout/stderr are ``subprocess.PIPE``.
    """
    # Set environment for local inference
    env = os.environ.copy()
    env["OPTILLM_API_KEY"] = "optillm"
    # Enable MPS fallback to CPU for unsupported operations (fixes macOS compatibility)
    env["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
    # Project root is the parent of the directory holding this file; optillm.py lives there
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    # NOTE(review): stdout/stderr are PIPEs that nothing reads while the server
    # runs. A chatty server can block once the OS pipe buffer fills. Consider a
    # log file or DEVNULL if tests hang.
    proc = subprocess.Popen([
        sys.executable, "optillm.py",
        "--model", model,
        "--port", str(port)
    ], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=project_root)
    _wait_for_port(proc, port, startup_timeout)
    return proc


def _wait_for_port(proc: subprocess.Popen, port: int, timeout: float) -> bool:
    """Poll ``localhost:port`` until it accepts a TCP connection.

    Returns True once a connection succeeds. Returns False if ``proc`` exits
    first or ``timeout`` seconds pass (immediately, if ``timeout <= 0``).
    """
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0 or proc.poll() is not None:
            return False
        try:
            # Connect-and-close is enough to prove the server is listening.
            with socket.create_connection(("localhost", port),
                                          timeout=min(0.5, remaining)):
                return True
        except OSError:
            # Not listening yet (refused / timed out); back off briefly
            # without overshooting the deadline.
            time.sleep(max(0.0, min(0.2, deadline - time.monotonic())))
def stop_test_server(proc: subprocess.Popen) -> "tuple[Optional[bytes], Optional[bytes]]":
    """Stop the test server and reap it.

    Shutdown sequence:

    1. Send ``terminate()`` and give the process 5 seconds to exit.
    2. If it is still running, fall back to ``kill()``.

    ``communicate()`` is used rather than ``wait()`` because the server is
    started with stdout/stderr PIPEs. The subprocess docs warn that ``wait()``
    can deadlock in that case. ``communicate()`` also drains the pipes and
    closes their file objects, so they are no longer leaked.

    Args:
        proc: Process handle, typically from ``start_test_server``.

    Returns:
        ``(stdout, stderr)`` as returned by ``communicate()``. Each element is
        ``None`` for a stream that was not piped. Callers that ignore the
        return value (it used to be ``None``) are unaffected.
    """
    proc.terminate()
    try:
        return proc.communicate(timeout=5)
    except subprocess.TimeoutExpired:
        # Documented cleanup pattern: communicate() does not kill the child on
        # timeout, so kill it explicitly and finish collecting output.
        proc.kill()
        return proc.communicate()
def get_simple_test_messages():
    """Return a minimal system+user chat for basic smoke checks.

    A new list is built on every call, so callers may mutate it freely.
    """
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    user_turn = {"role": "user", "content": "Say hello in one word."}
    return [system_turn, user_turn]
def get_math_test_messages():
    """Return a tiny arithmetic chat used to validate reasoning approaches.

    A new list is built on every call, so callers may mutate it freely.
    """
    turns = (
        ("system", "You are a helpful math assistant."),
        ("user", "What is 2 + 2? Answer with just the number."),
    )
    return [{"role": role, "content": text} for role, text in turns]
def get_thinking_test_messages():
    """Return a chat whose system prompt asks for <think></think> tags.

    A new list is built on every call, so callers may mutate it freely.
    """
    messages = []
    messages.append(dict(role="system", content="Think step by step and use <think></think> tags."))
    messages.append(dict(role="user", content="What is 3 * 4? Show your thinking."))
    return messages