1- from typing_extensions import Protocol , override
2- import openai
1+ import asyncio
2+ from types import TracebackType
3+ from typing_extensions import AsyncContextManager , Literal , Protocol , Self , TypedDict , cast , override
34
45from typechat ._internal .result import Failure , Result , Success
56
7+ import httpx
68
79class TypeChatLanguageModel (Protocol ):
8- async def complete (self , input : str ) -> Result [str ]:
10+ async def complete (self , prompt : str ) -> Result [str ]:
911 """
1012 Represents a AI language model that can complete prompts.
1113
@@ -16,30 +18,88 @@ async def complete(self, input: str) -> Result[str]:
1618 """
1719 ...
1820
class _PromptSection(TypedDict):
    """
    Represents a section of an LLM prompt with an associated role. TypeChat uses the "user" role for
    prompts it generates and the "assistant" role for previous LLM responses (which will be part of
    the prompt in repair attempts). TypeChat currently doesn't use the "system" role.
    """
    # Which conversation participant produced this section.
    role: Literal["system", "user", "assistant"]
    # The text content of this section of the prompt.
    content: str
# HTTP status codes treated as transient: a request that fails with one of
# these is retried (with a pause) instead of failing immediately.
# Only used for membership tests, so a frozenset gives O(1) lookup and
# guards against accidental mutation.
_TRANSIENT_ERROR_CODES = frozenset({
    429,  # Too Many Requests (rate limited)
    500,  # Internal Server Error
    502,  # Bad Gateway
    503,  # Service Unavailable
    504,  # Gateway Timeout
})
37+
class HttpxLanguageModel(TypeChatLanguageModel, AsyncContextManager):
    """
    A TypeChatLanguageModel that POSTs prompts to a chat-completions style REST
    endpoint (OpenAI-compatible request/response JSON shape) using httpx,
    retrying transient failures with a fixed pause.
    """
    # The chat-completions endpoint URL that requests are POSTed to.
    url: str
    # Extra HTTP headers merged into every request (e.g. authorization headers).
    headers: dict[str, str]
    # Request-body parameters merged into every request (e.g. {"model": ...}).
    default_params: dict[str, str]
    # Shared connection pool; closed via __aexit__ (or best-effort in __del__).
    _async_client: httpx.AsyncClient
    # Number of retries allowed after the initial attempt.
    _max_retry_attempts: int = 3
    # Fixed delay between attempts, in seconds.
    _retry_pause_seconds: float = 1.0

    def __init__(self, url: str, headers: dict[str, str], default_params: dict[str, str]):
        super().__init__()
        self.url = url
        self.headers = headers
        self.default_params = default_params
        self._async_client = httpx.AsyncClient()

    @override
    async def complete(self, prompt: str) -> Success[str] | Failure:
        """
        Sends `prompt` as a single "user" message and returns the first
        choice's message content as a Success. HTTP statuses listed in
        _TRANSIENT_ERROR_CODES and raised exceptions are retried up to
        _max_retry_attempts times; anything else (or an exhausted retry
        budget) is returned as a Failure.
        """
        headers = {
            "Content-Type": "application/json",
            **self.headers,
        }
        messages = [{"role": "user", "content": prompt}]
        body = {
            **self.default_params,
            "messages": messages,
            "temperature": 0.0,  # deterministic output for translation/repair
            "n": 1,
        }
        retry_count = 0
        while True:
            try:
                response = await self._async_client.post(
                    self.url,
                    headers=headers,
                    json=body,
                )
                if response.is_success:
                    # Narrow the parsed JSON to the minimal chat-completions
                    # shape actually read below (choices[0].message.content).
                    json_result = cast(
                        dict[Literal["choices"], list[dict[Literal["message"], _PromptSection]]],
                        response.json()
                    )
                    return Success(json_result["choices"][0]["message"]["content"] or "")

                # Non-transient status, or retry budget exhausted: give up.
                if response.status_code not in _TRANSIENT_ERROR_CODES or retry_count >= self._max_retry_attempts:
                    return Failure(f"REST API error {response.status_code}: {response.reason_phrase}")
            except Exception as e:
                # Network errors, JSON/shape errors, etc. share the retry budget.
                if retry_count >= self._max_retry_attempts:
                    return Failure(str(e))

            # Transient failure (retryable status or exception): pause, retry.
            await asyncio.sleep(self._retry_pause_seconds)
            retry_count += 1

    @override
    async def __aenter__(self) -> Self:
        return self

    @override
    async def __aexit__(self, __exc_type: type[BaseException] | None, __exc_value: BaseException | None, __traceback: TracebackType | None) -> bool | None:
        # Close the connection pool; returns None so exceptions propagate.
        await self._async_client.aclose()

    def __del__(self) -> None:
        # Best-effort cleanup if the model was never used as a context manager.
        # NOTE(review): this only works if an event loop is running at GC time;
        # otherwise the close is silently skipped.
        try:
            asyncio.get_running_loop().create_task(self._async_client.aclose())
        except Exception:
            pass
43103
def create_language_model(vals: dict[str, str | None]) -> TypeChatLanguageModel:
    """
    Creates a language model from a dictionary of configuration values
    (typically just `os.environ`).

    If `OPENAI_API_KEY` is present, an OpenAI model is created (also reading
    `OPENAI_MODEL`, and optionally `OPENAI_ENDPOINT` and `OPENAI_ORG`).
    Otherwise, if `AZURE_OPENAI_API_KEY` is present, an Azure OpenAI model is
    created (also reading `AZURE_OPENAI_ENDPOINT`). If neither key is present,
    a ValueError is raised.

    Args:
        vals: A dictionary of variables. Typically just `os.environ`.
    """

    def required_var(name: str) -> str:
        # Reject both absent keys and keys explicitly mapped to None.
        value = vals.get(name, None)
        if value is not None:
            return value
        raise ValueError(f"Missing environment variable {name}.")

    if "OPENAI_API_KEY" in vals:
        return create_openai_language_model(
            api_key=required_var("OPENAI_API_KEY"),
            model=required_var("OPENAI_MODEL"),
            endpoint=vals.get("OPENAI_ENDPOINT", None) or "https://api.openai.com/v1/chat/completions",
            org=vals.get("OPENAI_ORG", None) or "",
        )

    if "AZURE_OPENAI_API_KEY" in vals:
        return create_azure_openai_language_model(
            api_key=required_var("AZURE_OPENAI_API_KEY"),
            endpoint=required_var("AZURE_OPENAI_ENDPOINT"),
        )

    raise ValueError("Missing environment variables for OPENAI_API_KEY or AZURE_OPENAI_API_KEY.")
def create_openai_language_model(api_key: str, model: str, endpoint: str = "https://api.openai.com/v1/chat/completions", org: str = "") -> HttpxLanguageModel:
    """
    Creates a language model encapsulation of an OpenAI REST API endpoint.

    Args:
        api_key: The OpenAI API key.
        model: The OpenAI model name.
        endpoint: The OpenAI REST API endpoint.
        org: The OpenAI organization.
    """
    headers = {
        # Standard OpenAI bearer-token authentication.
        "Authorization": f"Bearer {api_key}",
        "OpenAI-Organization": org,
    }
    # The model name is sent in every request body alongside the messages.
    default_params = {
        "model": model,
    }
    return HttpxLanguageModel(url=endpoint, headers=headers, default_params=default_params)
160+
def create_azure_openai_language_model(api_key: str, endpoint: str) -> HttpxLanguageModel:
    """
    Creates a language model encapsulation of an Azure OpenAI REST API endpoint.

    Args:
        api_key: The Azure OpenAI API key.
        endpoint: The Azure OpenAI REST API endpoint.
    """
    # Both header forms are sent so either authentication mode works.
    headers = {
        # Needed when using managed identity
        "Authorization": f"Bearer {api_key}",
        # Needed when using regular API key
        "api-key": api_key,
    }
    # Azure encodes the deployment (model) in the endpoint URL, so no
    # default body parameters are needed here.
    return HttpxLanguageModel(url=endpoint, headers=headers, default_params={})