Skip to content

Commit e8fe963

Browse files
Replace usage of OpenAI SDK with direct API calls via httpx. (#198)
* Replace usage of OpenAI SDK with direct API calls via httpx. * Add `override` back. * Fix pyright errors.
1 parent 8d6eb85 commit e8fe963

2 files changed

Lines changed: 126 additions & 37 deletions

File tree

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ classifiers = [
2323
]
2424
dependencies = [
2525
"pydantic>=2.5.2",
26+
"httpx>=0.27.0",
2627
]
2728

2829
[project.urls]
@@ -39,7 +40,6 @@ virtual = ".hatch"
3940
[tool.hatch.envs.default]
4041
dependencies = [
4142
"coverage[toml]>=6.5",
42-
"openai>=1.3.6",
4343
"python-dotenv>=1.0.0",
4444
"pytest",
4545
"spotipy", # for examples
Lines changed: 125 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from typing_extensions import Protocol, override
2-
import openai
1+
import asyncio
2+
from types import TracebackType
3+
from typing_extensions import AsyncContextManager, Literal, Protocol, Self, TypedDict, cast, override
34

45
from typechat._internal.result import Failure, Result, Success
56

7+
import httpx
68

79
class TypeChatLanguageModel(Protocol):
    async def complete(self, prompt: str) -> Result[str]:
        """
        Represents an AI language model that can complete prompts.

        Args:
            prompt: The text of the prompt to send to the model.

        Returns:
            A `Success[str]` carrying the model's completion text, or a
            `Failure` describing why no completion could be produced.
        """
        ...

21+
class _PromptSection(TypedDict):
22+
"""
23+
Represents a section of an LLM prompt with an associated role. TypeChat uses the "user" role for
24+
prompts it generates and the "assistant" role for previous LLM responses (which will be part of
25+
the prompt in repair attempts). TypeChat currently doesn't use the "system" role.
26+
"""
27+
role: Literal["system", "user", "assistant"]
28+
content: str
29+
30+
_TRANSIENT_ERROR_CODES = [
31+
429,
32+
500,
33+
502,
34+
503,
35+
504,
36+
]
class HttpxLanguageModel(TypeChatLanguageModel, AsyncContextManager):
    """
    A `TypeChatLanguageModel` that talks to an OpenAI-compatible chat
    completion endpoint directly over HTTP via `httpx`, instead of using the
    OpenAI SDK.

    Responses with a transient status code (see `_TRANSIENT_ERROR_CODES`) and
    raised exceptions are retried up to `_max_retry_attempts` times, pausing
    `_retry_pause_seconds` between tries. May be used as an async context
    manager to guarantee the underlying HTTP client is closed.
    """

    url: str
    headers: dict[str, str]
    default_params: dict[str, str]
    _async_client: httpx.AsyncClient
    _max_retry_attempts: int = 3
    _retry_pause_seconds: float = 1.0

    def __init__(self, url: str, headers: dict[str, str], default_params: dict[str, str]):
        super().__init__()
        self.url = url
        self.headers = headers
        self.default_params = default_params
        self._async_client = httpx.AsyncClient()

    @override
    async def complete(self, prompt: str) -> Success[str] | Failure:
        """POST `prompt` as a single "user" message and return the first choice's text."""
        request_headers = {
            "Content-Type": "application/json",
            **self.headers,
        }
        payload = {
            **self.default_params,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.0,
            "n": 1,
        }
        attempt = 0
        while True:
            try:
                response = await self._async_client.post(
                    self.url,
                    headers=request_headers,
                    json=payload,
                )
                if response.is_success:
                    parsed = cast(
                        dict[Literal["choices"], list[dict[Literal["message"], _PromptSection]]],
                        response.json()
                    )
                    # An empty/None content is normalized to "".
                    return Success(parsed["choices"][0]["message"]["content"] or "")

                # Give up on non-transient statuses, or once retries are exhausted.
                if response.status_code not in _TRANSIENT_ERROR_CODES or attempt >= self._max_retry_attempts:
                    return Failure(f"REST API error {response.status_code}: {response.reason_phrase}")
            except Exception as e:
                if attempt >= self._max_retry_attempts:
                    return Failure(str(e))

            await asyncio.sleep(self._retry_pause_seconds)
            attempt += 1

    @override
    async def __aenter__(self) -> Self:
        return self

    @override
    async def __aexit__(self, __exc_type: type[BaseException] | None, __exc_value: BaseException | None, __traceback: TracebackType | None) -> bool | None:
        await self._async_client.aclose()

    def __del__(self):
        # Best-effort cleanup: schedule the client's shutdown if an event loop
        # is still running; otherwise (e.g. interpreter teardown) do nothing.
        try:
            asyncio.get_running_loop().create_task(self._async_client.aclose())
        except Exception:
            pass
def create_language_model(vals: dict[str, str | None]) -> TypeChatLanguageModel:
    """
    Creates a language model from a dictionary of configuration values
    (typically `os.environ`).

    If `OPENAI_API_KEY` is present, an OpenAI model is built from
    `OPENAI_MODEL` plus the optional `OPENAI_ENDPOINT` and `OPENAI_ORG`
    values. Otherwise, if `AZURE_OPENAI_API_KEY` is present, an Azure OpenAI
    model is built from `AZURE_OPENAI_ENDPOINT`.

    Args:
        vals: A dictionary of variables. Typically just `os.environ`.

    Raises:
        ValueError: If neither API key is present, or a required variable
            for the selected provider is missing.
    """

    def required_var(name: str) -> str:
        # Fetch a value that must be present and non-None.
        val = vals.get(name, None)
        if val is None:
            raise ValueError(f"Missing environment variable {name}.")
        return val

    if "OPENAI_API_KEY" in vals:
        api_key = required_var("OPENAI_API_KEY")
        model = required_var("OPENAI_MODEL")
        endpoint = vals.get("OPENAI_ENDPOINT", None) or "https://api.openai.com/v1/chat/completions"
        org = vals.get("OPENAI_ORG", None) or ""
        return create_openai_language_model(api_key, model, endpoint, org)
    elif "AZURE_OPENAI_API_KEY" in vals:
        api_key = required_var("AZURE_OPENAI_API_KEY")
        endpoint = required_var("AZURE_OPENAI_ENDPOINT")
        return create_azure_openai_language_model(api_key, endpoint)
    else:
        raise ValueError("Missing environment variables for OPENAI_API_KEY or AZURE_OPENAI_API_KEY.")
def create_openai_language_model(api_key: str, model: str, endpoint: str = "https://api.openai.com/v1/chat/completions", org: str = ""):
    """
    Creates a language model encapsulation of an OpenAI REST API endpoint.

    Args:
        api_key: The OpenAI API key.
        model: The OpenAI model name.
        endpoint: The OpenAI REST API endpoint.
        org: The OpenAI organization. Only sent as a header when non-empty:
            an empty `OpenAI-Organization` header carries no information and
            the API may reject requests naming an organization it cannot
            resolve.
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
    }
    # Fix: previously an empty "OpenAI-Organization" header was always sent.
    if org:
        headers["OpenAI-Organization"] = org
    default_params = {
        "model": model,
    }
    return HttpxLanguageModel(url=endpoint, headers=headers, default_params=default_params)
def create_azure_openai_language_model(api_key: str, endpoint: str):
    """
    Creates a language model encapsulation of an Azure OpenAI REST API endpoint.

    Args:
        api_key: The Azure OpenAI API key.
        endpoint: The Azure OpenAI REST API endpoint.
    """
    # Both header forms are supplied: "Authorization" is needed when using
    # managed identity, and "api-key" when using a regular API key.
    auth_headers = {
        "Authorization": f"Bearer {api_key}",
        "api-key": api_key,
    }
    return HttpxLanguageModel(url=endpoint, headers=auth_headers, default_params={})

0 commit comments

Comments
 (0)