from __future__ import annotations

from collections.abc import Iterable, Sequence
from typing import Callable

import mypy.subtypes
from mypy.erasetype import erase_typevars
from mypy.expandtype import expand_type
from mypy.nodes import Context, TypeInfo
from mypy.type_visitor import TypeTranslator
from mypy.typeops import get_all_type_vars
from mypy.types import (
    AnyType,
    CallableType,
    Instance,
    Parameters,
    ParamSpecFlavor,
    ParamSpecType,
    PartialType,
    ProperType,
    Type,
    TypeAliasType,
    TypeVarId,
    TypeVarLikeType,
    TypeVarTupleType,
    TypeVarType,
    UninhabitedType,
    UnpackType,
    get_proper_type,
    remove_dups,
)


def get_target_type(
    tvar: TypeVarLikeType,
    type: Type,
    callable: CallableType,
    report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
    context: Context,
    skip_unsatisfied: bool,
) -> Type | None:
    """Choose the type to substitute for `tvar` when `type` is applied to it.

    Resolution order:

    * An uninhabited `type` is replaced by the default of `tvar`, if it has one.
    * ParamSpec and TypeVarTuple variables accept `type` as is.
    * For a TypeVar with value restrictions, `Any` is accepted as is, as is
      another value-restricted TypeVar whose every value is the same type as one
      of the values of `tvar`. Otherwise the result is the narrowest value of
      `tvar` that `type` is a subtype of.
    * For any other TypeVar, `type` must be a subtype of the upper bound.

    When the value/bound check fails, None is returned if `skip_unsatisfied` is
    true; otherwise `report_incompatible_typevar_value` is called with
    `(callable, type, tvar.name, context)` and `type` itself is still returned.
    """
    proper = get_proper_type(type)
    if isinstance(proper, UninhabitedType) and tvar.has_default():
        return tvar.default
    if isinstance(tvar, (ParamSpecType, TypeVarTupleType)):
        # No value or bound checks are applied to these kinds of variables.
        return type
    assert isinstance(tvar, TypeVarType)

    allowed = tvar.values
    if not allowed:
        bound = tvar.upper_bound
        if tvar.name == "Self":
            # Internally constructed Self-types contain class type variables in upper bound,
            # so we need to erase them to avoid false positives. This is safe because we do
            # not support type variables in upper bounds of user defined types.
            bound = erase_typevars(bound)
        if mypy.subtypes.is_subtype(type, bound):
            return type
        if skip_unsatisfied:
            return None
        report_incompatible_typevar_value(callable, type, tvar.name, context)
        return type

    if isinstance(proper, AnyType):
        return type
    # Allow substituting T1 for T if every allowed value of T1
    # is also a legal value of T.
    if (
        isinstance(proper, TypeVarType)
        and proper.values
        and all(
            any(mypy.subtypes.is_same_type(v, other) for v in allowed)
            for other in proper.values
        )
    ):
        return type

    candidates = [v for v in allowed if mypy.subtypes.is_subtype(type, v)]
    if candidates:
        # If there is more than one matching value, we select the narrowest.
        narrowest = candidates[0]
        for candidate in candidates[1:]:
            if mypy.subtypes.is_subtype(candidate, narrowest):
                narrowest = candidate
        return narrowest
    if skip_unsatisfied:
        return None
    report_incompatible_typevar_value(callable, type, tvar.name, context)
    return type


def apply_generic_arguments(
    callable: CallableType,
    orig_types: Sequence[Type | None],
    report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
    context: Context,
    skip_unsatisfied: bool = False,
) -> CallableType:
    """Substitute type arguments for the type variables of a generic callable.

    For example, applying [int] to 'def [T] (T) -> T' gives 'def (int) -> int'.

    `orig_types` is matched positionally against `callable.variables` and may be
    shorter than it. A None entry means "leave this variable unapplied".

    If `skip_unsatisfied` is True, a type that violates the bound or value
    restriction of its type variable is silently left unapplied instead of
    being reported through `report_incompatible_typevar_value`.
    """
    tvars = callable.variables
    assert len(orig_types) <= len(tvars)

    # Check that inferred type variable values are compatible with allowed
    # values and bounds. Also, promote subtype values to allowed values.
    # The result is a map from type variable id to target type.
    id_to_type: dict[TypeVarId, Type] = {}
    for tvar, orig in zip(tvars, orig_types):
        assert not isinstance(orig, PartialType), "Internal error: must never apply partial type"
        if orig is None:
            continue
        target = get_target_type(
            tvar, orig, callable, report_incompatible_typevar_value, context, skip_unsatisfied
        )
        if target is not None:
            id_to_type[tvar.id] = target

    # TODO: validate arg_kinds/arg_names for ParamSpec and TypeVarTuple replacements,
    # not just type variable bounds above.

    # ParamSpec expansion is special-cased, so a callable whose ParamSpec is being
    # replaced must be expanded as a whole, not argument by argument. The same is
    # true for a callable with a variadic (unpacked) star argument.
    param_spec = callable.param_spec()
    expand_whole = param_spec is not None and id_to_type.get(param_spec.id) is not None
    if not expand_whole:
        var_arg = callable.var_arg()
        expand_whole = var_arg is not None and isinstance(var_arg.typ, UnpackType)
    if expand_whole:
        expanded = expand_type(callable, id_to_type)
        assert isinstance(expanded, CallableType)
        return expanded.copy_modified(
            variables=[tv for tv in tvars if tv.id not in id_to_type]
        )

    # Apply arguments to argument types.
    callable = callable.copy_modified(
        arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
    )

    # Apply arguments to TypeGuard and TypeIs if any.
    type_guard = (
        expand_type(callable.type_guard, id_to_type) if callable.type_guard is not None else None
    )
    type_is = expand_type(callable.type_is, id_to_type) if callable.type_is is not None else None

    # The callable may retain some type vars if only some were applied.
    # TODO: move apply_poly() logic here when new inference
    # becomes universally used (i.e. in all passes + in unification).
    # With this new logic we can actually *add* some new free variables.
    remaining_tvars: list[TypeVarLikeType] = []
    for tv in tvars:
        if tv.id in id_to_type:
            continue
        if tv.has_default():
            # The variable itself is not being replaced, so expanding it
            # only affects its default.
            with_expanded_default = expand_type(tv, id_to_type)
            assert isinstance(with_expanded_default, TypeVarLikeType)
            remaining_tvars.append(with_expanded_default)
        else:
            remaining_tvars.append(tv)

    return callable.copy_modified(
        ret_type=expand_type(callable.ret_type, id_to_type),
        variables=remaining_tvars,
        type_guard=type_guard,
        type_is=type_is,
    )


def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> CallableType | None:
    """Make free type variables generic in the type if possible.

    Each argument type and the return type of `tp` is translated separately,
    trying to find a valid binding site for every variable in `poly_tvars`
    along the way. The rules are the same as the ones used during semantic
    analysis, for example:
      * Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
      * Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
      * List[T] -> None (not possible)

    Returns the translated callable with an empty `variables` list, or None
    when some variable cannot be bound.
    """
    try:
        new_arg_types = [arg.accept(PolyTranslator(poly_tvars)) for arg in tp.arg_types]
        new_ret_type = tp.ret_type.accept(PolyTranslator(poly_tvars))
        return tp.copy_modified(arg_types=new_arg_types, ret_type=new_ret_type, variables=[])
    except PolyTranslationError:
        return None


class PolyTranslationError(Exception):
    """Raised by PolyTranslator when a type variable cannot be bound."""


class PolyTranslator(TypeTranslator):
    """Make free type variables generic in the type if possible.

    See docstring for apply_poly() for details.
    """

    def __init__(
        self,
        poly_tvars: Iterable[TypeVarLikeType],
        bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
        seen_aliases: frozenset[TypeInfo] = frozenset(),
    ) -> None:
        super().__init__()
        # Variables we are trying to find a binding site for.
        self.poly_tvars = set(poly_tvars)
        # This is a simplified version of TypeVarScope used during semantic analysis.
        self.bound_tvars = bound_tvars
        # Callback protocols already being expanded; see visit_instance().
        self.seen_aliases = seen_aliases

    def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
        """Return the not-yet-bound polymorphic variables used in the argument types of `t`.

        ParamSpec occurrences are normalized to their bare flavor with an empty
        prefix before the lookup. Duplicates are removed from the result.
        """
        found_vars = []
        for arg_type in t.arg_types:
            for tv in get_all_type_vars(arg_type):
                normalized: TypeVarLikeType
                if isinstance(tv, ParamSpecType):
                    normalized = tv.copy_modified(
                        flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], [])
                    )
                else:
                    normalized = tv
                if normalized in self.poly_tvars and normalized not in self.bound_tvars:
                    found_vars.append(normalized)
        return remove_dups(found_vars)

    def visit_callable_type(self, t: CallableType) -> Type:
        """Translate a callable, binding the variables found in its arguments to it."""
        found_vars = self.collect_vars(t)
        # The variables count as bound only while the callable itself is translated.
        self.bound_tvars |= set(found_vars)
        result = super().visit_callable_type(t)
        self.bound_tvars -= set(found_vars)

        assert isinstance(result, ProperType) and isinstance(result, CallableType)
        result.variables = result.variables + tuple(found_vars)
        return result

    def _fail_if_unbound(self, t: TypeVarLikeType) -> None:
        """Raise PolyTranslationError if `t` is polymorphic but has no binder in scope."""
        if t in self.poly_tvars and t not in self.bound_tvars:
            raise PolyTranslationError()

    def visit_type_var(self, t: TypeVarType) -> Type:
        self._fail_if_unbound(t)
        return super().visit_type_var(t)

    def visit_param_spec(self, t: ParamSpecType) -> Type:
        self._fail_if_unbound(t)
        return super().visit_param_spec(t)

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
        self._fail_if_unbound(t)
        return super().visit_type_var_tuple(t)

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        """Translate an alias: copy it if it has no args, otherwise translate its expansion."""
        if not t.args:
            return t.copy_modified()
        if t.is_recursive:
            # We can't handle polymorphic application for recursive generic aliases
            # without risking an infinite recursion, just give up for now.
            raise PolyTranslationError()
        return get_proper_type(t).accept(self)

    def visit_instance(self, t: Instance) -> Type:
        """Translate an instance, with special cases for ParamSpec classes and callback protocols."""
        info = t.type
        if info.has_param_spec_type:
            # We need this special-casing to preserve the possibility to store a
            # generic function in an instance type. Things like
            #     forall T . Foo[[x: T], T]
            # are not really expressible in current type system, but this looks like
            # a useful feature, so let's keep it.
            param_spec_index = next(
                i for (i, tv) in enumerate(info.defn.type_vars) if isinstance(tv, ParamSpecType)
            )
            p = get_proper_type(t.args[param_spec_index])
            if isinstance(p, Parameters):
                found_vars = self.collect_vars(p)
                self.bound_tvars |= set(found_vars)
                new_args = [a.accept(self) for a in t.args]
                self.bound_tvars -= set(found_vars)

                repl = new_args[param_spec_index]
                assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
                repl.variables = [*repl.variables, *found_vars]
                return t.copy_modified(args=new_args)

        # There is the same problem with callback protocols as with aliases
        # (callback protocols are essentially more flexible aliases to callables).
        if t.args and info.is_protocol and info.protocol_members == ["__call__"]:
            if info in self.seen_aliases:
                raise PolyTranslationError()
            call = mypy.subtypes.find_member("__call__", t, t, is_operator=True)
            assert call is not None
            nested = PolyTranslator(self.poly_tvars, self.bound_tvars, self.seen_aliases | {info})
            return call.accept(nested)
        return super().visit_instance(t)