1616import graphql .language .ObjectTypeDefinition ;
1717import graphql .language .ObjectValue ;
1818import graphql .language .OperationTypeDefinition ;
19- import graphql .language .ResolvedTypeDefinition ;
2019import graphql .language .ScalarTypeDefinition ;
2120import graphql .language .SchemaDefinition ;
2221import graphql .language .StringValue ;
@@ -338,7 +337,7 @@ private GraphQLInterfaceType buildInterfaceType(BuildContext buildCtx, Interface
338337 builder .name (typeDefinition .getName ());
339338 builder .description (buildDescription (typeDefinition ));
340339
341- builder .typeResolver (getTypeResolver (buildCtx , typeDefinition ));
340+ builder .typeResolver (getTypeResolverForInterface (buildCtx , typeDefinition ));
342341
343342 typeDefinition .getFieldDefinitions ().forEach (fieldDef ->
344343 builder .field (buildField (buildCtx , typeDefinition , fieldDef )));
@@ -349,7 +348,7 @@ private GraphQLUnionType buildUnionType(BuildContext buildCtx, UnionTypeDefiniti
349348 GraphQLUnionType .Builder builder = GraphQLUnionType .newUnionType ();
350349 builder .name (typeDefinition .getName ());
351350 builder .description (buildDescription (typeDefinition ));
352- builder .typeResolver (getTypeResolver (buildCtx , typeDefinition ));
351+ builder .typeResolver (getTypeResolverForUnion (buildCtx , typeDefinition ));
353352
354353 typeDefinition .getMemberTypes ().forEach (mt -> {
355354 GraphQLOutputType outputType = buildOutputType (buildCtx , mt );
@@ -474,18 +473,39 @@ private Object buildObjectValue(ObjectValue defaultValue) {
474473 return map ;
475474 }
476475
477- private TypeResolver getTypeResolver (BuildContext buildCtx , ResolvedTypeDefinition typeDefinition ) {
476+ private TypeResolver getTypeResolverForUnion (BuildContext buildCtx , UnionTypeDefinition unionType ) {
478477 TypeDefinitionRegistry typeRegistry = buildCtx .getTypeRegistry ();
479478 RuntimeWiring wiring = buildCtx .getWiring ();
480479 WiringFactory wiringFactory = wiring .getWiringFactory ();
481480
482481 TypeResolver typeResolver ;
483- if (wiringFactory .providesTypeResolver (typeRegistry , typeDefinition )) {
484- typeResolver = wiringFactory .getTypeResolver (typeRegistry , typeDefinition );
482+ if (wiringFactory .providesTypeResolver (typeRegistry , unionType )) {
483+ typeResolver = wiringFactory .getTypeResolver (typeRegistry , unionType );
485484 assertNotNull (typeResolver , "The WiringFactory indicated it provides a type resolver but then returned null" );
486485
487486 } else {
488- typeResolver = wiring .getTypeResolvers ().get (typeDefinition .getName ());
487+ typeResolver = wiring .getTypeResolvers ().get (unionType .getName ());
488+ if (typeResolver == null ) {
489+ // this really should be checked earlier via a pre-flight check
490+ typeResolver = new TypeResolverProxy ();
491+ }
492+ }
493+
494+ return typeResolver ;
495+ }
496+
497+ private TypeResolver getTypeResolverForInterface (BuildContext buildCtx , InterfaceTypeDefinition interfaceType ) {
498+ TypeDefinitionRegistry typeRegistry = buildCtx .getTypeRegistry ();
499+ RuntimeWiring wiring = buildCtx .getWiring ();
500+ WiringFactory wiringFactory = wiring .getWiringFactory ();
501+
502+ TypeResolver typeResolver ;
503+ if (wiringFactory .providesTypeResolver (typeRegistry , interfaceType )) {
504+ typeResolver = wiringFactory .getTypeResolver (typeRegistry , interfaceType );
505+ assertNotNull (typeResolver , "The WiringFactory indicated it provides a type resolver but then returned null" );
506+
507+ } else {
508+ typeResolver = wiring .getTypeResolvers ().get (interfaceType .getName ());
489509 if (typeResolver == null ) {
490510 // this really should be checked earlier via a pre-flight check
491511 typeResolver = new TypeResolverProxy ();
0 commit comments