"""Shared definitions used by different parts of type checker."""

from __future__ import annotations

from abc import abstractmethod
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from typing import NamedTuple, overload

from mypy_extensions import trait

from mypy.errorcodes import ErrorCode
from mypy.errors import ErrorWatcher
from mypy.message_registry import ErrorMessage
from mypy.nodes import (
    ArgKind,
    Context,
    Expression,
    FuncItem,
    LambdaExpr,
    MypyFile,
    Node,
    RefExpr,
    SymbolNode,
    TypeInfo,
    Var,
)
from mypy.plugin import CheckerPluginInterface, Plugin
from mypy.types import (
    CallableType,
    Instance,
    LiteralValue,
    Overloaded,
    PartialType,
    TupleType,
    Type,
    TypedDictType,
    TypeType,
)
from mypy.typevars import fill_typevars


# An object that represents either a precise type or a type with an upper bound;
# it is important for correct type inference with isinstance.
class TypeRange(NamedTuple):
    item: Type
    is_upper_bound: bool  # False => precise type


@trait
class ExpressionCheckerSharedApi:
    """Abstract declarations of the expression-checker methods listed below.

    Every method here is abstract and only raises NotImplementedError; a
    concrete class is expected to override all of them.
    """

    @abstractmethod
    def accept(
        self,
        node: Expression,
        type_context: Type | None = None,
        allow_none_return: bool = False,
        always_allow_any: bool = False,
        is_callee: bool = False,
    ) -> Type:
        raise NotImplementedError

    @abstractmethod
    def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
        raise NotImplementedError

    @abstractmethod
    def check_call(
        self,
        callee: Type,
        args: list[Expression],
        arg_kinds: list[ArgKind],
        context: Context,
        arg_names: Sequence[str | None] | None = None,
        callable_node: Expression | None = None,
        callable_name: str | None = None,
        object_type: Type | None = None,
        original_type: Type | None = None,
    ) -> tuple[Type, Type]:
        raise NotImplementedError

    @abstractmethod
    def transform_callee_type(
        self,
        callable_name: str | None,
        callee: Type,
        args: list[Expression],
        arg_kinds: list[ArgKind],
        context: Context,
        arg_names: Sequence[str | None] | None = None,
        object_type: Type | None = None,
    ) -> Type:
        raise NotImplementedError

    @abstractmethod
    def method_fullname(self, object_type: Type, method_name: str) -> str | None:
        raise NotImplementedError

    @abstractmethod
    def check_method_call_by_name(
        self,
        method: str,
        base_type: Type,
        args: list[Expression],
        arg_kinds: list[ArgKind],
        context: Context,
        original_type: Type | None = None,
    ) -> tuple[Type, Type]:
        raise NotImplementedError

    @abstractmethod
    def visit_typeddict_index_expr(
        self, td_type: TypedDictType, index: Expression, setitem: bool = False
    ) -> tuple[Type, set[str]]:
        raise NotImplementedError

    @abstractmethod
    def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Type:
        raise NotImplementedError

    @abstractmethod
    def analyze_static_reference(
        self,
        node: SymbolNode,
        ctx: Context,
        is_lvalue: bool,
        *,
        include_modules: bool = True,
        suppress_errors: bool = False,
    ) -> Type:
        raise NotImplementedError


@trait
class TypeCheckerSharedApi(CheckerPluginInterface):
    """Abstract declarations of the type-checker attributes and methods below.

    Extends CheckerPluginInterface. Every method declared here is abstract
    and only raises NotImplementedError (or is an overload stub); a concrete
    class is expected to override all of them.
    """

    plugin: Plugin
    module_refs: set[str]
    scope: CheckerScope
    checking_missing_await: bool
    allow_constructor_cache: bool

    @property
    @abstractmethod
    def expr_checker(self) -> ExpressionCheckerSharedApi:
        raise NotImplementedError

    @abstractmethod
    def named_type(self, name: str) -> Instance:
        raise NotImplementedError

    @abstractmethod
    def lookup_typeinfo(self, fullname: str) -> TypeInfo:
        raise NotImplementedError

    @abstractmethod
    def lookup_type(self, node: Expression) -> Type:
        raise NotImplementedError

    @abstractmethod
    def handle_cannot_determine_type(self, name: str, context: Context) -> None:
        raise NotImplementedError

    @abstractmethod
    def handle_partial_var_type(
        self, typ: PartialType, is_lvalue: bool, node: Var, context: Context
    ) -> Type:
        raise NotImplementedError

    @overload
    @abstractmethod
    def check_subtype(
        self,
        subtype: Type,
        supertype: Type,
        context: Context,
        msg: str,
        subtype_label: str | None = None,
        supertype_label: str | None = None,
        *,
        notes: list[str] | None = None,
        code: ErrorCode | None = None,
        outer_context: Context | None = None,
    ) -> bool: ...

    @overload
    @abstractmethod
    def check_subtype(
        self,
        subtype: Type,
        supertype: Type,
        context: Context,
        msg: ErrorMessage,
        subtype_label: str | None = None,
        supertype_label: str | None = None,
        *,
        notes: list[str] | None = None,
        outer_context: Context | None = None,
    ) -> bool: ...

    # Unfortunately, mypyc doesn't support abstract overloads yet.
    @abstractmethod
    def check_subtype(
        self,
        subtype: Type,
        supertype: Type,
        context: Context,
        msg: str | ErrorMessage,
        subtype_label: str | None = None,
        supertype_label: str | None = None,
        *,
        notes: list[str] | None = None,
        code: ErrorCode | None = None,
        outer_context: Context | None = None,
    ) -> bool:
        raise NotImplementedError

    @abstractmethod
    def get_final_context(self) -> bool:
        raise NotImplementedError

    @overload
    @abstractmethod
    def conditional_types_with_intersection(
        self,
        expr_type: Type,
        type_ranges: list[TypeRange] | None,
        ctx: Context,
        default: None = None,
    ) -> tuple[Type | None, Type | None]: ...

    @overload
    @abstractmethod
    def conditional_types_with_intersection(
        self, expr_type: Type, type_ranges: list[TypeRange] | None, ctx: Context, default: Type
    ) -> tuple[Type, Type]: ...

    # Unfortunately, mypyc doesn't support abstract overloads yet.
    @abstractmethod
    def conditional_types_with_intersection(
        self,
        expr_type: Type,
        type_ranges: list[TypeRange] | None,
        ctx: Context,
        default: Type | None = None,
    ) -> tuple[Type | None, Type | None]:
        raise NotImplementedError

    @abstractmethod
    def check_deprecated(self, node: Node | None, context: Context) -> None:
        raise NotImplementedError

    @abstractmethod
    def warn_deprecated(self, node: Node | None, context: Context) -> None:
        raise NotImplementedError

    @abstractmethod
    def type_is_iterable(self, type: Type) -> bool:
        raise NotImplementedError

    @abstractmethod
    def iterable_item_type(
        self, it: Instance | CallableType | TypeType | Overloaded, context: Context
    ) -> Type:
        raise NotImplementedError

    @abstractmethod
    @contextmanager
    def checking_await_set(self) -> Iterator[None]:
        raise NotImplementedError

    @abstractmethod
    def get_precise_awaitable_type(self, typ: Type, local_errors: ErrorWatcher) -> Type | None:
        raise NotImplementedError

    @abstractmethod
    def add_any_attribute_to_type(self, typ: Type, name: str) -> Type:
        raise NotImplementedError

    @abstractmethod
    def is_defined_in_stub(self, typ: Instance, /) -> bool:
        raise NotImplementedError


class CheckerScope:
    """Stack of the module, classes and functions currently being entered.

    The stack is seeded with the module passed to the constructor; classes
    and functions are added and removed through push_class() and
    push_function().
    """

    # We keep two stacks combined, to maintain the relative order
    stack: list[TypeInfo | FuncItem | MypyFile]

    def __init__(self, module: MypyFile) -> None:
        """Start a scope whose only stack entry is *module*."""
        self.stack = [module]

    def current_function(self) -> FuncItem | None:
        """Return the innermost FuncItem on the stack, or None if there is none."""
        for e in reversed(self.stack):
            if isinstance(e, FuncItem):
                return e
        return None

    def top_level_function(self) -> FuncItem | None:
        """Return top-level non-lambda function."""
        # Scans from the bottom of the stack, so the outermost match wins.
        for e in self.stack:
            if isinstance(e, FuncItem) and not isinstance(e, LambdaExpr):
                return e
        return None

    def active_class(self) -> TypeInfo | None:
        """Return the top of the stack if it is a TypeInfo, otherwise None."""
        if isinstance(self.stack[-1], TypeInfo):
            return self.stack[-1]
        return None

    def enclosing_class(self, func: FuncItem | None = None) -> TypeInfo | None:
        """Is there a class *directly* enclosing this function?

        If *func* is not given, the result of current_function() is used.
        Returns the stack entry immediately below the function when that
        entry is a TypeInfo, and None otherwise.
        """
        func = func or self.current_function()
        assert func, "This method must be called from inside a function"
        index = self.stack.index(func)
        # Index 0 would mean the function sits where the module is expected.
        assert index, "CheckerScope stack must always start with a module"
        enclosing = self.stack[index - 1]
        if isinstance(enclosing, TypeInfo):
            return enclosing
        return None

    def active_self_type(self) -> Instance | TupleType | None:
        """An instance or tuple type representing the current class.

        This returns None unless we are in class body or in a method.
        In particular, inside a function nested in method this returns None.
        """
        info = self.active_class()
        if not info and self.current_function():
            info = self.enclosing_class()
        if info:
            return fill_typevars(info)
        return None

    def current_self_type(self) -> Instance | TupleType | None:
        """Same as active_self_type() but handle functions nested in methods."""
        # The innermost TypeInfo anywhere on the stack is used.
        for item in reversed(self.stack):
            if isinstance(item, TypeInfo):
                return fill_typevars(item)
        return None

    def is_top_level(self) -> bool:
        """Is current scope top-level (no classes or functions)?"""
        return len(self.stack) == 1

    @contextmanager
    def push_function(self, item: FuncItem) -> Iterator[None]:
        """Keep *item* on the stack for the duration of the ``with`` body.

        The entry is popped even if the body raises, so the stack stays
        balanced.
        """
        self.stack.append(item)
        try:
            yield
        finally:
            self.stack.pop()

    @contextmanager
    def push_class(self, info: TypeInfo) -> Iterator[None]:
        """Keep *info* on the stack for the duration of the ``with`` body.

        The entry is popped even if the body raises, so the stack stays
        balanced.
        """
        self.stack.append(info)
        try:
            yield
        finally:
            self.stack.pop()