2222import com .google .cloud .vertexai .api .CountTokensRequest ;
2323import com .google .cloud .vertexai .api .CountTokensResponse ;
2424import com .google .cloud .vertexai .api .GenerateContentRequest ;
25- import com .google .cloud .vertexai .api .GenerateContentRequest .Builder ;
2625import com .google .cloud .vertexai .api .GenerateContentResponse ;
2726import com .google .cloud .vertexai .api .GenerationConfig ;
2827import com .google .cloud .vertexai .api .Part ;
2928import com .google .cloud .vertexai .api .SafetySetting ;
29+ import com .google .cloud .vertexai .api .Tool ;
3030import java .io .IOException ;
3131import java .util .ArrayList ;
3232import java .util .Arrays ;
@@ -40,8 +40,131 @@ public class GenerativeModel {
4040 private final VertexAI vertexAi ;
4141 private GenerationConfig generationConfig = null ;
4242 private List <SafetySetting > safetySettings = null ;
43+ private List <Tool > tools = null ;
4344 private Transport transport ;
4445
46+ public static Builder newBuilder () {
47+ return new Builder ();
48+ }
49+
50+ private GenerativeModel (Builder builder ) {
51+ this .modelName = builder .modelName ;
52+
53+ this .vertexAi = builder .vertexAi ;
54+
55+ this .resourceName =
56+ String .format (
57+ "projects/%s/locations/%s/publishers/google/models/%s" ,
58+ this .vertexAi .getProjectId (), this .vertexAi .getLocation (), this .modelName );
59+
60+ if (builder .generationConfig != null ) {
61+ this .generationConfig = builder .generationConfig ;
62+ }
63+ if (builder .safetySettings != null ) {
64+ this .safetySettings = builder .safetySettings ;
65+ }
66+ if (builder .tools != null ) {
67+ this .tools = builder .tools ;
68+ }
69+
70+ if (builder .transport != null ) {
71+ this .transport = builder .transport ;
72+ } else {
73+ this .transport = this .vertexAi .getTransport ();
74+ }
75+ }
76+
77+ /** Builder class for {@link GenerativeModel}. */
78+ public static class Builder {
79+ private String modelName ;
80+ private VertexAI vertexAi ;
81+ private GenerationConfig generationConfig ;
82+ private List <SafetySetting > safetySettings ;
83+ private List <Tool > tools ;
84+ private Transport transport ;
85+
86+ private Builder () {}
87+
88+ public GenerativeModel build () {
89+ if (this .modelName == null ) {
90+ throw new IllegalArgumentException (
91+ "modelName is required. Please call setModelName() before building." );
92+ }
93+ if (this .vertexAi == null ) {
94+ throw new IllegalArgumentException (
95+ "vertexAi is required. Please call setVertexAi() before building." );
96+ }
97+ return new GenerativeModel (this );
98+ }
99+
100+ /**
101+ * Set the name of the generative model. This is required for building a GenerativeModel
102+ * instance. Supported format: "gemini-pro", "models/gemini-pro",
103+ * "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model
104+ * names can be found at
105+ * https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
106+ */
107+ public Builder setModelName (String modelName ) {
108+ this .modelName = validateModelName (modelName );
109+ return this ;
110+ }
111+
112+ /**
113+ * Set {@link com.google.cloud.vertexai.VertexAI} that contains the default configs for the
114+ * generative model. This is required for building a GenerativeModel instance.
115+ */
116+ public Builder setVertexAi (VertexAI vertexAi ) {
117+ this .vertexAi = vertexAi ;
118+ return this ;
119+ }
120+
121+ /**
122+ * Set {@link com.google.cloud.vertexai.api.GenerationConfig} that will be used by default to
123+ * interact with the generative model.
124+ */
125+ public Builder setGenerationConfig (GenerationConfig generationConfig ) {
126+ this .generationConfig = generationConfig ;
127+ return this ;
128+ }
129+
130+ /**
131+ * Set a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will be used by
132+ * default to interact with the generative model.
133+ */
134+ public Builder setSafetySettings (List <SafetySetting > safetySettings ) {
135+ this .safetySettings = new ArrayList <>();
136+ for (SafetySetting safetySetting : safetySettings ) {
137+ if (safetySetting != null ) {
138+ this .safetySettings .add (safetySetting );
139+ }
140+ }
141+ return this ;
142+ }
143+
144+ /**
145+ * Set a list of {@link com.google.cloud.vertexai.api.Tool} that will be used by default to
146+ * interact with the generative model.
147+ */
148+ public Builder setTools (List <Tool > tools ) {
149+ this .tools = new ArrayList <>();
150+ for (Tool tool : tools ) {
151+ if (tool != null ) {
152+ this .tools .add (tool );
153+ }
154+ }
155+ return this ;
156+ }
157+
158+ /**
159+ * Set the {@link Transport} layer for API calls in the generative model. It overrides the
160+ * transport setting in {@link com.google.cloud.vertexai.VertexAI}
161+ */
162+ public Builder setTransport (Transport transport ) {
163+ this .transport = transport ;
164+ return this ;
165+ }
166+ }
167+
45168 /**
46169 * Construct a GenerativeModel instance.
47170 *
@@ -384,7 +507,8 @@ public GenerateContentResponse generateContent(
384507 public GenerateContentResponse generateContent (
385508 List <Content > contents , GenerationConfig generationConfig , List <SafetySetting > safetySettings )
386509 throws IOException {
387- Builder requestBuilder = GenerateContentRequest .newBuilder ().addAllContents (contents );
510+ GenerateContentRequest .Builder requestBuilder =
511+ GenerateContentRequest .newBuilder ().addAllContents (contents );
388512 if (generationConfig != null ) {
389513 requestBuilder .setGenerationConfig (generationConfig );
390514 } else if (this .generationConfig != null ) {
@@ -395,6 +519,9 @@ public GenerateContentResponse generateContent(
395519 } else if (this .safetySettings != null ) {
396520 requestBuilder .addAllSafetySettings (this .safetySettings );
397521 }
522+ if (this .tools != null ) {
523+ requestBuilder .addAllTools (this .tools );
524+ }
398525 return ResponseHandler .aggregateStreamIntoResponse (generateContentStream (requestBuilder ));
399526 }
400527
@@ -655,7 +782,8 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
655782 public ResponseStream <GenerateContentResponse > generateContentStream (
656783 List <Content > contents , GenerationConfig generationConfig , List <SafetySetting > safetySettings )
657784 throws IOException {
658- Builder requestBuilder = GenerateContentRequest .newBuilder ().addAllContents (contents );
785+ GenerateContentRequest .Builder requestBuilder =
786+ GenerateContentRequest .newBuilder ().addAllContents (contents );
659787 if (generationConfig != null ) {
660788 requestBuilder .setGenerationConfig (generationConfig );
661789 } else if (this .generationConfig != null ) {
@@ -666,6 +794,9 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
666794 } else if (this .safetySettings != null ) {
667795 requestBuilder .addAllSafetySettings (this .safetySettings );
668796 }
797+ if (this .tools != null ) {
798+ requestBuilder .addAllTools (this .tools );
799+ }
669800 return generateContentStream (requestBuilder );
670801 }
671802
@@ -678,8 +809,8 @@ public ResponseStream<GenerateContentResponse> generateContentStream(
678809 * com.google.cloud.vertexai.api.GenerateContentResponse}
679810 * @throws IOException if an I/O error occurs while making the API call
680811 */
681- private ResponseStream <GenerateContentResponse > generateContentStream (Builder requestBuilder )
682- throws IOException {
812+ private ResponseStream <GenerateContentResponse > generateContentStream (
813+ GenerateContentRequest . Builder requestBuilder ) throws IOException {
683814 GenerateContentRequest request = requestBuilder .setModel (this .resourceName ).build ();
684815 ResponseStream <GenerateContentResponse > responseStream = null ;
685816 if (this .transport == Transport .REST ) {
@@ -723,6 +854,16 @@ public void setSafetySettings(List<SafetySetting> safetySettings) {
723854 }
724855 }
725856
857+ /**
858+ * Sets the value for {@link #getTools}, which will be used by default for generating response.
859+ */
860+ public void setTools (List <Tool > tools ) {
861+ this .tools = new ArrayList <>();
862+ for (Tool tool : tools ) {
863+ this .tools .add (tool );
864+ }
865+ }
866+
726867 /**
727868 * Sets the value for {@link #getTransport}, which defines the layer for API calls in this
728869 * generative model.
@@ -760,6 +901,15 @@ public List<SafetySetting> getSafetySettings() {
760901 }
761902 }
762903
904+ /** Returns a list of {@link com.google.cloud.vertexai.api.Tool} of this generative model. */
905+ public List <Tool > getTools () {
906+ if (this .tools != null ) {
907+ return Collections .unmodifiableList (this .tools );
908+ } else {
909+ return null ;
910+ }
911+ }
912+
763913 public ChatSession startChat () {
764914 return new ChatSession (this );
765915 }
0 commit comments