from __future__ import annotations

from functools import partial
from typing import Callable, Final

import mypy.errorcodes as codes
from mypy import message_registry
from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
from mypy.plugin import (
    AttributeContext,
    ClassDefContext,
    FunctionContext,
    FunctionSigContext,
    MethodContext,
    MethodSigContext,
    Plugin,
)
from mypy.plugins.attrs import (
    attr_class_maker_callback,
    attr_class_makers,
    attr_dataclass_makers,
    attr_define_makers,
    attr_frozen_makers,
    attr_tag_callback,
    evolve_function_sig_callback,
    fields_function_sig_callback,
)
from mypy.plugins.common import try_getting_str_literals
from mypy.plugins.constants import (
    ENUM_NAME_ACCESS,
    ENUM_VALUE_ACCESS,
    SINGLEDISPATCH_CALLABLE_CALL_METHOD,
    SINGLEDISPATCH_REGISTER_CALLABLE_CALL_METHOD,
    SINGLEDISPATCH_REGISTER_METHOD,
)
from mypy.plugins.ctypes import (
    array_constructor_callback,
    array_getitem_callback,
    array_iter_callback,
    array_raw_callback,
    array_setitem_callback,
    array_value_callback,
)
from mypy.plugins.dataclasses import (
    dataclass_class_maker_callback,
    dataclass_makers,
    dataclass_tag_callback,
    replace_function_sig_callback,
)
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
from mypy.plugins.functools import (
    functools_total_ordering_maker_callback,
    functools_total_ordering_makers,
    partial_call_callback,
    partial_new_callback,
)
from mypy.plugins.singledispatch import (
    call_singledispatch_function_after_register_argument,
    call_singledispatch_function_callback,
    create_singledispatch_function_callback,
    singledispatch_register_callback,
)
from mypy.subtypes import is_subtype
from mypy.typeops import is_literal_type_like, make_simplified_union
from mypy.types import (
    TPDICT_FB_NAMES,
    AnyType,
    CallableType,
    FunctionLike,
    Instance,
    LiteralType,
    NoneType,
    TupleType,
    Type,
    TypedDictType,
    TypeOfAny,
    TypeVarType,
    UnionType,
    get_proper_type,
    get_proper_types,
)

# Fully qualified method names, one per TypedDict fallback class name in
# TPDICT_FB_NAMES, used below to select the TypedDict-specific hooks.
TD_SETDEFAULT_NAMES: Final = {n + ".setdefault" for n in TPDICT_FB_NAMES}
TD_POP_NAMES: Final = {n + ".pop" for n in TPDICT_FB_NAMES}
TD_DELITEM_NAMES: Final = {n + ".__delitem__" for n in TPDICT_FB_NAMES}

TD_UPDATE_METHOD_NAMES: Final = (
    {n + ".update" for n in TPDICT_FB_NAMES}
    | {n + ".__or__" for n in TPDICT_FB_NAMES}
    | {n + ".__ror__" for n in TPDICT_FB_NAMES}
    | {n + ".__ior__" for n in TPDICT_FB_NAMES}
)


class DefaultPlugin(Plugin):
    """Type checker plugin that is enabled by default."""

    def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
        """Return the return-type callback registered for function `fullname`, or None."""
        if fullname == "_ctypes.Array":
            return array_constructor_callback
        elif fullname == "functools.singledispatch":
            return create_singledispatch_function_callback
        elif fullname == "functools.partial":
            return partial_new_callback
        elif fullname == "enum.member":
            return enum_member_callback
        return None

    def get_function_signature_hook(
        self, fullname: str
    ) -> Callable[[FunctionSigContext], FunctionLike] | None:
        """Return the signature callback registered for function `fullname`, or None."""
        if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
            return evolve_function_sig_callback
        elif fullname in ("attr.fields", "attrs.fields"):
            return fields_function_sig_callback
        elif fullname == "dataclasses.replace":
            return replace_function_sig_callback
        return None

    def get_method_signature_hook(
        self, fullname: str
    ) -> Callable[[MethodSigContext], FunctionLike] | None:
        """Return the signature callback registered for method `fullname`, or None."""
        if fullname == "typing.Mapping.get":
            return typed_dict_get_signature_callback
        elif fullname in TD_SETDEFAULT_NAMES:
            return typed_dict_setdefault_signature_callback
        elif fullname in TD_POP_NAMES:
            return typed_dict_pop_signature_callback
        elif fullname == "_ctypes.Array.__setitem__":
            return array_setitem_callback
        elif fullname == SINGLEDISPATCH_CALLABLE_CALL_METHOD:
            return call_singledispatch_function_callback
        elif fullname in TD_UPDATE_METHOD_NAMES:
            return typed_dict_update_signature_callback
        return None

    def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
        """Return the return-type callback registered for method `fullname`, or None."""
        if fullname == "typing.Mapping.get":
            return typed_dict_get_callback
        elif fullname == "builtins.int.__pow__":
            return int_pow_callback
        elif fullname == "builtins.int.__neg__":
            return int_neg_callback
        elif fullname == "builtins.int.__pos__":
            return int_pos_callback
        elif fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"):
            return tuple_mul_callback
        elif fullname in TD_SETDEFAULT_NAMES:
            return typed_dict_setdefault_callback
        elif fullname in TD_POP_NAMES:
            return typed_dict_pop_callback
        elif fullname in TD_DELITEM_NAMES:
            return typed_dict_delitem_callback
        elif fullname == "_ctypes.Array.__getitem__":
            return array_getitem_callback
        elif fullname == "_ctypes.Array.__iter__":
            return array_iter_callback
        elif fullname == SINGLEDISPATCH_REGISTER_METHOD:
            return singledispatch_register_callback
        elif fullname == SINGLEDISPATCH_REGISTER_CALLABLE_CALL_METHOD:
            return call_singledispatch_function_after_register_argument
        elif fullname == "functools.partial.__call__":
            return partial_call_callback
        return None

    def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
        """Return the callback registered for attribute `fullname`, or None."""
        if fullname == "_ctypes.Array.value":
            return array_value_callback
        elif fullname == "_ctypes.Array.raw":
            return array_raw_callback
        elif fullname in ENUM_NAME_ACCESS:
            return enum_name_callback
        elif fullname in ENUM_VALUE_ACCESS:
            return enum_value_callback
        return None

    def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
        """Return the tagging callback for class decorator `fullname`, or None."""
        # These dataclass and attrs hooks run in the main semantic analysis pass
        # and only tag known dataclasses/attrs classes, so that the second
        # hooks (in get_class_decorator_hook_2) can detect dataclasses/attrs classes
        # in the MRO.
        if fullname in dataclass_makers:
            return dataclass_tag_callback
        if (
            fullname in attr_class_makers
            or fullname in attr_dataclass_makers
            or fullname in attr_frozen_makers
            or fullname in attr_define_makers
        ):
            return attr_tag_callback
        return None

    def get_class_decorator_hook_2(
        self, fullname: str
    ) -> Callable[[ClassDefContext], bool] | None:
        """Return the class-maker callback for class decorator `fullname`, or None.

        The attrs variants share one callback and differ only in the defaults
        bound through `partial`.
        """
        if fullname in dataclass_makers:
            return dataclass_class_maker_callback
        elif fullname in functools_total_ordering_makers:
            return functools_total_ordering_maker_callback
        elif fullname in attr_class_makers:
            return attr_class_maker_callback
        elif fullname in attr_dataclass_makers:
            return partial(attr_class_maker_callback, auto_attribs_default=True)
        elif fullname in attr_frozen_makers:
            return partial(
                attr_class_maker_callback, auto_attribs_default=None, frozen_default=True
            )
        elif fullname in attr_define_makers:
            return partial(
                attr_class_maker_callback, auto_attribs_default=None, slots_default=True
            )
        return None


def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.get.

    This is used to get better type context for the second argument that
    depends on a TypedDict value type.
    """
    signature = ctx.default_signature
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(signature.arg_types) == 2
        and len(signature.variables) == 1
        and len(ctx.args[1]) == 1
    ):
        key = ctx.args[0][0].value
        value_type = get_proper_type(ctx.type.items.get(key))
        ret_type = signature.ret_type
        if value_type:
            default_arg = ctx.args[1][0]
            if (
                isinstance(value_type, TypedDictType)
                and isinstance(default_arg, DictExpr)
                and len(default_arg.items) == 0
            ):
                # Caller has empty dict {} as default for typed dict.
                value_type = value_type.copy_modified(required_keys=set())
            # Tweak the signature to include the value type as context. It's
            # only needed for type inference since there's a union with a type
            # variable that accepts everything.
            tv = signature.variables[0]
            assert isinstance(tv, TypeVarType)
            return signature.copy_modified(
                arg_types=[signature.arg_types[0], make_simplified_union([value_type, tv])],
                ret_type=ret_type,
            )
    return signature


def typed_dict_get_callback(ctx: MethodContext) -> Type:
    """Infer a precise return type for TypedDict.get with literal first argument."""
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) >= 1
        and len(ctx.arg_types[0]) == 1
    ):
        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            return ctx.default_return_type

        output_types: list[Type] = []
        for key in keys:
            value_type = get_proper_type(ctx.type.items.get(key))
            if value_type is None:
                # Unknown key: fall back to the declared return type.
                return ctx.default_return_type

            if len(ctx.arg_types) == 1:
                output_types.append(value_type)
            elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
                default_arg = ctx.args[1][0]
                if (
                    isinstance(default_arg, DictExpr)
                    and len(default_arg.items) == 0
                    and isinstance(value_type, TypedDictType)
                ):
                    # Special case '{}' as the default for a typed dict type.
                    output_types.append(value_type.copy_modified(required_keys=set()))
                else:
                    output_types.append(value_type)
                    output_types.append(ctx.arg_types[1][0])

        if len(ctx.arg_types) == 1:
            # Only one formal argument slot: None is added to the result.
            output_types.append(NoneType())

        return make_simplified_union(output_types)
    return ctx.default_return_type


def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.pop.

    This is used to get better type context for the second argument that
    depends on a TypedDict value type.
    """
    signature = ctx.default_signature
    str_type = ctx.api.named_generic_type("builtins.str", [])
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(signature.arg_types) == 2
        and len(signature.variables) == 1
        and len(ctx.args[1]) == 1
    ):
        key = ctx.args[0][0].value
        value_type = ctx.type.items.get(key)
        if value_type:
            # Tweak the signature to include the value type as context. It's
            # only needed for type inference since there's a union with a type
            # variable that accepts everything.
            tv = signature.variables[0]
            assert isinstance(tv, TypeVarType)
            typ = make_simplified_union([value_type, tv])
            return signature.copy_modified(arg_types=[str_type, typ], ret_type=typ)
    return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])


def typed_dict_pop_callback(ctx: MethodContext) -> Type:
    """Type check and infer a precise return type for TypedDict.pop.

    Reports an error when the key is not a string literal, when a key is
    required or read-only, or when a key is not an item of the TypedDict.
    """
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) >= 1
        and len(ctx.arg_types[0]) == 1
    ):
        key_expr = ctx.args[0][0]
        keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(
                message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
                key_expr,
                code=codes.LITERAL_REQ,
            )
            return AnyType(TypeOfAny.from_error)

        value_types = []
        for key in keys:
            if key in ctx.type.required_keys or key in ctx.type.readonly_keys:
                ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, key_expr)

            value_type = ctx.type.items.get(key)
            if value_type:
                value_types.append(value_type)
            else:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, key_expr)
                return AnyType(TypeOfAny.from_error)

        # The outer guard only guarantees one argument slot, so make sure a
        # second slot exists before indexing it.
        if len(ctx.args) < 2 or len(ctx.args[1]) == 0:
            return make_simplified_union(value_types)
        elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
            return make_simplified_union([*value_types, ctx.arg_types[1][0]])
    return ctx.default_return_type


def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.setdefault.

    This is used to get better type context for the second argument that
    depends on a TypedDict value type.
    """
    signature = ctx.default_signature
    str_type = ctx.api.named_generic_type("builtins.str", [])
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(signature.arg_types) == 2
        and len(ctx.args[1]) == 1
    ):
        key = ctx.args[0][0].value
        value_type = ctx.type.items.get(key)
        if value_type:
            return signature.copy_modified(arg_types=[str_type, value_type])
    return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])


def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
    """Type check TypedDict.setdefault and infer a precise return type."""
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) == 2
        and len(ctx.arg_types[0]) == 1
        and len(ctx.arg_types[1]) == 1
    ):
        key_expr = ctx.args[0][0]
        keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(
                message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
                key_expr,
                code=codes.LITERAL_REQ,
            )
            return AnyType(TypeOfAny.from_error)

        assigned_readonly_keys = ctx.type.readonly_keys & set(keys)
        if assigned_readonly_keys:
            ctx.api.msg.readonly_keys_mutated(assigned_readonly_keys, context=key_expr)

        default_type = ctx.arg_types[1][0]
        default_expr = ctx.args[1][0]

        value_types = []
        for key in keys:
            value_type = ctx.type.items.get(key)

            if value_type is None:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, key_expr)
                return AnyType(TypeOfAny.from_error)

            # The signature_callback above can't always infer the right signature
            # (e.g. when the expression is a variable that happens to be a Literal str)
            # so we need to handle the check ourselves here and make sure the provided
            # default can be assigned to all key-value pairs we're updating.
            if not is_subtype(default_type, value_type):
                ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
                    default_type, value_type, default_expr
                )
                return AnyType(TypeOfAny.from_error)

            value_types.append(value_type)

        return make_simplified_union(value_types)
    return ctx.default_return_type


def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
    """Type check TypedDict.__delitem__."""
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) == 1
        and len(ctx.arg_types[0]) == 1
    ):
        key_expr = ctx.args[0][0]
        keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(
                message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
                key_expr,
                code=codes.LITERAL_REQ,
            )
            return AnyType(TypeOfAny.from_error)

        for key in keys:
            if key in ctx.type.required_keys or key in ctx.type.readonly_keys:
                ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, key_expr)
            elif key not in ctx.type.items:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, key_expr)
    return ctx.default_return_type


# Signature names that typed_dict_update_signature_callback treats as in-place
# mutation of the TypedDict.
_TP_DICT_MUTATING_METHODS: Final = frozenset({"update of TypedDict", "__ior__ of TypedDict"})


def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for methods that update `TypedDict`.

    This includes: `TypedDict.update`, `TypedDict.__or__`, `TypedDict.__ror__`,
    and `TypedDict.__ior__`.
    """
    signature = ctx.default_signature
    if isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1:
        arg_type = get_proper_type(signature.arg_types[0])
        if not isinstance(arg_type, TypedDictType):
            return signature
        arg_type = arg_type.as_anonymous()
        # Start with no required keys; the required keys of the inferred
        # argument are merged back in below.
        arg_type = arg_type.copy_modified(required_keys=set())
        if ctx.args and ctx.args[0]:
            if signature.name in _TP_DICT_MUTATING_METHODS:
                # If we want to mutate this object in place, we need to set this flag,
                # it will trigger an extra check in TypedDict's checker.
                arg_type.to_be_mutated = True
            with ctx.api.msg.filter_errors(
                filter_errors=lambda name, info: info.code != codes.TYPEDDICT_READONLY_MUTATED,
                save_filtered_errors=True,
            ):
                inferred = get_proper_type(
                    ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type)
                )
            if arg_type.to_be_mutated:
                arg_type.to_be_mutated = False  # Done!
            possible_tds = []
            if isinstance(inferred, TypedDictType):
                possible_tds = [inferred]
            elif isinstance(inferred, UnionType):
                possible_tds = [
                    t
                    for t in get_proper_types(inferred.relevant_items())
                    if isinstance(t, TypedDictType)
                ]
            items = []
            for td in possible_tds:
                item = arg_type.copy_modified(
                    required_keys=(arg_type.required_keys | td.required_keys)
                    & arg_type.items.keys()
                )
                if not ctx.api.options.extra_checks:
                    item = item.copy_modified(item_names=list(td.items))
                items.append(item)
            if items:
                arg_type = make_simplified_union(items)
        return signature.copy_modified(arg_types=[arg_type])
    return signature


def int_pow_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for int.__pow__.

    A non-negative int literal exponent gives `builtins.int`, a negated int
    literal gives `builtins.float`, and anything else keeps the default type.
    """
    # int.__pow__ has an optional modulo argument,
    # so we expect 2 argument positions
    if len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1 and len(ctx.arg_types[1]) == 0:
        arg = ctx.args[0][0]
        if isinstance(arg, IntExpr):
            exponent = arg.value
        elif isinstance(arg, UnaryExpr) and arg.op == "-" and isinstance(arg.expr, IntExpr):
            exponent = -arg.expr.value
        else:
            # Right operand not an int literal or a negated literal -- give up.
            return ctx.default_return_type
        if exponent >= 0:
            return ctx.api.named_generic_type("builtins.int", [])
        else:
            return ctx.api.named_generic_type("builtins.float", [])
    return ctx.default_return_type


def int_neg_callback(ctx: MethodContext, multiplier: int = -1) -> Type:
    """Infer a more precise return type for int.__neg__ and int.__pos__.

    This is mainly used to infer the return type as LiteralType
    if the original underlying object is a LiteralType object.

    `multiplier` is applied to the known literal value: -1 for negation,
    +1 for unary plus.
    """
    if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
        value = ctx.type.last_known_value.value
        fallback = ctx.type.last_known_value.fallback
        if isinstance(value, int):
            if is_literal_type_like(ctx.api.type_context[-1]):
                return LiteralType(value=multiplier * value, fallback=fallback)
            else:
                return ctx.type.copy_modified(
                    last_known_value=LiteralType(
                        value=multiplier * value,
                        fallback=fallback,
                        line=ctx.type.line,
                        column=ctx.type.column,
                    )
                )
    elif isinstance(ctx.type, LiteralType):
        value = ctx.type.value
        fallback = ctx.type.fallback
        if isinstance(value, int):
            return LiteralType(value=multiplier * value, fallback=fallback)
    return ctx.default_return_type


def int_pos_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for int.__pos__.

    This is identical to __neg__, except the value is not inverted.
    """
    return int_neg_callback(ctx, +1)


def tuple_mul_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.

    This is used to return a specific sized tuple if multiplied by Literal int
    """
    if not isinstance(ctx.type, TupleType):
        return ctx.default_return_type

    # Like the other callbacks in this module, only proceed when exactly one
    # actual argument was supplied for the first formal; otherwise indexing
    # below would raise IndexError.
    if not ctx.arg_types or len(ctx.arg_types[0]) != 1:
        return ctx.default_return_type

    arg_type = get_proper_type(ctx.arg_types[0][0])
    if isinstance(arg_type, Instance) and arg_type.last_known_value is not None:
        value = arg_type.last_known_value.value
        if isinstance(value, int):
            return ctx.type.copy_modified(items=ctx.type.items * value)
    elif isinstance(arg_type, LiteralType):
        value = arg_type.value
        if isinstance(value, int):
            return ctx.type.copy_modified(items=ctx.type.items * value)

    return ctx.default_return_type