-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathtest_google_generativeai_patch.py
More file actions
55 lines (48 loc) · 2.06 KB
/
test_google_generativeai_patch.py
File metadata and controls
55 lines (48 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import json
import unittest
import threading
from test.support.os_helper import EnvironmentVarGuard
from urllib.parse import urlparse
from http.server import BaseHTTPRequestHandler, HTTPServer
class HTTPHandler(BaseHTTPRequestHandler):
    """Fake proxy endpoint that records the most recent GET request it receives.

    The request details are stored on the class rather than the instance
    because the server builds a new handler instance for every request; the
    test reads them back from the class after the server has been shut down.
    """

    # True once a GET request has been handled.
    called = False
    # Path of the most recent GET request (None until one arrives).
    path = None
    # Headers of the most recent GET request.
    headers = {}

    def do_HEAD(self):
        """Answer any HEAD request with a bare 200 and no body."""
        self.send_response(200)
        # send_response() only buffers the status line and default headers;
        # end_headers() is what actually writes them to the client. Without
        # it the client gets no response at all.
        self.end_headers()

    def do_GET(self):
        """Record the request path and headers, then reply 200 with a JSON content type and no body."""
        HTTPHandler.path = self.path
        HTTPHandler.headers = self.headers
        HTTPHandler.called = True
        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()
class TestGoogleGenerativeAiPatch(unittest.TestCase):
    """Check that google.generativeai REST calls are sent to the configured data proxy."""

    # Address of the fake proxy; HTTPHandler records what reaches it.
    # NOTE(review): port 80 is privileged on most hosts, so this presumably
    # relies on the test running as root (e.g. inside a container) — confirm.
    endpoint = "http://127.0.0.1:80"

    def test_proxy_enabled(self):
        """list_models() should reach the proxy with the proxy/auth headers and the API key."""
        env = EnvironmentVarGuard()
        secrets_token = "secrets_token"
        proxy_token = "proxy_token"
        api_key = "NotARealAPIKey"
        env.set("KAGGLE_USER_SECRETS_TOKEN", secrets_token)
        env.set("KAGGLE_DATA_PROXY_TOKEN", proxy_token)
        env.set("KAGGLE_DATA_PROXY_URL", self.endpoint)
        env.set("KAGGLE_GRPC_DATA_PROXY_URL", "http://127.0.0.1:50001")
        env.set("KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY", "True")

        # HTTPHandler keeps its state on the class; clear anything left over
        # from an earlier request so stale values cannot satisfy the asserts.
        HTTPHandler.called = False
        HTTPHandler.path = None
        HTTPHandler.headers = {}

        server_address = urlparse(self.endpoint)
        with env:
            with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd:
                # Daemon thread + finally-shutdown: a failure below must not
                # leave serve_forever() running and hang the test process.
                server_thread = threading.Thread(target=httpd.serve_forever, daemon=True)
                server_thread.start()
                try:
                    # Imported here, after the environment is prepared, so the
                    # module is loaded with the proxy variables already set.
                    import google.generativeai as palm
                    palm.configure(api_key=api_key)
                    try:
                        for _ in palm.list_models():
                            pass
                    except Exception:
                        # The fake proxy replies with headers only (no JSON
                        # body), so the client call may fail once the request
                        # has been sent; only the recorded request matters.
                        pass
                finally:
                    httpd.shutdown()
                    server_thread.join()

        self.assertTrue(HTTPHandler.called)
        self.assertIn("/palmapi", HTTPHandler.path)
        self.assertEqual(proxy_token, HTTPHandler.headers["x-kaggle-proxy-data"])
        self.assertEqual(f"Bearer {secrets_token}", HTTPHandler.headers["x-kaggle-authorization"])
        self.assertEqual(api_key, HTTPHandler.headers["x-goog-api-key"])