Skip to content

Commit 1bfbef6

Browse files
Allow models to take lists of PromptSections, allow translators to take preambles. (#203)
* Allow models to take lists of PromptSections, allow translators to take preambles. * Export `PromptSection` and the new direct model creation functions. * Appropriately override `translate`.
1 parent 85d9e04 commit 1bfbef6

4 files changed

Lines changed: 40 additions & 22 deletions

File tree

python/examples/healthData/translator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
from typing_extensions import TypeVar, Any, override, TypedDict, Literal
33

4-
from typechat import TypeChatValidator, TypeChatLanguageModel, TypeChatTranslator, Result, Failure
4+
from typechat import TypeChatValidator, TypeChatLanguageModel, TypeChatTranslator, Result, Failure, PromptSection
55

66
from datetime import datetime
77

@@ -27,8 +27,8 @@ def __init__(
2727
self._additional_agent_instructions = additional_agent_instructions
2828

2929
@override
30-
async def translate(self, request: str) -> Result[T]:
31-
result = await super().translate(request=request)
30+
async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
31+
result = await super().translate(request=request, prompt_preamble=prompt_preamble)
3232
if not isinstance(result, Failure):
3333
self._chat_history.append(ChatMessage(source="assistant", body=result.value))
3434
return result

python/src/typechat/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: MIT
44

5-
from typechat._internal.model import TypeChatLanguageModel, create_language_model
5+
from typechat._internal.model import PromptSection, TypeChatLanguageModel, create_language_model, create_openai_language_model, create_azure_openai_language_model
66
from typechat._internal.result import Failure, Result, Success
77
from typechat._internal.translator import TypeChatTranslator
88
from typechat._internal.ts_conversion import python_type_to_typescript_schema
@@ -17,6 +17,9 @@
1717
"Failure",
1818
"Result",
1919
"python_type_to_typescript_schema",
20+
"PromptSection",
2021
"create_language_model",
21-
"process_requests"
22+
"create_openai_language_model",
23+
"create_azure_openai_language_model",
24+
"process_requests",
2225
]

python/src/typechat/_internal/model.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,17 @@
66

77
import httpx
88

9+
class PromptSection(TypedDict):
10+
"""
11+
Represents a section of an LLM prompt with an associated role. TypeChat uses the "user" role for
12+
prompts it generates and the "assistant" role for previous LLM responses (which will be part of
13+
the prompt in repair attempts). TypeChat currently doesn't use the "system" role.
14+
"""
15+
role: Literal["system", "user", "assistant"]
16+
content: str
17+
918
class TypeChatLanguageModel(Protocol):
10-
async def complete(self, prompt: str) -> Result[str]:
19+
async def complete(self, prompt: str | list[PromptSection]) -> Result[str]:
1120
"""
1221
Represents a AI language model that can complete prompts.
1322
@@ -18,15 +27,6 @@ async def complete(self, prompt: str) -> Result[str]:
1827
"""
1928
...
2029

21-
class _PromptSection(TypedDict):
22-
"""
23-
Represents a section of an LLM prompt with an associated role. TypeChat uses the "user" role for
24-
prompts it generates and the "assistant" role for previous LLM responses (which will be part of
25-
the prompt in repair attempts). TypeChat currently doesn't use the "system" role.
26-
"""
27-
role: Literal["system", "user", "assistant"]
28-
content: str
29-
3030
_TRANSIENT_ERROR_CODES = [
3131
429,
3232
500,
@@ -51,15 +51,18 @@ def __init__(self, url: str, headers: dict[str, str], default_params: dict[str,
5151
self._async_client = httpx.AsyncClient()
5252

5353
@override
54-
async def complete(self, prompt: str) -> Success[str] | Failure:
54+
async def complete(self, prompt: str | list[PromptSection]) -> Success[str] | Failure:
5555
headers = {
5656
"Content-Type": "application/json",
5757
**self.headers,
5858
}
59-
messages = [{"role": "user", "content": prompt}]
59+
60+
if isinstance(prompt, str):
61+
prompt = [{"role": "user", "content": prompt}]
62+
6063
body = {
6164
**self.default_params,
62-
"messages": messages,
65+
"messages": prompt,
6366
"temperature": 0.0,
6467
"n": 1,
6568
}
@@ -73,7 +76,7 @@ async def complete(self, prompt: str) -> Success[str] | Failure:
7376
)
7477
if response.is_success:
7578
json_result = cast(
76-
dict[Literal["choices"], list[dict[Literal["message"], _PromptSection]]],
79+
dict[Literal["choices"], list[dict[Literal["message"], PromptSection]]],
7780
response.json()
7881
)
7982
return Success(json_result["choices"][0]["message"]["content"] or "")

python/src/typechat/_internal/translator.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing_extensions import Generic, TypeVar
22

3-
from typechat._internal.model import TypeChatLanguageModel
3+
from typechat._internal.model import PromptSection, TypeChatLanguageModel
44
from typechat._internal.result import Failure, Result, Success
55
from typechat._internal.ts_conversion import python_type_to_typescript_schema
66
from typechat._internal.validator import TypeChatValidator
@@ -43,10 +43,11 @@ def __init__(
4343
if _raise_on_schema_errors and conversion_result.errors:
4444
error_text = "".join(f"\n- {error}" for error in conversion_result.errors)
4545
raise ValueError(f"Could not convert Python type to TypeScript schema: \n{error_text}")
46+
4647
self._type_name = conversion_result.typescript_type_reference
4748
self._schema_str = conversion_result.typescript_schema_str
4849

49-
async def translate(self, request: str) -> Result[T]:
50+
async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
5051
"""
5152
Translates a natural language request into an object of type `T`. If the JSON object returned by
5253
the language model fails to validate, repair attempts will be made up until `_max_repair_attempts`.
@@ -55,11 +56,22 @@ async def translate(self, request: str) -> Result[T]:
5556
5657
Args:
5758
request: A natural language request.
59+
prompt_preamble: An optional string or list of prompt sections to prepend to the generated prompt.\
60+
If a string is given, it is converted to a single "user" role prompt section.
5861
"""
5962
request = self._create_request_prompt(request)
63+
64+
prompt: str | list[PromptSection]
65+
if prompt_preamble is None:
66+
prompt = request
67+
else:
68+
if isinstance(prompt_preamble, str):
69+
prompt_preamble = [{"role": "user", "content": prompt_preamble}]
70+
prompt = [*prompt_preamble, {"role": "user", "content": request}]
71+
6072
num_repairs_attempted = 0
6173
while True:
62-
completion_response = await self.model.complete(request)
74+
completion_response = await self.model.complete(prompt)
6375
if isinstance(completion_response, Failure):
6476
return completion_response
6577

0 commit comments

Comments
 (0)