from __future__ import annotations from functools import wraps from inspect import Parameter from inspect import signature from typing import Any from typing import TYPE_CHECKING from typing import TypeVar import warnings from optuna._deprecated import _validate_two_version from optuna._experimental import _validate_version if TYPE_CHECKING: from collections.abc import Callable from collections.abc import Sequence from typing_extensions import ParamSpec _P = ParamSpec("_P") _T = TypeVar("_T") _DEPRECATION_WARNING_TEMPLATE = ( "Positional arguments {deprecated_positional_arg_names} in {func_name}() " "have been deprecated since v{d_ver}. " "They will be replaced with the corresponding keyword arguments in v{r_ver}, " "so please use the keyword specification instead. " "See https://github.com/optuna/optuna/releases/tag/v{d_ver} for details." ) def _get_positional_arg_names(func: "Callable[_P, _T]") -> list[str]: params = signature(func).parameters positional_arg_names = [ name for name, p in params.items() if p.default == Parameter.empty and p.kind == p.POSITIONAL_OR_KEYWORD ] return positional_arg_names def _infer_kwargs(previous_positional_arg_names: Sequence[str], *args: Any) -> dict[str, Any]: inferred_kwargs = {arg_name: val for val, arg_name in zip(args, previous_positional_arg_names)} return inferred_kwargs def convert_positional_args( *, previous_positional_arg_names: Sequence[str], deprecated_version: str, removed_version: str, warning_stacklevel: int = 2, ) -> "Callable[[Callable[_P, _T]], Callable[_P, _T]]": """Convert positional arguments to keyword arguments. Args: previous_positional_arg_names: List of names previously given as positional arguments. warning_stacklevel: Level of the stack trace where decorated function locates. deprecated_version: The version in which the use of positional arguments is deprecated. removed_version: The version in which the use of positional arguments will be removed. """ if deprecated_version is not None or removed_version is not None: if deprecated_version is None: raise ValueError( "deprecated_version must not be None when removed_version is specified." ) if removed_version is None: raise ValueError( "removed_version must not be None when deprecated_version is specified." ) _validate_version(deprecated_version) _validate_version(removed_version) _validate_two_version(deprecated_version, removed_version) def converter_decorator(func: "Callable[_P, _T]") -> "Callable[_P, _T]": assert set(previous_positional_arg_names).issubset(set(signature(func).parameters)), ( f"{set(previous_positional_arg_names)} is not a subset of" f" {set(signature(func).parameters)}" ) @wraps(func) def converter_wrapper(*args: Any, **kwargs: Any) -> "_T": warning_messages = [] positional_arg_names = _get_positional_arg_names(func) inferred_kwargs = _infer_kwargs(previous_positional_arg_names, *args) if len(inferred_kwargs) > len(positional_arg_names): expected_kwds = set(inferred_kwargs) - set(positional_arg_names) warning_messages.append( f"{func.__name__}() got {expected_kwds} as positional arguments " "but they were expected to be given as keyword arguments." ) if deprecated_version or removed_version: warning_messages.append( _DEPRECATION_WARNING_TEMPLATE.format( deprecated_positional_arg_names=previous_positional_arg_names, func_name=func.__name__, d_ver=deprecated_version, r_ver=removed_version, ) ) if warning_messages: warnings.warn( "\n".join(warning_messages), FutureWarning, stacklevel=warning_stacklevel ) if len(args) > len(previous_positional_arg_names): raise TypeError( f"{func.__name__}() takes {len(previous_positional_arg_names)} positional" f" arguments but {len(args)} were given." ) duplicated_kwds = set(kwargs).intersection(inferred_kwargs) if len(duplicated_kwds): # When specifying positional arguments that are not located at the end of args as # keyword arguments, raise TypeError as follows by imitating the Python standard # behavior raise TypeError( f"{func.__name__}() got multiple values for arguments {duplicated_kwds}." ) kwargs.update(inferred_kwargs) return func(**kwargs) # type: ignore[call-arg] return converter_wrapper return converter_decorator