Skip to content

Commit 497f3c9

Browse files
Add docstring comments for TypeChat for Python (#191)
* Add docstring comments, rework some logic around `create_language_model`. * Rename `TypeChatModel` to `TypeChatLanguageModel`.
1 parent aeb4d89 commit 497f3c9

12 files changed

Lines changed: 113 additions & 31 deletions

File tree

python/examples/healthData/translator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from typing_extensions import TypeVar, Any, override, TypedDict, Literal
33

4-
from typechat import TypeChatValidator, TypeChatModel, TypeChatTranslator, Result, Failure
4+
from typechat import TypeChatValidator, TypeChatLanguageModel, TypeChatTranslator, Result, Failure
55

66
from datetime import datetime
77

@@ -19,7 +19,7 @@ class TranslatorWithHistory(TypeChatTranslator[T]):
1919
_additional_agent_instructions: str
2020

2121
def __init__(
22-
self, model: TypeChatModel, validator: TypeChatValidator[T], target_type: type[T], additional_agent_instructions: str
22+
self, model: TypeChatLanguageModel, validator: TypeChatValidator[T], target_type: type[T], additional_agent_instructions: str
2323
):
2424
super().__init__(model=model, validator=validator, target_type=target_type)
2525
self._chat_history = []

python/examples/math/program.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Failure,
1818
Result,
1919
Success,
20-
TypeChatModel,
20+
TypeChatLanguageModel,
2121
TypeChatValidator,
2222
TypeChatTranslator,
2323
python_type_to_typescript_schema,
@@ -149,7 +149,7 @@ def validate(self, json_text: str) -> Result[JsonProgram]:
149149
class TypeChatProgramTranslator(TypeChatTranslator[JsonProgram]):
150150
_api_declaration_str: str
151151

152-
def __init__(self, model: TypeChatModel, validator: TypeChatProgramValidator, api_type: type):
152+
def __init__(self, model: TypeChatLanguageModel, validator: TypeChatProgramValidator, api_type: type):
153153
super().__init__(model=model, validator=validator, target_type=api_type)
154154
# TODO: the conversion result here has errors!
155155
conversion_result = python_type_to_typescript_schema(api_type)

python/examples/multiSchema/agents.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import json
1111

1212
from typing_extensions import TypeVar, Generic
13-
from typechat import Failure, TypeChatTranslator, TypeChatValidator, TypeChatModel
13+
from typechat import Failure, TypeChatTranslator, TypeChatValidator, TypeChatLanguageModel
1414

1515
import examples.math.schema as math_schema
1616
from examples.math.program import (
@@ -29,7 +29,7 @@ class JsonPrintAgent(Generic[T]):
2929
_validator: TypeChatValidator[T]
3030
_translator: TypeChatTranslator[T]
3131

32-
def __init__(self, model: TypeChatModel, target_type: type[T]):
32+
def __init__(self, model: TypeChatLanguageModel, target_type: type[T]):
3333
super().__init__()
3434
self._validator = TypeChatValidator(target_type)
3535
self._translator = TypeChatTranslator(model, self._validator, target_type)
@@ -47,7 +47,7 @@ class MathAgent:
4747
_validator: TypeChatProgramValidator
4848
_translator: TypeChatProgramTranslator
4949

50-
def __init__(self, model: TypeChatModel):
50+
def __init__(self, model: TypeChatLanguageModel):
5151
super().__init__()
5252
self._validator = TypeChatProgramValidator()
5353
self._translator = TypeChatProgramTranslator(model, self._validator, math_schema.MathAPI)
@@ -95,7 +95,7 @@ class MusicAgent:
9595
_client_context: ClientContext | None
9696
_authentication_vals: dict[str, str | None]
9797

98-
def __init__(self, model: TypeChatModel, authentication_vals: dict[str, str | None]):
98+
def __init__(self, model: TypeChatLanguageModel, authentication_vals: dict[str, str | None]):
9999
super().__init__()
100100
self._validator = TypeChatValidator(music_schema.PlayerActions)
101101
self._translator = TypeChatTranslator(model, self._validator, music_schema.PlayerActions)

python/examples/multiSchema/router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from typing_extensions import Any, Callable, Awaitable, TypedDict, Annotated
3-
from typechat import Failure, TypeChatValidator, TypeChatModel, TypeChatTranslator
3+
from typechat import Failure, TypeChatValidator, TypeChatLanguageModel, TypeChatTranslator
44

55

66
class AgentInfo(TypedDict):
@@ -18,7 +18,7 @@ class TextRequestRouter:
1818
_validator: TypeChatValidator[TaskClassification]
1919
_translator: TypeChatTranslator[TaskClassification]
2020

21-
def __init__(self, model: TypeChatModel):
21+
def __init__(self, model: TypeChatLanguageModel):
2222
super().__init__()
2323
self._validator = TypeChatValidator(TaskClassification)
2424
self._translator = TypeChatTranslator(model, self._validator, TaskClassification)

python/src/typechat/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5-
from typechat._internal.model import TypeChatModel, create_language_model
5+
from typechat._internal.model import TypeChatLanguageModel, create_language_model
66
from typechat._internal.result import Failure, Result, Success
77
from typechat._internal.translator import TypeChatTranslator
88
from typechat._internal.ts_conversion import python_type_to_typescript_schema
99
from typechat._internal.validator import TypeChatValidator
1010
from typechat._internal.interactive import process_requests
1111

1212
__all__ = [
13-
"TypeChatModel",
13+
"TypeChatLanguageModel",
1414
"TypeChatTranslator",
1515
"TypeChatValidator",
1616
"Success",

python/src/typechat/_internal/interactive.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22
from typing import Callable, Awaitable
33

44
async def process_requests(interactive_prompt: str, input_file_name: str | None, process_request: Callable[[str], Awaitable[None]]):
5+
"""
6+
A request processor for interactive input or input from a text file. If an input file name is specified,
7+
the callback function is invoked for each line in file. Otherwise, the callback function is invoked for
8+
each line of interactive input until the user types "quit" or "exit".
9+
10+
Args:
11+
interactive_prompt: Prompt to present to user.
12+
input_file_name: Input text file name, if any.
13+
process_request: Async callback function that is invoked for each interactive input or each line in text file.
14+
"""
515
if input_file_name is not None:
616
with open(input_file_name, "r") as file:
717
lines = filter(str.rstrip, file)

python/src/typechat/_internal/model.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
11
from typing_extensions import Protocol, override
2-
import os
32
import openai
43

54
from typechat._internal.result import Failure, Result, Success
65

76

8-
class TypeChatModel(Protocol):
7+
class TypeChatLanguageModel(Protocol):
98
async def complete(self, input: str) -> Result[str]:
9+
"""
10+
Represents a AI language model that can complete prompts.
11+
12+
TypeChat uses an implementation of this protocol to communicate
13+
with an AI service that can translate natural language requests to JSON
14+
instances according to a provided schema.
15+
The `create_language_model` function can create an instance.
16+
"""
1017
...
1118

1219

13-
class DefaultOpenAIModel(TypeChatModel):
20+
class DefaultOpenAIModel(TypeChatLanguageModel):
1421
model_name: str
1522
client: openai.AsyncOpenAI | openai.AsyncAzureOpenAI
1623

@@ -34,24 +41,46 @@ async def complete(self, input: str) -> Result[str]:
3441
except Exception as e:
3542
return Failure(str(e))
3643

37-
def create_language_model(vals: dict[str,str|None]) -> TypeChatModel:
38-
model: TypeChatModel
44+
def create_language_model(vals: dict[str, str | None]) -> TypeChatLanguageModel:
45+
"""
46+
Creates a language model encapsulation of an OpenAI or Azure OpenAI REST API endpoint
47+
chosen by a dictionary of variables (typically just `os.environ`).
48+
49+
If an `OPENAI_API_KEY` environment variable exists, an OpenAI model is constructed.
50+
The `OPENAI_ENDPOINT` and `OPENAI_MODEL` environment variables must also be defined or an error will be raised.
51+
52+
If an `AZURE_OPENAI_API_KEY` environment variable exists, an Azure OpenAI model is constructed.
53+
The `AZURE_OPENAI_ENDPOINT` environment variable must also be defined or an exception will be thrown.
54+
55+
If none of these key variables are defined, an exception is thrown.
56+
@returns An instance of `TypeChatLanguageModel`.
57+
58+
Args:
59+
vals: A dictionary of variables. Typically just `os.environ`.
60+
"""
61+
model: TypeChatLanguageModel
3962
client: openai.AsyncOpenAI | openai.AsyncAzureOpenAI
4063

64+
def required_var(name: str) -> str:
65+
val = vals.get(name, None)
66+
if val is None:
67+
raise ValueError(f"Missing environment variable {name}.")
68+
return val
69+
4170
if "OPENAI_API_KEY" in vals:
42-
client = openai.AsyncOpenAI(api_key=vals["OPENAI_API_KEY"])
43-
model = DefaultOpenAIModel(model_name=vals.get("OPENAI_MODEL", None) or "gpt-35-turbo", client=client)
71+
client = openai.AsyncOpenAI(api_key=required_var("OPENAI_API_KEY"))
72+
model = DefaultOpenAIModel(model_name=required_var("OPENAI_MODEL"), client=client)
4473

45-
elif "AZURE_OPENAI_API_KEY" in vals and "AZURE_OPENAI_ENDPOINT" in vals:
46-
os.environ["OPENAI_API_TYPE"] = "azure"
74+
elif "AZURE_OPENAI_API_KEY" in vals:
75+
openai.api_type = "azure"
4776
client = openai.AsyncAzureOpenAI(
48-
azure_endpoint=vals.get("AZURE_OPENAI_ENDPOINT", None) or "",
49-
api_key=vals["AZURE_OPENAI_API_KEY"],
77+
api_key=required_var("AZURE_OPENAI_API_KEY"),
78+
azure_endpoint=required_var("AZURE_OPENAI_ENDPOINT"),
5079
api_version="2023-03-15-preview",
5180
)
5281
model = DefaultOpenAIModel(model_name=vals.get("AZURE_OPENAI_MODEL", None) or "gpt-35-turbo", client=client)
5382

5483
else:
55-
raise ValueError("Missing environment variables for Open AI or Azure OpenAI model")
84+
raise ValueError("Missing environment variables for OPENAI_API_KEY or AZURE_OPENAI_API_KEY.")
5685

5786
return model
Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
from dataclasses import dataclass
2-
from typing_extensions import Generic, TypeVar
2+
from typing_extensions import Generic, TypeAlias, TypeVar
33

44
T = TypeVar("T", covariant=True)
55

66
@dataclass
77
class Success(Generic[T]):
8+
"An object representing a successful operation with a result of type `T`."
89
value: T
910

1011

1112
@dataclass
1213
class Failure:
14+
"An object representing an operation that failed for the reason given in `message`."
1315
message: str
1416

1517

16-
Result = Success[T] | Failure
18+
"""
19+
An object representing a successful or failed operation of type `T`.
20+
"""
21+
Result: TypeAlias = Success[T] | Failure

python/src/typechat/_internal/translator.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,53 @@
11
from typing_extensions import Generic, TypeVar
22

3-
from typechat._internal.model import TypeChatModel
3+
from typechat._internal.model import TypeChatLanguageModel
44
from typechat._internal.result import Failure, Result, Success
55
from typechat._internal.ts_conversion import python_type_to_typescript_schema
66
from typechat._internal.validator import TypeChatValidator
77

88
T = TypeVar("T", covariant=True)
99

1010
class TypeChatTranslator(Generic[T]):
11-
model: TypeChatModel
11+
"""
12+
Represents an object that can translate natural language requests in JSON objects of the given type.
13+
"""
14+
15+
model: TypeChatLanguageModel
1216
validator: TypeChatValidator[T]
1317
target_type: type[T]
1418
_type_name: str
1519
_schema_str: str
1620
_max_repair_attempts = 1
1721

18-
def __init__(self, model: TypeChatModel, validator: TypeChatValidator[T], target_type: type[T]):
22+
def __init__(self, model: TypeChatLanguageModel, validator: TypeChatValidator[T], target_type: type[T]):
23+
"""
24+
Args:
25+
model: The associated `TypeChatLanguageModel`.
26+
validator: The associated `TypeChatValidator[T]`.
27+
target_type: A runtime type object describing `T` - the expected shape of JSON data.
28+
"""
1929
super().__init__()
2030
self.model = model
21-
self.target_type = target_type
2231
self.validator = validator
32+
self.target_type = target_type
33+
2334
conversion_result = python_type_to_typescript_schema(target_type)
35+
# TODO: Examples may not work here!
36+
# if conversion_result.errors:
37+
# raise ValueError(f"Could not convert Python type to TypeScript schema: {conversion_result.errors}")
2438
self._type_name = conversion_result.typescript_type_reference
2539
self._schema_str = conversion_result.typescript_schema_str
2640

2741
async def translate(self, request: str) -> Result[T]:
42+
"""
43+
Translates a natural language request into an object of type `T`. If the JSON object returned by
44+
the language model fails to validate, repair attempts will be made up until `_max_repair_attempts`.
45+
The prompt for the subsequent attempts will include the diagnostics produced for the prior attempt.
46+
This often helps produce a valid instance.
47+
48+
Args:
49+
request: A natural language request.
50+
"""
2851
request = self._create_request_prompt(request)
2952
num_repairs_attempted = 0
3053
while True:

python/src/typechat/_internal/ts_conversion/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from typing_extensions import TypeAliasType
23

34
from typechat._internal.ts_conversion.python_type_to_ts_nodes import python_type_to_typescript_nodes
45
from typechat._internal.ts_conversion.ts_node_to_string import ts_declaration_to_str
@@ -19,7 +20,7 @@ class TypeScriptSchemaConversionResult:
1920
errors: list[str]
2021
"""Any errors that occurred during conversion."""
2122

22-
def python_type_to_typescript_schema(py_type: object) -> TypeScriptSchemaConversionResult:
23+
def python_type_to_typescript_schema(py_type: type | TypeAliasType) -> TypeScriptSchemaConversionResult:
2324
"""Converts a Python type to a TypeScript schema."""
2425

2526
node_conversion_result = python_type_to_typescript_nodes(py_type)

0 commit comments

Comments
 (0)