1616
1717package com .google .cloud .vertexai ;
1818
19+ import static com .google .common .base .Preconditions .checkArgument ;
20+ import static com .google .common .base .Preconditions .checkNotNull ;
21+
1922import com .google .api .core .InternalApi ;
2023import com .google .api .gax .core .CredentialsProvider ;
2124import com .google .api .gax .core .FixedCredentialsProvider ;
2831import com .google .cloud .vertexai .api .LlmUtilityServiceSettings ;
2932import com .google .cloud .vertexai .api .PredictionServiceClient ;
3033import com .google .cloud .vertexai .api .PredictionServiceSettings ;
34+ import com .google .common .base .Strings ;
3135import java .io .IOException ;
3236import java .util .List ;
37+ import java .util .concurrent .locks .ReentrantLock ;
3338import java .util .logging .Level ;
3439import java .util .logging .Logger ;
3540
@@ -56,9 +61,8 @@ public class VertexAI implements AutoCloseable {
5661 private Transport transport = Transport .GRPC ;
5762 // The clients will be instantiated lazily
5863 private PredictionServiceClient predictionServiceClient = null ;
59- private PredictionServiceClient predictionServiceRestClient = null ;
6064 private LlmUtilityServiceClient llmUtilityClient = null ;
61- private LlmUtilityServiceClient llmUtilityRestClient = null ;
65+ private final ReentrantLock lock = new ReentrantLock () ;
6266
6367 /**
6468 * Construct a VertexAI instance.
@@ -193,32 +197,35 @@ public Credentials getCredentials() throws IOException {
193197
194198 /** Sets the value for {@link #getTransport()}. */
195199 public void setTransport (Transport transport ) {
200+ checkNotNull (transport , "Transport can't be null." );
201+ if (this .transport == transport ) {
202+ return ;
203+ }
204+
196205 this .transport = transport ;
206+ resetClients ();
197207 }
198208
199209 /** Sets the value for {@link #getApiEndpoint()}. */
200210 public void setApiEndpoint (String apiEndpoint ) {
211+ checkArgument (!Strings .isNullOrEmpty (apiEndpoint ), "Api endpoint can't be null or empty." );
212+ if (this .apiEndpoint == apiEndpoint ) {
213+ return ;
214+ }
201215 this .apiEndpoint = apiEndpoint ;
216+ resetClients ();
217+ }
202218
219+ private void resetClients () {
203220 if (this .predictionServiceClient != null ) {
204221 this .predictionServiceClient .close ();
205222 this .predictionServiceClient = null ;
206223 }
207224
208- if (this .predictionServiceRestClient != null ) {
209- this .predictionServiceRestClient .close ();
210- this .predictionServiceRestClient = null ;
211- }
212-
213225 if (this .llmUtilityClient != null ) {
214226 this .llmUtilityClient .close ();
215227 this .llmUtilityClient = null ;
216228 }
217-
218- if (this .llmUtilityRestClient != null ) {
219- this .llmUtilityRestClient .close ();
220- this .llmUtilityRestClient = null ;
221- }
222229 }
223230
224231 /**
@@ -230,78 +237,47 @@ public void setApiEndpoint(String apiEndpoint) {
230237 */
231238 @ InternalApi
232239 public PredictionServiceClient getPredictionServiceClient () throws IOException {
233- if (this .transport == Transport .GRPC ) {
234- return getPredictionServiceGrpcClient ();
235- } else {
236- return getPredictionServiceRestClient ();
240+ if (predictionServiceClient != null ) {
241+ return predictionServiceClient ;
237242 }
238- }
239-
240- /**
241- * Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
242- * first prediction API call is made.
243- *
244- * @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
245- * method calls that map to the API methods.
246- */
247- private PredictionServiceClient getPredictionServiceGrpcClient () throws IOException {
248- if (predictionServiceClient == null ) {
249- PredictionServiceSettings .Builder settingsBuilder = PredictionServiceSettings .newBuilder ();
250- settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
251- if (this .credentialsProvider != null ) {
252- settingsBuilder .setCredentialsProvider (this .credentialsProvider );
243+ lock .lock ();
244+ try {
245+ if (predictionServiceClient == null ) {
246+ PredictionServiceSettings settings = getPredictionServiceSettings ();
247+ // Disable the warning message logged in getApplicationDefault
248+ Logger defaultCredentialsProviderLogger =
249+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
250+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
251+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
252+ predictionServiceClient = PredictionServiceClient .create (settings );
253+ defaultCredentialsProviderLogger .setLevel (previousLevel );
253254 }
254- HeaderProvider headerProvider =
255- FixedHeaderProvider .create (
256- "user-agent" ,
257- String .format (
258- "%s/%s" ,
259- Constants .USER_AGENT_HEADER ,
260- GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
261- settingsBuilder .setHeaderProvider (headerProvider );
262- // Disable the warning message logged in getApplicationDefault
263- Logger defaultCredentialsProviderLogger =
264- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
265- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
266- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
267- predictionServiceClient = PredictionServiceClient .create (settingsBuilder .build ());
268- defaultCredentialsProviderLogger .setLevel (previousLevel );
255+ return predictionServiceClient ;
256+ } finally {
257+ lock .unlock ();
269258 }
270- return predictionServiceClient ;
271259 }
272260
273- /**
274- * Returns the {@link PredictionServiceClient} with REST. The client will be instantiated when the
275- * first prediction API call is made.
276- *
277- * @return {@link PredictionServiceClient} that send REST requests to the backing service through
278- * method calls that map to the API methods.
279- */
280- private PredictionServiceClient getPredictionServiceRestClient () throws IOException {
281- if (predictionServiceRestClient == null ) {
282- PredictionServiceSettings .Builder settingsBuilder =
283- PredictionServiceSettings .newHttpJsonBuilder ();
284- settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
285- if (this .credentialsProvider != null ) {
286- settingsBuilder .setCredentialsProvider (this .credentialsProvider );
287- }
288- HeaderProvider headerProvider =
289- FixedHeaderProvider .create (
290- "user-agent" ,
291- String .format (
292- "%s/%s" ,
293- Constants .USER_AGENT_HEADER ,
294- GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
295- settingsBuilder .setHeaderProvider (headerProvider );
296- // Disable the warning message logged in getApplicationDefault
297- Logger defaultCredentialsProviderLogger =
298- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
299- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
300- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
301- predictionServiceRestClient = PredictionServiceClient .create (settingsBuilder .build ());
302- defaultCredentialsProviderLogger .setLevel (previousLevel );
261+ private PredictionServiceSettings getPredictionServiceSettings () throws IOException {
262+ PredictionServiceSettings .Builder builder ;
263+ if (transport == Transport .REST ) {
264+ builder = PredictionServiceSettings .newHttpJsonBuilder ();
265+ } else {
266+ builder = PredictionServiceSettings .newBuilder ();
267+ }
268+ builder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
269+ if (this .credentialsProvider != null ) {
270+ builder .setCredentialsProvider (this .credentialsProvider );
303271 }
304- return predictionServiceRestClient ;
272+ HeaderProvider headerProvider =
273+ FixedHeaderProvider .create (
274+ "user-agent" ,
275+ String .format (
276+ "%s/%s" ,
277+ Constants .USER_AGENT_HEADER ,
278+ GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
279+ builder .setHeaderProvider (headerProvider );
280+ return builder .build ();
305281 }
306282
307283 /**
@@ -313,78 +289,47 @@ private PredictionServiceClient getPredictionServiceRestClient() throws IOExcept
313289 */
314290 @ InternalApi
315291 public LlmUtilityServiceClient getLlmUtilityClient () throws IOException {
316- if (this .transport == Transport .GRPC ) {
317- return getLlmUtilityGrpcClient ();
318- } else {
319- return getLlmUtilityRestClient ();
292+ if (llmUtilityClient != null ) {
293+ return llmUtilityClient ;
320294 }
321- }
322-
323- /**
324- * Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
325- * first API call is made.
326- *
327- * @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
328- * method calls that map to the API methods.
329- */
330- private LlmUtilityServiceClient getLlmUtilityGrpcClient () throws IOException {
331- if (llmUtilityClient == null ) {
332- LlmUtilityServiceSettings .Builder settingsBuilder = LlmUtilityServiceSettings .newBuilder ();
333- settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
334- if (this .credentialsProvider != null ) {
335- settingsBuilder .setCredentialsProvider (this .credentialsProvider );
295+ lock .lock ();
296+ try {
297+ if (llmUtilityClient == null ) {
298+ LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings ();
299+ // Disable the warning message logged in getApplicationDefault
300+ Logger defaultCredentialsProviderLogger =
301+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
302+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
303+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
304+ llmUtilityClient = LlmUtilityServiceClient .create (settings );
305+ defaultCredentialsProviderLogger .setLevel (previousLevel );
336306 }
337- HeaderProvider headerProvider =
338- FixedHeaderProvider .create (
339- "user-agent" ,
340- String .format (
341- "%s/%s" ,
342- Constants .USER_AGENT_HEADER ,
343- GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
344- settingsBuilder .setHeaderProvider (headerProvider );
345- // Disable the warning message logged in getApplicationDefault
346- Logger defaultCredentialsProviderLogger =
347- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
348- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
349- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
350- llmUtilityClient = LlmUtilityServiceClient .create (settingsBuilder .build ());
351- defaultCredentialsProviderLogger .setLevel (previousLevel );
307+ return llmUtilityClient ;
308+ } finally {
309+ lock .unlock ();
352310 }
353- return llmUtilityClient ;
354311 }
355312
356- /**
357- * Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
358- * first API call is made.
359- *
360- * @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
361- * method calls that map to the API methods.
362- */
363- private LlmUtilityServiceClient getLlmUtilityRestClient () throws IOException {
364- if (llmUtilityRestClient == null ) {
365- LlmUtilityServiceSettings .Builder settingsBuilder =
366- LlmUtilityServiceSettings .newHttpJsonBuilder ();
367- settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
368- if (this .credentialsProvider != null ) {
369- settingsBuilder .setCredentialsProvider (this .credentialsProvider );
370- }
371- HeaderProvider headerProvider =
372- FixedHeaderProvider .create (
373- "user-agent" ,
374- String .format (
375- "%s/%s" ,
376- Constants .USER_AGENT_HEADER ,
377- GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
378- settingsBuilder .setHeaderProvider (headerProvider );
379- // Disable the warning message logged in getApplicationDefault
380- Logger defaultCredentialsProviderLogger =
381- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
382- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
383- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
384- llmUtilityRestClient = LlmUtilityServiceClient .create (settingsBuilder .build ());
385- defaultCredentialsProviderLogger .setLevel (previousLevel );
313+ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings () throws IOException {
314+ LlmUtilityServiceSettings .Builder settingsBuilder ;
315+ if (transport == Transport .REST ) {
316+ settingsBuilder = LlmUtilityServiceSettings .newHttpJsonBuilder ();
317+ } else {
318+ settingsBuilder = LlmUtilityServiceSettings .newBuilder ();
319+ }
320+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
321+ if (this .credentialsProvider != null ) {
322+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
386323 }
387- return llmUtilityRestClient ;
324+ HeaderProvider headerProvider =
325+ FixedHeaderProvider .create (
326+ "user-agent" ,
327+ String .format (
328+ "%s/%s" ,
329+ Constants .USER_AGENT_HEADER ,
330+ GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
331+ settingsBuilder .setHeaderProvider (headerProvider );
332+ return settingsBuilder .build ();
388333 }
389334
390335 /** Closes the VertexAI instance together with all its instantiated clients. */
@@ -393,14 +338,8 @@ public void close() {
393338 if (predictionServiceClient != null ) {
394339 predictionServiceClient .close ();
395340 }
396- if (predictionServiceRestClient != null ) {
397- predictionServiceRestClient .close ();
398- }
399341 if (llmUtilityClient != null ) {
400342 llmUtilityClient .close ();
401343 }
402- if (llmUtilityRestClient != null ) {
403- llmUtilityRestClient .close ();
404- }
405344 }
406345}
0 commit comments