forked from GoogleCloudPlatform/python-docs-samples
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_gemini_examples.py
More file actions
171 lines (126 loc) · 4.9 KB
/
test_gemini_examples.py
File metadata and controls
171 lines (126 loc) · 4.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import vertexai
import gemini_all_modalities
import gemini_audio
import gemini_chat_example
import gemini_count_token_example
import gemini_grounding_example
import gemini_guide_example
import gemini_multi_image_example
import gemini_pdf_example
import gemini_pro_basic_example
import gemini_pro_config_example
import gemini_safety_config_example
import gemini_single_turn_video_example
import gemini_system_instruction
import gemini_text_input_example
import gemini_video_audio
# Project is taken from the environment (None if the variable is unset);
# the region is fixed for every test in this module.
PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT")
LOCATION = "us-central1"

# Initialise the Vertex AI SDK once, when this module is imported.
vertexai.init(location=LOCATION, project=PROJECT_ID)
def test_gemini_guide_example() -> None:
    """Check that the guide sample's ``generate_text`` returns non-empty text."""
    # The original lower-cased the text but never used the result; only the
    # length matters here, so that dead statement is gone.
    text = gemini_guide_example.generate_text(PROJECT_ID)
    assert len(text) > 0
def test_gemini_text_input_example() -> None:
    """Check that generating from a plain text prompt yields non-empty output."""
    generated = gemini_text_input_example.generate_from_text_input(PROJECT_ID)
    assert len(generated) >= 1
def test_gemini_pro_basic_example() -> None:
    """Check that the basic Gemini Pro sample produces non-empty text."""
    output = gemini_pro_basic_example.generate_text(PROJECT_ID)
    assert len(output) != 0
def test_gemini_pro_config_example() -> None:
    """Download ``scones.jpg`` locally, run the config sample, then clean up.

    The image is saved into the current working directory. It is always
    removed afterwards, even when the sample call raises or the assertion
    fails, so a failed run does not leave the file behind.

    Raises:
        Exception: if the image is not present after the download.
    """
    import urllib.request

    fname = "scones.jpg"
    url = "https://storage.googleapis.com/generativeai-downloads/images/scones.jpg"
    urllib.request.urlretrieve(url, fname)

    # Guard clause: fail fast with the original message if nothing landed on disk.
    if not os.path.isfile(fname):
        raise Exception("File(scones.jpg) not found!")

    try:
        text = gemini_pro_config_example.generate_text(PROJECT_ID)
        assert len(text) > 0
    finally:
        # Previously cleanup only ran on success; always remove the download.
        os.remove(fname)
def test_gemini_multi_image_example() -> None:
    """Check the multi-image answer is non-empty and mentions "city" and "landmark".

    The comparison is case-insensitive because the answer is lower-cased first.
    """
    answer = gemini_multi_image_example.generate_text_multimodal(PROJECT_ID).lower()
    assert len(answer) > 0
    for keyword in ("city", "landmark"):
        assert keyword in answer
def test_gemini_count_token_example() -> None:
    """Check both token-counting samples return a truthy response with usage metadata."""

    def _assert_has_usage(result) -> None:
        # The response itself and its ``usage_metadata`` must both be truthy.
        assert result
        assert result.usage_metadata

    # Text-only counting first, then the multimodal variant.
    _assert_has_usage(gemini_count_token_example.count_tokens(PROJECT_ID))
    _assert_has_usage(
        gemini_count_token_example.count_tokens_multimodal(PROJECT_ID)
    )
def test_gemini_safety_config_example() -> None:
    """Check that the safety-settings sample still returns non-empty text."""
    reply = gemini_safety_config_example.generate_text(PROJECT_ID)
    assert len(reply) >= 1
def test_gemini_single_turn_video_example() -> None:
    """Check the video description is non-empty and mentions an expected keyword.

    The text is lower-cased, then it must contain at least one of the keywords
    below as a substring.
    """
    text = gemini_single_turn_video_example.generate_text(PROJECT_ID).lower()
    assert len(text) > 0
    # Any one keyword suffices, rather than requiring an exact description.
    expected_words = ("zoo", "tiger", "leaf", "water", "animals", "photos")
    # Generator instead of a throwaway list, so any() can short-circuit.
    assert any(word in text for word in expected_words)
@pytest.mark.skip(
    reason="TODO: Exception Logs indicate safety filters are likely blocking model output b/339985493"
)
def test_gemini_pdf_example() -> None:
    """Check the PDF-analysis sample returns non-empty text (currently skipped)."""
    analysis = gemini_pdf_example.analyze_pdf(PROJECT_ID)
    assert len(analysis) != 0
def test_gemini_chat_example() -> None:
    """Check that both the plain and the streaming chat samples return a greeting.

    Each reply is lower-cased, must be non-empty, and must contain at least one
    of the substrings in ``greetings``. The plain sample is checked first.
    """
    greetings = ("hi", "hello", "greeting")

    def _assert_greeting(reply: str) -> None:
        # Shared check that was previously copy-pasted for each sample.
        reply = reply.lower()
        assert len(reply) > 0
        # NOTE(review): this is a substring test, so "hi" also matches inside
        # words such as "this" or "which"; the check is therefore loose.
        assert any(word in reply for word in greetings)

    _assert_greeting(gemini_chat_example.chat_text_example(PROJECT_ID))
    _assert_greeting(gemini_chat_example.chat_stream_example(PROJECT_ID))
@pytest.mark.skip(
    reason="Unable to test Google Search grounding due to allowlist restrictions."
)
def test_gemini_grounding_web_example() -> None:
    """Check Google Search grounding returns a truthy response (currently skipped)."""
    result = gemini_grounding_example.generate_text_with_grounding_web(PROJECT_ID)
    assert result
def test_gemini_grounding_vais_example() -> None:
    """Check grounding on the test Vertex AI Search data store returns a truthy response."""
    # Full resource name of the data store, built from the project ID.
    data_store_path = (
        f"projects/{PROJECT_ID}/locations/global/collections/default_collection/"
        "dataStores/grounding-test-datastore"
    )
    result = gemini_grounding_example.generate_text_with_grounding_vertex_ai_search(
        PROJECT_ID, data_store_path=data_store_path
    )
    assert result
def test_summarize_audio() -> None:
    """Check that the audio summarisation sample returns non-empty text."""
    summary = gemini_audio.summarize_audio(PROJECT_ID)
    assert len(summary) >= 1
def test_transcript_audio() -> None:
    """Check that the audio transcription sample returns non-empty text."""
    transcript = gemini_audio.transcript_audio(PROJECT_ID)
    assert len(transcript) != 0
def test_analyze_video_with_audio() -> None:
    """Check that the video-with-audio analysis sample returns non-empty text."""
    analysis = gemini_video_audio.analyze_video_with_audio(PROJECT_ID)
    assert len(analysis) >= 1
def test_analyze_all_modalities() -> None:
    """Check that the all-modalities analysis sample returns non-empty text."""
    analysis = gemini_all_modalities.analyze_all_modalities(PROJECT_ID)
    assert len(analysis) != 0
def test_set_system_instruction() -> None:
    """Check that the system-instruction sample returns non-empty text."""
    reply = gemini_system_instruction.set_system_instruction(PROJECT_ID)
    assert len(reply) >= 1