Skip to content

Commit 765b109

Browse files
feat(generativeai): Add sample for Prompt Optimiser (GoogleCloudPlatform#12624)
* feat(genai): add prompt optimizer sample * add config files and tests * update comments * lint fix * πŸ¦‰ Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * lint fix and update region tag * πŸ¦‰ Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * lint and quota fix * refactor and headercheck exclusion * test dir path * test moving config files and update bucket * update headercheck * update acc to review * πŸ¦‰ Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * update location * update --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 8df3ca6 commit 765b109

File tree

7 files changed

+242
-0
lines changed

7 files changed

+242
-0
lines changed

β€Ž.github/header-checker-lint.ymlβ€Ž

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ ignoreFiles:
2525
- "dlp/snippets/resources/harmless.txt"
2626
- "dlp/snippets/resources/test.txt"
2727
- "dlp/snippets/resources/term_list.txt"
28+
- "generative_ai/prompts/test_resources/sample_prompt_template.txt"
29+
- "generative_ai/prompts/test_resources/sample_system_instruction.txt"
2830
- "service_extensions/callouts/add_header/service_pb2.py"
2931
- "service_extensions/callouts/add_header/service_pb2_grpc.py"
3032

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
18+
def optimize_prompts(
19+
project: str,
20+
location: str,
21+
staging_bucket: str,
22+
configuration_path: str,
23+
) -> str:
24+
"""Improve prompts by evaluating the model's response to sample prompts against specified evaluation metric(s).
25+
Args:
26+
project: Google Cloud Project ID.
27+
location: Location where you want to run the Vertex AI prompt optimizer.
28+
staging_bucket: Specify the Google Cloud Storage bucket to store outputs and metadata. For example, gs://bucket-name
29+
configuration_path: URI of the configuration file in your Google Cloud Storage bucket. For example, gs://bucket-name/configuration.json.
30+
Returns:
31+
custom_job.resource_name: Returns the resource name of the job created of type: projects/project-id/locations/location/customJobs/job-id
32+
"""
33+
# [START generativeaionvertexai_prompt_optimizer]
34+
from google.cloud import aiplatform
35+
36+
# TODO(developer): Update & uncomment below line
37+
# project = "your-gcp-project-id"
38+
# location = "location"
39+
# staging_bucket = "output-bucket-gcs-uri"
40+
# configuration_path = "configuration-file-gcs-uri"
41+
aiplatform.init(project=project, location=location, staging_bucket=staging_bucket)
42+
43+
worker_pool_specs = [
44+
{
45+
"replica_count": 1,
46+
"container_spec": {
47+
"image_uri": "us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/apd:preview_v1_0",
48+
"args": [f"--config={configuration_path}"],
49+
},
50+
"machine_spec": {
51+
"machine_type": "n1-standard-4",
52+
},
53+
}
54+
]
55+
56+
custom_job = aiplatform.CustomJob(
57+
display_name="Prompt Optimizer example",
58+
worker_pool_specs=worker_pool_specs,
59+
)
60+
custom_job.submit()
61+
print(f"Job resource name: {custom_job.resource_name}")
62+
63+
# [END generativeaionvertexai_prompt_optimizer]
64+
return custom_job.resource_name
65+
66+
67+
if __name__ == "__main__":
68+
optimize_prompts(
69+
os.environ["PROJECT_ID"],
70+
"us-central1",
71+
os.environ["PROMPT_OPTIMIZER_BUCKET_NAME"],
72+
os.environ["JSON_CONFIG_PATH"],
73+
)
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import os
17+
import time
18+
from typing import Callable
19+
20+
from google.cloud import aiplatform, storage
21+
from google.cloud.aiplatform import CustomJob
22+
from google.cloud.aiplatform_v1 import JobState
23+
from google.cloud.exceptions import NotFound
24+
from google.cloud.storage import transfer_manager
25+
26+
from prompt_optimizer import optimize_prompts
27+
28+
import pytest
29+
30+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
31+
STAGING_BUCKET_NAME = "prompt_optimizer_bucket"
32+
CONFIGURATION_DIRECTORY = "test_resources"
33+
CONFIGURATION_FILENAME = "sample_configuration.json"
34+
LOCATION = "us-central1"
35+
OUTPUT_PATH = "instruction"
36+
37+
STORAGE_CLIENT = storage.Client()
38+
39+
40+
def _clean_resources(bucket_resource_name: str) -> None:
41+
# delete blobs and bucket if exists
42+
try:
43+
bucket = STORAGE_CLIENT.get_bucket(bucket_resource_name)
44+
except NotFound:
45+
print(f"Bucket {bucket_resource_name} cannot be accessed")
46+
return
47+
48+
blobs = bucket.list_blobs()
49+
for blob in blobs:
50+
blob.delete()
51+
bucket.delete()
52+
53+
54+
def substitute_env_variable(data: dict, target_key: str, env_var_name: str) -> dict:
55+
# substitute env variables in the given config file with runtime values
56+
if isinstance(data, dict):
57+
for key, value in data.items():
58+
if key == target_key:
59+
data[key] = os.environ.get(env_var_name)
60+
else:
61+
data[key] = substitute_env_variable(value, target_key, env_var_name)
62+
elif isinstance(data, list):
63+
for i, value in enumerate(data):
64+
data[i] = substitute_env_variable(value, target_key, env_var_name)
65+
return data
66+
67+
68+
def update_json() -> dict:
69+
# Load the JSON file
70+
file_path = os.path.join(
71+
os.path.dirname(__file__), CONFIGURATION_DIRECTORY, CONFIGURATION_FILENAME
72+
)
73+
with open(file_path, "r") as f:
74+
data = json.load(f)
75+
# Substitute only the "project" variable with the value of "PROJECT_ID"
76+
substituted_data = substitute_env_variable(data, "project", "PROJECT_ID")
77+
return substituted_data
78+
79+
80+
@pytest.fixture(scope="session")
81+
def bucket_name() -> str:
82+
filenames = [
83+
"sample_prompt_template.txt",
84+
"sample_prompts.jsonl",
85+
"sample_system_instruction.txt",
86+
]
87+
# cleanup existing stale resources
88+
_clean_resources(STAGING_BUCKET_NAME)
89+
# create bucket
90+
bucket = STORAGE_CLIENT.bucket(STAGING_BUCKET_NAME)
91+
bucket.storage_class = "STANDARD"
92+
new_bucket = STORAGE_CLIENT.create_bucket(bucket, location="us")
93+
# update JSON to substitute env variables
94+
substituted_data = update_json()
95+
# convert the JSON data to a byte string
96+
json_str = json.dumps(substituted_data, indent=2)
97+
json_bytes = json_str.encode("utf-8")
98+
# upload substituted JSON file to the bucket
99+
blob = bucket.blob(CONFIGURATION_FILENAME)
100+
blob.upload_from_string(json_bytes)
101+
# upload config files to the bucket
102+
transfer_manager.upload_many_from_filenames(
103+
new_bucket,
104+
filenames,
105+
source_directory=os.path.join(
106+
os.path.dirname(__file__), CONFIGURATION_DIRECTORY
107+
),
108+
)
109+
yield new_bucket.name
110+
_clean_resources(new_bucket.name)
111+
112+
113+
def _main_test(test_func: Callable) -> None:
114+
job_resource_name: str = ""
115+
timeout = 900 # seconds
116+
# wait for the job to complete
117+
try:
118+
job_resource_name = test_func()
119+
start_time = time.time()
120+
while (
121+
get_job(job_resource_name).state
122+
not in [JobState.JOB_STATE_SUCCEEDED, JobState.JOB_STATE_FAILED]
123+
and time.time() - start_time < timeout
124+
):
125+
time.sleep(10)
126+
finally:
127+
# delete job
128+
get_job(job_resource_name).delete()
129+
130+
131+
def test_prompt_optimizer(bucket_name: pytest.fixture()) -> None:
132+
_main_test(
133+
test_func=lambda: optimize_prompts(
134+
PROJECT_ID,
135+
LOCATION,
136+
f"gs://{bucket_name}",
137+
f"gs://{bucket_name}/{CONFIGURATION_FILENAME}",
138+
)
139+
)
140+
assert (
141+
STORAGE_CLIENT.get_bucket(bucket_name).list_blobs(prefix=OUTPUT_PATH)
142+
is not None
143+
)
144+
145+
146+
def get_job(job_resource_name: str) -> CustomJob:
147+
return aiplatform.CustomJob.get(
148+
resource_name=job_resource_name, project=PROJECT_ID, location=LOCATION
149+
)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"project": "$PROJECT_ID",
3+
"system_instruction_path": "gs://prompt_optimizer_bucket/sample_system_instruction.txt",
4+
"prompt_template_path": "gs://prompt_optimizer_bucket/sample_prompt_template.txt",
5+
"target_model": "gemini-1.5-flash-001",
6+
"eval_metrics_types": ["safety"],
7+
"optimization_mode": "instruction",
8+
"input_data_path": "gs://prompt_optimizer_bucket/sample_prompts.jsonl",
9+
"output_path": "gs://prompt_optimizer_bucket",
10+
"eval_metrics_weights": [1]
11+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Question: Do {{animal_name}} {{animal_activity}}?
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{"animal_name": "Bears", "animal_activity": "Eat grapes"}
2+
{"animal_name": "Cows", "animal_activity": "swim in the ocean"}
3+
{"animal_name": "Bees", "animal_activity": "Ride donkeys"}
4+
{"animal_name": "Cats", "animal_activity": "go to school"}
5+
{"animal_name": "Lions", "animal_activity": "hunt"}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Based on the following text respond to the questions.'\n' Be concise, and answer \"I don't know\" if the response cannot be found in the provided text.

0 commit comments

Comments
Β (0)