diff --git a/src/main/java/graphql/schema/SchemaUtil.java b/src/main/java/graphql/schema/SchemaUtil.java index 368dd1940e..bb6d80bae4 100644 --- a/src/main/java/graphql/schema/SchemaUtil.java +++ b/src/main/java/graphql/schema/SchemaUtil.java @@ -51,7 +51,7 @@ private void collectTypes(GraphQLType root, Map result) { } else if (root instanceof GraphQLUnionType) { collectTypesForUnions((GraphQLUnionType) root, result); } else if (root instanceof GraphQLInputObjectType) { - result.put(((GraphQLInputObjectType) root).getName(), root); + collectTypesForInputObjects((GraphQLInputObjectType) root, result); } else if (root instanceof GraphQLTypeReference) { // nothing to do } else { @@ -95,6 +95,16 @@ private void collectTypesForObjects(GraphQLObjectType objectType, Map result) { + if (result.containsKey(objectType.getName())) return; + result.put(objectType.getName(), objectType); + + for (GraphQLInputObjectField fieldDefinition : objectType.getFields()) { + collectTypes(fieldDefinition.getType(), result); + } + } + + public Map allTypes(GraphQLSchema schema, Set dictionary) { Map typesByName = new LinkedHashMap<>(); collectTypes(schema.getQueryType(), typesByName); diff --git a/src/test/groovy/graphql/NestedInputSchema.java b/src/test/groovy/graphql/NestedInputSchema.java new file mode 100644 index 0000000000..3491991b1b --- /dev/null +++ b/src/test/groovy/graphql/NestedInputSchema.java @@ -0,0 +1,99 @@ +package graphql; + +import graphql.schema.DataFetcher; +import graphql.schema.DataFetchingEnvironment; +import graphql.schema.GraphQLArgument; +import graphql.schema.GraphQLFieldDefinition; +import graphql.schema.GraphQLInputObjectField; +import graphql.schema.GraphQLInputObjectType; +import graphql.schema.GraphQLObjectType; +import graphql.schema.GraphQLSchema; +import java.util.Map; +import static graphql.Scalars.GraphQLBoolean; +import static graphql.Scalars.GraphQLInt; + +public class NestedInputSchema { + + + public static GraphQLSchema createSchema() { + + + GraphQLObjectType root = rootType(); + + return GraphQLSchema.newSchema() + .query(root) + .build(); + } + + public static GraphQLObjectType rootType() { + return GraphQLObjectType.newObject() + + .name("Root") + .field(GraphQLFieldDefinition.newFieldDefinition() + .name("value") + .type(GraphQLInt) + .dataFetcher(new DataFetcher() { + @Override + public Object get(DataFetchingEnvironment environment) { + int initialValue = environment.getArgument("initialValue"); + Map filter = environment.getArgument("filter"); + if (filter != null) { + if (filter.containsKey("even")) { + Boolean even = (Boolean) filter.get("even"); + if (even && (initialValue%2 != 0)) { + return 0; + } else if (!even && (initialValue%2 == 0)) { + return 0; + } + } + if (filter.containsKey("range")) { + Map range = (Map) filter.get("range"); + if (initialValue < range.get("lowerBound") || + initialValue > range.get("upperBound")) { + return 0; + } + } + } + return initialValue; + }}) + .argument(GraphQLArgument.newArgument() + .name("intialValue") + .type(GraphQLInt) + .defaultValue(5) + .build()) + .argument(GraphQLArgument.newArgument() + .name("filter") + .type(filterType()) + .build()) + .build()) + .build(); + } + + public static GraphQLInputObjectType filterType() { + return GraphQLInputObjectType.newInputObject() + .name("Filter") + .field(GraphQLInputObjectField.newInputObjectField() + .name("even") + .type(GraphQLBoolean) + .build()) + .field(GraphQLInputObjectField.newInputObjectField() + .name("range") + .type(rangeType()) + .build()) + .build(); + } + + public static GraphQLInputObjectType rangeType() { + return GraphQLInputObjectType.newInputObject() + .name("Range") + .field(GraphQLInputObjectField.newInputObjectField() + .name("lowerBound") + .type(GraphQLInt) + .build()) + .field(GraphQLInputObjectField.newInputObjectField() + .name("upperBound") + .type(GraphQLInt) + .build()) + .build(); + } +} diff --git a/src/test/groovy/graphql/schema/SchemaUtilTest.groovy b/src/test/groovy/graphql/schema/SchemaUtilTest.groovy index 476715ab34..efd19988eb 100644 --- a/src/test/groovy/graphql/schema/SchemaUtilTest.groovy +++ b/src/test/groovy/graphql/schema/SchemaUtilTest.groovy @@ -1,11 +1,14 @@ package graphql.schema +import graphql.NestedInputSchema +import graphql.Scalars import graphql.introspection.Introspection import spock.lang.Specification import java.util.Collections; import static graphql.Scalars.GraphQLBoolean +import static graphql.Scalars.GraphQLInt import static graphql.Scalars.GraphQLString import static graphql.StarWarsSchema.* @@ -30,4 +33,27 @@ class SchemaUtilTest extends Specification { (Introspection.__Directive.name) : Introspection.__Directive, (GraphQLBoolean.name) : GraphQLBoolean] } + + def "collectAllTypesNestedInput"() { + when: + Map types = new SchemaUtil().allTypes(NestedInputSchema.createSchema(), Collections.emptySet()); + Map expected = + + [(NestedInputSchema.rootType().name) : NestedInputSchema.rootType(), + (NestedInputSchema.filterType().name) : NestedInputSchema.filterType(), + (NestedInputSchema.rangeType().name) : NestedInputSchema.rangeType(), + (GraphQLInt.name) : GraphQLInt, + (GraphQLString.name) : GraphQLString, + (Introspection.__Schema.name) : Introspection.__Schema, + (Introspection.__Type.name) : Introspection.__Type, + (Introspection.__TypeKind.name) : Introspection.__TypeKind, + (Introspection.__Field.name) : Introspection.__Field, + (Introspection.__InputValue.name): Introspection.__InputValue, + (Introspection.__EnumValue.name) : Introspection.__EnumValue, + (Introspection.__Directive.name) : Introspection.__Directive, + (GraphQLBoolean.name) : GraphQLBoolean]; + then: + types.keySet() == expected.keySet() + } + }