diff --git a/mypy/checker.py b/mypy/checker.py index 80402e71dce6..53227c0c4a81 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -8529,6 +8529,9 @@ def lookup_fully_qualified_or_none(self, fullname: str, /) -> SymbolTableNode | except KeyError: return None + def record_fixed_type(self, fixed: TypeInfo | TypeAlias) -> None: + pass + def fail( self, msg: str, diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 48ea7ab51f61..f93982a12a72 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4915,10 +4915,11 @@ def visit_type_application(self, tapp: TypeApplication) -> Type: if tapp.expr.node.python_3_12_type_alias: return self.type_alias_type_type() # Subscription of a (generic) alias in runtime context, expand the alias. - item = instantiate_type_alias( + item, _ = instantiate_type_alias( tapp.expr.node, tapp.types, self.chk.fail, + self.chk.note, tapp.expr.node.no_args, tapp, self.chk.options, @@ -4983,17 +4984,16 @@ class LongName(Generic[T]): ... # A = List[Tuple[T, T]] # x = A() <- same as List[Tuple[Any, Any]], see PEP 484. disallow_any = self.chk.options.disallow_any_generics and self.is_callee - item = get_proper_type( - set_any_tvars( - alias, - [], - ctx.line, - ctx.column, - self.chk.options, - disallow_any=disallow_any, - fail=self.msg.fail, - ) + item, _ = set_any_tvars( + alias, + [], + ctx.line, + ctx.column, + self.chk.options, + disallow_any=disallow_any, + fail=self.msg.fail, ) + item = get_proper_type(item) if isinstance(item, Instance): # Normally we get a callable type (or overloaded) with .is_type_obj() true # representing the class's constructor diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 30ced27aef22..82885065934f 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -180,6 +180,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage: IMPLICIT_GENERIC_ANY_BUILTIN: Final = ( 'Implicit generic "Any". Use "{}" and specify generic parameters' ) +NO_CYCLIC_DEFAULT: Final = "Cyclic type variable defaults are not supported" INVALID_UNPACK: Final = "{} cannot be unpacked (must be tuple or TypeVarTuple)" INVALID_UNPACK_POSITION: Final = "Unpack is only valid in a variadic position" INVALID_PARAM_SPEC_LOCATION: Final = "Invalid location for ParamSpec {}" diff --git a/mypy/nodes.py b/mypy/nodes.py index 32a694560b24..e050a1aa3421 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -3145,7 +3145,15 @@ class TypeVarLikeExpr(SymbolNode, Expression): Note that they are constructed by the semantic analyzer. """ - __slots__ = ("_name", "_fullname", "upper_bound", "default", "variance", "is_new_style") + __slots__ = ( + "_name", + "_fullname", + "upper_bound", + "default", + "variance", + "is_new_style", + "default_depends", + ) _name: str _fullname: str @@ -3160,6 +3168,9 @@ class TypeVarLikeExpr(SymbolNode, Expression): # TypeVar(..., contravariant=True) defines a contravariant type # variable. variance: int + # Record instances and type aliases that appear bare/implicit in the default value + # of this type variable. This is needed to detect recursive type variable defaults. + default_depends: set[TypeInfo | TypeAlias] | None def __init__( self, @@ -3178,6 +3189,7 @@ def __init__( self.default = default self.variance = variance self.is_new_style = is_new_style + self.default_depends = None @property def name(self) -> str: @@ -3655,6 +3667,7 @@ class is generic then it will be a type constructor of higher kind. "is_type_check_only", "deprecated", "type_object_type", + "default_depends", ) _fullname: str # Fully qualified name @@ -3816,6 +3829,16 @@ class is generic then it will be a type constructor of higher kind. # appears in runtime context. type_object_type: mypy.types.FunctionLike | None + # Type variables whose defaults depend on defaults of type variables in other classes + # and type aliases. We keep track of this to safely handle situations like this one: + # class C[T = D]: ... + # class D[S = C]: ... + # x: C + # Since we apply fix_instance() eagerly, inferring a precise type is quite tricky. + # Therefore, we infer the type of `x` as `C[D[Any]]` to avoid infinite recursion. + # Keys are type variable full names. + default_depends: dict[str, set[TypeAlias | TypeInfo]] + FLAGS: Final = [ "is_abstract", "is_enum", @@ -3877,6 +3900,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None self.is_type_check_only = False self.deprecated = None self.type_object_type = None + self.default_depends = {} def add_type_vars(self) -> None: self.has_type_var_tuple_type = False @@ -4542,6 +4566,7 @@ def f(x: B[T]) -> T: ... # without T, Any would be used here "eager", "tvar_tuple_index", "python_3_12_type_alias", + "default_depends", ) __match_args__ = ("name", "target", "alias_tvars", "no_args") @@ -4574,6 +4599,8 @@ def __init__( self.eager = eager self.python_3_12_type_alias = python_3_12_type_alias self.tvar_tuple_index = None + # This plays the same role as TypeInfo.default_depends attribute. + self.default_depends: dict[str, set[TypeAlias | TypeInfo]] = {} for i, t in enumerate(alias_tvars): if isinstance(t, mypy.types.TypeVarTupleType): self.tvar_tuple_index = i diff --git a/mypy/semanal.py b/mypy/semanal.py index da58c9586966..2666aa7fbf6a 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -539,6 +539,11 @@ def __init__( # import foo.bar self.transitive_submodule_imports: dict[str, set[str]] = {} + # Instances and type aliases that were fixed using default valuers of type + # variables. This can be used on-demand by type analyzer. Use record_fixed_type() + # to create the set lazily. + self.types_fixed: set[TypeInfo | TypeAlias] | None = None + # mypyc doesn't properly handle implementing an abstractproperty # with a regular attribute so we make them properties @property @@ -1883,6 +1888,9 @@ def analyze_type_param( upper_bound = self.named_type("builtins.tuple", [self.object_type()]) else: upper_bound = self.object_type() + # Reset fixed types both before and after each collection just in case. + if self.types_fixed is not None: + self.types_fixed.clear() if type_param.default: default = self.anal_type( type_param.default, @@ -1892,6 +1900,7 @@ def analyze_type_param( allow_param_spec_literals=type_param.kind == PARAM_SPEC_KIND, allow_tuple_literal=type_param.kind == PARAM_SPEC_KIND, allow_unpack=type_param.kind == TYPE_VAR_TUPLE_KIND, + analyzing_tvar_def=True, ) if default is None: default = PlaceholderType(None, [], context.line) @@ -1903,6 +1912,8 @@ def analyze_type_param( default = self.check_typevartuple_default(default, type_param.default) else: default = AnyType(TypeOfAny.from_omitted_generics) + default_depends = self.types_fixed + self.types_fixed = None if type_param.kind == TYPE_VAR_KIND: values: list[Type] = [] if type_param.values: @@ -1915,7 +1926,7 @@ def analyze_type_param( values.append(AnyType(TypeOfAny.from_error)) else: values.append(analyzed) - return TypeVarExpr( + tv = TypeVarExpr( name=type_param.name, fullname=fullname, values=values, @@ -1926,7 +1937,7 @@ def analyze_type_param( line=context.line, ) elif type_param.kind == PARAM_SPEC_KIND: - return ParamSpecExpr( + tv = ParamSpecExpr( name=type_param.name, fullname=fullname, upper_bound=upper_bound, @@ -1937,7 +1948,7 @@ def analyze_type_param( else: assert type_param.kind == TYPE_VAR_TUPLE_KIND tuple_fallback = self.named_type("builtins.tuple", [self.object_type()]) - return TypeVarTupleExpr( + tv = TypeVarTupleExpr( name=type_param.name, fullname=fullname, upper_bound=upper_bound, @@ -1946,6 +1957,8 @@ def analyze_type_param( is_new_style=True, line=context.line, ) + tv.default_depends = default_depends + return tv def pop_type_args(self, type_args: list[TypeParam] | None) -> None: if not type_args: @@ -1972,24 +1985,22 @@ def analyze_class(self, defn: ClassDef) -> None: self.infer_metaclass_and_bases_from_compat_helpers(defn) bases = defn.base_type_exprs - bases, tvar_defs, is_protocol = self.clean_up_bases_and_infer_type_variables( - defn, bases, context=defn + bases, tvar_defs, is_protocol, declared_tvars = ( + self.clean_up_bases_and_infer_type_variables(defn, bases, context=defn) ) self.check_type_alias_bases(bases) - - for tvd in tvar_defs: - if isinstance(tvd, TypeVarType) and any( - has_placeholder(t) for t in [tvd.upper_bound] + tvd.values - ): - # Some type variable bounds or values are not ready, we need - # to re-analyze this class. - self.defer() - if has_placeholder(tvd.default): - # Placeholder values in TypeVarLikeTypes may get substituted in. - # Defer current target until they are ready. - self.mark_incomplete(defn.name, defn) - return + default_depends: dict[str, set[TypeAlias | TypeInfo]] = {} + for _, tv in declared_tvars: + if tv.default_depends is not None: + default_depends[tv.fullname] = tv.default_depends + + if any(has_placeholder(tvd) for tvd in tvar_defs): + # Some type variable bounds or values are not ready, we need to + # re-analyze this class. Note we force progress to handle cases like + # class C[T = C], this matches logic in process_typevar_parameters() + # for "old style" type variables. + self.defer(force_progress=tvar_defs != defn.type_vars) self.analyze_class_keywords(defn) bases_result = self.analyze_base_classes(bases) @@ -2004,6 +2015,10 @@ def analyze_class(self, defn: ClassDef) -> None: # are okay in nested positions, since they can't affect the MRO. self.mark_incomplete(defn.name, defn) return + if any(has_placeholder(base) for base, _ in base_types): + # We need to manually call defer() in case a placeholder was brought by a + # type variable default, so that type analyzer didn't call it. + self.defer() declared_metaclass, should_defer, any_meta = self.get_declared_metaclass( defn.name, defn.metaclass @@ -2017,14 +2032,18 @@ def analyze_class(self, defn: ClassDef) -> None: if defn.info: self.setup_type_vars(defn, tvar_defs) self.setup_alias_type_vars(defn) + defn.info.default_depends = default_depends return if self.analyze_namedtuple_classdef(defn, tvar_defs): + if defn.info: + defn.info.default_depends = default_depends return # Create TypeInfo for class now that base classes and the MRO can be calculated. self.prepare_class_def(defn) self.setup_type_vars(defn, tvar_defs) + defn.info.default_depends = default_depends if base_error: defn.info.fallback_to_any = True if any_meta: @@ -2264,7 +2283,7 @@ def analyze_class_decorator_common(self, defn: ClassDef, decorator: Expression) def clean_up_bases_and_infer_type_variables( self, defn: ClassDef, base_type_exprs: list[Expression], context: Context - ) -> tuple[list[Expression], list[TypeVarLikeType], bool]: + ) -> tuple[list[Expression], list[TypeVarLikeType], bool, list[tuple[str, TypeVarLikeExpr]]]: """Remove extra base classes such as Generic and infer type vars. For example, consider this class: @@ -2276,7 +2295,8 @@ class Foo(Bar, Generic[T]): ... Note that this is performed *before* semantic analysis. - Returns (remaining base expressions, inferred type variables, is protocol). + Returns a tuple: + (remaining base expressions, type variables, is protocol, type variable expressions). """ removed: list[int] = [] declared_tvars: TypeVarLikeList = [] @@ -2356,7 +2376,7 @@ class Foo(Bar, Generic[T]): ... defn.removed_base_type_exprs.append(defn.base_type_exprs[i]) del base_type_exprs[i] tvar_defs = self.tvar_defs_from_tvars(declared_tvars, context) - return base_type_exprs, tvar_defs, is_protocol + return base_type_exprs, tvar_defs, is_protocol, declared_tvars def analyze_class_typevar_declaration( self, base: Type, has_type_var_tuple: bool @@ -3980,11 +4000,14 @@ def analyze_alias( declared_type_vars: TypeVarLikeList | None = None, all_declared_type_params_names: list[str] | None = None, python_3_12_type_alias: bool = False, - ) -> tuple[Type | None, list[TypeVarLikeType], set[str], bool]: + ) -> tuple[ + Type | None, list[TypeVarLikeType], set[str], bool, dict[str, set[TypeAlias | TypeInfo]] + ]: """Check if 'rvalue' is a valid type allowed for aliasing (e.g. not a type variable). If yes, return the corresponding type, a list of type variables for generic aliases, - a set of names the alias depends on, and True if the original type has empty tuple index. + a set of names the alias depends on, whether the original type has empty tuple index, + and any type variables whose defaults depend on other classes or type aliases. An example for the dependencies: A = int B = str @@ -4000,7 +4023,7 @@ def analyze_alias( self.fail( "Invalid type alias: expression is not a valid type", rvalue, code=codes.VALID_TYPE ) - return None, [], set(), False + return None, [], set(), False, {} found_type_vars = self.find_type_var_likes(typ) namespace = self.qualified_name(name) @@ -4039,7 +4062,11 @@ def analyze_alias( new_tvar_defs.append(td) indexed = bool(isinstance(typ, UnboundType) and (typ.args or typ.empty_tuple_index)) - return analyzed, new_tvar_defs, depends_on, indexed + default_depends = {} + for _, tv in alias_type_vars: + if tv.default_depends is not None: + default_depends[tv.fullname] = tv.default_depends + return analyzed, new_tvar_defs, depends_on, indexed, default_depends def is_pep_613(self, s: AssignmentStmt) -> bool: if s.unanalyzed_type is not None and isinstance(s.unanalyzed_type, UnboundType): @@ -4139,9 +4166,10 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: alias_tvars: list[TypeVarLikeType] = [] depends_on: set[str] = set() indexed = False + default_depends: dict[str, set[TypeAlias | TypeInfo]] = {} else: tag = self.track_incomplete_refs() - res, alias_tvars, depends_on, indexed = self.analyze_alias( + res, alias_tvars, depends_on, indexed, default_depends = self.analyze_alias( lvalue.name, rvalue, allow_placeholder=True, @@ -4164,6 +4192,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: # may appear in nested positions), therefore use becomes_typeinfo=True. self.mark_incomplete(lvalue.name, rvalue, becomes_typeinfo=True) return True + self.add_type_alias_deps(depends_on) check_for_explicit_any(res, self.options, self.is_typeshed_stub_file, self.msg, context=s) # When this type alias gets "inlined", the Any is not explicit anymore, @@ -4202,6 +4231,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: eager=eager, python_3_12_type_alias=pep_695, ) + alias_node.default_depends = default_depends if isinstance(s.rvalue, (IndexExpr, CallExpr, OpExpr)): # Note: CallExpr is for "void = type(None)" and OpExpr is for "X | Y" union syntax. if not isinstance(s.rvalue.analyzed, TypeAliasExpr): @@ -4221,6 +4251,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: # Copy expansion to the existing alias, this matches how we update base classes # for a TypeInfo _in place_ if there are nested placeholders. existing.node.target = res + existing.node.default_depends = default_depends existing.node.alias_tvars = alias_tvars existing.node.no_args = no_args updated = True @@ -4230,6 +4261,8 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: # Otherwise just replace existing placeholder with type alias *in place*. existing._node = alias_node updated = True + # TODO: switch type aliases to if has_placeholder(): process_placeholder() pattern. + # Type aliases are last notable exception from this logic. if updated: if self.final_iteration: self.cannot_resolve_name(lvalue.name, "name", s) @@ -4753,6 +4786,8 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: n_values = call.arg_kinds[1:].count(ARG_POS) values = self.analyze_value_types(call.args[1 : 1 + n_values]) + if self.types_fixed is not None: + self.types_fixed.clear() res = self.process_typevar_parameters( call.args[1 + n_values :], call.arg_names[1 + n_values :], @@ -4760,6 +4795,8 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: n_values, s, ) + default_depends = self.types_fixed + self.types_fixed = None if res is None: return False variance, upper_bound, default = res @@ -4800,6 +4837,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: type_var = TypeVarExpr( name, self.qualified_name(name), values, upper_bound, default, variance ) + type_var.default_depends = default_depends type_var.line = call.line call.analyzed = type_var updated = True @@ -4813,6 +4851,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: call.analyzed.upper_bound = upper_bound call.analyzed.values = values call.analyzed.default = default + call.analyzed.default_depends = default_depends if any(has_placeholder(v) for v in values): self.process_placeholder(None, "TypeVar values", s, force_progress=updated) elif has_placeholder(upper_bound): @@ -4965,6 +5004,11 @@ def process_typevar_parameters( variance = INVARIANT return variance, upper_bound, default + def record_fixed_type(self, fixed: TypeInfo | TypeAlias) -> None: + if self.types_fixed is None: + self.types_fixed = set() + self.types_fixed.add(fixed) + def get_typevarlike_argument( self, typevarlike_name: str, @@ -4976,7 +5020,7 @@ def get_typevarlike_argument( allow_param_spec_literals: bool = False, allow_unpack: bool = False, report_invalid_typevar_arg: bool = True, - ) -> ProperType | None: + ) -> Type | None: try: # We want to use our custom error message below, so we suppress # the default error message for invalid types here. @@ -4987,6 +5031,7 @@ def get_typevarlike_argument( allow_unbound_tvars=allow_unbound_tvars, allow_param_spec_literals=allow_param_spec_literals, allow_unpack=allow_unpack, + analyzing_tvar_def=param_name == "default", ) if analyzed is None: # Type variables are special: we need to place them in the symbol table @@ -4996,15 +5041,19 @@ def get_typevarlike_argument( # class Custom(Generic[T]): # ... analyzed = PlaceholderType(None, [], context.line) - typ = get_proper_type(analyzed) - if report_invalid_typevar_arg and isinstance(typ, AnyType) and typ.is_from_error: + if ( + report_invalid_typevar_arg + and isinstance(analyzed, ProperType) + and isinstance(analyzed, AnyType) + and analyzed.is_from_error + ): self.fail( message_registry.TYPEVAR_ARG_MUST_BE_TYPE.format(typevarlike_name, param_name), param_value, ) # Note: we do not return 'None' here -- we want to continue # using the AnyType. - return typ + return analyzed except TypeTranslationError: if report_invalid_typevar_arg: self.fail( @@ -5049,6 +5098,8 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool: if n_values != 0: self.fail('Too many positional arguments for "ParamSpec"', s) + if self.types_fixed is not None: + self.types_fixed.clear() default: Type = AnyType(TypeOfAny.from_omitted_generics) for param_value, param_name in zip( call.args[1 + n_values :], call.arg_names[1 + n_values :] @@ -5073,6 +5124,8 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool: "The variance and bound arguments to ParamSpec do not have defined semantics yet", s, ) + default_depends = self.types_fixed + self.types_fixed = None # PEP 612 reserves the right to define bound, covariant and contravariant arguments to # ParamSpec in a later PEP. If and when that happens, we should do something @@ -5083,12 +5136,14 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool: name, self.qualified_name(name), self.object_type(), default, INVARIANT ) paramspec_var.line = call.line + paramspec_var.default_depends = default_depends call.analyzed = paramspec_var updated = True else: assert isinstance(call.analyzed, ParamSpecExpr) updated = default != call.analyzed.default call.analyzed.default = default + call.analyzed.default_depends = default_depends if has_placeholder(default): self.process_placeholder(None, "ParamSpec default", s, force_progress=updated) @@ -5111,6 +5166,8 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool: self.fail('Too many positional arguments for "TypeVarTuple"', s) default: Type = AnyType(TypeOfAny.from_omitted_generics) + if self.types_fixed is not None: + self.types_fixed.clear() for param_value, param_name in zip( call.args[1 + n_values :], call.arg_names[1 + n_values :] ): @@ -5129,6 +5186,9 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool: else: self.fail(f'Unexpected keyword argument "{param_name}" for "TypeVarTuple"', s) + default_depends = self.types_fixed + self.types_fixed = None + name = self.extract_typevarlike_name(s, call) if name is None: return False @@ -5146,12 +5206,14 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool: INVARIANT, ) typevartuple_var.line = call.line + typevartuple_var.default_depends = default_depends call.analyzed = typevartuple_var updated = True else: assert isinstance(call.analyzed, TypeVarTupleExpr) updated = default != call.analyzed.default call.analyzed.default = default + call.analyzed.default_depends = default_depends if has_placeholder(default): self.process_placeholder(None, "TypeVarTuple default", s, force_progress=updated) @@ -5682,7 +5744,7 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None: return tag = self.track_incomplete_refs() - res, alias_tvars, depends_on, indexed = self.analyze_alias( + res, alias_tvars, depends_on, indexed, default_depends = self.analyze_alias( s.name.name, s.value.expr(), allow_placeholder=True, @@ -5709,13 +5771,10 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None: self.mark_incomplete(s.name.name, s.value, becomes_typeinfo=True) return - # Now go through all new variables and temporary replace all tvars that still - # refer to some placeholders. We defer the whole alias and will revisit it again, - # as well as all its dependents. - for i, tv in enumerate(alias_tvars): - if has_placeholder(tv): - self.mark_incomplete(s.name.name, s.value, becomes_typeinfo=True) - alias_tvars[i] = self._trivial_typevarlike_like(tv) + if any(has_placeholder(tv) for tv in alias_tvars): + # Defer the alias if some type variables are not ready, same as for classes. + # Note: progress is forced below (if needed). + self.defer() self.add_type_alias_deps(depends_on) check_for_explicit_any( @@ -5741,6 +5800,7 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None: eager=eager, python_3_12_type_alias=True, ) + alias_node.default_depends = default_depends s.alias_node = alias_node if ( @@ -5757,6 +5817,7 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None: # Copy expansion to the existing alias, this matches how we update base classes # for a TypeInfo _in place_ if there are nested placeholders. existing.node.target = res + existing.node.default_depends = default_depends existing.node.alias_tvars = alias_tvars updated = True # Invalidate recursive status cache in case it was previously set. @@ -5783,46 +5844,6 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None: finally: self.pop_type_args(s.type_args) - def _trivial_typevarlike_like(self, tv: TypeVarLikeType) -> TypeVarLikeType: - object_type = self.named_type("builtins.object") - if isinstance(tv, TypeVarType): - return TypeVarType( - tv.name, - tv.fullname, - tv.id, - values=[], - upper_bound=object_type, - default=AnyType(TypeOfAny.from_omitted_generics), - variance=tv.variance, - line=tv.line, - column=tv.column, - ) - elif isinstance(tv, TypeVarTupleType): - tuple_type = self.named_type("builtins.tuple", [object_type]) - return TypeVarTupleType( - tv.name, - tv.fullname, - tv.id, - upper_bound=tuple_type, - tuple_fallback=tuple_type, - default=AnyType(TypeOfAny.from_omitted_generics), - line=tv.line, - column=tv.column, - ) - elif isinstance(tv, ParamSpecType): - return ParamSpecType( - tv.name, - tv.fullname, - tv.id, - flavor=tv.flavor, - upper_bound=object_type, - default=AnyType(TypeOfAny.from_omitted_generics), - line=tv.line, - column=tv.column, - ) - else: - assert False, f"Unknown TypeVarLike: {tv!r}" - # # Expressions # @@ -7708,6 +7729,7 @@ def expr_to_analyzed_type( allow_unbound_tvars: bool = False, allow_param_spec_literals: bool = False, allow_unpack: bool = False, + analyzing_tvar_def: bool = False, ) -> Type | None: if isinstance(expr, CallExpr): # This is a legacy syntax intended mostly for Python 2, we keep it for @@ -7739,6 +7761,7 @@ def expr_to_analyzed_type( allow_unbound_tvars=allow_unbound_tvars, allow_param_spec_literals=allow_param_spec_literals, allow_unpack=allow_unpack, + analyzing_tvar_def=analyzing_tvar_def, ) def analyze_type_expr(self, expr: Expression) -> None: @@ -7766,6 +7789,7 @@ def type_analyzer( prohibit_self_type: str | None = None, prohibit_special_class_field_types: str | None = None, allow_type_any: bool = False, + analyzing_tvar_def: bool = False, ) -> TypeAnalyser: if tvar_scope is None: tvar_scope = self.tvar_scope @@ -7787,6 +7811,7 @@ def type_analyzer( prohibit_self_type=prohibit_self_type, prohibit_special_class_field_types=prohibit_special_class_field_types, allow_type_any=allow_type_any, + analyzing_tvar_def=analyzing_tvar_def, ) tpan.in_dynamic_func = bool(self.function_stack and self.function_stack[-1].is_dynamic()) tpan.global_scope = not self.type and not self.function_stack @@ -7813,6 +7838,7 @@ def anal_type( prohibit_self_type: str | None = None, prohibit_special_class_field_types: str | None = None, allow_type_any: bool = False, + analyzing_tvar_def: bool = False, ) -> Type | None: """Semantically analyze a type. @@ -7850,6 +7876,7 @@ def anal_type( prohibit_self_type=prohibit_self_type, prohibit_special_class_field_types=prohibit_special_class_field_types, allow_type_any=allow_type_any, + analyzing_tvar_def=analyzing_tvar_def, ) tag = self.track_incomplete_refs() typ = typ.accept(a) diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index a85d4ed00b5e..c682a3eeb6fb 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -24,6 +24,7 @@ SymbolNode, SymbolTable, SymbolTableNode, + TypeAlias, TypeInfo, ) from mypy.plugin import SemanticAnalyzerPluginInterface @@ -84,6 +85,10 @@ def lookup_fully_qualified(self, fullname: str, /) -> SymbolTableNode: def lookup_fully_qualified_or_none(self, fullname: str, /) -> SymbolTableNode | None: raise NotImplementedError + @abstractmethod + def record_fixed_type(self, fixed: TypeInfo | TypeAlias) -> None: + raise NotImplementedError + @abstractmethod def fail( self, diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 1b38481ba000..d668121bc5b9 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -595,7 +595,15 @@ def visit_type_alias_type(self, t: TypeAliasType, /) -> bool: elif t in self.seen_aliases: return self.default self.seen_aliases.add(t) - return get_proper_type(t).accept(self) + res = get_proper_type(t).accept(self) + # This is a weird edge case: if a type alias has unused type variables, we + # should visit arguments even if we didn't find anything in the expansion. + # As an optimization, do this only for new style type aliases. + assert t.alias is not None + if self.strategy == ANY_STRATEGY: + return res or (t.alias.python_3_12_type_alias and self.query_types(t.args)) + else: + return res and (not t.alias.python_3_12_type_alias or self.query_types(t.args)) def query_types(self, types: list[Type] | tuple[Type, ...]) -> bool: """Perform a query for a sequence of types using the strategy to combine the results.""" diff --git a/mypy/typeanal.py b/mypy/typeanal.py index db5625619262..c5101fd8ee21 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -222,6 +222,7 @@ def __init__( allowed_alias_tvars: list[TypeVarLikeType] | None = None, allow_type_any: bool = False, alias_type_params_names: list[str] | None = None, + analyzing_tvar_def: bool = False, ) -> None: self.api = api self.fail_func = api.fail @@ -268,6 +269,8 @@ def __init__( self.allow_type_any = allow_type_any self.allow_type_var_tuple = False self.allow_unpack = allow_unpack + # Set when we are analyzing a default of a type variable. + self.analyzing_tvar_def = analyzing_tvar_def def lookup_qualified( self, name: str, ctx: Context, suppress_errors: bool = False @@ -473,17 +476,22 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) an_args = self.pack_paramspec_args(an_args, t.empty_tuple_index) disallow_any = self.options.disallow_any_generics and not self.is_typeshed_stub - res = instantiate_type_alias( + res, used_default = instantiate_type_alias( node, an_args, self.fail, + self.note, node.no_args, t, self.options, unexpanded_type=t, disallow_any=disallow_any, empty_tuple_index=t.empty_tuple_index, + analyzing_tvar_def=self.analyzing_tvar_def, ) + if self.analyzing_tvar_def and used_default and isinstance(res, TypeAliasType): + assert res.alias is not None + self.api.record_fixed_type(res.alias) # The only case where instantiate_type_alias() can return an incorrect instance is # when it is top-level instance, so no need to recurse. if ( @@ -492,7 +500,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) and not (self.defining_alias and self.nesting_level == 0) and not validate_instance(res, self.fail, t.empty_tuple_index) ): - fix_instance( + used_default = fix_instance( res, self.fail, self.note, @@ -500,7 +508,10 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) options=self.options, use_generic_error=True, unexpanded_type=t, + analyzing_tvar_def=self.analyzing_tvar_def, ) + if self.analyzing_tvar_def and used_default: + self.api.record_fixed_type(res.type) if node.eager: res = get_proper_type(res) return res @@ -877,29 +888,39 @@ def analyze_type_with_type_info( if not (self.defining_alias and self.nesting_level == 0) and not validate_instance( instance, self.fail, empty_tuple_index ): - fix_instance( + used_default = fix_instance( instance, self.fail, self.note, disallow_any=self.options.disallow_any_generics and not self.is_typeshed_stub, options=self.options, + analyzing_tvar_def=self.analyzing_tvar_def, ) + if self.analyzing_tvar_def and used_default: + self.api.record_fixed_type(info) tup = info.tuple_type if tup is not None: # The class has a Tuple[...] base class so it will be # represented as a tuple type. if info.special_alias: - return instantiate_type_alias( + res, used_default = instantiate_type_alias( info.special_alias, # TODO: should we allow NamedTuples generic in ParamSpec? self.anal_array(args, allow_unpack=True), self.fail, + self.note, False, ctx, self.options, use_standard_error=True, + analyzing_tvar_def=self.analyzing_tvar_def, ) + if self.analyzing_tvar_def and used_default: + # For convenience, we make default depend on the original TypeInfo, + # *not* on the special alias. + self.api.record_fixed_type(info) + return res return tup.copy_modified( items=self.anal_array(tup.items, allow_unpack=True), fallback=instance ) @@ -908,16 +929,23 @@ def analyze_type_with_type_info( # The class has a TypedDict[...] base class so it will be # represented as a typeddict type. if info.special_alias: - return instantiate_type_alias( + res, used_default = instantiate_type_alias( info.special_alias, # TODO: should we allow TypedDicts generic in ParamSpec? self.anal_array(args, allow_unpack=True), self.fail, + self.note, False, ctx, self.options, use_standard_error=True, + analyzing_tvar_def=self.analyzing_tvar_def, ) + if self.analyzing_tvar_def and used_default: + # For convenience, we make default depend on the original TypeInfo, + # *not* on the special alias. + self.api.record_fixed_type(info) + return res # Create a named TypedDictType return td.copy_modified( item_types=self.anal_array(list(td.items.values())), fallback=instance @@ -2026,6 +2054,7 @@ def get_omitted_any( options: Options, fullname: str | None = None, unexpanded_type: Type | None = None, + used_default: bool = False, ) -> AnyType: if disallow_any: typ = unexpanded_type or orig_type @@ -2036,6 +2065,8 @@ def get_omitted_any( typ, code=codes.TYPE_ARG, ) + if used_default: + note(message_registry.NO_CYCLIC_DEFAULT, typ, code=codes.TYPE_ARG) any_type = AnyType(TypeOfAny.from_error, line=typ.line, column=typ.column) else: @@ -2065,11 +2096,13 @@ def fix_instance( options: Options, use_generic_error: bool = False, unexpanded_type: Type | None = None, -) -> None: + analyzing_tvar_def: bool = False, +) -> bool: """Fix a malformed instance by replacing all type arguments with TypeVar default or Any. Also emit a suitable error if this is not due to implicit Any's. """ + used_default = False arg_count = len(t.args) min_tv_count = sum(not tv.has_default() for tv in t.type.defn.type_vars) max_tv_count = len(t.type.type_vars) @@ -2088,15 +2121,33 @@ def fix_instance( if tv is None: continue if arg is None: + use_any = False if tv.has_default(): arg = tv.default + if analyzing_tvar_def: + # Record the use of default only when analyzing another default. + used_default = True + if is_typevar_default_recursive(tv.fullname, t.type): + # If this results in infinite recursion, use Any instead. + use_any = True else: + use_any = True + if use_any: if any_type is None: fullname = None if use_generic_error else t.type.fullname any_type = get_omitted_any( - disallow_any, fail, note, t, options, fullname, unexpanded_type + disallow_any, + fail, + note, + t, + options, + fullname, + unexpanded_type, + used_default, ) arg = any_type + else: + assert arg is not None args.append(arg) env[tv.id] = arg t.args = tuple(args) @@ -2106,12 +2157,14 @@ def fix_instance( fixed = expand_type(t, env) assert isinstance(fixed, Instance) t.args = fixed.args + return used_default def instantiate_type_alias( node: TypeAlias, args: list[Type], fail: MsgCallback, + note: MsgCallback, no_args: bool, ctx: Context, options: Options, @@ -2120,7 +2173,8 @@ def instantiate_type_alias( disallow_any: bool = False, use_standard_error: bool = False, empty_tuple_index: bool = False, -) -> Type: + analyzing_tvar_def: bool = False, +) -> tuple[Type, bool]: """Create an instance of a (generic) type alias from alias node and type arguments. We are following the rules outlined in TypeAlias docstring. @@ -2167,15 +2221,17 @@ def instantiate_type_alias( options, disallow_any=disallow_any, fail=fail, + note=note, unexpanded_type=unexpanded_type, + analyzing_tvar_def=analyzing_tvar_def, ) if max_tv_count == 0 and act_len == 0: if no_args: assert isinstance(node.target, Instance) # type: ignore[misc] # Note: this is the only case where we use an eager expansion. See more info about # no_args aliases like L = List in the docstring for TypeAlias class. - return Instance(node.target.type, [], line=ctx.line, column=ctx.column) - return TypeAliasType(node, [], line=ctx.line, column=ctx.column) + return Instance(node.target.type, [], line=ctx.line, column=ctx.column), False + return TypeAliasType(node, [], line=ctx.line, column=ctx.column), False if ( max_tv_count == 0 and act_len > 0 @@ -2187,7 +2243,7 @@ def instantiate_type_alias( tp.column = ctx.column tp.end_line = ctx.end_line tp.end_column = ctx.end_column - return tp + return tp, False if node.tvar_tuple_index is None: if any(isinstance(a, UnpackType) for a in args): # A variadic unpack in fixed size alias (fixed unpacks must be flattened by the caller) @@ -2233,7 +2289,18 @@ def instantiate_type_alias( ) fail(msg, ctx, code=codes.TYPE_ARG) args = [] - return set_any_tvars(node, args, ctx.line, ctx.column, options, from_error=True) + return set_any_tvars( + node, + args, + ctx.line, + ctx.column, + options, + disallow_any=disallow_any, + fail=fail, + note=note, + from_error=not correct, + analyzing_tvar_def=analyzing_tvar_def, + ) elif node.tvar_tuple_index is not None: # We also need to check if we are not performing a type variable tuple split. unpack = find_unpack_in_list(args) @@ -2260,8 +2327,8 @@ def instantiate_type_alias( ): exp = get_proper_type(typ) assert isinstance(exp, Instance) - return exp.args[-1] - return typ + return exp.args[-1], False + return typ, False def set_any_tvars( @@ -2275,8 +2342,11 @@ def set_any_tvars( disallow_any: bool = False, special_form: bool = False, fail: MsgCallback | None = None, + note: MsgCallback | None = None, unexpanded_type: Type | None = None, -) -> TypeAliasType: + analyzing_tvar_def: bool = False, +) -> tuple[TypeAliasType, bool]: + used_default = False if from_error or disallow_any: type_of_any = TypeOfAny.from_error elif special_form: @@ -2294,6 +2364,12 @@ def set_any_tvars( if arg is None: if tv.has_default(): arg = tv.default + # Same as for instances, record and avoid infinite recursion. + if analyzing_tvar_def: + used_default = True + if is_typevar_default_recursive(tv.fullname, node): + arg = any_type + used_any_type = True else: arg = any_type used_any_type = True @@ -2310,7 +2386,7 @@ def set_any_tvars( assert isinstance(fixed, TypeAliasType) t.args = fixed.args - if used_any_type and disallow_any and node.alias_tvars: + if used_any_type and disallow_any and node.alias_tvars and not from_error: assert fail is not None if unexpanded_type: type_str = ( @@ -2326,7 +2402,34 @@ def set_any_tvars( Context(newline, newcolumn), code=codes.TYPE_ARG, ) - return t + if used_default: + assert note is not None + note( + message_registry.NO_CYCLIC_DEFAULT, + Context(newline, newcolumn), + code=codes.TYPE_ARG, + ) + return t, used_default + + +def is_typevar_default_recursive(tv_fname: str, start: TypeInfo | TypeAlias) -> bool: + """Check if the type variable can lead to infinite recursion via defaults.""" + if tv_fname not in start.default_depends: + return False + todo = start.default_depends[tv_fname].copy() + seen: set[TypeAlias | TypeInfo] = set() + while todo: + node = todo.pop() + if node is start: + return True + if node in seen: + # We don't return True here, since we are interested only in + # recursion via the original type variable. + continue + seen.add(node) + for dep_nodes in node.default_depends.values(): + todo |= dep_nodes + return False class DivergingAliasDetector(TrivialSyntheticTypeTranslator): diff --git a/mypy/types.py b/mypy/types.py index 40c3839e2efc..d0b0f1b7a1bc 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -699,6 +699,7 @@ def __eq__(self, other: object) -> bool: self.id == other.id and self.upper_bound == other.upper_bound and self.values == other.values + and self.default == other.default ) def serialize(self) -> JsonDict: @@ -854,7 +855,12 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, ParamSpecType): return NotImplemented # Upper bound can be ignored, since it's determined by flavor. - return self.id == other.id and self.flavor == other.flavor and self.prefix == other.prefix + return ( + self.id == other.id + and self.flavor == other.flavor + and self.prefix == other.prefix + and self.default == other.default + ) def serialize(self) -> JsonDict: assert not self.id.is_meta_var() @@ -1003,7 +1009,9 @@ def __hash__(self) -> int: def __eq__(self, other: object) -> bool: if not isinstance(other, TypeVarTupleType): return NotImplemented - return self.id == other.id and self.min_len == other.min_len + return ( + self.id == other.id and self.min_len == other.min_len and self.default == other.default + ) def copy_modified( self, diff --git a/test-data/unit/check-flags.test b/test-data/unit/check-flags.test index a281218af58e..dd4687181ca4 100644 --- a/test-data/unit/check-flags.test +++ b/test-data/unit/check-flags.test @@ -2717,3 +2717,34 @@ if ( or z is None ): pass + +[case testRecursiveTypeVarDefaultMutualDisallow] +# flags: --disallow-any-generic +from typing import TypeVar, Generic + +class C(Generic["T"]): + pass + +class D(Generic["S"]): + pass + +T = TypeVar("T", default=D) # E: Missing type arguments for generic type "D" \ + # N: Cyclic type variable defaults are not supported +S = TypeVar("S", default=C) # E: Missing type arguments for generic type "C" \ + # N: Cyclic type variable defaults are not supported + +c: C +d: D +reveal_type(c) # N: Revealed type is "__main__.C[__main__.D[Any]]" +reveal_type(d) # N: Revealed type is "__main__.D[__main__.C[Any]]" + +[case testRecursiveTypeVarDefaultOnlyAliasDisallow] +# flags: --disallow-any-generic +from typing import TypeVar + +T = TypeVar("T", default="A") # E: Missing type arguments for generic type "A" \ + # N: Cyclic type variable defaults are not supported +A = list[T] +a: A +reveal_type(a) # N: Revealed type is "builtins.list[builtins.list[Any]]" +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-python313.test b/test-data/unit/check-python313.test index 8a80977fb22a..87bd722742b5 100644 --- a/test-data/unit/check-python313.test +++ b/test-data/unit/check-python313.test @@ -337,3 +337,89 @@ type Result[T, E] = Ok[T, E] | Err[E, T] class Bar[U]: def foo(data: U, cond: bool) -> Result[U, str]: return Ok(data) if cond else Err("Error") + +[case testRecursiveTypeVarDefaultBasicNewStyle] +class C[T: C = C]: + pass + +c: C +reveal_type(c) # N: Revealed type is "__main__.C[__main__.C[Any]]" + +[case testRecursiveTypeVarDefaultMutualNewStyle] +class C[T = D]: + pass + +class D[S = C]: + pass + +c: C +d: D +reveal_type(c) # N: Revealed type is "__main__.C[__main__.D[Any]]" +reveal_type(d) # N: Revealed type is "__main__.D[__main__.C[Any]]" + +[case testNonRecursiveSimpleTypeVarDefaultNewStyle] +class Child[S = Parent]: ... + +class Parent[T = int]: ... + +reveal_type(Child()) # N: Revealed type is "__main__.Child[__main__.Parent[builtins.int]]" + +[case testNonRecursiveTypeVarDefaultImportCycleClassNewStyle] +import exp +[file exp.pyi] +import ind + +class F[T: D = D]: + x: T + +class D(E): ... +class E: ... + +[file ind.pyi] +from exp import F + +class Ind(F): ... +x: Ind +reveal_type(x.x) # N: Revealed type is "exp.D" + +[case testNonRecursiveTypeVarDefaultImportCycleAliasNewStyle] +import exp +[file exp.pyi] +import ind + +type F[T: D = D] = list[T] + +class D(E): ... +class E: ... + +[file ind.pyi] +from exp import F + +type Ind = list[F] +x: Ind +reveal_type(x) # N: Revealed type is "builtins.list[builtins.list[exp.D]]" + +[case testRecursiveTypeVarDefaultClassAndAliasNewStyle] +class Trait[T = Pattern]: + pass + +class Pattern1(Trait): + pass +class Pattern2(Trait): + pass + +type Pattern = Pattern1 | Pattern2 + +reveal_type(Trait()) # N: Revealed type is "__main__.Trait[__main__.Pattern1 | __main__.Pattern2]" +[builtins fixtures/tuple.pyi] + +[case testRecursiveTypeVarDefaultOnlyAliasNewStyle] +type A[T = A] = list[T] +a: A +reveal_type(a) # N: Revealed type is "builtins.list[builtins.list[Any]]" +[builtins fixtures/tuple.pyi] + +[case testRecursiveAliasTypeVarDefaultNewStyle] +type A[T = A] = int +a: A +reveal_type(a) # N: Revealed type is "builtins.int" diff --git a/test-data/unit/check-typevar-defaults.test b/test-data/unit/check-typevar-defaults.test index 535d882ccf3c..5895830631b4 100644 --- a/test-data/unit/check-typevar-defaults.test +++ b/test-data/unit/check-typevar-defaults.test @@ -937,3 +937,107 @@ reveal_type(D) # N: Revealed type is "def [T2 = Any, T1 = Any] () -> __main__.D d: D reveal_type(d) # N: Revealed type is "__main__.D[Any, Any]" [builtins fixtures/tuple.pyi] + +[case testRecursiveTypeVarDefaultBasic] +from typing import TypeVar, Generic + +class C(Generic["T"]): + pass + +T = TypeVar("T", bound=C, default=C) + +c: C +reveal_type(c) # N: Revealed type is "__main__.C[__main__.C[Any]]" + +[case testRecursiveTypeVarDefaultMutual] +from typing import TypeVar, Generic + +class C(Generic["T"]): + pass + +class D(Generic["S"]): + pass + +T = TypeVar("T", default=D) +S = TypeVar("S", default=C) + +c: C +d: D +reveal_type(c) # N: Revealed type is "__main__.C[__main__.D[Any]]" +reveal_type(d) # N: Revealed type is "__main__.D[__main__.C[Any]]" + +[case testNonRecursiveSimpleTypeVarDefault] +from typing import TypeVar, Generic + +S = TypeVar("S", default="Parent") +class Child(Generic[S]): ... + +T = TypeVar("T", default=int) +class Parent(Generic[T]): ... + +reveal_type(Child()) # N: Revealed type is "__main__.Child[__main__.Parent[builtins.int]]" + +[case testNonRecursiveTypeVarDefaultImportCycleClass] +import exp +[file exp.pyi] +from typing import Generic, TypeVar +import ind + +T = TypeVar("T", bound=D, default=D) +class F(Generic[T]): + x: T + +class D(E): ... +class E: ... + +[file ind.pyi] +from exp import F + +class Ind(F): ... +x: Ind +reveal_type(x.x) # N: Revealed type is "exp.D" + +[case testNonRecursiveTypeVarDefaultImportCycleAlias] +import exp +[file exp.pyi] +from typing import TypeVar +import ind + +T = TypeVar("T", bound=D, default=D) +F = list[T] + +class D(E): ... +class E: ... + +[file ind.pyi] +from exp import F + +Ind = list[F] +x: Ind +reveal_type(x) # N: Revealed type is "builtins.list[builtins.list[exp.D]]" + +[case testRecursiveTypeVarDefaultClassAndAlias] +from typing import Generic, TypeVar, Union + +T = TypeVar("T", default="Pattern") + +class Trait(Generic[T]): + pass + +class Pattern1(Trait): + pass + +Pattern = Union[Pattern1, None] + +reveal_type(Trait()) # N: Revealed type is "__main__.Trait[__main__.Pattern1 | None]" +[builtins fixtures/tuple.pyi] + +[case testRecursiveTypeVarDefaultOnlyAlias] +from typing import TypeVar + +T = TypeVar("T", default="A") + +A = list[T] +a: A +reveal_type(a) # N: Revealed type is "builtins.list[builtins.list[Any]]" +[builtins fixtures/tuple.pyi]