Skip to content

Commit c165431

Browse files
authored
Vertex SDK - LLM - Added tuning samples for the code generation models - code-bison (GoogleCloudPlatform#10504)
* Vertex SDK - LLM - Added tuning samples for the code generation models - `code-bison` * Bumping the SDK version * Fixed the called function names
1 parent b5b828a commit c165431

5 files changed

Lines changed: 255 additions & 1 deletion
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_sdk_list_tuned_code_generation_models]
16+
17+
import vertexai
18+
from vertexai.preview.language_models import CodeGenerationModel
19+
20+
21+
def list_tuned_code_generation_models(
    project_id: str,
    location: str,
) -> "list[str]":  # quoted so the annotation is never evaluated on older Pythons
    """List the tuned models derived from the `code-bison@001` base model.

    Args:
        project_id: GCP Project ID, used to initialize vertexai.
        location: GCP Region, used to initialize vertexai.

    Returns:
        The tuned model names reported by `list_tuned_model_names()`.
        The same value is also printed to stdout.
    """
    vertexai.init(project=project_id, location=location)
    model = CodeGenerationModel.from_pretrained("code-bison@001")
    tuned_model_names = model.list_tuned_model_names()
    print(tuned_model_names)
    # [END aiplatform_sdk_list_tuned_code_generation_models]
    # The return sits after the END region tag, presumably so that it is left
    # out of the published snippet while still letting tests inspect the names.
    return tuned_model_names


if __name__ == "__main__":
    # NOTE(review): the function declares two required parameters
    # (project_id, location) but is called here with none, so running this
    # file as a script raises TypeError. Supply real values before use.
    list_tuned_code_generation_models()
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import backoff
18+
from google.api_core.exceptions import ResourceExhausted
19+
from google.cloud import aiplatform
20+
21+
import list_tuned_code_generation_models
22+
23+
24+
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
25+
_LOCATION = "us-central1"
26+
27+
28+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_list_tuned_code_generation_models() -> None:
    """Check that no listed tuned model carries the text-model fixture name.

    Every name returned by the sample is looked up in the model registry and
    the display name of version "1" is compared against the fixture marker.
    The test passes only when none of them contains it.
    """
    fixture_marker = (
        "Vertex LLM Test Fixture "
        "(list_tuned_models_test.py::test_list_tuned_models)"
    )
    listed_names = list_tuned_code_generation_models.list_tuned_code_generation_models(
        _PROJECT_ID,
        _LOCATION,
    )
    matching_models = sum(
        1
        for listed_name in listed_names
        if fixture_marker
        in aiplatform.models.ModelRegistry(model=listed_name)
        .get_version_info("1")
        .model_display_name
    )
    assert matching_models == 0

generative_ai/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
pandas==1.3.5; python_version == '3.7'
22
pandas==2.0.1; python_version > '3.7'
3-
google-cloud-aiplatform[pipelines]==1.28.1
3+
google-cloud-aiplatform[pipelines]==1.29.0
44
google-auth==2.17.3
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_sdk_tune_code_generation_model]
16+
from __future__ import annotations
17+
18+
19+
from google.auth import default
20+
import pandas as pd
21+
import vertexai
22+
from vertexai.preview.language_models import CodeGenerationModel
23+
24+
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
25+
26+
27+
def tune_code_generation_model(
    project_id: str,
    location: str,
    training_data: pd.DataFrame | str,
    train_steps: int = 300,
) -> CodeGenerationModel:
    """Tune a new model, based on prompt-response data.

    "training_data" can be either the GCS URI of a file formatted in JSONL format
    (for example: training_data=f'gs://{bucket}/{filename}.jsonl'), or a pandas
    DataFrame. Each training example should be a JSONL record with two keys, for
    example:
      {
        "input_text": <input prompt>,
        "output_text": <associated output>
      },
    or the pandas DataFrame should contain two columns:
      ['input_text', 'output_text']
    with rows for each training example.

    Args:
      project_id: GCP Project ID, used to initialize vertexai
      location: GCP Region, used to initialize vertexai; also passed as the
        location of the tuned model
      training_data: GCS URI of jsonl file or pandas dataframe of training data
      train_steps: Number of training steps to use when tuning the model.

    Returns:
      The `code-bison@001` model object on which `tune_model` was called, so
      that callers can inspect the tuning job through it.
    """
    vertexai.init(project=project_id, location=location, credentials=credentials)
    model = CodeGenerationModel.from_pretrained("code-bison@001")

    model.tune_model(
        training_data=training_data,
        # Optional:
        train_steps=train_steps,
        # The tuning job location is fixed here and is independent of `location`.
        tuning_job_location="europe-west4",
        tuned_model_location=location,
    )

    print(model._job.status)
    # [END aiplatform_sdk_tune_code_generation_model]
    # The return sits after the END region tag, presumably so that it is left
    # out of the published snippet while still letting tests use the model.
    return model


if __name__ == "__main__":
    # NOTE(review): project_id, location and training_data are required
    # parameters but none are passed here, so running this file as a script
    # raises TypeError. Supply real values before use.
    tune_code_generation_model()
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import uuid
17+
18+
from google.cloud import aiplatform
19+
from google.cloud import storage
20+
from google.cloud.aiplatform.compat.types import pipeline_state
21+
import pytest
22+
from vertexai.preview.language_models import TextGenerationModel
23+
24+
import tune_code_generation_model
25+
26+
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
27+
_LOCATION = "us-central1"
28+
_BUCKET = os.environ["CLOUD_STORAGE_BUCKET"]
29+
30+
31+
def get_model_display_name(tuned_model: TextGenerationModel) -> str:
    """Return the "model_display_name" runtime parameter of the tuning pipeline.

    The value is read from the pipeline job nested inside the model's tuning
    job (`tuned_model._job._job`).
    """
    pipeline_job = tuned_model._job._job
    runtime_parameters = dict(
        pipeline_job._gca_resource.runtime_config.parameter_values
    )
    return runtime_parameters["model_display_name"]
37+
38+
39+
def upload_to_gcs(bucket: str, name: str, data: str) -> None:
    """Upload the string `data` as object `name` in the bucket named `bucket`."""
    storage.Client().get_bucket(bucket).blob(name).upload_from_string(data)
44+
45+
46+
def download_from_gcs(bucket: str, name: str) -> str:
    """Download object `name` from `bucket` and return its first ten lines.

    The object is fetched as bytes and decoded with the default codec. Only
    the leading ten lines are kept, re-joined with "\\n".
    """
    target_blob = storage.Client().get_bucket(bucket).blob(name)
    decoded_text = target_blob.download_as_bytes().decode()
    leading_lines = decoded_text.splitlines()[:10]
    return "\n".join(leading_lines)
52+
53+
54+
def delete_from_gcs(bucket: str, name: str) -> None:
    """Delete object `name` from the Cloud Storage bucket named `bucket`."""
    storage.Client().get_bucket(bucket).blob(name).delete()
59+
60+
61+
@pytest.fixture(scope="function")
def training_data_filename() -> str:
    """Stage a temporary training file in `_BUCKET` and yield its gs:// URI.

    The first lines of the public headline-classification sample are copied
    into a uniquely named ".jsonl" object. The object is deleted again once
    the test using the fixture has finished, whether or not it passed.
    """
    object_name = f"{uuid.uuid4()}.jsonl"
    upload_to_gcs(
        _BUCKET,
        object_name,
        download_from_gcs(
            "cloud-samples-data",
            "ai-platform/generative_ai/headline_classification.jsonl",
        ),
    )
    try:
        yield f"gs://{_BUCKET}/{object_name}"
    finally:
        delete_from_gcs(_BUCKET, object_name)
72+
73+
74+
def teardown_model(
    tuned_model: TextGenerationModel, training_data_filename: str
) -> None:
    """Delete tuned models created from the given training file.

    For every tuned model whose version "1" display name contains
    `training_data_filename`, endpoints with the same display name that have
    at least one deployed model are undeployed and deleted, and then the model
    itself is deleted.

    Args:
        tuned_model: Model object used to enumerate the tuned model names.
        training_data_filename: Substring searched for in each tuned model's
            display name to decide whether the model belongs to this test.
    """
    for tuned_model_name in tuned_model.list_tuned_model_names():
        model_registry = aiplatform.models.ModelRegistry(model=tuned_model_name)
        # Fetch the version info once; it was previously requested twice.
        display_name = model_registry.get_version_info("1").model_display_name
        if training_data_filename not in display_name:
            continue

        for endpoint in aiplatform.Endpoint.list():
            # Tear each matching endpoint down exactly once. Looping over the
            # deployed models would repeat undeploy/delete on an endpoint that
            # has already been deleted. Endpoints without deployed models are
            # still left alone, as before.
            if endpoint.display_name == display_name and endpoint.list_models():
                endpoint.undeploy_all()
                endpoint.delete()
        aiplatform.Model(model_registry.model_resource_name).delete()
90+
91+
92+
@pytest.mark.skip("Blocked on b/277959219")
def test_tuning_code_generation_model(training_data_filename: str) -> None:
    """Tune with a single training step and expect the job to succeed.

    Takes approx. 20 minutes. The tuned model is torn down afterwards whether
    or not the assertion holds.
    """
    tuned_model = tune_code_generation_model.tune_code_generation_model(
        project_id=_PROJECT_ID,
        location=_LOCATION,
        training_data=training_data_filename,
        train_steps=1,
    )
    try:
        job_status = tuned_model._job.status
        assert job_status == pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
    finally:
        teardown_model(tuned_model, training_data_filename)

0 commit comments

Comments
 (0)