Skip to content

Commit 1d73827

Browse files
committed
feat: add multi-speaker voice config
PiperOrigin-RevId: 759431774
1 parent fb60681 commit 1d73827

6 files changed

Lines changed: 517 additions & 44 deletions

File tree

google/genai/_live_converters.py

Lines changed: 201 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,193 @@
2222
from ._common import set_value_by_path as setv
2323

2424

25+
def _PrebuiltVoiceConfig_to_mldev(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the mldev request dict for a prebuilt voice config.

    Args:
      api_client: Not read by this converter.
      from_object: Source dict or object queried for ``voice_name``.
      parent_object: Not read by this converter; present so the signature
        matches the sibling converters.

    Returns:
      A new dict with ``voiceName`` set when ``voice_name`` is not None,
      otherwise an empty dict.
    """
    converted: dict[str, Any] = {}

    voice_name = getv(from_object, ['voice_name'])
    if voice_name is not None:
        setv(converted, ['voiceName'], voice_name)

    return converted
35+
36+
37+
def _PrebuiltVoiceConfig_to_vertex(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the Vertex request dict for a prebuilt voice config.

    Args:
      api_client: Not read by this converter.
      from_object: Source dict or object queried for ``voice_name``.
      parent_object: Not read by this converter; present so the signature
        matches the sibling converters.

    Returns:
      A new dict with ``voiceName`` set when ``voice_name`` is not None,
      otherwise an empty dict.
    """
    vertex_dict: dict[str, Any] = {}

    name = getv(from_object, ['voice_name'])
    if name is None:
        # Nothing to copy; the caller receives an empty mapping.
        return vertex_dict

    setv(vertex_dict, ['voiceName'], name)
    return vertex_dict
47+
48+
49+
def _VoiceConfig_to_mldev(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the mldev request dict for a voice config.

    Args:
      api_client: Forwarded unchanged to the nested converter.
      from_object: Source dict or object queried for
        ``prebuilt_voice_config``.
      parent_object: Not read by this converter.

    Returns:
      A new dict with ``prebuiltVoiceConfig`` set to the converted nested
      config when ``prebuilt_voice_config`` is not None, otherwise an empty
      dict.
    """
    converted: dict[str, Any] = {}

    prebuilt = getv(from_object, ['prebuilt_voice_config'])
    if prebuilt is not None:
        # The dict under construction is handed down as the child's parent.
        nested = _PrebuiltVoiceConfig_to_mldev(api_client, prebuilt, converted)
        setv(converted, ['prebuiltVoiceConfig'], nested)

    return converted
65+
66+
67+
def _VoiceConfig_to_vertex(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the Vertex request dict for a voice config.

    Args:
      api_client: Forwarded unchanged to the nested converter.
      from_object: Source dict or object queried for
        ``prebuilt_voice_config``.
      parent_object: Not read by this converter.

    Returns:
      A new dict with ``prebuiltVoiceConfig`` set to the converted nested
      config when ``prebuilt_voice_config`` is not None, otherwise an empty
      dict.
    """
    vertex_dict: dict[str, Any] = {}

    prebuilt_cfg = getv(from_object, ['prebuilt_voice_config'])
    if prebuilt_cfg is None:
        return vertex_dict

    # The dict under construction is handed down as the child's parent.
    setv(
        vertex_dict,
        ['prebuiltVoiceConfig'],
        _PrebuiltVoiceConfig_to_vertex(api_client, prebuilt_cfg, vertex_dict),
    )
    return vertex_dict
83+
84+
85+
def _SpeakerVoiceConfig_to_mldev(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the mldev request dict for one speaker's voice config.

    Args:
      api_client: Forwarded unchanged to the nested voice-config converter.
      from_object: Source dict or object queried for ``speaker`` and
        ``voice_config``.
      parent_object: Not read by this converter.

    Returns:
      A new dict containing ``speaker`` and/or ``voiceConfig`` for each
      source field that is not None.
    """
    converted: dict[str, Any] = {}

    speaker = getv(from_object, ['speaker'])
    if speaker is not None:
        setv(converted, ['speaker'], speaker)

    voice_cfg = getv(from_object, ['voice_config'])
    if voice_cfg is not None:
        nested = _VoiceConfig_to_mldev(api_client, voice_cfg, converted)
        setv(converted, ['voiceConfig'], nested)

    return converted
104+
105+
106+
def _SpeakerVoiceConfig_to_vertex(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Reject speaker voice config fields for Vertex AI.

    Args:
      api_client: Not read by this converter.
      from_object: Source dict or object queried for ``speaker`` and
        ``voice_config``.
      parent_object: Not read by this converter.

    Returns:
      An empty dict when neither field is set.

    Raises:
      ValueError: If ``speaker`` or ``voice_config`` is not None; ``speaker``
        is checked first.
    """
    # Checked in this order so the first offending field names the error.
    unsupported = (
        ('speaker', 'speaker parameter is not supported in Vertex AI.'),
        (
            'voice_config',
            'voice_config parameter is not supported in Vertex AI.',
        ),
    )
    for field, message in unsupported:
        if getv(from_object, [field]) is not None:
            raise ValueError(message)

    empty: dict[str, Any] = {}
    return empty
119+
120+
121+
def _MultiSpeakerVoiceConfig_to_mldev(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the mldev request dict for a multi-speaker voice config.

    Args:
      api_client: Forwarded unchanged to the per-speaker converter.
      from_object: Source dict or object queried for
        ``speaker_voice_configs``.
      parent_object: Not read by this converter.

    Returns:
      A new dict with ``speakerVoiceConfigs`` set to the list of converted
      entries when ``speaker_voice_configs`` is not None, otherwise an empty
      dict.
    """
    converted: dict[str, Any] = {}

    speaker_cfgs = getv(from_object, ['speaker_voice_configs'])
    if speaker_cfgs is None:
        return converted

    converted_speakers = []
    for speaker_cfg in speaker_cfgs:
        converted_speakers.append(
            _SpeakerVoiceConfig_to_mldev(api_client, speaker_cfg, converted)
        )
    setv(converted, ['speakerVoiceConfigs'], converted_speakers)

    return converted
138+
139+
140+
def _MultiSpeakerVoiceConfig_to_vertex(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Reject multi-speaker voice config content for Vertex AI.

    Args:
      api_client: Not read by this converter.
      from_object: Source dict or object queried for
        ``speaker_voice_configs``.
      parent_object: Not read by this converter.

    Returns:
      An empty dict when ``speaker_voice_configs`` is None.

    Raises:
      ValueError: If ``speaker_voice_configs`` is not None.
    """
    speaker_cfgs = getv(from_object, ['speaker_voice_configs'])
    if speaker_cfgs is not None:
        raise ValueError(
            'speaker_voice_configs parameter is not supported in Vertex AI.'
        )

    vertex_dict: dict[str, Any] = {}
    return vertex_dict
152+
153+
154+
def _SpeechConfig_to_mldev(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the mldev request dict for a speech config.

    Args:
      api_client: Forwarded unchanged to the nested converters.
      from_object: Source dict or object queried for ``voice_config``,
        ``multi_speaker_voice_config`` and ``language_code``.
      parent_object: Not read by this converter.

    Returns:
      A new dict containing ``voiceConfig``, ``multiSpeakerVoiceConfig``
      and/or ``languageCode`` for each source field that is not None.
    """
    converted: dict[str, Any] = {}

    voice_cfg = getv(from_object, ['voice_config'])
    if voice_cfg is not None:
        setv(
            converted,
            ['voiceConfig'],
            _VoiceConfig_to_mldev(api_client, voice_cfg, converted),
        )

    multi_speaker_cfg = getv(from_object, ['multi_speaker_voice_config'])
    if multi_speaker_cfg is not None:
        setv(
            converted,
            ['multiSpeakerVoiceConfig'],
            _MultiSpeakerVoiceConfig_to_mldev(
                api_client, multi_speaker_cfg, converted
            ),
        )

    language_code = getv(from_object, ['language_code'])
    if language_code is not None:
        setv(converted, ['languageCode'], language_code)

    return converted
184+
185+
186+
def _SpeechConfig_to_vertex(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the Vertex request dict for a speech config.

    Args:
      api_client: Forwarded unchanged to the nested voice-config converter.
      from_object: Source dict or object queried for ``voice_config``,
        ``multi_speaker_voice_config`` and ``language_code``.
      parent_object: Not read by this converter.

    Returns:
      A new dict containing ``voiceConfig`` and/or ``languageCode`` for each
      source field that is not None.

    Raises:
      ValueError: If ``multi_speaker_voice_config`` is not None. The check
        runs after ``voice_config`` has been converted.
    """
    vertex_dict: dict[str, Any] = {}

    voice_cfg = getv(from_object, ['voice_config'])
    if voice_cfg is not None:
        nested_voice = _VoiceConfig_to_vertex(
            api_client, voice_cfg, vertex_dict
        )
        setv(vertex_dict, ['voiceConfig'], nested_voice)

    if getv(from_object, ['multi_speaker_voice_config']) is not None:
        raise ValueError(
            'multi_speaker_voice_config parameter is not supported in Vertex AI.'
        )

    lang = getv(from_object, ['language_code'])
    if lang is not None:
        setv(vertex_dict, ['languageCode'], lang)

    return vertex_dict
210+
211+
25212
def _Blob_to_mldev(
26213
api_client: BaseApiClient,
27214
from_object: Union[dict[str, Any], object],
@@ -1010,7 +1197,13 @@ def _LiveConnectConfig_to_mldev(
10101197
setv(
10111198
parent_object,
10121199
['setup', 'generationConfig', 'speechConfig'],
1013-
getv(from_object, ['speech_config']),
1200+
_SpeechConfig_to_mldev(
1201+
api_client,
1202+
t.t_live_speech_config(
1203+
api_client, getv(from_object, ['speech_config'])
1204+
),
1205+
to_object,
1206+
),
10141207
)
10151208

10161209
if getv(from_object, ['system_instruction']) is not None:
@@ -1154,7 +1347,13 @@ def _LiveConnectConfig_to_vertex(
11541347
setv(
11551348
parent_object,
11561349
['setup', 'generationConfig', 'speechConfig'],
1157-
getv(from_object, ['speech_config']),
1350+
_SpeechConfig_to_vertex(
1351+
api_client,
1352+
t.t_live_speech_config(
1353+
api_client, getv(from_object, ['speech_config'])
1354+
),
1355+
to_object,
1356+
),
11581357
)
11591358

11601359
if getv(from_object, ['system_instruction']) is not None:

google/genai/_transformers.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -866,26 +866,29 @@ def t_speech_config(
866866
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=origin)
867867
)
868868
)
869-
if (
870-
isinstance(origin, dict)
871-
and 'voice_config' in origin
872-
and origin['voice_config'] is not None
873-
and 'prebuilt_voice_config' in origin['voice_config']
874-
and origin['voice_config']['prebuilt_voice_config'] is not None
875-
and 'voice_name' in origin['voice_config']['prebuilt_voice_config']
876-
):
877-
return types.SpeechConfig(
878-
voice_config=types.VoiceConfig(
879-
prebuilt_voice_config=types.PrebuiltVoiceConfig(
880-
voice_name=origin['voice_config']['prebuilt_voice_config'].get(
881-
'voice_name'
882-
)
883-
)
884-
)
885-
)
869+
if isinstance(origin, dict):
870+
return types.SpeechConfig.model_validate(origin)
871+
886872
raise ValueError(f'Unsupported speechConfig type: {type(origin)}')
887873

888874

875+
def t_live_speech_config(
    client: _api_client.BaseApiClient,
    origin: types.SpeechConfigOrDict,
) -> Optional[types.SpeechConfig]:
    """Normalize a live-API speech config into a ``types.SpeechConfig``.

    Args:
      client: Not read by this transformer.
      origin: A ``types.SpeechConfig`` instance, or a dict that is validated
        into one with ``types.SpeechConfig.model_validate``.

    Returns:
      The resulting ``types.SpeechConfig``, or None when ``origin`` is None.

    Raises:
      ValueError: If ``origin`` is neither a ``types.SpeechConfig`` nor a
        dict, or if the config sets ``multi_speaker_voice_config``.
    """
    if origin is None:
        # Matches the Optional return type instead of failing on an unbound
        # local further down.
        return None

    if isinstance(origin, types.SpeechConfig):
        speech_config = origin
    elif isinstance(origin, dict):
        speech_config = types.SpeechConfig.model_validate(origin)
    else:
        # Without this branch an unexpected type left `speech_config`
        # unassigned and surfaced as UnboundLocalError.
        raise ValueError(f'Unsupported speechConfig type: {type(origin)}')

    if speech_config.multi_speaker_voice_config is not None:
        raise ValueError(
            'multi_speaker_voice_config is not supported in the live API.'
        )

    return speech_config
890+
891+
889892
def t_tool(
890893
client: _api_client.BaseApiClient, origin: Any
891894
) -> Optional[Union[types.Tool, Any]]:

google/genai/models.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,46 @@ def _VoiceConfig_to_mldev(
500500
return to_object
501501

502502

503+
def _SpeakerVoiceConfig_to_mldev(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the mldev request dict for one speaker's voice config.

    Args:
      api_client: Forwarded unchanged to the nested voice-config converter.
      from_object: Source dict or object queried for ``speaker`` and
        ``voice_config``.
      parent_object: Not read by this converter.

    Returns:
      A new dict containing ``speaker`` and/or ``voiceConfig`` for each
      source field that is not None.
    """
    mldev_dict: dict[str, Any] = {}

    speaker_name = getv(from_object, ['speaker'])
    if speaker_name is not None:
        setv(mldev_dict, ['speaker'], speaker_name)

    voice = getv(from_object, ['voice_config'])
    if voice is not None:
        setv(
            mldev_dict,
            ['voiceConfig'],
            _VoiceConfig_to_mldev(api_client, voice, mldev_dict),
        )

    return mldev_dict
522+
523+
524+
def _MultiSpeakerVoiceConfig_to_mldev(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Build the mldev request dict for a multi-speaker voice config.

    Args:
      api_client: Forwarded unchanged to the per-speaker converter.
      from_object: Source dict or object queried for
        ``speaker_voice_configs``.
      parent_object: Not read by this converter.

    Returns:
      A new dict with ``speakerVoiceConfigs`` set to the list of converted
      entries when ``speaker_voice_configs`` is not None, otherwise an empty
      dict.
    """
    mldev_dict: dict[str, Any] = {}

    entries = getv(from_object, ['speaker_voice_configs'])
    if entries is not None:
        # Each entry receives the dict under construction as its parent.
        per_speaker = [
            _SpeakerVoiceConfig_to_mldev(api_client, entry, mldev_dict)
            for entry in entries
        ]
        setv(mldev_dict, ['speakerVoiceConfigs'], per_speaker)

    return mldev_dict
541+
542+
503543
def _SpeechConfig_to_mldev(
504544
api_client: BaseApiClient,
505545
from_object: Union[dict[str, Any], object],
@@ -515,6 +555,17 @@ def _SpeechConfig_to_mldev(
515555
),
516556
)
517557

558+
if getv(from_object, ['multi_speaker_voice_config']) is not None:
559+
setv(
560+
to_object,
561+
['multiSpeakerVoiceConfig'],
562+
_MultiSpeakerVoiceConfig_to_mldev(
563+
api_client,
564+
getv(from_object, ['multi_speaker_voice_config']),
565+
to_object,
566+
),
567+
)
568+
518569
if getv(from_object, ['language_code']) is not None:
519570
setv(to_object, ['languageCode'], getv(from_object, ['language_code']))
520571

@@ -1746,6 +1797,35 @@ def _VoiceConfig_to_vertex(
17461797
return to_object
17471798

17481799

1800+
def _SpeakerVoiceConfig_to_vertex(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Reject speaker voice config fields for Vertex AI.

    Args:
      api_client: Not read by this converter.
      from_object: Source dict or object queried for ``speaker`` and
        ``voice_config``.
      parent_object: Not read by this converter.

    Returns:
      An empty dict when neither field is set.

    Raises:
      ValueError: If ``speaker`` or ``voice_config`` is not None; ``speaker``
        is checked first.
    """
    speaker_value = getv(from_object, ['speaker'])
    if speaker_value is not None:
        raise ValueError('speaker parameter is not supported in Vertex AI.')

    voice_value = getv(from_object, ['voice_config'])
    if voice_value is not None:
        raise ValueError(
            'voice_config parameter is not supported in Vertex AI.'
        )

    result: dict[str, Any] = {}
    return result
1813+
1814+
1815+
def _MultiSpeakerVoiceConfig_to_vertex(
    api_client: BaseApiClient,
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Reject multi-speaker voice config content for Vertex AI.

    Args:
      api_client: Not read by this converter.
      from_object: Source dict or object queried for
        ``speaker_voice_configs``.
      parent_object: Not read by this converter.

    Returns:
      An empty dict when ``speaker_voice_configs`` is None.

    Raises:
      ValueError: If ``speaker_voice_configs`` is not None.
    """
    result: dict[str, Any] = {}

    has_speaker_cfgs = getv(from_object, ['speaker_voice_configs']) is not None
    if has_speaker_cfgs:
        raise ValueError(
            'speaker_voice_configs parameter is not supported in Vertex AI.'
        )

    return result
1827+
1828+
17491829
def _SpeechConfig_to_vertex(
17501830
api_client: BaseApiClient,
17511831
from_object: Union[dict[str, Any], object],
@@ -1761,6 +1841,11 @@ def _SpeechConfig_to_vertex(
17611841
),
17621842
)
17631843

1844+
if getv(from_object, ['multi_speaker_voice_config']) is not None:
1845+
raise ValueError(
1846+
'multi_speaker_voice_config parameter is not supported in Vertex AI.'
1847+
)
1848+
17641849
if getv(from_object, ['language_code']) is not None:
17651850
setv(to_object, ['languageCode'], getv(from_object, ['language_code']))
17661851

0 commit comments

Comments
 (0)