diff --git a/gemini-api/pom.xml b/gemini-api/pom.xml
index 596b85c..6398210 100644
--- a/gemini-api/pom.xml
+++ b/gemini-api/pom.xml
@@ -6,7 +6,7 @@
swiss.ameri
gemini
- 1beta.0.2.4
+ 1beta.0.2.6-SNAPSHOT
gemini-api
diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java
index 32c4cb0..0417ac1 100644
--- a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java
+++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java
@@ -8,10 +8,7 @@
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.UUID;
+import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
@@ -145,7 +142,10 @@ public List safetyRatings(UUID id) {
return emptyList();
}
return response.candidates().stream()
- .flatMap(candidate -> candidate.safetyRatings().stream())
+ // when streaming, we sometimes don't get a safety rating... (with 1.5 pro)
+ .map(ResponseCandidate::safetyRatings)
+ .filter(Objects::nonNull)
+ .flatMap(Collection::stream)
.toList();
}
diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerationConfig.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerationConfig.java
index 5aa64ff..2142da0 100644
--- a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerationConfig.java
+++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerationConfig.java
@@ -36,7 +36,7 @@
public record GenerationConfig(
List stopSequences,
String responseMimeType,
- String responseSchema,
+ Schema responseSchema,
Integer maxOutputTokens,
Double temperature,
Double topP,
@@ -53,7 +53,7 @@ public static GenerationConfigBuilder builder() {
public static class GenerationConfigBuilder {
private final List stopSequences = new ArrayList<>();
private String responseMimeType;
- private String responseSchema;
+ private Schema responseSchema;
private Integer maxOutputTokens;
private Double temperature;
private Double topP;
@@ -69,7 +69,7 @@ public GenerationConfigBuilder responseMimeType(String responseMimeType) {
return this;
}
- public GenerationConfigBuilder responseSchema(String responseSchema) {
+ public GenerationConfigBuilder responseSchema(Schema responseSchema) {
this.responseSchema = responseSchema;
return this;
}
diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/Schema.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/Schema.java
new file mode 100644
index 0000000..0f29b0c
--- /dev/null
+++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/Schema.java
@@ -0,0 +1,178 @@
+package swiss.ameri.gemini.api;
+
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * The Schema object allows the definition of input and output data types.
+ * These types can be objects, but also primitives and arrays.
+ * Represents a select subset of an OpenAPI 3.0 schema object.
+ *
+ * @param type Required. Data type.
+ * @param format Optional. The format of the data. This is used only for primitive datatypes.
+ * Supported formats:
+ * for NUMBER type: float, double
+ * for INTEGER type: int32, int64
+ * for STRING type: enum
+ * @param description Optional. A brief description of the parameter. This could contain examples of use.
+ * Parameter description may be formatted as Markdown.
+ * @param nullable Optional. Indicates if the value may be null.
+ * @param ameri_swiss_enum Optional. Note: the ameri_swiss prefix must be removed by the {@link swiss.ameri.gemini.spi.JsonParser}.
+ * Possible values of the element of Type.STRING with enum format.
+ * For example we can define an Enum Direction as :
+ * {type:STRING, format:enum, enum:["EAST", NORTH", "SOUTH", "WEST"]}
+ * @param maxItems Optional. Maximum number of the elements for Type.ARRAY.
+ * @param minItems Optional. Minimum number of the elements for Type.ARRAY.
+ * @param properties Optional. Properties of Type.OBJECT.
+ * An object containing a list of "key": value pairs. Example:
+ * { "name": "wrench", "mass": "1.3kg", "count": "3" }.
+ * @param required Optional. Required properties of Type.OBJECT.
+ * @param items Optional. Schema of the elements of Type.ARRAY.
+ * @see Schema for further information.
+ */
+public record Schema(
+ Type type,
+ String format,
+ String description,
+ Boolean nullable,
+ List ameri_swiss_enum,
+ String maxItems,
+ String minItems,
+ Map properties,
+ List required,
+ Schema items
+) {
+
+
+ /**
+ * Create a {@link SchemaBuilder}.
+ *
+ * @return an empty {@link SchemaBuilder}
+ */
+ public static SchemaBuilder builder() {
+ return new SchemaBuilder();
+ }
+
+ /**
+ * A builder for {@link Schema}. Currently, does not validate the fields when building the model. Not thread-safe.
+ */
+ public static class SchemaBuilder {
+ private Type type;
+ private String format;
+ private String description;
+ private Boolean nullable;
+ private List ameri_swiss_enum;
+ private String maxItems;
+ private String minItems;
+ private Map properties;
+ private List required;
+ private Schema items;
+
+
+ private SchemaBuilder() {
+ }
+
+ public Schema build() {
+ return new Schema(
+ this.type,
+ this.format,
+ this.description,
+ this.nullable,
+ this.ameri_swiss_enum,
+ this.maxItems,
+ this.minItems,
+ this.properties,
+ this.required,
+ this.items
+ );
+ }
+
+ public SchemaBuilder type(Type type) {
+ this.type = type;
+ return this;
+ }
+
+ public SchemaBuilder format(String format) {
+ this.format = format;
+ return this;
+ }
+
+ public SchemaBuilder description(String description) {
+ this.description = description;
+ return this;
+ }
+
+ public SchemaBuilder nullable(Boolean nullable) {
+ this.nullable = nullable;
+ return this;
+ }
+
+ public SchemaBuilder ameri_swiss_enum(List ameri_swiss_enum) {
+ this.ameri_swiss_enum = ameri_swiss_enum;
+ return this;
+ }
+
+ public SchemaBuilder maxItems(String maxItems) {
+ this.maxItems = maxItems;
+ return this;
+ }
+
+ public SchemaBuilder minItems(String minItems) {
+ this.minItems = minItems;
+ return this;
+ }
+
+ public SchemaBuilder properties(Map properties) {
+ this.properties = properties;
+ return this;
+ }
+
+ public SchemaBuilder required(List required) {
+ this.required = required;
+ return this;
+ }
+
+ public SchemaBuilder items(Schema items) {
+ this.items = items;
+ return this;
+ }
+ }
+
+ /**
+ * Type contains the list of OpenAPI data types.
+ *
+ * @see Data types
+ */
+ public enum Type {
+ /**
+ * Not specified, should not be used.
+ */
+ TYPE_UNSPECIFIED,
+ /**
+ * String type.
+ */
+ STRING,
+ /**
+ * Number type.
+ */
+ NUMBER,
+ /**
+ * Integer type.
+ */
+ INTEGER,
+ /**
+ * Boolean type.
+ */
+ BOOLEAN,
+ /**
+ * Array type.
+ */
+ ARRAY,
+ /**
+ * Object type.
+ */
+ OBJECT
+ }
+}
+
diff --git a/gemini-gson/pom.xml b/gemini-gson/pom.xml
index 5766553..7804315 100644
--- a/gemini-gson/pom.xml
+++ b/gemini-gson/pom.xml
@@ -7,7 +7,7 @@
swiss.ameri
gemini
- 1beta.0.2.4
+ 1beta.0.2.6-SNAPSHOT
gemini-gson
diff --git a/gemini-gson/src/main/java/swiss/ameri/gemini/gson/GsonJsonParser.java b/gemini-gson/src/main/java/swiss/ameri/gemini/gson/GsonJsonParser.java
index 09a13aa..6c0dc87 100644
--- a/gemini-gson/src/main/java/swiss/ameri/gemini/gson/GsonJsonParser.java
+++ b/gemini-gson/src/main/java/swiss/ameri/gemini/gson/GsonJsonParser.java
@@ -1,6 +1,9 @@
package swiss.ameri.gemini.gson;
+import com.google.gson.FieldNamingStrategy;
import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import swiss.ameri.gemini.api.Schema;
import swiss.ameri.gemini.spi.JsonParser;
/**
@@ -8,6 +11,17 @@
*/
public class GsonJsonParser implements JsonParser {
+ /**
+ * Field naming strategy to avoid usage of illegal field names in java.
+ * See e.g. {@link Schema#ameri_swiss_enum()}, which cannot be named {@code enum}.
+ */
+ public static final FieldNamingStrategy FIELD_NAMING_STRATEGY = field -> {
+ if (field.getName().startsWith("ameri_swiss_")) {
+ return field.getName().substring("ameri_swiss_".length());
+ }
+ return field.getName();
+ };
+
private final Gson gson;
/**
@@ -23,7 +37,7 @@ public GsonJsonParser(Gson gson) {
* Create a default {@link JsonParser} instance.
*/
public GsonJsonParser() {
- this(new Gson());
+ this(new GsonBuilder().setFieldNamingStrategy(FIELD_NAMING_STRATEGY).create());
}
@Override
diff --git a/gemini-tester/pom.xml b/gemini-tester/pom.xml
index 90285d5..c505e3e 100644
--- a/gemini-tester/pom.xml
+++ b/gemini-tester/pom.xml
@@ -7,7 +7,7 @@
swiss.ameri
gemini
- 1beta.0.2.4
+ 1beta.0.2.6-SNAPSHOT
gemini-tester
diff --git a/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java b/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java
index 3895b96..0b81e42 100644
--- a/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java
+++ b/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java
@@ -8,6 +8,7 @@
import java.io.InputStream;
import java.util.Base64;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
@@ -38,6 +39,8 @@ public static void main(String[] args) throws Exception {
countTokens(genAi);
generateContent(genAi);
generateContentStream(genAi);
+ generateWithResponseSchema(genAi);
+ generateContentStreamWithResponseSchema(genAi);
multiChatTurn(genAi);
textAndImage(genAi);
embedContents(genAi);
@@ -90,7 +93,7 @@ private static void countTokens(GenAi genAi) {
private static void multiChatTurn(GenAi genAi) {
System.out.println("----- multi turn chat");
GenerativeModel chatModel = GenerativeModel.builder()
- .modelName(ModelVariant.GEMINI_1_0_PRO)
+ .modelName(ModelVariant.GEMINI_1_5_PRO)
.addContent(new Content.TextContent(
Content.Role.USER.roleName(),
"Write the first line of a story about a magic backpack."
@@ -133,9 +136,68 @@ private static void generateContent(GenAi genAi) throws InterruptedException, Ex
.get(20, TimeUnit.SECONDS);
}
+
+ private static void generateContentStreamWithResponseSchema(GenAi genAi) {
+ System.out.println("----- Generate content (streaming) with response schema -- with usage meta data");
+ var model = createResponseSchemaModel();
+ genAi.generateContentStream(model)
+ .forEach(x -> {
+ System.out.println(x);
+ // note that the usage metadata is updated as it arrives
+ System.out.println(genAi.usageMetadata(x.id()));
+ System.out.println(genAi.safetyRatings(x.id()));
+ });
+ }
+
+ private static void generateWithResponseSchema(GenAi genAi) throws InterruptedException, ExecutionException, TimeoutException {
+ var model = createResponseSchemaModel();
+ System.out.println("----- Generate with response schema (blocking)");
+ genAi.generateContent(model)
+ .thenAccept(gcr -> {
+ System.out.println(gcr);
+ System.out.println("----- Generate with response schema (blocking) usage meta data & safety ratings");
+ System.out.println(genAi.usageMetadata(gcr.id()));
+ System.out.println(genAi.safetyRatings(gcr.id()).stream().map(GenAi.SafetyRating::toTypedSafetyRating).toList());
+ })
+ .get(20, TimeUnit.SECONDS);
+ }
+
+ private static GenerativeModel createResponseSchemaModel() {
+ return GenerativeModel.builder()
+ .modelName(ModelVariant.GEMINI_1_5_FLASH)
+ .addContent(Content.textContent(
+ Content.Role.USER,
+ "List 3 popular cookie recipes."
+ ))
+ .addSafetySetting(SafetySetting.of(
+ SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+ SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH
+ ))
+ .generationConfig(new GenerationConfig(
+ null,
+ "application/json",
+ Schema.builder()
+ .type(Schema.Type.ARRAY)
+ .items(Schema.builder()
+ .type(Schema.Type.OBJECT)
+ .properties(Map.of(
+ "recipe_name", Schema.builder()
+ .type(Schema.Type.STRING)
+ .build()
+ ))
+ .build())
+ .build(),
+ null,
+ null,
+ null,
+ null
+ ))
+ .build();
+ }
+
private static GenerativeModel createStoryModel() {
return GenerativeModel.builder()
- .modelName(ModelVariant.GEMINI_1_0_PRO)
+ .modelName(ModelVariant.GEMINI_1_5_PRO)
.addContent(Content.textContent(
Content.Role.USER,
"Write a 50 word story about a magic backpack."
@@ -159,7 +221,7 @@ private static GenerativeModel createStoryModel() {
private static void getModel(GenAi genAi) {
System.out.println("----- Get Model");
System.out.println(
- genAi.getModel(ModelVariant.GEMINI_1_0_PRO)
+ genAi.getModel(ModelVariant.GEMINI_1_5_PRO)
);
}
diff --git a/pom.xml b/pom.xml
index 3ef3b06..100c860 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
swiss.ameri
gemini
- 1beta.0.2.4
+ 1beta.0.2.6-SNAPSHOT
pom