diff --git a/gemini-api/pom.xml b/gemini-api/pom.xml index 596b85c..6398210 100644 --- a/gemini-api/pom.xml +++ b/gemini-api/pom.xml @@ -6,7 +6,7 @@ swiss.ameri gemini - 1beta.0.2.4 + 1beta.0.2.6-SNAPSHOT gemini-api diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java index 32c4cb0..0417ac1 100644 --- a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java +++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java @@ -8,10 +8,7 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; +import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; @@ -145,7 +142,10 @@ public List safetyRatings(UUID id) { return emptyList(); } return response.candidates().stream() - .flatMap(candidate -> candidate.safetyRatings().stream()) + // when streaming, we sometimes don't get a safety rating... (with 1.5 pro) + .map(ResponseCandidate::safetyRatings) + .filter(Objects::nonNull) + .flatMap(Collection::stream) .toList(); } diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerationConfig.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerationConfig.java index 5aa64ff..2142da0 100644 --- a/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerationConfig.java +++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/GenerationConfig.java @@ -36,7 +36,7 @@ public record GenerationConfig( List stopSequences, String responseMimeType, - String responseSchema, + Schema responseSchema, Integer maxOutputTokens, Double temperature, Double topP, @@ -53,7 +53,7 @@ public static GenerationConfigBuilder builder() { public static class GenerationConfigBuilder { private final List stopSequences = new ArrayList<>(); private String responseMimeType; - private String responseSchema; + private Schema responseSchema; private Integer maxOutputTokens; private Double temperature; private Double topP; @@ -69,7 +69,7 @@ public GenerationConfigBuilder responseMimeType(String responseMimeType) { return this; } - public GenerationConfigBuilder responseSchema(String responseSchema) { + public GenerationConfigBuilder responseSchema(Schema responseSchema) { this.responseSchema = responseSchema; return this; } diff --git a/gemini-api/src/main/java/swiss/ameri/gemini/api/Schema.java b/gemini-api/src/main/java/swiss/ameri/gemini/api/Schema.java new file mode 100644 index 0000000..0f29b0c --- /dev/null +++ b/gemini-api/src/main/java/swiss/ameri/gemini/api/Schema.java @@ -0,0 +1,178 @@ +package swiss.ameri.gemini.api; + + +import java.util.List; +import java.util.Map; + +/** + * The Schema object allows the definition of input and output data types. + * These types can be objects, but also primitives and arrays. + * Represents a select subset of an OpenAPI 3.0 schema object. + * + * @param type Required. Data type. + * @param format Optional. The format of the data. This is used only for primitive datatypes. + * Supported formats: + * for NUMBER type: float, double + * for INTEGER type: int32, int64 + * for STRING type: enum + * @param description Optional. A brief description of the parameter. This could contain examples of use. + * Parameter description may be formatted as Markdown. + * @param nullable Optional. Indicates if the value may be null. + * @param ameri_swiss_enum Optional. Note: the ameri_swiss prefix must be removed by the {@link swiss.ameri.gemini.spi.JsonParser}. + * Possible values of the element of Type.STRING with enum format. + * For example we can define an Enum Direction as : + * {type:STRING, format:enum, enum:["EAST", NORTH", "SOUTH", "WEST"]} + * @param maxItems Optional. Maximum number of the elements for Type.ARRAY. + * @param minItems Optional. Minimum number of the elements for Type.ARRAY. + * @param properties Optional. Properties of Type.OBJECT. + * An object containing a list of "key": value pairs. Example: + * { "name": "wrench", "mass": "1.3kg", "count": "3" }. + * @param required Optional. Required properties of Type.OBJECT. + * @param items Optional. Schema of the elements of Type.ARRAY. + * @see Schema for further information. + */ +public record Schema( + Type type, + String format, + String description, + Boolean nullable, + List ameri_swiss_enum, + String maxItems, + String minItems, + Map properties, + List required, + Schema items +) { + + + /** + * Create a {@link SchemaBuilder}. + * + * @return an empty {@link SchemaBuilder} + */ + public static SchemaBuilder builder() { + return new SchemaBuilder(); + } + + /** + * A builder for {@link Schema}. Currently, does not validate the fields when building the model. Not thread-safe. + */ + public static class SchemaBuilder { + private Type type; + private String format; + private String description; + private Boolean nullable; + private List ameri_swiss_enum; + private String maxItems; + private String minItems; + private Map properties; + private List required; + private Schema items; + + + private SchemaBuilder() { + } + + public Schema build() { + return new Schema( + this.type, + this.format, + this.description, + this.nullable, + this.ameri_swiss_enum, + this.maxItems, + this.minItems, + this.properties, + this.required, + this.items + ); + } + + public SchemaBuilder type(Type type) { + this.type = type; + return this; + } + + public SchemaBuilder format(String format) { + this.format = format; + return this; + } + + public SchemaBuilder description(String description) { + this.description = description; + return this; + } + + public SchemaBuilder nullable(Boolean nullable) { + this.nullable = nullable; + return this; + } + + public SchemaBuilder ameri_swiss_enum(List ameri_swiss_enum) { + this.ameri_swiss_enum = ameri_swiss_enum; + return this; + } + + public SchemaBuilder maxItems(String maxItems) { + this.maxItems = maxItems; + return this; + } + + public SchemaBuilder minItems(String minItems) { + this.minItems = minItems; + return this; + } + + public SchemaBuilder properties(Map properties) { + this.properties = properties; + return this; + } + + public SchemaBuilder required(List required) { + this.required = required; + return this; + } + + public SchemaBuilder items(Schema items) { + this.items = items; + return this; + } + } + + /** + * Type contains the list of OpenAPI data types. + * + * @see Data types + */ + public enum Type { + /** + * Not specified, should not be used. + */ + TYPE_UNSPECIFIED, + /** + * String type. + */ + STRING, + /** + * Number type. + */ + NUMBER, + /** + * Integer type. + */ + INTEGER, + /** + * Boolean type. + */ + BOOLEAN, + /** + * Array type. + */ + ARRAY, + /** + * Object type. + */ + OBJECT + } +} + diff --git a/gemini-gson/pom.xml b/gemini-gson/pom.xml index 5766553..7804315 100644 --- a/gemini-gson/pom.xml +++ b/gemini-gson/pom.xml @@ -7,7 +7,7 @@ swiss.ameri gemini - 1beta.0.2.4 + 1beta.0.2.6-SNAPSHOT gemini-gson diff --git a/gemini-gson/src/main/java/swiss/ameri/gemini/gson/GsonJsonParser.java b/gemini-gson/src/main/java/swiss/ameri/gemini/gson/GsonJsonParser.java index 09a13aa..6c0dc87 100644 --- a/gemini-gson/src/main/java/swiss/ameri/gemini/gson/GsonJsonParser.java +++ b/gemini-gson/src/main/java/swiss/ameri/gemini/gson/GsonJsonParser.java @@ -1,6 +1,9 @@ package swiss.ameri.gemini.gson; +import com.google.gson.FieldNamingStrategy; import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import swiss.ameri.gemini.api.Schema; import swiss.ameri.gemini.spi.JsonParser; /** @@ -8,6 +11,17 @@ */ public class GsonJsonParser implements JsonParser { + /** + * Field naming strategy to avoid usage of illegal field names in java. + * See e.g. {@link Schema#ameri_swiss_enum()}, which cannot be named {@code enum}. + */ + public static final FieldNamingStrategy FIELD_NAMING_STRATEGY = field -> { + if (field.getName().startsWith("ameri_swiss_")) { + return field.getName().substring("ameri_swiss_".length()); + } + return field.getName(); + }; + private final Gson gson; /** @@ -23,7 +37,7 @@ public GsonJsonParser(Gson gson) { * Create a default {@link JsonParser} instance. */ public GsonJsonParser() { - this(new Gson()); + this(new GsonBuilder().setFieldNamingStrategy(FIELD_NAMING_STRATEGY).create()); } @Override diff --git a/gemini-tester/pom.xml b/gemini-tester/pom.xml index 90285d5..c505e3e 100644 --- a/gemini-tester/pom.xml +++ b/gemini-tester/pom.xml @@ -7,7 +7,7 @@ swiss.ameri gemini - 1beta.0.2.4 + 1beta.0.2.6-SNAPSHOT gemini-tester diff --git a/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java b/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java index 3895b96..0b81e42 100644 --- a/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java +++ b/gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java @@ -8,6 +8,7 @@ import java.io.InputStream; import java.util.Base64; import java.util.List; +import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -38,6 +39,8 @@ public static void main(String[] args) throws Exception { countTokens(genAi); generateContent(genAi); generateContentStream(genAi); + generateWithResponseSchema(genAi); + generateContentStreamWithResponseSchema(genAi); multiChatTurn(genAi); textAndImage(genAi); embedContents(genAi); @@ -90,7 +93,7 @@ private static void countTokens(GenAi genAi) { private static void multiChatTurn(GenAi genAi) { System.out.println("----- multi turn chat"); GenerativeModel chatModel = GenerativeModel.builder() - .modelName(ModelVariant.GEMINI_1_0_PRO) + .modelName(ModelVariant.GEMINI_1_5_PRO) .addContent(new Content.TextContent( Content.Role.USER.roleName(), "Write the first line of a story about a magic backpack." @@ -133,9 +136,68 @@ private static void generateContent(GenAi genAi) throws InterruptedException, Ex .get(20, TimeUnit.SECONDS); } + + private static void generateContentStreamWithResponseSchema(GenAi genAi) { + System.out.println("----- Generate content (streaming) with response schema -- with usage meta data"); + var model = createResponseSchemaModel(); + genAi.generateContentStream(model) + .forEach(x -> { + System.out.println(x); + // note that the usage metadata is updated as it arrives + System.out.println(genAi.usageMetadata(x.id())); + System.out.println(genAi.safetyRatings(x.id())); + }); + } + + private static void generateWithResponseSchema(GenAi genAi) throws InterruptedException, ExecutionException, TimeoutException { + var model = createResponseSchemaModel(); + System.out.println("----- Generate with response schema (blocking)"); + genAi.generateContent(model) + .thenAccept(gcr -> { + System.out.println(gcr); + System.out.println("----- Generate with response schema (blocking) usage meta data & safety ratings"); + System.out.println(genAi.usageMetadata(gcr.id())); + System.out.println(genAi.safetyRatings(gcr.id()).stream().map(GenAi.SafetyRating::toTypedSafetyRating).toList()); + }) + .get(20, TimeUnit.SECONDS); + } + + private static GenerativeModel createResponseSchemaModel() { + return GenerativeModel.builder() + .modelName(ModelVariant.GEMINI_1_5_FLASH) + .addContent(Content.textContent( + Content.Role.USER, + "List 3 popular cookie recipes." + )) + .addSafetySetting(SafetySetting.of( + SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH + )) + .generationConfig(new GenerationConfig( + null, + "application/json", + Schema.builder() + .type(Schema.Type.ARRAY) + .items(Schema.builder() + .type(Schema.Type.OBJECT) + .properties(Map.of( + "recipe_name", Schema.builder() + .type(Schema.Type.STRING) + .build() + )) + .build()) + .build(), + null, + null, + null, + null + )) + .build(); + } + private static GenerativeModel createStoryModel() { return GenerativeModel.builder() - .modelName(ModelVariant.GEMINI_1_0_PRO) + .modelName(ModelVariant.GEMINI_1_5_PRO) .addContent(Content.textContent( Content.Role.USER, "Write a 50 word story about a magic backpack." @@ -159,7 +221,7 @@ private static GenerativeModel createStoryModel() { private static void getModel(GenAi genAi) { System.out.println("----- Get Model"); System.out.println( - genAi.getModel(ModelVariant.GEMINI_1_0_PRO) + genAi.getModel(ModelVariant.GEMINI_1_5_PRO) ); } diff --git a/pom.xml b/pom.xml index 3ef3b06..100c860 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ swiss.ameri gemini - 1beta.0.2.4 + 1beta.0.2.6-SNAPSHOT pom