Skip to content

Commit 7078fb7

Browse files
copybara-service[bot]Zhenyi Qi
andauthored
chore: [vertexai] Make instantiation of clients thread safe. (#10588)
PiperOrigin-RevId: 615418373 Co-authored-by: Zhenyi Qi <zhenyiqi@google.com>
1 parent 67c8784 commit 7078fb7

File tree

2 files changed

+108
-165
lines changed

2 files changed

+108
-165
lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java

Lines changed: 89 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package com.google.cloud.vertexai;
1818

19+
import static com.google.common.base.Preconditions.checkArgument;
20+
import static com.google.common.base.Preconditions.checkNotNull;
21+
1922
import com.google.api.core.InternalApi;
2023
import com.google.api.gax.core.CredentialsProvider;
2124
import com.google.api.gax.core.FixedCredentialsProvider;
@@ -28,8 +31,10 @@
2831
import com.google.cloud.vertexai.api.LlmUtilityServiceSettings;
2932
import com.google.cloud.vertexai.api.PredictionServiceClient;
3033
import com.google.cloud.vertexai.api.PredictionServiceSettings;
34+
import com.google.common.base.Strings;
3135
import java.io.IOException;
3236
import java.util.List;
37+
import java.util.concurrent.locks.ReentrantLock;
3338
import java.util.logging.Level;
3439
import java.util.logging.Logger;
3540

@@ -56,9 +61,8 @@ public class VertexAI implements AutoCloseable {
5661
private Transport transport = Transport.GRPC;
5762
// The clients will be instantiated lazily
5863
private PredictionServiceClient predictionServiceClient = null;
59-
private PredictionServiceClient predictionServiceRestClient = null;
6064
private LlmUtilityServiceClient llmUtilityClient = null;
61-
private LlmUtilityServiceClient llmUtilityRestClient = null;
65+
private final ReentrantLock lock = new ReentrantLock();
6266

6367
/**
6468
* Construct a VertexAI instance.
@@ -193,32 +197,35 @@ public Credentials getCredentials() throws IOException {
193197

194198
/** Sets the value for {@link #getTransport()}. */
195199
public void setTransport(Transport transport) {
200+
checkNotNull(transport, "Transport can't be null.");
201+
if (this.transport == transport) {
202+
return;
203+
}
204+
196205
this.transport = transport;
206+
resetClients();
197207
}
198208

199209
/** Sets the value for {@link #getApiEndpoint()}. */
200210
public void setApiEndpoint(String apiEndpoint) {
211+
checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "Api endpoint can't be null or empty.");
212+
if (this.apiEndpoint == apiEndpoint) {
213+
return;
214+
}
201215
this.apiEndpoint = apiEndpoint;
216+
resetClients();
217+
}
202218

219+
private void resetClients() {
203220
if (this.predictionServiceClient != null) {
204221
this.predictionServiceClient.close();
205222
this.predictionServiceClient = null;
206223
}
207224

208-
if (this.predictionServiceRestClient != null) {
209-
this.predictionServiceRestClient.close();
210-
this.predictionServiceRestClient = null;
211-
}
212-
213225
if (this.llmUtilityClient != null) {
214226
this.llmUtilityClient.close();
215227
this.llmUtilityClient = null;
216228
}
217-
218-
if (this.llmUtilityRestClient != null) {
219-
this.llmUtilityRestClient.close();
220-
this.llmUtilityRestClient = null;
221-
}
222229
}
223230

224231
/**
@@ -230,78 +237,47 @@ public void setApiEndpoint(String apiEndpoint) {
230237
*/
231238
@InternalApi
232239
public PredictionServiceClient getPredictionServiceClient() throws IOException {
233-
if (this.transport == Transport.GRPC) {
234-
return getPredictionServiceGrpcClient();
235-
} else {
236-
return getPredictionServiceRestClient();
240+
if (predictionServiceClient != null) {
241+
return predictionServiceClient;
237242
}
238-
}
239-
240-
/**
241-
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
242-
* first prediction API call is made.
243-
*
244-
* @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
245-
* method calls that map to the API methods.
246-
*/
247-
private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException {
248-
if (predictionServiceClient == null) {
249-
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder();
250-
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
251-
if (this.credentialsProvider != null) {
252-
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
243+
lock.lock();
244+
try {
245+
if (predictionServiceClient == null) {
246+
PredictionServiceSettings settings = getPredictionServiceSettings();
247+
// Disable the warning message logged in getApplicationDefault
248+
Logger defaultCredentialsProviderLogger =
249+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
250+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
251+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
252+
predictionServiceClient = PredictionServiceClient.create(settings);
253+
defaultCredentialsProviderLogger.setLevel(previousLevel);
253254
}
254-
HeaderProvider headerProvider =
255-
FixedHeaderProvider.create(
256-
"user-agent",
257-
String.format(
258-
"%s/%s",
259-
Constants.USER_AGENT_HEADER,
260-
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
261-
settingsBuilder.setHeaderProvider(headerProvider);
262-
// Disable the warning message logged in getApplicationDefault
263-
Logger defaultCredentialsProviderLogger =
264-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
265-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
266-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
267-
predictionServiceClient = PredictionServiceClient.create(settingsBuilder.build());
268-
defaultCredentialsProviderLogger.setLevel(previousLevel);
255+
return predictionServiceClient;
256+
} finally {
257+
lock.unlock();
269258
}
270-
return predictionServiceClient;
271259
}
272260

273-
/**
274-
* Returns the {@link PredictionServiceClient} with REST. The client will be instantiated when the
275-
* first prediction API call is made.
276-
*
277-
* @return {@link PredictionServiceClient} that send REST requests to the backing service through
278-
* method calls that map to the API methods.
279-
*/
280-
private PredictionServiceClient getPredictionServiceRestClient() throws IOException {
281-
if (predictionServiceRestClient == null) {
282-
PredictionServiceSettings.Builder settingsBuilder =
283-
PredictionServiceSettings.newHttpJsonBuilder();
284-
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
285-
if (this.credentialsProvider != null) {
286-
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
287-
}
288-
HeaderProvider headerProvider =
289-
FixedHeaderProvider.create(
290-
"user-agent",
291-
String.format(
292-
"%s/%s",
293-
Constants.USER_AGENT_HEADER,
294-
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
295-
settingsBuilder.setHeaderProvider(headerProvider);
296-
// Disable the warning message logged in getApplicationDefault
297-
Logger defaultCredentialsProviderLogger =
298-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
299-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
300-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
301-
predictionServiceRestClient = PredictionServiceClient.create(settingsBuilder.build());
302-
defaultCredentialsProviderLogger.setLevel(previousLevel);
261+
private PredictionServiceSettings getPredictionServiceSettings() throws IOException {
262+
PredictionServiceSettings.Builder builder;
263+
if (transport == Transport.REST) {
264+
builder = PredictionServiceSettings.newHttpJsonBuilder();
265+
} else {
266+
builder = PredictionServiceSettings.newBuilder();
267+
}
268+
builder.setEndpoint(String.format("%s:443", this.apiEndpoint));
269+
if (this.credentialsProvider != null) {
270+
builder.setCredentialsProvider(this.credentialsProvider);
303271
}
304-
return predictionServiceRestClient;
272+
HeaderProvider headerProvider =
273+
FixedHeaderProvider.create(
274+
"user-agent",
275+
String.format(
276+
"%s/%s",
277+
Constants.USER_AGENT_HEADER,
278+
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
279+
builder.setHeaderProvider(headerProvider);
280+
return builder.build();
305281
}
306282

307283
/**
@@ -313,78 +289,47 @@ private PredictionServiceClient getPredictionServiceRestClient() throws IOExcept
313289
*/
314290
@InternalApi
315291
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
316-
if (this.transport == Transport.GRPC) {
317-
return getLlmUtilityGrpcClient();
318-
} else {
319-
return getLlmUtilityRestClient();
292+
if (llmUtilityClient != null) {
293+
return llmUtilityClient;
320294
}
321-
}
322-
323-
/**
324-
* Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
325-
* first API call is made.
326-
*
327-
* @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
328-
* method calls that map to the API methods.
329-
*/
330-
private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException {
331-
if (llmUtilityClient == null) {
332-
LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder();
333-
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
334-
if (this.credentialsProvider != null) {
335-
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
295+
lock.lock();
296+
try {
297+
if (llmUtilityClient == null) {
298+
LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings();
299+
// Disable the warning message logged in getApplicationDefault
300+
Logger defaultCredentialsProviderLogger =
301+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
302+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
303+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
304+
llmUtilityClient = LlmUtilityServiceClient.create(settings);
305+
defaultCredentialsProviderLogger.setLevel(previousLevel);
336306
}
337-
HeaderProvider headerProvider =
338-
FixedHeaderProvider.create(
339-
"user-agent",
340-
String.format(
341-
"%s/%s",
342-
Constants.USER_AGENT_HEADER,
343-
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
344-
settingsBuilder.setHeaderProvider(headerProvider);
345-
// Disable the warning message logged in getApplicationDefault
346-
Logger defaultCredentialsProviderLogger =
347-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
348-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
349-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
350-
llmUtilityClient = LlmUtilityServiceClient.create(settingsBuilder.build());
351-
defaultCredentialsProviderLogger.setLevel(previousLevel);
307+
return llmUtilityClient;
308+
} finally {
309+
lock.unlock();
352310
}
353-
return llmUtilityClient;
354311
}
355312

356-
/**
357-
* Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
358-
* first API call is made.
359-
*
360-
* @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
361-
* method calls that map to the API methods.
362-
*/
363-
private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
364-
if (llmUtilityRestClient == null) {
365-
LlmUtilityServiceSettings.Builder settingsBuilder =
366-
LlmUtilityServiceSettings.newHttpJsonBuilder();
367-
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
368-
if (this.credentialsProvider != null) {
369-
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
370-
}
371-
HeaderProvider headerProvider =
372-
FixedHeaderProvider.create(
373-
"user-agent",
374-
String.format(
375-
"%s/%s",
376-
Constants.USER_AGENT_HEADER,
377-
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
378-
settingsBuilder.setHeaderProvider(headerProvider);
379-
// Disable the warning message logged in getApplicationDefault
380-
Logger defaultCredentialsProviderLogger =
381-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
382-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
383-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
384-
llmUtilityRestClient = LlmUtilityServiceClient.create(settingsBuilder.build());
385-
defaultCredentialsProviderLogger.setLevel(previousLevel);
313+
private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IOException {
314+
LlmUtilityServiceSettings.Builder settingsBuilder;
315+
if (transport == Transport.REST) {
316+
settingsBuilder = LlmUtilityServiceSettings.newHttpJsonBuilder();
317+
} else {
318+
settingsBuilder = LlmUtilityServiceSettings.newBuilder();
319+
}
320+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
321+
if (this.credentialsProvider != null) {
322+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
386323
}
387-
return llmUtilityRestClient;
324+
HeaderProvider headerProvider =
325+
FixedHeaderProvider.create(
326+
"user-agent",
327+
String.format(
328+
"%s/%s",
329+
Constants.USER_AGENT_HEADER,
330+
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
331+
settingsBuilder.setHeaderProvider(headerProvider);
332+
return settingsBuilder.build();
388333
}
389334

390335
/** Closes the VertexAI instance together with all its instantiated clients. */
@@ -393,14 +338,8 @@ public void close() {
393338
if (predictionServiceClient != null) {
394339
predictionServiceClient.close();
395340
}
396-
if (predictionServiceRestClient != null) {
397-
predictionServiceRestClient.close();
398-
}
399341
if (llmUtilityClient != null) {
400342
llmUtilityClient.close();
401343
}
402-
if (llmUtilityRestClient != null) {
403-
llmUtilityRestClient.close();
404-
}
405344
}
406345
}

0 commit comments

Comments
 (0)