From 96fd64a66baaadfe6f18b8e1420501aef422461e Mon Sep 17 00:00:00 2001 From: Yvonne Yu Date: Tue, 21 May 2024 19:06:12 -0700 Subject: [PATCH] feat: [vertexai] infer location and project when user doesn't specify them. PiperOrigin-RevId: 635997756 --- .../com/google/cloud/vertexai/Constants.java | 4 + .../com/google/cloud/vertexai/VertexAI.java | 67 ++++- .../vertexai/generativeai/ChatSession.java | 108 ++++--- .../generativeai/GenerativeModel.java | 13 +- .../google/cloud/vertexai/VertexAITest.java | 263 +++++++++++++++++- .../generativeai/ChatSessionTest.java | 20 +- .../it/ITGenerativeModelIntegrationTest.java | 16 +- 7 files changed, 435 insertions(+), 56 deletions(-) diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/Constants.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/Constants.java index 3175e0cfcac5..f81cc129cf56 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/Constants.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/Constants.java @@ -20,6 +20,10 @@ public final class Constants { // Constants for VertexAI class public static final String USER_AGENT_HEADER = "model-builder"; + static final String DEFAULT_LOCATION = "us-central1"; + static final String GOOGLE_CLOUD_REGION = "GOOGLE_CLOUD_REGION"; + static final String CLOUD_ML_REGION = "CLOUD_ML_REGION"; + static final String GOOGLE_CLOUD_PROJECT = "GOOGLE_CLOUD_PROJECT"; private Constants() {} } diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java index 30abfe14cc51..4071c676adb0 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java @@ -27,6 +27,7 @@ import com.google.api.gax.rpc.FixedHeaderProvider; import com.google.api.gax.rpc.HeaderProvider; import com.google.auth.Credentials; +import com.google.auth.oauth2.GoogleCredentials; import com.google.cloud.vertexai.api.LlmUtilityServiceClient; import com.google.cloud.vertexai.api.LlmUtilityServiceSettings; import com.google.cloud.vertexai.api.PredictionServiceClient; @@ -67,6 +68,11 @@ public class VertexAI implements AutoCloseable { private final transient Supplier predictionClientSupplier; private final transient Supplier llmClientSupplier; + @InternalApi + static Optional getEnvironmentVariable(String envKey) { + return Optional.ofNullable(System.getenv(envKey)); + } + /** * Constructs a VertexAI instance. * @@ -85,6 +91,29 @@ public VertexAI(String projectId, String location) { /* llmClientSupplierOpt= */ Optional.empty()); } + /** + * Constructs a VertexAI instance. + * + *

Note: SDK infers location from runtime environment first. If there is no location + * inferred from runtime environment, SDK will default location to `us-central1`. + * + *

SDK will infer projectId from runtime environment and GoogleCredentials. + * + * @throws java.lang.IllegalArgumentException If there is not projectId inferred from either + * runtime environment or GoogleCredentials + */ + public VertexAI() { + this( + null, + null, + Transport.GRPC, + ImmutableList.of(), + /* credentials= */ Optional.empty(), + /* apiEndpoint= */ Optional.empty(), + /* predictionClientSupplierOpt= */ Optional.empty(), + /* llmClientSupplierOpt= */ Optional.empty()); + } + private VertexAI( String projectId, String location, @@ -98,12 +127,8 @@ private VertexAI( throw new IllegalArgumentException( "At most one of Credentials and scopes should be specified."); } - checkArgument(!Strings.isNullOrEmpty(projectId), "projectId can't be null or empty"); - checkArgument(!Strings.isNullOrEmpty(location), "location can't be null or empty"); checkNotNull(transport, "transport can't be null"); - - this.projectId = projectId; - this.location = location; + this.location = Strings.isNullOrEmpty(location) ? inferLocation() : location; this.transport = transport; if (credentials.isPresent()) { @@ -118,13 +143,15 @@ private VertexAI( .build(); } + this.projectId = Strings.isNullOrEmpty(projectId) ? inferProjectId() : projectId; this.predictionClientSupplier = Suppliers.memoize(predictionClientSupplierOpt.orElse(this::newPredictionServiceClient)); this.llmClientSupplier = Suppliers.memoize(llmClientSupplierOpt.orElse(this::newLlmUtilityClient)); - this.apiEndpoint = apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", location)); + this.apiEndpoint = + apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", this.location)); } /** Builder for {@link VertexAI}. */ @@ -141,8 +168,6 @@ public static class Builder { private Supplier llmClientSupplier; public VertexAI build() { - checkNotNull(projectId, "projectId must be set."); - checkNotNull(location, "location must be set."); return new VertexAI( projectId, @@ -339,6 +364,32 @@ private LlmUtilityServiceClient newLlmUtilityClient() { } } + private String inferProjectId() { + final String projectNotFoundErrorMessage = + ("Unable to infer your project. Please provide a project Id by one of the following:" + + "\n- Passing a constructor argument by using new VertexAI(String projectId, String" + + " location)" + + "\n- Setting project using 'gcloud config set project my-project'"); + final Optional projectIdOptional = + getEnvironmentVariable(Constants.GOOGLE_CLOUD_PROJECT); + if (projectIdOptional.isPresent()) { + return projectIdOptional.get(); + } + try { + return Optional.ofNullable((GoogleCredentials) this.credentialsProvider.getCredentials()) + .map((credentials) -> credentials.getQuotaProjectId()) + .orElseThrow(() -> new IllegalArgumentException(projectNotFoundErrorMessage)); + } catch (IOException e) { + throw new IllegalArgumentException(projectNotFoundErrorMessage, e); + } + } + + private String inferLocation() { + return getEnvironmentVariable(Constants.GOOGLE_CLOUD_REGION) + .orElse( + getEnvironmentVariable(Constants.CLOUD_ML_REGION).orElse(Constants.DEFAULT_LOCATION)); + } + private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IOException { LlmUtilityServiceSettings.Builder settingsBuilder; if (transport == Transport.REST) { diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java index 5d3e7dd4df12..9f5b07f1eb93 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java @@ -29,19 +29,24 @@ import com.google.cloud.vertexai.api.GenerationConfig; import com.google.cloud.vertexai.api.SafetySetting; import com.google.cloud.vertexai.api.Tool; +import com.google.cloud.vertexai.api.ToolConfig; import com.google.common.collect.ImmutableList; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Optional; -/** Represents a conversation between the user and the model */ +/** + * Represents a conversation between the user and the model. + * + *

Note: this class is NOT thread-safe. + */ public final class ChatSession { private final GenerativeModel model; private final Optional rootChatSession; private final Optional automaticFunctionCallingResponder; - private List history = new ArrayList<>(); - private int previousHistorySize = 0; + private List history; + private int previousHistorySize; private Optional> currentResponseStream; private Optional currentResponse; @@ -50,7 +55,7 @@ public final class ChatSession { * GenerationConfig) inherits from the model. */ public ChatSession(GenerativeModel model) { - this(model, Optional.empty(), Optional.empty()); + this(model, new ArrayList<>(), 0, Optional.empty(), Optional.empty()); } /** @@ -58,6 +63,9 @@ public ChatSession(GenerativeModel model) { * Configurations of the chat (e.g., GenerationConfig) inherits from the model. * * @param model a {@link GenerativeModel} instance that generates contents in the chat. + * @param history a list of {@link Content} containing interleaving conversation between "user" + * and "model". + * @param previousHistorySize the size of the previous history. * @param rootChatSession a root {@link ChatSession} instance. All the chat history in the current * chat session will be merged to the root chat session. * @param automaticFunctionCallingResponder an {@link AutomaticFunctionCallingResponder} instance @@ -66,10 +74,14 @@ public ChatSession(GenerativeModel model) { */ private ChatSession( GenerativeModel model, + List history, + int previousHistorySize, Optional rootChatSession, Optional automaticFunctionCallingResponder) { checkNotNull(model, "model should not be null"); this.model = model; + this.history = history; + this.previousHistorySize = previousHistorySize; this.rootChatSession = rootChatSession; this.automaticFunctionCallingResponder = automaticFunctionCallingResponder; currentResponseStream = Optional.empty(); @@ -84,15 +96,12 @@ private ChatSession( * @return a new {@link ChatSession} instance with the specified GenerationConfig. */ public ChatSession withGenerationConfig(GenerationConfig generationConfig) { - ChatSession rootChat = rootChatSession.orElse(this); - ChatSession newChatSession = - new ChatSession( - model.withGenerationConfig(generationConfig), - Optional.of(rootChat), - automaticFunctionCallingResponder); - newChatSession.history = history; - newChatSession.previousHistorySize = previousHistorySize; - return newChatSession; + return new ChatSession( + model.withGenerationConfig(generationConfig), + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + automaticFunctionCallingResponder); } /** @@ -103,15 +112,12 @@ public ChatSession withGenerationConfig(GenerationConfig generationConfig) { * @return a new {@link ChatSession} instance with the specified SafetySettings. */ public ChatSession withSafetySettings(List safetySettings) { - ChatSession rootChat = rootChatSession.orElse(this); - ChatSession newChatSession = - new ChatSession( - model.withSafetySettings(safetySettings), - Optional.of(rootChat), - automaticFunctionCallingResponder); - newChatSession.history = history; - newChatSession.previousHistorySize = previousHistorySize; - return newChatSession; + return new ChatSession( + model.withSafetySettings(safetySettings), + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + automaticFunctionCallingResponder); } /** @@ -122,13 +128,44 @@ public ChatSession withSafetySettings(List safetySettings) { * @return a new {@link ChatSession} instance with the specified Tools. */ public ChatSession withTools(List tools) { - ChatSession rootChat = rootChatSession.orElse(this); - ChatSession newChatSession = - new ChatSession( - model.withTools(tools), Optional.of(rootChat), automaticFunctionCallingResponder); - newChatSession.history = history; - newChatSession.previousHistorySize = previousHistorySize; - return newChatSession; + return new ChatSession( + model.withTools(tools), + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + automaticFunctionCallingResponder); + } + + /** + * Creates a copy of the current ChatSession with updated ToolConfig. + * + * @param toolConfig a {@link com.google.cloud.vertexai.api.ToolConfig} that will be used in the + * new ChatSession. + * @return a new {@link ChatSession} instance with the specified ToolConfigs. + */ + public ChatSession withToolConfig(ToolConfig toolConfig) { + return new ChatSession( + model.withToolConfig(toolConfig), + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + automaticFunctionCallingResponder); + } + + /** + * Creates a copy of the current ChatSession with updated SystemInstruction. + * + * @param systemInstruction a {@link com.google.cloud.vertexai.api.Content} containing system + * instructions. + * @return a new {@link ChatSession} instance with the specified ToolConfigs. + */ + public ChatSession withSystemInstruction(Content systemInstruction) { + return new ChatSession( + model.withSystemInstruction(systemInstruction), + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + automaticFunctionCallingResponder); } /** @@ -141,13 +178,12 @@ public ChatSession withTools(List tools) { */ public ChatSession withAutomaticFunctionCallingResponder( AutomaticFunctionCallingResponder automaticFunctionCallingResponder) { - ChatSession rootChat = rootChatSession.orElse(this); - ChatSession newChatSession = - new ChatSession( - model, Optional.of(rootChat), Optional.of(automaticFunctionCallingResponder)); - newChatSession.history = history; - newChatSession.previousHistorySize = previousHistorySize; - return newChatSession; + return new ChatSession( + model, + history, + previousHistorySize, + Optional.of(rootChatSession.orElse(this)), + Optional.of(automaticFunctionCallingResponder)); } /** diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java index d88fdc5da081..ed86b5993607 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/GenerativeModel.java @@ -39,7 +39,13 @@ import java.util.List; import java.util.Optional; -/** This class holds a generative model that can complete what you provided. */ +/** + * This class holds a generative model that can complete what you provided. This class is + * thread-safe. + * + *

Note: The instances of {@link ChatSession} returned by {@link GenerativeModel#startChat()} are + * NOT thread-safe. + */ public final class GenerativeModel { private final String modelName; private final String resourceName; @@ -645,6 +651,11 @@ public Optional getSystemInstruction() { return systemInstruction; } + /** + * Returns a new {@link ChatSession} instance that can be used to start a chat with this model. + * + *

Note: the returned {@link ChatSession} instance is NOT thread-safe. + */ public ChatSession startChat() { return new ChatSession(this); } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java index 9ee2a14c7f79..4d7bfff73fc1 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java @@ -20,12 +20,15 @@ import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; +import com.google.api.gax.core.GoogleCredentialsProvider; import com.google.auth.oauth2.GoogleCredentials; import com.google.cloud.vertexai.api.PredictionServiceClient; import com.google.cloud.vertexai.api.PredictionServiceSettings; import com.google.common.collect.ImmutableList; import java.io.IOException; +import java.util.Optional; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -43,25 +46,156 @@ public final class VertexAITest { private static final String TEST_ENDPOINT = "test_endpoint"; private static final String TEST_DEFAULT_ENDPOINT = String.format("%s-aiplatform.googleapis.com", TEST_LOCATION); + private static final Optional EMPTY_ENV_VAR_OPTIONAL = Optional.ofNullable(null); private VertexAI vertexAi; - @Rule public final MockitoRule mocksRule = MockitoJUnit.rule(); @Mock private GoogleCredentials mockGoogleCredentials; @Mock private PredictionServiceClient mockPredictionServiceClient; + @Mock private GoogleCredentialsProvider.Builder mockCredentialsProviderBuilder; + + @Mock private GoogleCredentialsProvider mockCredentialsProvider; + @Test public void testInstantiateVertexAI_usingConstructor_shouldContainRightFields() throws IOException { vertexAi = new VertexAI(TEST_PROJECT, TEST_LOCATION); + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); assertThat(vertexAi.getLocation()).isEqualTo(TEST_LOCATION); assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); assertThat(vertexAi.getApiEndpoint()).isEqualTo(TEST_DEFAULT_ENDPOINT); } + @Test + public void + testInstantiateVertexAI_usingConstructorNoArgsProjectEnvVarSet_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStaticVertexAI = mockStatic(VertexAI.class)) { + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + + vertexAi = new VertexAI(); + + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo("us-central1"); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo("us-central1-aiplatform.googleapis.com"); + } + } + + @Test + public void + testInstantiateVertexAI_usingConstructorNoArgsProjectEnvVarNotSet_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStaticVertexAI = mockStatic(VertexAI.class); + MockedStatic mockStaticPredictionServiceSettings = + mockStatic(PredictionServiceSettings.class); ) { + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticPredictionServiceSettings + .when(() -> PredictionServiceSettings.defaultCredentialsProviderBuilder()) + .thenReturn(mockCredentialsProviderBuilder); + when(mockCredentialsProviderBuilder.build()).thenReturn(mockCredentialsProvider); + when(mockCredentialsProvider.getCredentials()).thenReturn(mockGoogleCredentials); + when(mockGoogleCredentials.getQuotaProjectId()).thenReturn(TEST_PROJECT); + + vertexAi = new VertexAI(); + + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo("us-central1"); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo("us-central1-aiplatform.googleapis.com"); + } + } + + @Test + public void + testConstructor_noArgsCredentialsProviderThrowsIOException_shouldThrowIllegalArgumentException() + throws IOException { + try (MockedStatic mockStatic = mockStatic(PredictionServiceSettings.class)) { + final String expectedErrorMessage = + ("Unable to infer your project. Please provide a project Id by one of the following:" + + "\n- Passing a constructor argument by using new VertexAI(String projectId, String" + + " location)" + + "\n- Setting project using 'gcloud config set project my-project'"); + mockStatic + .when(() -> PredictionServiceSettings.defaultCredentialsProviderBuilder()) + .thenReturn(mockCredentialsProviderBuilder); + when(mockCredentialsProviderBuilder.build()).thenReturn(mockCredentialsProvider); + when(mockCredentialsProvider.getCredentials()).thenThrow(new IOException("")); + + IllegalArgumentException thrown = + assertThrows(IllegalArgumentException.class, () -> new VertexAI()); + assertThat(thrown).hasMessageThat().contains(expectedErrorMessage); + } + } + + @Test + public void + testInstantiateVertexAI_usingConstructorLocationFromGOOGLE_CLOUD_REGION_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStaticVertexAI = mockStatic(VertexAI.class)) { + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(Optional.of("us-central2")); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(Optional.of("us-central3")); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + + vertexAi = new VertexAI(); + + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo("us-central2"); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo("us-central2-aiplatform.googleapis.com"); + } + } + + @Test + public void + testInstantiateVertexAI_usingConstructorLocationFromCLOUD_ML_REGION_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStaticVertexAI = mockStatic(VertexAI.class)) { + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(Optional.of("us-central2")); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + + vertexAi = new VertexAI(); + + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo("us-central2"); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo("us-central2-aiplatform.googleapis.com"); + } + } + @Test public void testInstantiateVertexAI_builderWithCredentials_shouldContainRightFields() throws IOException { @@ -71,6 +205,7 @@ public void testInstantiateVertexAI_builderWithCredentials_shouldContainRightFie .setLocation(TEST_LOCATION) .setCredentials(mockGoogleCredentials) .build(); + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); assertThat(vertexAi.getLocation()).isEqualTo(TEST_LOCATION); assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); @@ -79,8 +214,132 @@ public void testInstantiateVertexAI_builderWithCredentials_shouldContainRightFie } @Test - public void testInstantiateVertexAI_builderWithScopes_throwsIlegalArgumentException() + public void testInstantiateVertexAI_builderNoArgsProjectEnvVarSet_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStaticVertexAI = mockStatic(VertexAI.class)) { + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + + vertexAi = new VertexAI.Builder().build(); + + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo("us-central1"); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo("us-central1-aiplatform.googleapis.com"); + } + } + + @Test + public void testInstantiateVertexAI_builderNoArgsProjectEnvVarNotSet_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStaticVertexAI = mockStatic(VertexAI.class); + MockedStatic mockStaticPredictionServiceSettings = + mockStatic(PredictionServiceSettings.class); ) { + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticPredictionServiceSettings + .when(() -> PredictionServiceSettings.defaultCredentialsProviderBuilder()) + .thenReturn(mockCredentialsProviderBuilder); + when(mockCredentialsProviderBuilder.build()).thenReturn(mockCredentialsProvider); + when(mockCredentialsProvider.getCredentials()).thenReturn(mockGoogleCredentials); + when(mockGoogleCredentials.getQuotaProjectId()).thenReturn(TEST_PROJECT); + + vertexAi = new VertexAI.Builder().build(); + + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo("us-central1"); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo("us-central1-aiplatform.googleapis.com"); + } + } + + @Test + public void + testBuilder_noArgsCredentialsProviderThrowsIOException_shouldThrowIllegalArgumentException() + throws IOException { + try (MockedStatic mockStaticPredictionServiceSettings = + mockStatic(PredictionServiceSettings.class)) { + final String expectedErrorMessage = + ("Unable to infer your project. Please provide a project Id by one of the following:" + + "\n- Passing a constructor argument by using new VertexAI(String projectId, String" + + " location)" + + "\n- Setting project using 'gcloud config set project my-project'"); + mockStaticPredictionServiceSettings + .when(() -> PredictionServiceSettings.defaultCredentialsProviderBuilder()) + .thenReturn(mockCredentialsProviderBuilder); + when(mockCredentialsProviderBuilder.build()).thenReturn(mockCredentialsProvider); + when(mockCredentialsProvider.getCredentials()).thenThrow(new IOException("")); + + IllegalArgumentException thrown = + assertThrows(IllegalArgumentException.class, () -> new VertexAI.Builder().build()); + assertThat(thrown).hasMessageThat().contains(expectedErrorMessage); + } + } + + @Test + public void + testInstantiateVertexAI_builderLocationFromGOOGLE_CLOUD_REGION_shouldContainRightFields() + throws IOException { + try (MockedStatic mockStaticVertexAI = mockStatic(VertexAI.class)) { + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(Optional.of("us-central2")); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(Optional.of("us-central3")); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + + vertexAi = new VertexAI.Builder().build(); + + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo("us-central2"); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo("us-central2-aiplatform.googleapis.com"); + } + } + + @Test + public void testInstantiateVertexAI_builderLocationFromCLOUD_ML_REGION_shouldContainRightFields() throws IOException { + try (MockedStatic mockStaticVertexAI = mockStatic(VertexAI.class)) { + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(EMPTY_ENV_VAR_OPTIONAL); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(Optional.of("us-central2")); + mockStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + + vertexAi = new VertexAI.Builder().build(); + + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo("us-central2"); + assertThat(vertexAi.getTransport()).isEqualTo(Transport.GRPC); + assertThat(vertexAi.getApiEndpoint()).isEqualTo("us-central2-aiplatform.googleapis.com"); + } + } + + @Test + public void testInstantiateVertexAI_builderWithScopes_throwsIlegalArgumentException() + throws IllegalArgumentException { IllegalArgumentException thrown = assertThrows( IllegalArgumentException.class, diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java index 0537d96fb0dc..aa0c7bf911df 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java @@ -29,6 +29,7 @@ import com.google.cloud.vertexai.api.Candidate.FinishReason; import com.google.cloud.vertexai.api.Content; import com.google.cloud.vertexai.api.FunctionCall; +import com.google.cloud.vertexai.api.FunctionCallingConfig; import com.google.cloud.vertexai.api.FunctionDeclaration; import com.google.cloud.vertexai.api.GenerateContentRequest; import com.google.cloud.vertexai.api.GenerateContentResponse; @@ -40,6 +41,7 @@ import com.google.cloud.vertexai.api.SafetySetting.HarmBlockThreshold; import com.google.cloud.vertexai.api.Schema; import com.google.cloud.vertexai.api.Tool; +import com.google.cloud.vertexai.api.ToolConfig; import com.google.cloud.vertexai.api.Type; import com.google.protobuf.Struct; import com.google.protobuf.Value; @@ -174,6 +176,16 @@ public final class ChatSessionTest { .build()) .addRequired("location"))) .build(); + private static final ToolConfig TOOL_CONFIG = + ToolConfig.newBuilder() + .setFunctionCallingConfig( + FunctionCallingConfig.newBuilder() + .setMode(FunctionCallingConfig.Mode.ANY) + .addAllowedFunctionNames("getCurrentWeather")) + .build(); + private static final Content SYSTEM_INSTRUCTION = + ContentMaker.fromString( + "You're a helpful assistant that starts all its answers with: \"COOL\""); @Rule public final MockitoRule mocksRule = MockitoJUnit.rule(); @@ -518,7 +530,9 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception { rootChat .withGenerationConfig(GENERATION_CONFIG) .withSafetySettings(Arrays.asList(SAFETY_SETTING)) - .withTools(Arrays.asList(TOOL)); + .withTools(Arrays.asList(TOOL)) + .withToolConfig(TOOL_CONFIG) + .withSystemInstruction(SYSTEM_INSTRUCTION); response = childChat.sendMessage(SAMPLE_MESSAGE_2); // (Assert) root chat history should contain all 4 contents @@ -532,8 +546,12 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception { ArgumentCaptor request = ArgumentCaptor.forClass(GenerateContentRequest.class); verify(mockUnaryCallable, times(2)).call(request.capture()); + Content expectedSystemInstruction = SYSTEM_INSTRUCTION.toBuilder().clearRole().build(); assertThat(request.getAllValues().get(1).getGenerationConfig()).isEqualTo(GENERATION_CONFIG); assertThat(request.getAllValues().get(1).getSafetySettings(0)).isEqualTo(SAFETY_SETTING); assertThat(request.getAllValues().get(1).getTools(0)).isEqualTo(TOOL); + assertThat(request.getAllValues().get(1).getToolConfig()).isEqualTo(TOOL_CONFIG); + assertThat(request.getAllValues().get(1).getSystemInstruction()) + .isEqualTo(expectedSystemInstruction); } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java index 960037c77a1b..68a508c4682e 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java @@ -130,12 +130,7 @@ private static void assertNonEmptyAndLogTextContentOfResponseStream( @Test public void generateContent_restTransport_nonEmptyCandidateList() throws IOException { - try (VertexAI vertexAiViaRest = - new VertexAI.Builder() - .setProjectId(PROJECT_ID) - .setLocation(LOCATION) - .setTransport(Transport.REST) - .build()) { + try (VertexAI vertexAiViaRest = new VertexAI.Builder().setTransport(Transport.REST).build()) { GenerativeModel textModelWithRest = new GenerativeModel(MODEL_NAME_TEXT, vertexAiViaRest); GenerateContentResponse response = textModelWithRest.generateContent(TEXT); @@ -145,9 +140,14 @@ public void generateContent_restTransport_nonEmptyCandidateList() throws IOExcep @Test public void generateContent_withPlainText_nonEmptyCandidateList() throws IOException { - GenerateContentResponse response = textModel.generateContent(TEXT); + try (final VertexAI vertexAiInferredArgs = new VertexAI()) { + final GenerativeModel textModel = + new GenerativeModel(MODEL_NAME_TEXT, vertexAiInferredArgs) + .withGenerationConfig(GenerationConfig.newBuilder().setTemperature(0).build()); + GenerateContentResponse response = textModel.generateContent(TEXT); - assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response); + assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response); + } } @Test