Skip to content

Commit a1930e4

Browse files
Various doc updates and expose some internals (#236)
* Switch `sentiment` example to be a dataclass. * Make it possible to actually reference the attributes of the internally-constructed models. * Update doc comment.
1 parent 11f5992 commit a1930e4

4 files changed

Lines changed: 20 additions & 15 deletions

File tree

python/examples/sentiment/demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ async def request_handler(message: str):
1717
print(result.message)
1818
else:
1919
result = result.value
20-
print(f"The sentiment is {result['sentiment']}")
20+
print(f"The sentiment is {result.sentiment}")
2121

2222
file_path = sys.argv[1] if len(sys.argv) == 2 else None
2323
await process_requests("😀> ", file_path, request_handler)

python/examples/sentiment/schema.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing_extensions import Literal, TypedDict, Annotated, Doc
1+
from dataclasses import dataclass
2+
from typing_extensions import Literal, Annotated, Doc
23

3-
4-
class Sentiment(TypedDict):
4+
@dataclass
5+
class Sentiment:
56
"""
67
The following is a schema definition for determining the sentiment of a some user input.
78
"""

python/src/typechat/_internal/model.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,14 @@ class HttpxLanguageModel(TypeChatLanguageModel, AsyncContextManager):
3939
url: str
4040
headers: dict[str, str]
4141
default_params: dict[str, str]
42+
# Specifies the maximum number of retry attempts.
43+
max_retry_attempts: int = 3
44+
# Specifies the delay before retrying in milliseconds.
45+
retry_pause_seconds: float = 1.0
46+
# Specifies how long a request should wait in seconds
47+
# before timing out with a Failure.
48+
timeout_seconds = 10
4249
_async_client: httpx.AsyncClient
43-
_max_retry_attempts: int = 3
44-
_retry_pause_seconds: float = 1.0
45-
_timeout_seconds = 10
4650

4751
def __init__(self, url: str, headers: dict[str, str], default_params: dict[str, str]):
4852
super().__init__()
@@ -74,7 +78,7 @@ async def complete(self, prompt: str | list[PromptSection]) -> Success[str] | Fa
7478
self.url,
7579
headers=headers,
7680
json=body,
77-
timeout=self._timeout_seconds
81+
timeout=self.timeout_seconds
7882
)
7983
if response.is_success:
8084
json_result = cast(
@@ -83,13 +87,13 @@ async def complete(self, prompt: str | list[PromptSection]) -> Success[str] | Fa
8387
)
8488
return Success(json_result["choices"][0]["message"]["content"] or "")
8589

86-
if response.status_code not in _TRANSIENT_ERROR_CODES or retry_count >= self._max_retry_attempts:
90+
if response.status_code not in _TRANSIENT_ERROR_CODES or retry_count >= self.max_retry_attempts:
8791
return Failure(f"REST API error {response.status_code}: {response.reason_phrase}")
8892
except Exception as e:
89-
if retry_count >= self._max_retry_attempts:
93+
if retry_count >= self.max_retry_attempts:
9094
return Failure(str(e) or f"{repr(e)} raised from within internal TypeChat language model.")
9195

92-
await asyncio.sleep(self._retry_pause_seconds)
96+
await asyncio.sleep(self.retry_pause_seconds)
9397
retry_count += 1
9498

9599
@override
@@ -106,7 +110,7 @@ def __del__(self):
106110
except Exception:
107111
pass
108112

109-
def create_language_model(vals: dict[str, str | None]) -> TypeChatLanguageModel:
113+
def create_language_model(vals: dict[str, str | None]) -> HttpxLanguageModel:
110114
"""
111115
Creates a language model encapsulation of an OpenAI or Azure OpenAI REST API endpoint
112116
chosen by a dictionary of variables (typically just `os.environ`).
@@ -144,7 +148,7 @@ def required_var(name: str) -> str:
144148
else:
145149
raise ValueError("Missing environment variables for OPENAI_API_KEY or AZURE_OPENAI_API_KEY.")
146150

147-
def create_openai_language_model(api_key: str, model: str, endpoint: str = "https://api.openai.com/v1/chat/completions", org: str = ""):
151+
def create_openai_language_model(api_key: str, model: str, endpoint: str = "https://api.openai.com/v1/chat/completions", org: str = "") -> HttpxLanguageModel:
148152
"""
149153
Creates a language model encapsulation of an OpenAI REST API endpoint.
150154
@@ -163,7 +167,7 @@ def create_openai_language_model(api_key: str, model: str, endpoint: str = "http
163167
}
164168
return HttpxLanguageModel(url=endpoint, headers=headers, default_params=default_params)
165169

166-
def create_azure_openai_language_model(api_key: str, endpoint: str):
170+
def create_azure_openai_language_model(api_key: str, endpoint: str) -> HttpxLanguageModel:
167171
"""
168172
Creates a language model encapsulation of an Azure OpenAI REST API endpoint.
169173

python/src/typechat/_internal/validator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
class TypeChatValidator(Generic[T]):
1212
"""
13-
Validates JSON text against a given Python type.
13+
Validates an object against a given Python type.
1414
"""
1515

1616
_adapted_type: pydantic.TypeAdapter[T]

0 commit comments

Comments
 (0)