11from typing_extensions import Protocol , override
2- import os
32import openai
43
54from typechat ._internal .result import Failure , Result , Success
65
76
8- class TypeChatModel (Protocol ):
7+ class TypeChatLanguageModel (Protocol ):
98 async def complete (self , input : str ) -> Result [str ]:
9+ """
10+ Represents a AI language model that can complete prompts.
11+
12+ TypeChat uses an implementation of this protocol to communicate
13+ with an AI service that can translate natural language requests to JSON
14+ instances according to a provided schema.
15+ The `create_language_model` function can create an instance.
16+ """
1017 ...
1118
1219
13- class DefaultOpenAIModel (TypeChatModel ):
20+ class DefaultOpenAIModel (TypeChatLanguageModel ):
1421 model_name : str
1522 client : openai .AsyncOpenAI | openai .AsyncAzureOpenAI
1623
@@ -34,24 +41,46 @@ async def complete(self, input: str) -> Result[str]:
3441 except Exception as e :
3542 return Failure (str (e ))
3643
37- def create_language_model (vals : dict [str ,str | None ]) -> TypeChatModel :
38- model : TypeChatModel
44+ def create_language_model (vals : dict [str , str | None ]) -> TypeChatLanguageModel :
45+ """
46+ Creates a language model encapsulation of an OpenAI or Azure OpenAI REST API endpoint
47+ chosen by a dictionary of variables (typically just `os.environ`).
48+
49+ If an `OPENAI_API_KEY` environment variable exists, an OpenAI model is constructed.
50+ The `OPENAI_ENDPOINT` and `OPENAI_MODEL` environment variables must also be defined or an error will be raised.
51+
52+ If an `AZURE_OPENAI_API_KEY` environment variable exists, an Azure OpenAI model is constructed.
53+ The `AZURE_OPENAI_ENDPOINT` environment variable must also be defined or an exception will be thrown.
54+
55+ If none of these key variables are defined, an exception is thrown.
56+ @returns An instance of `TypeChatLanguageModel`.
57+
58+ Args:
59+ vals: A dictionary of variables. Typically just `os.environ`.
60+ """
61+ model : TypeChatLanguageModel
3962 client : openai .AsyncOpenAI | openai .AsyncAzureOpenAI
4063
64+ def required_var (name : str ) -> str :
65+ val = vals .get (name , None )
66+ if val is None :
67+ raise ValueError (f"Missing environment variable { name } ." )
68+ return val
69+
4170 if "OPENAI_API_KEY" in vals :
42- client = openai .AsyncOpenAI (api_key = vals [ "OPENAI_API_KEY" ] )
43- model = DefaultOpenAIModel (model_name = vals . get ("OPENAI_MODEL" , None ) or "gpt-35-turbo" , client = client )
71+ client = openai .AsyncOpenAI (api_key = required_var ( "OPENAI_API_KEY" ) )
72+ model = DefaultOpenAIModel (model_name = required_var ("OPENAI_MODEL" ) , client = client )
4473
45- elif "AZURE_OPENAI_API_KEY" in vals and "AZURE_OPENAI_ENDPOINT" in vals :
46- os . environ [ "OPENAI_API_TYPE" ] = "azure"
74+ elif "AZURE_OPENAI_API_KEY" in vals :
75+ openai . api_type = "azure"
4776 client = openai .AsyncAzureOpenAI (
48- azure_endpoint = vals . get ( "AZURE_OPENAI_ENDPOINT" , None ) or "" ,
49- api_key = vals [ "AZURE_OPENAI_API_KEY" ] ,
77+ api_key = required_var ( "AZURE_OPENAI_API_KEY" ) ,
78+ azure_endpoint = required_var ( "AZURE_OPENAI_ENDPOINT" ) ,
5079 api_version = "2023-03-15-preview" ,
5180 )
5281 model = DefaultOpenAIModel (model_name = vals .get ("AZURE_OPENAI_MODEL" , None ) or "gpt-35-turbo" , client = client )
5382
5483 else :
55- raise ValueError ("Missing environment variables for Open AI or Azure OpenAI model " )
84+ raise ValueError ("Missing environment variables for OPENAI_API_KEY or AZURE_OPENAI_API_KEY. " )
5685
5786 return model
0 commit comments