from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast

import polars as pl

from narwhals._polars.utils import (
    BACKEND_VERSION,
    PolarsAnyNamespace,
    PolarsCatNamespace,
    PolarsDateTimeNamespace,
    PolarsListNamespace,
    PolarsStringNamespace,
    PolarsStructNamespace,
    extract_args_kwargs,
    extract_native,
    narwhals_to_native_dtype,
)
from narwhals._utils import Implementation, no_default, requires

if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import Self

    from narwhals._compliant.typing import Accessor
    from narwhals._expression_parsing import ExprMetadata
    from narwhals._polars.dataframe import Method
    from narwhals._polars.namespace import PolarsNamespace
    from narwhals._polars.series import PolarsSeries
    from narwhals._typing import NoDefault
    from narwhals._utils import Version
    from narwhals.typing import IntoDType, ModeKeepStrategy


class PolarsExpr:
    """Compliant-expression wrapper around a native ``polars.Expr``.

    Methods either forward explicitly (to translate arguments, unwrap other
    ``PolarsExpr`` operands, or patch behaviour on older Polars versions) or are
    forwarded dynamically by ``__getattr__`` (see the ``Method`` annotations at
    the bottom of the class).
    """

    # CompliantExpr
    _implementation: Implementation = Implementation.POLARS
    _version: Version
    _native_expr: pl.Expr
    _evaluate_output_names: Any
    _alias_output_names: Any
    __call__: Any

    @classmethod
    def _from_series(cls, series: PolarsSeries) -> Self:
        # `series.native` is not a `pl.Expr`, hence the type-ignore.
        return cls(series.native, version=series._version)  # type: ignore[arg-type]

    # CompliantExpr + builtin descriptor
    # TODO @dangotbanned: Remove in #2713
    @classmethod
    def from_column_names(cls, *_: Any, **__: Any) -> Self:
        raise NotImplementedError

    @classmethod
    def from_column_indices(cls, *_: Any, **__: Any) -> Self:
        raise NotImplementedError

    def __narwhals_expr__(self) -> Self:  # pragma: no cover
        return self

    def __narwhals_namespace__(self) -> PolarsNamespace:  # pragma: no cover
        from narwhals._polars.namespace import PolarsNamespace

        return PolarsNamespace(version=self._version)

    def __init__(self, expr: pl.Expr, version: Version) -> None:
        """Wrap ``expr``; ``version`` is carried along to every derived expression."""
        self._native_expr = expr
        self._version = version

    @property
    def _backend_version(self) -> tuple[int, ...]:
        """Installed Polars version, used to select version-specific code paths."""
        return self._implementation._backend_version()

    @property
    def native(self) -> pl.Expr:
        """The wrapped ``polars.Expr``."""
        return self._native_expr

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsExpr"

    def _with_native(self, expr: pl.Expr) -> Self:
        """Wrap a new native expression, keeping this expression's ``version``."""
        return self.__class__(expr, self._version)

    def broadcast(self) -> Self:
        # Let Polars do its thing.
        return self

    @property
    def _metadata(self) -> ExprMetadata:
        # NOTE(review): `_opt_metadata` is not declared in this class. If it is
        # never assigned, normal lookup fails and `__getattr__` returns a
        # forwarding function (not None), so the assert passes - confirm intent.
        assert self._opt_metadata is not None  # noqa: S101
        return cast("ExprMetadata", self._opt_metadata)

    def __getattr__(self, attr: str) -> Any:
        """Fallback: forward ``attr`` to the native expression.

        Arguments pass through ``extract_args_kwargs`` (presumably unwrapping
        compliant objects to native ones), and the result is re-wrapped.
        """

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._with_native(getattr(self.native, attr)(*pos, **kwds))

        return func

    def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]:
        """Spell the minimum-observations keyword for the installed Polars.

        Returns ``{"min_periods": n}`` on Polars < 1.21 and
        ``{"min_samples": n}`` otherwise.
        """
        name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
        return {name: min_samples}

    def cast(self, dtype: IntoDType) -> Self:
        dtype_pl = narwhals_to_native_dtype(dtype, self._version)
        return self._with_native(self.native.cast(dtype_pl))

    def clip_lower(self, lower_bound: PolarsExpr) -> Self:
        lower_native = extract_native(lower_bound)
        return self._with_native(self.native.clip(lower_native))

    def clip_upper(self, upper_bound: PolarsExpr) -> Self:
        upper_native = extract_native(upper_bound)
        return self._with_native(self.native.clip(None, upper_native))

    def ewm_mean(
        self,
        *,
        com: float | None,
        span: float | None,
        half_life: float | None,
        alpha: float | None,
        adjust: bool,
        min_samples: int,
        ignore_nulls: bool,
    ) -> Self:
        """Exponentially-weighted moving average.

        On Polars < 1.0 the output is masked back to null wherever the input
        is null.
        """
        native = self.native.ewm_mean(
            com=com,
            span=span,
            half_life=half_life,
            alpha=alpha,
            adjust=adjust,
            ignore_nulls=ignore_nulls,
            **self._renamed_min_periods(min_samples),
        )
        if self._backend_version < (1,):  # pragma: no cover
            native = pl.when(~self.native.is_null()).then(native).otherwise(None)
        return self._with_native(native)

    def is_nan(self) -> Self:
        if self._backend_version >= (1, 18):
            native = self.native.is_nan()
        else:  # pragma: no cover
            # `when(...).then(...)` without `otherwise` yields null for null inputs.
            native = pl.when(self.native.is_not_null()).then(self.native.is_nan())
        return self._with_native(native)

    def is_finite(self) -> Self:
        if self._backend_version >= (1, 18):
            native = self.native.is_finite()
        else:  # pragma: no cover
            # `when(...).then(...)` without `otherwise` yields null for null inputs.
            native = pl.when(self.native.is_not_null()).then(self.native.is_finite())
        return self._with_native(native)

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        """Evaluate this expression over windows, optionally ordered within each.

        With an empty ``partition_by`` a constant key is used, so the whole
        frame forms a single window.

        Raises:
            NotImplementedError: If ``order_by`` is non-empty on Polars < 1.10.
        """
        # Use `pl.repeat(1, pl.len())` instead of `pl.lit(1)` to avoid issues for
        # non-numeric types: https://github.com/pola-rs/polars/issues/24756.
        pl_partition_by = partition_by or pl.repeat(1, pl.len())
        # `Expr.over(order_by=...)` only exists from Polars 1.10. Below that, the
        # keyword must not be passed at all (the else-branch always passes it).
        if self._backend_version < (1, 10):
            if order_by:
                msg = "`order_by` in Polars requires version 1.10 or greater"
                raise NotImplementedError(msg)
            native = self.native.over(pl_partition_by)
        else:
            native = self.native.over(pl_partition_by, order_by=order_by or None)
        return self._with_native(native)

    @requires.backend_version((1,))
    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_var(
            window_size=window_size, center=center, ddof=ddof, **kwds
        )
        return self._with_native(native)

    @requires.backend_version((1,))
    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_std(
            window_size=window_size, center=center, ddof=ddof, **kwds
        )
        return self._with_native(native)

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_sum(window_size=window_size, center=center, **kwds)
        return self._with_native(native)

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        kwds = self._renamed_min_periods(min_samples)
        native = self.native.rolling_mean(window_size=window_size, center=center, **kwds)
        return self._with_native(native)

    def map_batches(
        self,
        function: Callable[[Any], Any],
        return_dtype: IntoDType | None,
        *,
        returns_scalar: bool,
    ) -> Self:
        """Apply ``function`` batch-wise through ``pl.Expr.map_batches``.

        Without an explicit ``return_dtype``, Polars >= 1.32 receives
        ``pl.self_dtype()`` and older versions receive ``None``.
        ``returns_scalar`` is only forwarded on Polars >= 0.20.31.
        """
        pl_version = self._backend_version
        return_dtype_pl = (
            narwhals_to_native_dtype(return_dtype, self._version)
            if return_dtype is not None
            else None
            if pl_version < (1, 32)
            else pl.self_dtype()
        )
        kwargs: dict[str, Any] = (
            {} if pl_version < (0, 20, 31) else {"returns_scalar": returns_scalar}
        )
        native = self.native.map_batches(function, return_dtype_pl, **kwargs)
        return self._with_native(native)

    @requires.backend_version((1,))
    def replace_strict(
        self,
        default: PolarsExpr | NoDefault,
        old: Sequence[Any],
        new: Sequence[Any],
        *,
        return_dtype: IntoDType | None,
    ) -> Self:
        """Map ``old`` values to ``new`` ones.

        ``default`` is forwarded only when supplied (it is omitted for
        ``no_default``).
        """
        # `is not None` (rather than truthiness) matches `cast`/`map_batches`.
        return_dtype_pl = (
            narwhals_to_native_dtype(return_dtype, self._version)
            if return_dtype is not None
            else None
        )
        extra_kwargs = (
            {} if default is no_default else {"default": extract_native(default)}
        )
        native = self.native.replace_strict(
            old, new, return_dtype=return_dtype_pl, **extra_kwargs
        )
        return self._with_native(native)

    # Binary operators: unwrap the right-hand operand and delegate to Polars.
    def __eq__(self, other: PolarsExpr) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__eq__(extract_native(other)))

    def __ne__(self, other: PolarsExpr) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__ne__(extract_native(other)))

    def __ge__(self, other: Any) -> Self:
        return self._with_native(self.native.__ge__(extract_native(other)))

    def __gt__(self, other: Any) -> Self:
        return self._with_native(self.native.__gt__(extract_native(other)))

    def __le__(self, other: Any) -> Self:
        return self._with_native(self.native.__le__(extract_native(other)))

    def __lt__(self, other: Any) -> Self:
        return self._with_native(self.native.__lt__(extract_native(other)))

    def __and__(self, other: PolarsExpr) -> Self:
        return self._with_native(self.native.__and__(extract_native(other)))

    def __or__(self, other: PolarsExpr) -> Self:
        return self._with_native(self.native.__or__(extract_native(other)))

    def __add__(self, other: Any) -> Self:
        return self._with_native(self.native.__add__(extract_native(other)))

    def __sub__(self, other: Any) -> Self:
        return self._with_native(self.native.__sub__(extract_native(other)))

    def __mul__(self, other: Any) -> Self:
        return self._with_native(self.native.__mul__(extract_native(other)))

    def __pow__(self, other: Any) -> Self:
        return self._with_native(self.native.__pow__(extract_native(other)))

    def __truediv__(self, other: Any) -> Self:
        return self._with_native(self.native.__truediv__(extract_native(other)))

    def __floordiv__(self, other: Any) -> Self:
        return self._with_native(self.native.__floordiv__(extract_native(other)))

    def __rfloordiv__(self, other: Any) -> Self:
        native = self.native
        result = native.__rfloordiv__(extract_native(other))
        if self._backend_version < (1, 10, 0):
            # Polars 1.9.0 and earlier returns 0 for division by 0 in rfloordiv.
            result = pl.when(native != 0).then(result).otherwise(None)
        return self._with_native(result)

    def __mod__(self, other: Any) -> Self:
        return self._with_native(self.native.__mod__(extract_native(other)))

    def __invert__(self) -> Self:
        return self._with_native(self.native.__invert__())

    def cum_count(self, *, reverse: bool) -> Self:
        return self._with_native(self.native.cum_count(reverse=reverse))

    def mode(self, *, keep: ModeKeepStrategy) -> Self:
        """Most frequent value(s); ``keep="any"`` keeps only the first of them."""
        result = self.native.mode()
        return self._with_native(result.first() if keep == "any" else result)

    @property
    def dt(self) -> PolarsExprDateTimeNamespace:
        return PolarsExprDateTimeNamespace(self)

    @property
    def str(self) -> PolarsExprStringNamespace:
        return PolarsExprStringNamespace(self)

    @property
    def cat(self) -> PolarsExprCatNamespace:
        return PolarsExprCatNamespace(self)

    @property
    def name(self) -> PolarsExprNameNamespace:
        return PolarsExprNameNamespace(self)

    @property
    def list(self) -> PolarsExprListNamespace:
        return PolarsExprListNamespace(self)

    @property
    def struct(self) -> PolarsExprStructNamespace:
        return PolarsExprStructNamespace(self)

    # Polars
    # These are typing-only declarations; at runtime they resolve via __getattr__.
    abs: Method[Self]
    all: Method[Self]
    any: Method[Self]
    alias: Method[Self]
    arg_max: Method[Self]
    arg_min: Method[Self]
    arg_true: Method[Self]
    ceil: Method[Self]
    count: Method[Self]
    cum_max: Method[Self]
    cum_min: Method[Self]
    cum_prod: Method[Self]
    cum_sum: Method[Self]
    diff: Method[Self]
    drop_nulls: Method[Self]
    exp: Method[Self]
    fill_null: Method[Self]
    fill_nan: Method[Self]
    first: Method[Self]
    floor: Method[Self]
    last: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    is_between: Method[Self]
    is_duplicated: Method[Self]
    is_first_distinct: Method[Self]
    is_in: Method[Self]
    is_last_distinct: Method[Self]
    is_null: Method[Self]
    is_unique: Method[Self]
    kurtosis: Method[Self]
    len: Method[Self]
    log: Method[Self]
    max: Method[Self]
    mean: Method[Self]
    median: Method[Self]
    min: Method[Self]
    n_unique: Method[Self]
    null_count: Method[Self]
    quantile: Method[Self]
    rank: Method[Self]
    round: Method[Self]
    sample: Method[Self]
    shift: Method[Self]
    skew: Method[Self]
    sqrt: Method[Self]
    std: Method[Self]
    sum: Method[Self]
    sort: Method[Self]
    tail: Method[Self]
    unique: Method[Self]
    var: Method[Self]
    __rsub__: Method[Self]
    __rmod__: Method[Self]
    __rpow__: Method[Self]
    __rtruediv__: Method[Self]


class PolarsExprNamespace(PolarsAnyNamespace[PolarsExpr, pl.Expr]):
    """Common base for the accessor namespaces of :class:`PolarsExpr`."""

    def __init__(self, expr: PolarsExpr) -> None:
        self._expr = expr

    @property
    def compliant(self) -> PolarsExpr:
        """The owning compliant expression."""
        return self._expr

    @property
    def native(self) -> pl.Expr:
        """The owning expression's native ``polars.Expr``."""
        return self._expr.native


class PolarsExprDateTimeNamespace(
    PolarsExprNamespace, PolarsDateTimeNamespace[PolarsExpr, pl.Expr]
): ...
class PolarsExprStringNamespace(
    PolarsExprNamespace, PolarsStringNamespace[PolarsExpr, pl.Expr]
):
    """``.str`` accessor for ``PolarsExpr`` with shims for older Polars releases."""

    def to_titlecase(self) -> PolarsExpr:
        """Capitalise the first letter of each word and lowercase the rest.

        On Polars < 1.35 this is emulated: the lowercased text is split into
        chunks matching ``[a-z]*[^a-z]*``, each chunk is titlecased, and the
        chunks are concatenated again.
        """
        source = self.native
        if BACKEND_VERSION < (1, 35):
            chunks = source.str.to_lowercase().str.extract_all(r"[a-z]*[^a-z]*")
            rejoined = chunks.list.eval(pl.element().str.to_titlecase()).list.join("")
            return self.compliant._with_native(rejoined)
        return self.compliant._with_native(source.str.to_titlecase())  # pragma: no cover

    @requires.backend_version((0, 20, 5))
    def zfill(self, width: int) -> PolarsExpr:
        """Left-pad with ``0`` up to ``width`` characters.

        On Polars <= 1.30.0, strings that start with ``+`` and are shorter
        than ``width`` are handled specially: the sign is dropped, the rest is
        zero-filled to ``width - 1``, and ``+`` is prepended again.
        """
        source = self.native
        padded = source.str.zfill(width)
        if self.compliant._backend_version > (1, 30, 0):
            return self.compliant._with_native(padded)
        n_chars = source.str.len_chars()
        needs_sign_fix = source.str.starts_with("+") & (n_chars < width)
        sign_kept = (
            source.str.slice(1, n_chars).str.zfill(width - 1).str.pad_start(width, "+")
        )
        return self.compliant._with_native(
            pl.when(needs_sign_fix).then(sign_kept).otherwise(padded)
        )

    def replace(
        self, value: PolarsExpr, pattern: str, *, literal: bool, n: int
    ) -> PolarsExpr:
        """Replace the first ``n`` matches of ``pattern`` with ``value``."""
        replaced = self.native.str.replace(
            pattern, extract_native(value), literal=literal, n=n
        )
        return self.compliant._with_native(replaced)

    def replace_all(
        self, value: PolarsExpr, pattern: str, *, literal: bool
    ) -> PolarsExpr:
        """Replace every match of ``pattern`` with ``value``."""
        replaced = self.native.str.replace_all(
            pattern, extract_native(value), literal=literal
        )
        return self.compliant._with_native(replaced)


class PolarsExprCatNamespace(
    PolarsExprNamespace, PolarsCatNamespace[PolarsExpr, pl.Expr]
):
    """``.cat`` accessor for ``PolarsExpr``; all behaviour comes from the bases."""
class PolarsExprNameNamespace(PolarsExprNamespace):
    """``.name`` accessor for ``PolarsExpr``.

    Only typing declarations live here; presumably the base namespace dispatches
    them to ``pl.Expr.name`` using ``_accessor`` - confirm in the base class.
    """

    _accessor: ClassVar[Accessor] = "name"
    keep: Method[PolarsExpr]
    map: Method[PolarsExpr]
    prefix: Method[PolarsExpr]
    suffix: Method[PolarsExpr]
    to_lowercase: Method[PolarsExpr]
    to_uppercase: Method[PolarsExpr]


class PolarsExprListNamespace(
    PolarsExprNamespace, PolarsListNamespace[PolarsExpr, pl.Expr]
):
    """``.list`` accessor for ``PolarsExpr`` with shims for older Polars releases."""

    def len(self) -> PolarsExpr:
        """Length of each list.

        On Polars < 1.17 the result is cast to ``UInt32``. Before 1.16, null
        lists are also explicitly masked to null.
        """
        source = self.native
        lengths = source.list.len()
        version = self.compliant._backend_version
        if version >= (1, 17):
            return self.compliant._with_native(lengths)
        if version < (1, 16):  # pragma: no cover
            lengths = pl.when(~source.is_null()).then(lengths)
        return self.compliant._with_native(lengths.cast(pl.UInt32()))  # pragma: no cover

    def contains(self, item: Any) -> PolarsExpr:
        """Whether each list contains ``item``.

        On Polars < 1.28, rows whose list is null are forced to null.
        """
        source = self.native
        found: pl.Expr = source.list.contains(item)
        if self.compliant._backend_version < (1, 28):
            found = pl.when(source.is_not_null()).then(found)
        return self.compliant._with_native(found)


class PolarsExprStructNamespace(
    PolarsExprNamespace, PolarsStructNamespace[PolarsExpr, pl.Expr]
):
    """``.struct`` accessor for ``PolarsExpr``; all behaviour comes from the bases."""