From e20ff59741365ed5bf4d6ee39c5175ae51e04122 Mon Sep 17 00:00:00 2001 From: Michael Ameri Date: Fri, 17 Jan 2025 23:19:23 +0100 Subject: [PATCH 1/4] add function calling --- .../java/swiss/ameri/gemini/api/Content.java | 61 +++++++++++++ .../swiss/ameri/gemini/api/FunctionCall.java | 16 ++++ .../ameri/gemini/api/FunctionDeclaration.java | 20 +++++ .../ameri/gemini/api/FunctionResponse.java | 17 ++++ .../java/swiss/ameri/gemini/api/GenAi.java | 80 +++++++++++++++-- .../ameri/gemini/api/GenerativeModel.java | 35 ++++++-- .../ameri/gemini/tester/GeminiTester.java | 85 ++++++++++++++++++- 7 files changed, 297 insertions(+), 17 deletions(-) create mode 100644 gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionCall.java create mode 100644 gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionDeclaration.java create mode 100644 gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionResponse.java diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/Content.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/Content.java index 57a39ad..1234ca8 100644 --- a/gemini-api/src/main/java/swiss/ameri/gemini/api/Content.java +++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/Content.java @@ -9,6 +9,41 @@ */ public sealed interface Content { + /** + * Role belonging to this turn in the conversation. + * + * @return string value of a {@link Role} + */ + String role(); + + /** + * Create a {@link FunctionCallContent}. + * + * @param role belonging to this turn in the conversation. + * @param functionCall by the role + * @return a {@link FunctionCallContent} + */ + static FunctionCallContent functionCallContent( + Role role, + FunctionCall functionCall + ) { + return new FunctionCallContent(role == null ? null : role.roleName(), functionCall); + } + + /** + * Create a {@link FunctionResponseContent}. + * + * @param role belonging to this turn in the conversation. 
+ * @param functionResponse by the role + * @return a {@link FunctionResponseContent} + */ + static FunctionResponseContent functionResponseContent( + Role role, + FunctionResponse functionResponse + ) { + return new FunctionResponseContent(role == null ? null : role.roleName(), functionResponse); + } + /** * Create a {@link TextContent}. * @@ -49,6 +84,32 @@ static TextAndMediaContent.TextAndMediaContentBuilder textAndMediaContentBuilder return TextAndMediaContent.builder(); } + /** + * A predicted FunctionCall returned from the model that contains a string representing the FunctionDeclaration.name + * with the arguments and their values. + * + * @param role belonging to this turn in the conversation. see {@link Role} + * @param functionCall returned from the model + */ + record FunctionCallContent( + String role, + FunctionCall functionCall + ) implements Content { + } + + /** + * The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name and + * a structured JSON object containing any output from the function is used as context to the model. + * This should contain the result of a FunctionCall made based on model prediction. + * + * @param role belonging to this turn in the conversation. see {@link Role} + * @param functionResponse response to a function call + */ + record FunctionResponseContent( + String role, + FunctionResponse functionResponse + ) implements Content { + } /** * A part of a conversation that contains text. 
+ diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionCall.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionCall.java new file mode 100644 index 0000000..85fc5d9 --- /dev/null +++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionCall.java @@ -0,0 +1,16 @@ +package swiss.ameri.gemini.api; + +import java.util.Map; + +/** + * A predicted FunctionCall returned from the model that contains a string representing the FunctionDeclaration.name + * with the arguments and their values. + * + * @param name Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63. + * @param format Optional. The function parameters and values in JSON object format. + */ +public record FunctionCall( + String name, + Map format +) { +} diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionDeclaration.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionDeclaration.java new file mode 100644 index 0000000..350857b --- /dev/null +++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionDeclaration.java @@ -0,0 +1,20 @@ +package swiss.ameri.gemini.api; + +/** + * Structured representation of a function declaration as defined by the OpenAPI 3.0.3 specification. + * Included in this declaration are the function name and parameters. + * This FunctionDeclaration is a representation of a block of code that can be used as a Tool by the model and executed by the client. + * + * @param name Required. The name of the function. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63. + * @param description Required. A brief description of the function. + * @param parameters Optional. Describes the parameters to this function. + * Reflects the OpenAPI 3.0.3 Parameter Object string Key: the name of the parameter. + * Parameter names are case-sensitive. + * Schema Value: the Schema defining the type used for the parameter. 
+ */ +public record FunctionDeclaration( + String name, + String description, + Schema parameters +) { +} diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionResponse.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionResponse.java new file mode 100644 index 0000000..1c16246 --- /dev/null +++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionResponse.java @@ -0,0 +1,17 @@ +package swiss.ameri.gemini.api; + +import java.util.Map; + +/** + * The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name and + * a structured JSON object containing any output from the function is used as context to the model. + * This should contain the result of a FunctionCall made based on model prediction. + * + * @param name Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63. + * @param response Required. The function response in JSON object format. + */ +public record FunctionResponse( + String name, + Map response +) { +} diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java index 0417ac1..9106eda 100644 --- a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java +++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java @@ -8,7 +8,13 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; -import java.util.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; @@ -331,6 +337,10 @@ public CompletableFuture> embedContents( private static GenerateContentRequest convert(GenerativeModel model) { List generationContents = 
convertGenerationContents(model); + List tools = new ArrayList<>(); + if (!model.functionDeclarations().isEmpty()) { + tools.add(new Tool(model.functionDeclarations())); + } return new GenerateContentRequest( model.modelName(), generationContents, @@ -341,7 +351,9 @@ private static GenerateContentRequest convert(GenerativeModel model) { model.systemInstruction().stream() .map(SystemInstructionPart::new) .toList() - ) + ), + tools.isEmpty() ? null : + List.of(new Tool(model.functionDeclarations())) ); } @@ -355,6 +367,8 @@ private static List convertGenerationContents(GenerativeModel List.of( new GenerationPart( textContent.text(), + null, + null, null ) ) @@ -368,7 +382,9 @@ private static List convertGenerationContents(GenerativeModel new InlineData( imageContent.media().mimeType(), imageContent.media().mediaBase64() - ) + ), + null, + null ) ) ); @@ -379,6 +395,8 @@ private static List convertGenerationContents(GenerativeModel Stream.of( new GenerationPart( textAndImagesContent.text(), + null, + null, null ) ), @@ -388,10 +406,36 @@ private static List convertGenerationContents(GenerativeModel new InlineData( imageData.mimeType(), imageData.mediaBase64() - ) + ), + null, + null )) ).toList() ); + } else if (content instanceof Content.FunctionCallContent functionCallContent) { + return new GenerationContent( + functionCallContent.role(), + List.of( + new GenerationPart( + null, + null, + functionCallContent.functionCall(), + null + ) + ) + ); + } else if (content instanceof Content.FunctionResponseContent functionResponseContent) { + return new GenerationContent( + functionResponseContent.role(), + List.of( + new GenerationPart( + null, + null, + null, + functionResponseContent.functionResponse() + ) + ) + ); } else { throw new GeminiException("Unexpected content:\n" + content); } @@ -423,11 +467,13 @@ public void close() { * * @param id the id of the request, for subsequent queries regarding metadata of the query * @param text of the generated content + * @param 
functionCall Optional. if the model wants to call a function * @param finishReason the reason generation was finished, according to FinishReason */ public record GeneratedContent( UUID id, String text, + FunctionCall functionCall, String finishReason ) { } @@ -555,9 +601,10 @@ private GeneratedContent parse(String body, UUID uuid) { // we assume we always get a candidate. Otherwise, there is probably something wrong with the input var candidate = gcr.candidates().get(0); if (candidate.content() == null) { - return new GeneratedContent(uuid, "", candidate.finishReason()); + return new GeneratedContent(uuid, "", null, candidate.finishReason()); } - return new GeneratedContent(uuid, candidate.content().parts().get(0).text(), candidate.finishReason()); + GenerationPart firstPart = candidate.content().parts().get(0); + return new GeneratedContent(uuid, firstPart.text(), firstPart.functionCall(), candidate.finishReason()); } catch (Exception e) { throw new GeminiException("Unexpected body:\n" + body, e); } @@ -613,7 +660,17 @@ private record GenerateContentRequest( List contents, List safetySettings, GenerationConfig generationConfig, - SystemInstruction systemInstruction + SystemInstruction systemInstruction, + List tools + ) { + } + + /** + * See Tool + */ + private record Tool( + List functionDeclarations + // still missing CodeExecution and GoogleSearchRetrieval ) { } @@ -633,10 +690,15 @@ private record GenerationContent( ) { } + /** + * See Part + */ private record GenerationPart( - // contains one or the other + // contains one of these String text, - InlineData inline_data + InlineData inline_data, + FunctionCall functionCall, + FunctionResponse functionResponse ) { } diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerativeModel.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerativeModel.java index 2bb9b48..77d4377 100644 --- a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerativeModel.java +++ 
b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerativeModel.java @@ -6,18 +6,20 @@ /** * Contains all the information needed for Gemini API to generate new content. * - * @param modelName to be used. see {@link ModelVariant}. Must start with "models/" - * @param contents given as input to Gemini API - * @param safetySettings optional, to adjust safety settings - * @param generationConfig optional, to configure the prompt - * @param systemInstruction optional, system instruction + * @param modelName to be used. see {@link ModelVariant}. Must start with "models/" + * @param contents given as input to Gemini API + * @param safetySettings optional, to adjust safety settings + * @param generationConfig optional, to configure the prompt + * @param systemInstruction optional, system instruction + * @param functionDeclarations optional, functions the model may call */ public record GenerativeModel( String modelName, List contents, List safetySettings, GenerationConfig generationConfig, - List systemInstruction + List systemInstruction, + List functionDeclarations ) { /** @@ -38,6 +40,7 @@ public static class GenerativeModelBuilder { private final List contents = new ArrayList<>(); private final List safetySettings = new ArrayList<>(); private final List systemInstructions = new ArrayList<>(); + private final List functionDeclarations = new ArrayList<>(); private GenerativeModelBuilder() { } @@ -96,6 +99,17 @@ public GenerativeModelBuilder addSafetySetting(SafetySetting safetySetting) { return this; } + /** + * Add function declarations + * + * @param functionDeclaration to be added + * @return this + */ + public GenerativeModelBuilder addFunctionDeclaration(FunctionDeclaration functionDeclaration) { + this.functionDeclarations.add(functionDeclaration); + return this; + } + /** * Set the generation config * @@ -113,7 +127,14 @@ public GenerativeModelBuilder generationConfig(GenerationConfig generationConfig * @return a completed (not necessarily validated) {@link 
GenerativeModel} */ public GenerativeModel build() { - return new GenerativeModel(modelName, contents, safetySettings, generationConfig, systemInstructions); + return new GenerativeModel( + modelName, + contents, + safetySettings, + generationConfig, + systemInstructions, + functionDeclarations + ); } } diff --git a/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java b/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java index e28ebde..294c8fa 100644 --- a/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java +++ b/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java @@ -1,6 +1,15 @@ package swiss.ameri.gemini.tester; -import swiss.ameri.gemini.api.*; +import swiss.ameri.gemini.api.Content; +import swiss.ameri.gemini.api.FunctionCall; +import swiss.ameri.gemini.api.FunctionDeclaration; +import swiss.ameri.gemini.api.FunctionResponse; +import swiss.ameri.gemini.api.GenAi; +import swiss.ameri.gemini.api.GenerationConfig; +import swiss.ameri.gemini.api.GenerativeModel; +import swiss.ameri.gemini.api.ModelVariant; +import swiss.ameri.gemini.api.SafetySetting; +import swiss.ameri.gemini.api.Schema; import swiss.ameri.gemini.gson.GsonJsonParser; import swiss.ameri.gemini.spi.JsonParser; @@ -44,6 +53,8 @@ public static void main(String[] args) throws Exception { multiChatTurn(genAi); textAndImage(genAi); embedContents(genAi); + functionCall(genAi); + functionResponse(genAi); } @@ -111,6 +122,78 @@ private static void multiChatTurn(GenAi genAi) { .forEach(System.out::println); } + private static void functionCall(GenAi genAi) throws ExecutionException, InterruptedException, TimeoutException { + System.out.println("----- Function call"); + GenerativeModel chatModel = GenerativeModel.builder() + .modelName(ModelVariant.GEMINI_1_5_PRO) + .addContent(new Content.TextContent( + Content.Role.USER.roleName(), + "What is the current weather in Zurich?" 
+ )) + .addFunctionDeclaration(new FunctionDeclaration( + "getCurrentWeather", + "Get the current weather for a city.", + Schema.builder() + .type(Schema.Type.OBJECT) + .properties(Map.of("city", Schema.builder() + .type(Schema.Type.STRING) + .build())) + .build() + )) + .build(); + genAi.generateContent(chatModel) + .thenAccept(generatedContent -> { + System.out.println(generatedContent); + if (generatedContent.functionCall() == null) { + throw new RuntimeException("Expected a function call..."); + } + }) + .get(20, TimeUnit.SECONDS); + } + + private static void functionResponse(GenAi genAi) throws ExecutionException, InterruptedException, TimeoutException { + System.out.println("----- Function response"); + GenerativeModel chatModel = GenerativeModel.builder() + .modelName(ModelVariant.GEMINI_1_5_PRO) + .addContent(new Content.TextContent( + Content.Role.USER.roleName(), + "What is the current weather in Zurich?" + )) + .addContent(Content.functionCallContent( + Content.Role.MODEL, + new FunctionCall( + "getCurrentWeather", + null + ) + )) + .addContent(Content.functionResponseContent( + Content.Role.USER, + new FunctionResponse( + "getCurrentWeather", + Map.of("temperatureCelsius", "13") + ) + )) + .addFunctionDeclaration(new FunctionDeclaration( + "getCurrentWeather", + "Get the current weather for a city.", + Schema.builder() + .type(Schema.Type.OBJECT) + .properties(Map.of("city", Schema.builder() + .type(Schema.Type.STRING) + .build())) + .build() + )) + .build(); + genAi.generateContent(chatModel) + .thenAccept(generatedContent -> { + System.out.println(generatedContent); + if (generatedContent.text() == null) { + throw new RuntimeException("Expected a text..."); + } + }) + .get(20, TimeUnit.SECONDS); + } + private static void generateContentStream(GenAi genAi) { System.out.println("----- Generate content (streaming) -- with usage meta data"); var model = createStoryModel(); From 249abba9d2de9d2a1fa003888c5279371cc5e4ef Mon Sep 17 00:00:00 2001 From: 
Michael Ameri Date: Sat, 18 Jan 2025 09:53:37 +0100 Subject: [PATCH 2/4] fix naming --- .../src/main/java/swiss/ameri/gemini/api/FunctionCall.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionCall.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionCall.java index 85fc5d9..a91718d 100644 --- a/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionCall.java +++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/FunctionCall.java @@ -6,11 +6,11 @@ * A predicted FunctionCall returned from the model that contains a string representing the FunctionDeclaration.name * with the arguments and their values. * - * @param name Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63. - * @param format Optional. The function parameters and values in JSON object format. + * @param name Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63. + * @param args Optional. The function parameters and values in JSON object format. 
*/ public record FunctionCall( String name, - Map format + Map args ) { } From ac2c28cf8a6c0b267575277f2c70cb8b0868e2cf Mon Sep 17 00:00:00 2001 From: Michael Ameri Date: Sat, 18 Jan 2025 16:25:00 +0100 Subject: [PATCH 3/4] release 1beta.0.2.7 --- gemini-api/pom.xml | 2 +- gemini-gson/pom.xml | 2 +- gemini-tester/pom.xml | 2 +- pom.xml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gemini-api/pom.xml b/gemini-api/pom.xml index db31fa2..76751c2 100644 --- a/gemini-api/pom.xml +++ b/gemini-api/pom.xml @@ -6,7 +6,7 @@ swiss.ameri gemini - 1beta.0.2.7-SNAPSHOT + 1beta.0.2.7 gemini-api diff --git a/gemini-gson/pom.xml b/gemini-gson/pom.xml index 7f2681d..a8f9034 100644 --- a/gemini-gson/pom.xml +++ b/gemini-gson/pom.xml @@ -7,7 +7,7 @@ swiss.ameri gemini - 1beta.0.2.7-SNAPSHOT + 1beta.0.2.7 gemini-gson diff --git a/gemini-tester/pom.xml b/gemini-tester/pom.xml index 4f8081e..6e863bd 100644 --- a/gemini-tester/pom.xml +++ b/gemini-tester/pom.xml @@ -7,7 +7,7 @@ swiss.ameri gemini - 1beta.0.2.7-SNAPSHOT + 1beta.0.2.7 gemini-tester diff --git a/pom.xml b/pom.xml index ac4448a..8ec67f1 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ swiss.ameri gemini - 1beta.0.2.7-SNAPSHOT + 1beta.0.2.7 pom From 261f87c56e1a3b02a06611bdac5dedaeee759812 Mon Sep 17 00:00:00 2001 From: Michael Ameri Date: Sat, 18 Jan 2025 16:27:16 +0100 Subject: [PATCH 4/4] next release 1beta.0.2.8-SNAPSHOT --- gemini-api/pom.xml | 2 +- gemini-gson/pom.xml | 2 +- gemini-tester/pom.xml | 2 +- pom.xml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gemini-api/pom.xml b/gemini-api/pom.xml index 76751c2..cb65a8a 100644 --- a/gemini-api/pom.xml +++ b/gemini-api/pom.xml @@ -6,7 +6,7 @@ swiss.ameri gemini - 1beta.0.2.7 + 1beta.0.2.8-SNAPSHOT gemini-api diff --git a/gemini-gson/pom.xml b/gemini-gson/pom.xml index a8f9034..5a72c63 100644 --- a/gemini-gson/pom.xml +++ b/gemini-gson/pom.xml @@ -7,7 +7,7 @@ swiss.ameri gemini - 1beta.0.2.7 + 1beta.0.2.8-SNAPSHOT gemini-gson 
diff --git a/gemini-tester/pom.xml b/gemini-tester/pom.xml index 6e863bd..5c97ea2 100644 --- a/gemini-tester/pom.xml +++ b/gemini-tester/pom.xml @@ -7,7 +7,7 @@ swiss.ameri gemini - 1beta.0.2.7 + 1beta.0.2.8-SNAPSHOT gemini-tester diff --git a/pom.xml b/pom.xml index 8ec67f1..42d82ca 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ swiss.ameri gemini - 1beta.0.2.7 + 1beta.0.2.8-SNAPSHOT pom