from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._compliant import EagerExpr
from narwhals._expression_parsing import (
    evaluate_nodes,
    evaluate_output_names_and_aliases,
)
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    not_implemented,
)
from narwhals.functions import col as nw_col

if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import Self

    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries
    from narwhals._utils import Version, _LimitedContext


class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
    _implementation: Implementation = Implementation.PYARROW

    def __init__(
        self,
        call: EvalSeries[ArrowDataFrame, ArrowSeries],
        *,
        evaluate_output_names: EvalNames[ArrowDataFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        implementation: Implementation = Implementation.PYARROW,
    ) -> None:
        self._call = call
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[ArrowDataFrame],
        /,
        *,
        context: _LimitedContext,
    ) -> Self:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            try:
                return [
                    ArrowSeries(
                        df.native[column_name], name=column_name, version=df._version
                    )
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            tbl = df.native
            cols = df.columns
            return [
                ArrowSeries.from_native(tbl[i], name=cols[i], context=df)
                for i in column_indices
            ]

        return cls(
            func,
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    def __narwhals_namespace__(self) -> ArrowNamespace:
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(version=self._version)

    def _reuse_series_extra_kwargs(
        self, *, returns_scalar: bool = False
    ) -> dict[str, Any]:
        return {"_return_py_scalar": False} if returns_scalar else {}

    def _over_without_partition_by(self, order_by: Sequence[str]) -> Self:
        # e.g. `nw.col('a').cum_sum().over(order_by=key)`,
        # which we can always easily support, as it doesn't require grouping.
        assert order_by  # noqa: S101
        meta = self._metadata

        def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
            token = generate_temporary_column_name(8, df.columns)
            df = df.with_row_index(token, order_by=None).sort(
                *order_by, descending=False, nulls_last=False
            )
            results = self(df.drop([token], strict=True))
            if meta is not None and meta.is_scalar_like:
                # We need to broadcast the results to the original size, since
                # `over` is a length-preserving operation.
                size = len(df)
                return [s._with_native(pa.repeat(s.item(), size)) for s in results]
            # TODO(marco): is there a way to do this efficiently without
            # doing 2 sorts? Here we're sorting the dataframe and then
            # again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
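            # `token` holds each row's original position, so sorting by its
            # values yields the permutation that maps the sorted results back
            # to the original row order.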
            sorting_indices = pc.sort_indices(df.get_column(token).native)
            return [s._with_native(s.native.take(sorting_indices)) for s in results]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        if not partition_by:
            assert order_by  # noqa: S101
            return self._over_without_partition_by(order_by)

        # We have something like `prev.leaf().over(...)` (e.g. `nw.col('a').sum().over('b')`), where:
        # - `prev` must be elementwise (in the example: `nw.col('a')`)
        # - `leaf` must be an aggregation (in the example: `sum`)
        #
        # We first evaluate `prev` as-is, and then evaluate `leaf().over(...)` by doing a `group_by`.
        meta = self._metadata
        if partition_by and (
            not meta.current_node.kind.is_scalar_like
            or (meta.prev is not None and not meta.prev.is_elementwise)
        ):
            msg = (
                "Only elementary aggregations are supported for `.over` in the PyArrow backend "
                "when `partition_by` is specified.\n\n"
                "Please see: "
                "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
            )
            raise NotImplementedError(msg)

        nodes = list(reversed(list(self._metadata.iter_nodes_reversed())))

        def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
            plx = self.__narwhals_namespace__()
            if meta.prev is not None:
                df = df.with_columns(cast("ArrowExpr", evaluate_nodes(nodes[:-1], plx)))
                _, aliases = evaluate_output_names_and_aliases(self, df, [])
                leaf_ce = cast(
                    "ArrowExpr",
                    nw_col(*aliases)._append_node(nodes[-1])._to_compliant_expr(plx),
                )
            else:
                _, aliases = evaluate_output_names_and_aliases(self, df, [])
                leaf_ce = self
            if order_by:
                df = df.sort(*order_by, descending=False, nulls_last=False)
            if overlap := set(aliases).intersection(partition_by):
                # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
                # we just don't support it yet.
                msg = (
                    f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
                    "This is not yet supported."
                )
                raise NotImplementedError(msg)
            if not any(
                ca.null_count > 0
                for ca in df.simple_select(*partition_by).native.columns
            ):
                # Fast path: no null keys, so a plain group-by plus left join
                # restores the per-row aggregated values.
                tmp = df.group_by(partition_by, drop_null_keys=False).agg(leaf_ce)
                tmp = df.simple_select(*partition_by).join(
                    tmp,
                    how="left",
                    left_on=partition_by,
                    right_on=partition_by,
                    suffix="_right",
                )
                return [tmp.get_column(alias) for alias in aliases]
            if len(partition_by) == 1:
                # Null keys: PyArrow joins don't match nulls, so for a single
                # partition column we dictionary-encode it (nulls included, via
                # `null_encoding="encode"`) and join on the integer codes instead.
                plx = self.__narwhals_namespace__()
                tmp_name = generate_temporary_column_name(8, df.columns)
                dict_array = (
                    df.native.column(partition_by[0])
                    .dictionary_encode("encode")
                    .combine_chunks()
                )
                indices = dict_array.indices  # type: ignore[attr-defined]
                indices_expr = plx._expr._from_series(
                    plx._series.from_native(indices, context=plx)
                )
                table_encoded = df.with_columns(indices_expr.alias(tmp_name))
                windowed = table_encoded.group_by([tmp_name], drop_null_keys=False).agg(
                    leaf_ce
                )
                ret = (
                    table_encoded.simple_select(tmp_name)
                    .join(
                        windowed,
                        left_on=[tmp_name],
                        right_on=[tmp_name],
                        how="inner",
                        suffix="_right",
                    )
                    .drop([tmp_name], strict=False)
                )
                return [ret.get_column(alias) for alias in aliases]
            msg = "`over` with `partition_by` and multiple columns which contain null values is not yet supported for PyArrow"
            raise NotImplementedError(msg)

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    ewm_mean = not_implemented()
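
# --- Usage sketch ------------------------------------------------------------
# Illustrative only, not part of the module: how the `over` machinery above
# surfaces at the narwhals API level with a PyArrow-backed frame. The data and
# column names are made up, and the `over(order_by=...)` form assumes a
# narwhals release recent enough to accept `order_by` (the case handled by
# `_over_without_partition_by` above).
#
#     import narwhals as nw
#     import pyarrow as pa
#
#     df = nw.from_native(
#         pa.table({"a": [1, 2, 3, 4], "b": ["x", "x", "y", "y"]}), eager_only=True
#     )
#
#     # With `partition_by`: aggregate per group of `b`, then join the result
#     # back, so the output column keeps the input length.
#     df.with_columns(a_sum=nw.col("a").sum().over("b"))
#
#     # Without `partition_by`: sort by `order_by`, evaluate the expression,
#     # then restore the original row order.
#     df.with_columns(a_cum=nw.col("a").cum_sum().over(order_by="b"))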