3030import java .util .HashMap ;
3131import java .util .HashSet ;
3232import java .util .Set ;
33- import java .util .function . Predicate ;
33+ import java .util .stream . StreamSupport ;
3434
3535import com .sun .source .tree .LambdaExpressionTree .BodyKind ;
3636import com .sun .tools .javac .code .*;
4545import com .sun .tools .javac .util .JCDiagnostic .Error ;
4646import com .sun .tools .javac .util .JCDiagnostic .Warning ;
4747
48- import com .sun .tools .javac .code .Kinds .Kind ;
4948import com .sun .tools .javac .code .Symbol .*;
5049import com .sun .tools .javac .tree .JCTree .*;
5150
5251import static com .sun .tools .javac .code .Flags .*;
5352import static com .sun .tools .javac .code .Flags .BLOCK ;
5453import static com .sun .tools .javac .code .Kinds .Kind .*;
54+ import com .sun .tools .javac .code .Type .TypeVar ;
5555import static com .sun .tools .javac .code .TypeTag .BOOLEAN ;
5656import static com .sun .tools .javac .code .TypeTag .VOID ;
5757import com .sun .tools .javac .resources .CompilerProperties .Fragments ;
58- import com .sun .tools .javac .tree .JCTree .JCParenthesizedPattern ;
5958import static com .sun .tools .javac .tree .JCTree .Tag .*;
6059import com .sun .tools .javac .util .JCDiagnostic .Fragment ;
6160
@@ -665,7 +664,7 @@ public void visitSwitch(JCSwitch tree) {
665664 ListBuffer <PendingExit > prevPendingExits = pendingExits ;
666665 pendingExits = new ListBuffer <>();
667666 scan (tree .selector );
668- Set <Object > constants = tree .patternSwitch ? allSwitchConstants ( tree . selector ) : null ;
667+ Set <Symbol > constants = tree .patternSwitch ? new HashSet <>( ) : null ;
669668 for (List <JCCase > l = tree .cases ; l .nonEmpty (); l = l .tail ) {
670669 alive = Liveness .ALIVE ;
671670 JCCase c = l .head ;
@@ -687,8 +686,9 @@ public void visitSwitch(JCSwitch tree) {
687686 l .tail .head .pos (),
688687 Warnings .PossibleFallThroughIntoCase );
689688 }
690- if ((constants == null || !constants .isEmpty ()) && !tree .hasTotalPattern &&
691- tree .patternSwitch && !TreeInfo .isErrorEnumSwitch (tree .selector , tree .cases )) {
689+ if (!tree .hasTotalPattern && tree .patternSwitch &&
690+ !TreeInfo .isErrorEnumSwitch (tree .selector , tree .cases ) &&
691+ (constants == null || !isExhaustive (tree .selector .type , constants ))) {
692692 log .error (tree , Errors .NotExhaustiveStatement );
693693 }
694694 if (!tree .hasTotalPattern ) {
@@ -702,7 +702,7 @@ public void visitSwitchExpression(JCSwitchExpression tree) {
702702 ListBuffer <PendingExit > prevPendingExits = pendingExits ;
703703 pendingExits = new ListBuffer <>();
704704 scan (tree .selector );
705- Set <Object > constants = allSwitchConstants ( tree . selector );
705+ Set <Symbol > constants = new HashSet <>( );
706706 Liveness prevAlive = alive ;
707707 for (List <JCCase > l = tree .cases ; l .nonEmpty (); l = l .tail ) {
708708 alive = Liveness .ALIVE ;
@@ -723,47 +723,83 @@ public void visitSwitchExpression(JCSwitchExpression tree) {
723723 }
724724 c .completesNormally = alive != Liveness .DEAD ;
725725 }
726- if (( constants == null || ! constants . isEmpty ()) && !tree .hasTotalPattern &&
727- !TreeInfo . isErrorEnumSwitch (tree .selector , tree . cases )) {
726+ if (! tree . hasTotalPattern && !TreeInfo . isErrorEnumSwitch ( tree .selector , tree . cases ) &&
727+ !isExhaustive (tree .selector . type , constants )) {
728728 log .error (tree , Errors .NotExhaustive );
729729 }
730730 alive = prevAlive ;
731731 alive = alive .or (resolveYields (tree , prevPendingExits ));
732732 }
733733
734- private Set <Object > allSwitchConstants (JCExpression selector ) {
735- Set <Object > constants = null ;
736- TypeSymbol selectorSym = selector .type .tsym ;
737- if ((selectorSym .flags () & ENUM ) != 0 ) {
738- constants = new HashSet <>();
739- Predicate <Symbol > enumConstantFilter =
740- s -> (s .flags () & ENUM ) != 0 && s .kind == Kind .VAR ;
741- for (Symbol s : selectorSym .members ().getSymbols (enumConstantFilter )) {
742- constants .add (s .name );
743- }
744- } else if (selectorSym .isAbstract () && selectorSym .isSealed () && selectorSym .kind == Kind .TYP ) {
745- constants = new HashSet <>();
746- constants .addAll (((ClassSymbol ) selectorSym ).permitted );
747- }
748- return constants ;
749- }
750-
751- private void handleConstantCaseLabel (Set <Object > constants , JCCaseLabel pat ) {
734+ private void handleConstantCaseLabel (Set <Symbol > constants , JCCaseLabel pat ) {
752735 if (constants != null ) {
753736 if (pat .isExpression ()) {
754737 JCExpression expr = (JCExpression ) pat ;
755- if (expr .hasTag (IDENT ))
756- constants .remove (((JCIdent ) expr ).name );
738+ if (expr .hasTag (IDENT ) && (( JCIdent ) expr ). sym . isEnum () )
739+ constants .add (((JCIdent ) expr ).sym );
757740 } else if (pat .isPattern ()) {
758741 PatternPrimaryType patternType = TreeInfo .primaryPatternType ((JCPattern ) pat );
759742
760743 if (patternType .unconditional ()) {
761- constants .remove (patternType .type ().tsym );
744+ constants .add (patternType .type ().tsym );
762745 }
763746 }
764747 }
765748 }
766749
750+ private void transitiveCovers (Set <Symbol > covered ) {
751+ List <Symbol > todo = List .from (covered );
752+ while (todo .nonEmpty ()) {
753+ Symbol sym = todo .head ;
754+ todo = todo .tail ;
755+ switch (sym .kind ) {
756+ case VAR -> {
757+ Iterable <Symbol > constants = sym .owner
758+ .members ()
759+ .getSymbols (s -> s .isEnum () &&
760+ s .kind == VAR );
761+ boolean hasAll = StreamSupport .stream (constants .spliterator (), false )
762+ .allMatch (covered ::contains );
763+
764+ if (hasAll && covered .add (sym .owner )) {
765+ todo = todo .prepend (sym .owner );
766+ }
767+ }
768+
769+ case TYP -> {
770+ for (Type sup : types .directSupertypes (sym .type )) {
771+ if (sup .tsym .kind == TYP && sup .tsym .isAbstract () && sup .tsym .isSealed ()) {
772+ boolean hasAll = ((ClassSymbol ) sup .tsym ).permitted
773+ .stream ()
774+ .allMatch (covered ::contains );
775+
776+ if (hasAll && covered .add (sup .tsym )) {
777+ todo = todo .prepend (sup .tsym );
778+ }
779+ }
780+ }
781+ }
782+ }
783+ }
784+ }
785+
786+ private boolean isExhaustive (Type seltype , Set <Symbol > covered ) {
787+ transitiveCovers (covered );
788+ return switch (seltype .getTag ()) {
789+ case CLASS -> {
790+ if (seltype .isCompound ()) {
791+ if (seltype .isIntersection ()) {
792+ yield ((Type .IntersectionClassType ) seltype ).getComponents ().stream ().anyMatch (t -> isExhaustive (t , covered ));
793+ }
794+ yield false ;
795+ }
796+ yield covered .contains (seltype .tsym );
797+ }
798+ case TYPEVAR -> isExhaustive (((TypeVar ) seltype ).getUpperBound (), covered );
799+ default -> false ;
800+ };
801+ }
802+
767803 public void visitTry (JCTry tree ) {
768804 ListBuffer <PendingExit > prevPendingExits = pendingExits ;
769805 pendingExits = new ListBuffer <>();
0 commit comments