import plotly.graph_objs as go
import plotly.io as pio
from collections import namedtuple, OrderedDict
from ._special_inputs import IdentityMap, Constant, Range
from .trendline_functions import ols, lowess, rolling, expanding, ewm
from _plotly_utils.basevalidators import ColorscaleValidator
from plotly.colors import qualitative, sequential
import math
from plotly._subplots import (
make_subplots,
_set_trace_grid_reference,
_subplot_type_for_trace_type,
)
import narwhals.stable.v1 as nw
# The reason to use narwhals.stable.v1 is to have a stable and perfectly
# backwards-compatible API, hence the confidence to not pin the Narwhals version
# exactly. This allows multiple major libraries to depend on Narwhals without
# preventing users from installing them all together due to dependency conflicts.
NO_COLOR = "px_no_color_constant"
trendline_functions = dict(
lowess=lowess, rolling=rolling, ewm=ewm, expanding=expanding, ols=ols
)
# Declare all supported attributes, across all plot types
direct_attrables = (
["base", "x", "y", "z", "a", "b", "c", "r", "theta", "size", "x_start", "x_end"]
+ ["hover_name", "text", "names", "values", "parents", "wide_cross"]
+ ["ids", "error_x", "error_x_minus", "error_y", "error_y_minus", "error_z"]
+ ["error_z_minus", "lat", "lon", "locations", "animation_group"]
)
array_attrables = ["dimensions", "custom_data", "hover_data", "path", "wide_variable"]
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
renameable_group_attrables = [
"color", # renamed to marker.color or line.color in infer_config
"symbol", # renamed to marker.symbol in infer_config
"line_dash", # renamed to line.dash in infer_config
"pattern_shape", # renamed to marker.pattern.shape in infer_config
]
all_attrables = (
direct_attrables + array_attrables + group_attrables + renameable_group_attrables
)
cartesians = [go.Scatter, go.Scattergl, go.Bar, go.Funnel, go.Box, go.Violin]
cartesians += [go.Histogram, go.Histogram2d, go.Histogram2dContour]
class PxDefaults(object):
__slots__ = [
"template",
"width",
"height",
"color_discrete_sequence",
"color_discrete_map",
"color_continuous_scale",
"symbol_sequence",
"symbol_map",
"line_dash_sequence",
"line_dash_map",
"pattern_shape_sequence",
"pattern_shape_map",
"size_max",
"category_orders",
"labels",
]
def __init__(self):
self.reset()
def reset(self):
self.template = None
self.width = None
self.height = None
self.color_discrete_sequence = None
self.color_discrete_map = {}
self.color_continuous_scale = None
self.symbol_sequence = None
self.symbol_map = {}
self.line_dash_sequence = None
self.line_dash_map = {}
self.pattern_shape_sequence = None
self.pattern_shape_map = {}
self.size_max = 20
self.category_orders = {}
self.labels = {}
defaults = PxDefaults()
del PxDefaults
MAPBOX_TOKEN = None
def set_mapbox_access_token(token):
"""
Arguments:
token: A Mapbox token to be used in `plotly.express.scatter_mapbox` and \
`plotly.express.line_mapbox` figures. See \
https://docs.mapbox.com/help/how-mapbox-works/access-tokens/ for more details
"""
global MAPBOX_TOKEN
MAPBOX_TOKEN = token
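# Example usage (illustrative; the token value is a placeholder):
#     import plotly.express as px
#     px.set_mapbox_access_token("pk.your-token")
#     fig = px.scatter_mapbox(df, lat="lat", lon="lon")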
def get_trendline_results(fig):
"""
Extracts fit statistics for trendlines (when applied to figures generated with
the `trendline` argument set to `"ols"`).
Arguments:
fig: the output of a `plotly.express` charting call
Returns:
A `pandas.DataFrame` with a column "px_fit_results" containing the `statsmodels`
results objects, along with columns identifying the subset of the data the
trendline was fit on.
"""
return fig._px_trendlines
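# Example usage (illustrative):
#     fig = px.scatter(df, x="x", y="y", trendline="ols")
#     results = px.get_trendline_results(fig)
#     print(results.px_fit_results.iloc[0].summary())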
Mapping = namedtuple(
"Mapping",
[
"show_in_trace_name",
"grouper",
"val_map",
"sequence",
"updater",
"variable",
"facet",
],
)
TraceSpec = namedtuple("TraceSpec", ["constructor", "attrs", "trace_patch", "marginal"])
def get_label(args, column):
try:
return args["labels"][column]
except Exception:
return column
def invert_label(args, column):
"""Invert mapping.
Find key corresponding to value column in dict args["labels"].
Returns `column` if the value does not exist.
"""
reversed_labels = {value: key for (key, value) in args["labels"].items()}
try:
return reversed_labels[column]
except Exception:
return column
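# For example (illustrative): with args["labels"] == {"gdp": "GDP (US$)"},
# get_label(args, "gdp") returns "GDP (US$)" and
# invert_label(args, "GDP (US$)") returns "gdp".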
def _is_continuous(df: nw.DataFrame, col_name: str) -> bool:
if nw.dependencies.is_pandas_like_dataframe(df_native := df.to_native()):
# fastpath for pandas: Narwhals' Series.dtype has a bit of overhead, as it
# tries to distinguish between true "object" columns, and "string" columns
# disguised as "object". But here, we deal with neither.
return df_native[col_name].dtype.kind in "ifc"
return df.get_column(col_name).dtype.is_numeric()
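# For example (illustrative): a float64 or int64 column is continuous
# (dtype.kind in "ifc" on the pandas fastpath), while a string or
# categorical column is not.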
def _to_unix_epoch_seconds(s: nw.Series) -> nw.Series:
dtype = s.dtype
if dtype == nw.Date:
return s.dt.timestamp("ms") / 1_000
if dtype == nw.Datetime:
if dtype.time_unit in ("s", "ms"):
return s.dt.timestamp("ms") / 1_000
elif dtype.time_unit == "us":
return s.dt.timestamp("us") / 1_000_000
elif dtype.time_unit == "ns":
return s.dt.timestamp("ns") / 1_000_000_000
else:
msg = "Unexpected dtype, please report a bug"
raise ValueError(msg)
else:
msg = f"Expected Date or Datetime, got {dtype}"
raise TypeError(msg)
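# For example (illustrative): a nanosecond-precision Datetime of
# 2020-01-01T00:00:00 maps to 1_577_836_800.0 seconds since the Unix epoch.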
def _generate_temporary_column_name(n_bytes, columns) -> str:
"""Wraps of Narwhals generate_temporary_column_name to generate a token
which is guaranteed to not be in columns, nor in [col + token for col in columns]
"""
counter = 0
while True:
# This is guaranteed to not be in columns by Narwhals
token = nw.generate_temporary_column_name(n_bytes, columns=columns)
# Now check that it is not in the [col + token for col in columns] list
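# (note: f"{c}{token}" can equal token only when c == "", so in practice
# this retry only triggers for degenerate empty-string column names)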
if token not in {f"{c}{token}" for c in columns}:
return token
counter += 1
if counter > 100:
msg = (
"Internal Error: Plotly was not able to generate a column name with "
f"{n_bytes=} and not in {columns}.\n"
"Please report this to "
"https://github.com/plotly/plotly.py/issues/new and we will try to "
"replicate and fix it."
)
raise AssertionError(msg)
def get_decorated_label(args, column, role):
original_label = label = get_label(args, column)
if "histfunc" in args and (
(role == "z")
or (role == "x" and "orientation" in args and args["orientation"] == "h")
or (role == "y" and "orientation" in args and args["orientation"] == "v")
):
histfunc = args["histfunc"] or "count"
if histfunc != "count":
label = "%s of %s" % (histfunc, label)
else:
label = "count"
if "histnorm" in args and args["histnorm"] is not None:
if label == "count":
label = args["histnorm"]
else:
histnorm = args["histnorm"]
if histfunc == "sum":
if histnorm == "probability":
label = "%s of %s" % ("fraction", label)
elif histnorm == "percent":
label = "%s of %s" % (histnorm, label)
else:
label = "%s weighted by %s" % (histnorm, original_label)
elif histnorm == "probability":
label = "%s of sum of %s" % ("fraction", label)
elif histnorm == "percent":
label = "%s of sum of %s" % ("percent", label)
else:
label = "%s of %s" % (histnorm, label)
if "barnorm" in args and args["barnorm"] is not None:
label = "%s (normalized as %s)" % (label, args["barnorm"])
return label
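# For example (illustrative): with histfunc="avg" and y="tip" in a vertical
# histogram, the decorated y label becomes "avg of tip"; with the default
# histfunc ("count") and histnorm="probability", it becomes "probability".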
def make_mapping(args, variable):
if variable == "line_group" or variable == "animation_frame":
return Mapping(
show_in_trace_name=False,
grouper=args[variable],
val_map={},
sequence=[""],
variable=variable,
updater=(lambda trace, v: v),
facet=None,
)
if variable == "facet_row" or variable == "facet_col":
letter = "x" if variable == "facet_col" else "y"
return Mapping(
show_in_trace_name=False,
variable=letter,
grouper=args[variable],
val_map={},
sequence=list(range(1, 1000)),
updater=(lambda trace, v: v),
facet="row" if variable == "facet_row" else "col",
)
(parent, variable, *other_variables) = variable.split(".")
vprefix = variable
arg_name = variable
if variable == "color":
vprefix = "color_discrete"
if variable == "dash":
arg_name = "line_dash"
vprefix = "line_dash"
if variable in ["pattern", "shape"]:
arg_name = "pattern_shape"
vprefix = "pattern_shape"
if args[vprefix + "_map"] == "identity":
val_map = IdentityMap()
else:
val_map = args[vprefix + "_map"].copy()
return Mapping(
show_in_trace_name=True,
variable=variable,
grouper=args[arg_name],
val_map=val_map,
sequence=args[vprefix + "_sequence"],
updater=lambda trace, v: trace.update(
{parent: {".".join([variable] + other_variables): v}}
),
facet=None,
)
def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
"""Populates a dict with arguments to update trace
Parameters
----------
args : dict
args to be used for the trace
trace_spec : NamedTuple
which kind of trace to be used (has constructor, marginal etc.
attributes)
trace_data : pandas DataFrame
data
mapping_labels : dict
to be used for hovertemplate
sizeref : float
marker sizeref
Returns
-------
trace_patch : dict
dict to be used to update trace
fit_results : dict
fit information to be used for trendlines
"""
trace_data: nw.DataFrame
df: nw.DataFrame = args["data_frame"]
if "line_close" in args and args["line_close"]:
trace_data = nw.concat([trace_data, trace_data.head(1)], how="vertical")
trace_patch = trace_spec.trace_patch.copy() or {}
fit_results = None
hover_header = ""
for attr_name in trace_spec.attrs:
attr_value = args[attr_name]
attr_label = get_decorated_label(args, attr_value, attr_name)
if attr_name == "dimensions":
dims = [
(name, trace_data.get_column(name))
for name in trace_data.columns
if ((not attr_value) or (name in attr_value))
and (trace_spec.constructor != go.Parcoords or _is_continuous(df, name))
and (
trace_spec.constructor != go.Parcats
or (attr_value is not None and name in attr_value)
or nw.to_py_scalar(df.get_column(name).n_unique())
<= args["dimensions_max_cardinality"]
)
]
trace_patch["dimensions"] = [
dict(label=get_label(args, name), values=column)
for (name, column) in dims
]
if trace_spec.constructor == go.Splom:
for d in trace_patch["dimensions"]:
d["axis"] = dict(matches=True)
mapping_labels["%{xaxis.title.text}"] = "%{x}"
mapping_labels["%{yaxis.title.text}"] = "%{y}"
elif attr_value is not None:
if attr_name == "size":
if "marker" not in trace_patch:
trace_patch["marker"] = dict()
trace_patch["marker"]["size"] = trace_data.get_column(attr_value)
trace_patch["marker"]["sizemode"] = "area"
trace_patch["marker"]["sizeref"] = sizeref
mapping_labels[attr_label] = "%{marker.size}"
elif attr_name == "marginal_x":
if trace_spec.constructor == go.Histogram:
mapping_labels["count"] = "%{y}"
elif attr_name == "marginal_y":
if trace_spec.constructor == go.Histogram:
mapping_labels["count"] = "%{x}"
elif attr_name == "trendline":
if (
args["x"]
and args["y"]
and len(
trace_data.select(nw.col(args["x"], args["y"])).drop_nulls()
)
> 1
):
# sorting changes the row order, which would desynchronize other per-row
# attrs; safe here because trace_specs with "trendline" have no other attrs
sorted_trace_data = trace_data.sort(by=args["x"], nulls_last=True)
y = sorted_trace_data.get_column(args["y"])
x = sorted_trace_data.get_column(args["x"])
if x.dtype == nw.Datetime or x.dtype == nw.Date:
# convert to unix epoch seconds
x = _to_unix_epoch_seconds(x)
elif not x.dtype.is_numeric():
try:
x = x.cast(nw.Float64())
except ValueError:
raise ValueError(
"Could not convert value of 'x' ('%s') into a numeric type. "
"If 'x' contains stringified dates, please convert to a datetime column."
% args["x"]
)
if not y.dtype.is_numeric():
try:
y = y.cast(nw.Float64())
except ValueError:
raise ValueError(
"Could not convert value of 'y' into a numeric type."
)
# preserve original values of "x" in case they're dates
# otherwise numpy/pandas can mess with the timezones
# NB this means trendline functions must output one-to-one with the input series
# i.e. we can't do resampling, because then the X values might not line up!
non_missing = ~(x.is_null() | y.is_null())
trace_patch["x"] = sorted_trace_data.filter(non_missing).get_column(
args["x"]
)
if (
trace_patch["x"].dtype == nw.Datetime
and trace_patch["x"].dtype.time_zone is not None
):
# Remove time zone so that local time is displayed
trace_patch["x"] = (
trace_patch["x"].dt.replace_time_zone(None).to_numpy()
)
else:
trace_patch["x"] = trace_patch["x"].to_numpy()
trendline_function = trendline_functions[attr_value]
y_out, hover_header, fit_results = trendline_function(
args["trendline_options"],
sorted_trace_data.get_column(args["x"]), # narwhals series
x.to_numpy(), # numpy array
y.to_numpy(), # numpy array
args["x"],
args["y"],
non_missing.to_numpy(), # numpy array
)
assert len(y_out) == len(trace_patch["x"]), (
"missing-data-handling failure in trendline code"
)
trace_patch["y"] = y_out
mapping_labels[get_label(args, args["x"])] = "%{x}"
mapping_labels[get_label(args, args["y"])] = "%{y} (trend)"
elif attr_name.startswith("error"):
error_xy = attr_name[:7]
arr = "arrayminus" if attr_name.endswith("minus") else "array"
if error_xy not in trace_patch:
trace_patch[error_xy] = {}
trace_patch[error_xy][arr] = trace_data.get_column(attr_value)
elif attr_name == "custom_data":
if len(attr_value) > 0:
# here we store a data frame in customdata, and it's serialized
# as a list of row lists, which is what we want
trace_patch["customdata"] = trace_data.select(nw.col(attr_value))
elif attr_name == "hover_name":
if trace_spec.constructor not in [
go.Histogram,
go.Histogram2d,
go.Histogram2dContour,
]:
trace_patch["hovertext"] = trace_data.get_column(attr_value)
if hover_header == "":
hover_header = "%{hovertext}
"
elif attr_name == "hover_data":
if trace_spec.constructor not in [
go.Histogram,
go.Histogram2d,
go.Histogram2dContour,
]:
hover_is_dict = isinstance(attr_value, dict)
customdata_cols = args.get("custom_data") or []
for col in attr_value:
if hover_is_dict and not attr_value[col]:
continue
if col in [
args.get("x"),
args.get("y"),
args.get("z"),
args.get("base"),
]:
continue
try:
position = args["custom_data"].index(col)
except (ValueError, AttributeError, KeyError):
position = len(customdata_cols)
customdata_cols.append(col)
attr_label_col = get_decorated_label(args, col, None)
mapping_labels[attr_label_col] = "%%{customdata[%d]}" % (
position
)
if len(customdata_cols) > 0:
# here we store a data frame in customdata, and it's serialized
# as a list of row lists, which is what we want
# dict.fromkeys(customdata_cols) deduplicates column names
# while maintaining the original order.
trace_patch["customdata"] = trace_data.select(
*[nw.col(c) for c in dict.fromkeys(customdata_cols)]
)
elif attr_name == "color":
if trace_spec.constructor in [
go.Choropleth,
go.Choroplethmap,
go.Choroplethmapbox,
]:
trace_patch["z"] = trace_data.get_column(attr_value)
trace_patch["coloraxis"] = "coloraxis1"
mapping_labels[attr_label] = "%{z}"
elif trace_spec.constructor in [
go.Sunburst,
go.Treemap,
go.Icicle,
go.Pie,
go.Funnelarea,
]:
if "marker" not in trace_patch:
trace_patch["marker"] = dict()
if args.get("color_is_continuous"):
trace_patch["marker"]["colors"] = trace_data.get_column(
attr_value
)
trace_patch["marker"]["coloraxis"] = "coloraxis1"
mapping_labels[attr_label] = "%{color}"
else:
trace_patch["marker"]["colors"] = []
if args["color_discrete_map"] is not None:
mapping = args["color_discrete_map"].copy()
else:
mapping = {}
for cat in trace_data.get_column(attr_value).to_list():
# although trace_data.get_column(attr_value) is a Narwhals
# Series, which is an iterable, explicitly calling to_list()
# makes sure that the elements we loop over are Python objects
# in all cases, since depending on the backend (e.g. PyArrow)
# this may not otherwise be the case
if mapping.get(cat) is None:
mapping[cat] = args["color_discrete_sequence"][
len(mapping) % len(args["color_discrete_sequence"])
]
trace_patch["marker"]["colors"].append(mapping[cat])
else:
colorable = "marker"
if trace_spec.constructor in [go.Parcats, go.Parcoords]:
colorable = "line"
if colorable not in trace_patch:
trace_patch[colorable] = dict()
trace_patch[colorable]["color"] = trace_data.get_column(attr_value)
trace_patch[colorable]["coloraxis"] = "coloraxis1"
mapping_labels[attr_label] = "%%{%s.color}" % colorable
elif attr_name == "animation_group":
trace_patch["ids"] = trace_data.get_column(attr_value)
elif attr_name == "locations":
trace_patch[attr_name] = trace_data.get_column(attr_value)
mapping_labels[attr_label] = "%{location}"
elif attr_name == "values":
trace_patch[attr_name] = trace_data.get_column(attr_value)
_label = "value" if attr_label == "values" else attr_label
mapping_labels[_label] = "%{value}"
elif attr_name == "parents":
trace_patch[attr_name] = trace_data.get_column(attr_value)
_label = "parent" if attr_label == "parents" else attr_label
mapping_labels[_label] = "%{parent}"
elif attr_name == "ids":
trace_patch[attr_name] = trace_data.get_column(attr_value)
_label = "id" if attr_label == "ids" else attr_label
mapping_labels[_label] = "%{id}"
elif attr_name == "names":
if trace_spec.constructor in [
go.Sunburst,
go.Treemap,
go.Icicle,
go.Pie,
go.Funnelarea,
]:
trace_patch["labels"] = trace_data.get_column(attr_value)
_label = "label" if attr_label == "names" else attr_label
mapping_labels[_label] = "%{label}"
else:
trace_patch[attr_name] = trace_data.get_column(attr_value)
else:
trace_patch[attr_name] = trace_data.get_column(attr_value)
mapping_labels[attr_label] = "%%{%s}" % attr_name
elif (trace_spec.constructor == go.Histogram and attr_name in ["x", "y"]) or (
trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour]
and attr_name == "z"
):
# ensure that stuff like "count" gets into the hoverlabel
mapping_labels[attr_label] = "%%{%s}" % attr_name
if trace_spec.constructor not in [go.Parcoords, go.Parcats]:
# Modify mapping_labels according to hover_data keys
# if hover_data is a dict
mapping_labels_copy = OrderedDict(mapping_labels)
if args["hover_data"] and isinstance(args["hover_data"], dict):
for k, v in mapping_labels.items():
# We need to invert the mapping here
k_args = invert_label(args, k)
if k_args in args["hover_data"]:
formatter = args["hover_data"][k_args][0]
if formatter:
if isinstance(formatter, str):
mapping_labels_copy[k] = v.replace("}", "%s}" % formatter)
else:
_ = mapping_labels_copy.pop(k)
hover_lines = [k + "=" + v for k, v in mapping_labels_copy.items()]
trace_patch["hovertemplate"] = hover_header + "
".join(hover_lines)
trace_patch["hovertemplate"] += ""
return trace_patch, fit_results
def configure_axes(args, constructor, fig, orders):
configurators = {
go.Scatter3d: configure_3d_axes,
go.Scatterternary: configure_ternary_axes,
go.Scatterpolar: configure_polar_axes,
go.Scatterpolargl: configure_polar_axes,
go.Barpolar: configure_polar_axes,
go.Scattermap: configure_map,
go.Choroplethmap: configure_map,
go.Densitymap: configure_map,
go.Scattermapbox: configure_mapbox,
go.Choroplethmapbox: configure_mapbox,
go.Densitymapbox: configure_mapbox,
go.Scattergeo: configure_geo,
go.Choropleth: configure_geo,
}
for c in cartesians:
configurators[c] = configure_cartesian_axes
if constructor in configurators:
configurators[constructor](args, fig, orders)
def set_cartesian_axis_opts(args, axis, letter, orders):
log_key = "log_" + letter
range_key = "range_" + letter
if log_key in args and args[log_key]:
axis["type"] = "log"
if range_key in args and args[range_key]:
axis["range"] = [math.log(r, 10) for r in args[range_key]]
elif range_key in args and args[range_key]:
axis["range"] = args[range_key]
if args[letter] in orders:
axis["categoryorder"] = "array"
axis["categoryarray"] = (
orders[args[letter]]
if isinstance(axis, go.layout.XAxis)
else list(reversed(orders[args[letter]])) # top down for Y axis
)
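# For example (illustrative): with log_x=True and range_x=[1, 1000], the
# x-axis "range" becomes [0.0, 3.0] (up to floating-point rounding), since
# plotly log axes take base-10 exponents rather than data values.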
def configure_cartesian_marginal_axes(args, fig, orders):
nrows = len(fig._grid_ref)
ncols = len(fig._grid_ref[0])
# Set y-axis titles and axis options in the left-most column
for yaxis in fig.select_yaxes(col=1):
set_cartesian_axis_opts(args, yaxis, "y", orders)
# Set x-axis titles and axis options in the bottom-most row
for xaxis in fig.select_xaxes(row=1):
set_cartesian_axis_opts(args, xaxis, "x", orders)
# Configure axis ticks on marginal subplots
if args["marginal_x"]:
fig.update_yaxes(
showticklabels=False, showline=False, ticks="", range=None, row=nrows
)
if args["template"].layout.yaxis.showgrid is None:
fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows)
if args["template"].layout.xaxis.showgrid is None:
fig.update_xaxes(showgrid=True, row=nrows)
if args["marginal_y"]:
fig.update_xaxes(
showticklabels=False, showline=False, ticks="", range=None, col=ncols
)
if args["template"].layout.xaxis.showgrid is None:
fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols)
if args["template"].layout.yaxis.showgrid is None:
fig.update_yaxes(showgrid=True, col=ncols)
# Add axis titles to non-marginal subplots
y_title = get_decorated_label(args, args["y"], "y")
if args["marginal_x"]:
fig.update_yaxes(title_text=y_title, row=1, col=1)
else:
for row in range(1, nrows + 1):
fig.update_yaxes(title_text=y_title, row=row, col=1)
x_title = get_decorated_label(args, args["x"], "x")
if args["marginal_y"]:
fig.update_xaxes(title_text=x_title, row=1, col=1)
else:
for col in range(1, ncols + 1):
fig.update_xaxes(title_text=x_title, row=1, col=col)
# Configure axis type across all x-axes
if "log_x" in args and args["log_x"]:
fig.update_xaxes(type="log")
# Configure axis type across all y-axes
if "log_y" in args and args["log_y"]:
fig.update_yaxes(type="log")
# Configure matching and axis type for marginal y-axes
matches_y = "y" + str(ncols + 1)
if args["marginal_x"]:
for row in range(2, nrows + 1, 2):
fig.update_yaxes(matches=matches_y, type=None, row=row)
if args["marginal_y"]:
for col in range(2, ncols + 1, 2):
fig.update_xaxes(matches="x2", type=None, col=col)
def configure_cartesian_axes(args, fig, orders):
if ("marginal_x" in args and args["marginal_x"]) or (
"marginal_y" in args and args["marginal_y"]
):
configure_cartesian_marginal_axes(args, fig, orders)
return
# Set y-axis titles and axis options in the left-most column
y_title = get_decorated_label(args, args["y"], "y")
for yaxis in fig.select_yaxes(col=1):
yaxis.update(title_text=y_title)
set_cartesian_axis_opts(args, yaxis, "y", orders)
# Set x-axis titles and axis options in the bottom-most row
x_title = get_decorated_label(args, args["x"], "x")
for xaxis in fig.select_xaxes(row=1):
if "is_timeline" not in args:
xaxis.update(title_text=x_title)
set_cartesian_axis_opts(args, xaxis, "x", orders)
# Configure axis type across all x-axes
if "log_x" in args and args["log_x"]:
fig.update_xaxes(type="log")
# Configure axis type across all y-axes
if "log_y" in args and args["log_y"]:
fig.update_yaxes(type="log")
if "is_timeline" in args:
fig.update_xaxes(type="date")
if "ecdfmode" in args:
if args["orientation"] == "v":
fig.update_yaxes(rangemode="tozero")
else:
fig.update_xaxes(rangemode="tozero")
def configure_ternary_axes(args, fig, orders):
fig.update_ternaries(
aaxis=dict(title_text=get_label(args, args["a"])),
baxis=dict(title_text=get_label(args, args["b"])),
caxis=dict(title_text=get_label(args, args["c"])),
)
def configure_polar_axes(args, fig, orders):
patch = dict(
angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]),
radialaxis=dict(),
)
for var, axis in [("r", "radialaxis"), ("theta", "angularaxis")]:
if args[var] in orders:
patch[axis]["categoryorder"] = "array"
patch[axis]["categoryarray"] = orders[args[var]]
radialaxis = patch["radialaxis"]
if args["log_r"]:
radialaxis["type"] = "log"
if args["range_r"]:
radialaxis["range"] = [math.log(x, 10) for x in args["range_r"]]
else:
if args["range_r"]:
radialaxis["range"] = args["range_r"]
if args["range_theta"]:
patch["sector"] = args["range_theta"]
fig.update_polars(patch)
def configure_3d_axes(args, fig, orders):
patch = dict(
xaxis=dict(title_text=get_label(args, args["x"])),
yaxis=dict(title_text=get_label(args, args["y"])),
zaxis=dict(title_text=get_label(args, args["z"])),
)
for letter in ["x", "y", "z"]:
axis = patch[letter + "axis"]
if args["log_" + letter]:
axis["type"] = "log"
if args["range_" + letter]:
axis["range"] = [math.log(x, 10) for x in args["range_" + letter]]
else:
if args["range_" + letter]:
axis["range"] = args["range_" + letter]
if args[letter] in orders:
axis["categoryorder"] = "array"
axis["categoryarray"] = orders[args[letter]]
fig.update_scenes(patch)
def configure_mapbox(args, fig, orders):
center = args["center"]
if not center and "lat" in args and "lon" in args:
center = dict(
lat=args["data_frame"][args["lat"]].mean(),
lon=args["data_frame"][args["lon"]].mean(),
)
fig.update_mapboxes(
accesstoken=MAPBOX_TOKEN,
center=center,
zoom=args["zoom"],
style=args["mapbox_style"],
)
def configure_map(args, fig, orders):
center = args["center"]
if not center and "lat" in args and "lon" in args:
center = dict(
lat=args["data_frame"][args["lat"]].mean(),
lon=args["data_frame"][args["lon"]].mean(),
)
fig.update_maps(
center=center,
zoom=args["zoom"],
style=args["map_style"],
)
def configure_geo(args, fig, orders):
fig.update_geos(
center=args["center"],
scope=args["scope"],
fitbounds=args["fitbounds"],
visible=args["basemap_visible"],
projection=dict(type=args["projection"]),
)
def configure_animation_controls(args, constructor, fig):
def frame_args(duration):
return {
"frame": {"duration": duration, "redraw": constructor != go.Scatter},
"mode": "immediate",
"fromcurrent": True,
"transition": {"duration": duration, "easing": "linear"},
}
if "animation_frame" in args and args["animation_frame"] and len(fig.frames) > 1:
fig.layout.updatemenus = [
{
"buttons": [
{
"args": [None, frame_args(500)],
"label": "▶",
"method": "animate",
},
{
"args": [[None], frame_args(0)],
"label": "◼",
"method": "animate",
},
],
"direction": "left",
"pad": {"r": 10, "t": 70},
"showactive": False,
"type": "buttons",
"x": 0.1,
"xanchor": "right",
"y": 0,
"yanchor": "top",
}
]
fig.layout.sliders = [
{
"active": 0,
"yanchor": "top",
"xanchor": "left",
"currentvalue": {
"prefix": get_label(args, args["animation_frame"]) + "="
},
"pad": {"b": 10, "t": 60},
"len": 0.9,
"x": 0.1,
"y": 0,
"steps": [
{
"args": [[f.name], frame_args(0)],
"label": f.name,
"method": "animate",
}
for f in fig.frames
],
}
]
def make_trace_spec(args, constructor, attrs, trace_patch):
if constructor in [go.Scatter, go.Scatterpolar]:
if "render_mode" in args and (
args["render_mode"] == "webgl"
or (
args["render_mode"] == "auto"
and len(args["data_frame"]) > 1000
and args.get("line_shape") != "spline"
and args["animation_frame"] is None
)
):
if constructor == go.Scatter:
constructor = go.Scattergl
if "orientation" in trace_patch:
del trace_patch["orientation"]
else:
constructor = go.Scatterpolargl
# Create base trace specification
result = [TraceSpec(constructor, attrs, trace_patch, None)]
# Add marginal trace specifications
for letter in ["x", "y"]:
if "marginal_" + letter in args and args["marginal_" + letter]:
trace_spec = None
axis_map = dict(
xaxis="x1" if letter == "x" else "x2",
yaxis="y1" if letter == "y" else "y2",
)
if args["marginal_" + letter] == "histogram":
trace_spec = TraceSpec(
constructor=go.Histogram,
attrs=[letter, "marginal_" + letter],
trace_patch=dict(opacity=0.5, bingroup=letter, **axis_map),
marginal=letter,
)
elif args["marginal_" + letter] == "violin":
trace_spec = TraceSpec(
constructor=go.Violin,
attrs=[letter, "hover_name", "hover_data"],
trace_patch=dict(scalegroup=letter),
marginal=letter,
)
elif args["marginal_" + letter] == "box":
trace_spec = TraceSpec(
constructor=go.Box,
attrs=[letter, "hover_name", "hover_data"],
trace_patch=dict(notched=True),
marginal=letter,
)
elif args["marginal_" + letter] == "rug":
symbols = {"x": "line-ns-open", "y": "line-ew-open"}
trace_spec = TraceSpec(
constructor=go.Box,
attrs=[letter, "hover_name", "hover_data"],
trace_patch=dict(
fillcolor="rgba(255,255,255,0)",
line={"color": "rgba(255,255,255,0)"},
boxpoints="all",
jitter=0,
hoveron="points",
marker={"symbol": symbols[letter]},
),
marginal=letter,
)
if "color" in attrs or "color" not in args:
if "marker" not in trace_spec.trace_patch:
trace_spec.trace_patch["marker"] = dict()
first_default_color = args["color_continuous_scale"][0]
trace_spec.trace_patch["marker"]["color"] = first_default_color
result.append(trace_spec)
# Add trendline trace specifications
if args.get("trendline") and args.get("trendline_scope", "trace") == "trace":
result.append(make_trendline_spec(args, constructor))
return result
def make_trendline_spec(args, constructor):
trace_spec = TraceSpec(
constructor=(
go.Scattergl
if constructor == go.Scattergl # could be contour
else go.Scatter
),
attrs=["trendline"],
trace_patch=dict(mode="lines"),
marginal=None,
)
if args["trendline_color_override"]:
trace_spec.trace_patch["line"] = dict(color=args["trendline_color_override"])
return trace_spec
def one_group(x):
return ""
def apply_default_cascade(args):
# first we apply px.defaults to unspecified args
for param in defaults.__slots__:
if param in args and args[param] is None:
args[param] = getattr(defaults, param)
# load the default template if set, otherwise "plotly"
if args["template"] is None:
if pio.templates.default is not None:
args["template"] = pio.templates.default
else:
args["template"] = "plotly"
try:
# retrieve the actual template if we were given a name
args["template"] = pio.templates[args["template"]]
except Exception:
# otherwise try to build a real template
args["template"] = go.layout.Template(args["template"])
# if colors not set explicitly or in px.defaults, defer to a template
# if the template doesn't have one, we set some final fallback defaults
if "color_continuous_scale" in args:
if (
args["color_continuous_scale"] is None
and args["template"].layout.colorscale.sequential
):
args["color_continuous_scale"] = [
x[1] for x in args["template"].layout.colorscale.sequential
]
if args["color_continuous_scale"] is None:
args["color_continuous_scale"] = sequential.Viridis
if "color_discrete_sequence" in args:
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
args["color_discrete_sequence"] = args["template"].layout.colorway
if args["color_discrete_sequence"] is None:
args["color_discrete_sequence"] = qualitative.D3
# if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
# see if we can defer to template. If not, set reasonable defaults
if "symbol_sequence" in args:
if args["symbol_sequence"] is None and args["template"].data.scatter:
args["symbol_sequence"] = [
scatter.marker.symbol for scatter in args["template"].data.scatter
]
if not args["symbol_sequence"] or not any(args["symbol_sequence"]):
args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"]
if "line_dash_sequence" in args:
if args["line_dash_sequence"] is None and args["template"].data.scatter:
args["line_dash_sequence"] = [
scatter.line.dash for scatter in args["template"].data.scatter
]
if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]):
args["line_dash_sequence"] = [
"solid",
"dot",
"dash",
"longdash",
"dashdot",
"longdashdot",
]
if "pattern_shape_sequence" in args:
if args["pattern_shape_sequence"] is None and args["template"].data.bar:
args["pattern_shape_sequence"] = [
bar.marker.pattern.shape for bar in args["template"].data.bar
]
if not args["pattern_shape_sequence"] or not any(
args["pattern_shape_sequence"]
):
args["pattern_shape_sequence"] = ["", "/", "\\", "x", "+", "."]
def _check_name_not_reserved(field_name, reserved_names):
if field_name not in reserved_names:
return field_name
else:
raise NameError(
"A name conflict was encountered for argument '%s'. "
"A column or index with name '%s' is ambiguous." % (field_name, field_name)
)
def _get_reserved_col_names(args):
"""
This function builds a list of columns of the data_frame argument used
as arguments, either as str/int arguments or given as columns
(pandas series type).
"""
df: nw.DataFrame = args["data_frame"]
reserved_names = set()
for field in args:
if field not in all_attrables:
continue
names = args[field] if field in array_attrables else [args[field]]
if names is None:
continue
for arg in names:
if arg is None:
continue
elif isinstance(arg, str): # no need to add ints since kw arg are not ints
reserved_names.add(arg)
elif nw.dependencies.is_into_series(arg):
arg_series = nw.from_native(arg, series_only=True)
arg_name = arg_series.name
if arg_name and arg_name in df.columns:
in_df = (arg_series == df.get_column(arg_name)).all()
if in_df:
reserved_names.add(arg_name)
elif arg is nw.maybe_get_index(df) and arg.name is not None:
reserved_names.add(arg.name)
return reserved_names
def _is_col_list(columns, arg, is_pd_like, native_namespace):
"""Returns True if arg looks like it's a list of columns or references to columns
in df_input, and False otherwise (in which case it's assumed to be a single column
or reference to a column).
"""
if arg is None or isinstance(arg, str) or isinstance(arg, int):
return False
if is_pd_like and isinstance(arg, native_namespace.MultiIndex):
return False # just to keep existing behaviour for now
try:
iter(arg)
except TypeError:
return False # not iterable
for c in arg:
if isinstance(c, str) or isinstance(c, int):
if columns is None or c not in columns:
return False
else:
try:
iter(c)
except TypeError:
return False # not iterable
return True
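# For example (illustrative): with columns == ["a", "b"],
# _is_col_list(columns, ["a", "b"], ...) is True (wide-form input), while
# _is_col_list(columns, "a", ...) is False (a single column reference).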
def _isinstance_listlike(x):
"""Returns True if x is an iterable which can be transformed into a pandas Series,
False for the other types of possible values of a `hover_data` dict.
A tuple of length 2 is a special case corresponding to a (format, data) tuple.
"""
if (
isinstance(x, str)
or (isinstance(x, tuple) and len(x) == 2)
or isinstance(x, bool)
or x is None
):
return False
else:
return True
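# For example (illustrative): ":.2f" (a format string), a (format, data)
# tuple of length 2, True/False and None all return False, while a plain
# list like [1, 2, 3] returns True.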
def _escape_col_name(columns, col_name, extra):
if columns is None:
return col_name
while col_name in columns or col_name in extra:
col_name = "_" + col_name
return col_name
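# For example (illustrative): _escape_col_name(["value"], "value", [])
# returns "_value", prefixing underscores until the name no longer collides.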
def to_named_series(x, name=None, native_namespace=None):
"""Assuming x is list-like or even an existing Series, returns a new Series named `name`."""
# With `pass_through=True`, the original object will be returned if unable to convert
# to a Narwhals Series.
x = nw.from_native(x, series_only=True, pass_through=True)
if isinstance(x, nw.Series):
return x.rename(name)
elif native_namespace is not None:
return nw.new_series(name=name, values=x, native_namespace=native_namespace)
else:
try:
import pandas as pd
return nw.new_series(name=name, values=x, native_namespace=pd)
except ImportError:
msg = "Pandas installation is required if no dataframe is provided."
raise NotImplementedError(msg)
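# For example (illustrative): to_named_series([1, 2, 3], name="y") returns a
# pandas-backed Narwhals Series named "y" when no native_namespace is given
# and pandas is installed.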
def process_args_into_dataframe(
args, wide_mode, var_name, value_name, is_pd_like, native_namespace
):
"""
After this function runs, the `all_attrables` keys of `args` all contain only
references to columns of `df_output`. This function handles the extraction of data
from `args["attrable"]` and column-name-generation as appropriate, and adds the
data to `df_output` and then replaces `args["attrable"]` with the appropriate
reference.
"""
df_input: nw.DataFrame | None = args["data_frame"]
df_provided = df_input is not None
# we use a dict instead of a dataframe directly so that repeatedly setting
# columns doesn't trigger a pandas PerformanceWarning.
# a dict is used instead of a list because the columns need to be overwritten.
df_output = {}
constants = {}
ranges = []
wide_id_vars = set()
reserved_names = _get_reserved_col_names(args) if df_provided else set()
# Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords
if "dimensions" in args and args["dimensions"] is None:
if not df_provided:
raise ValueError(
"No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument."
)
else:
df_output = {col: df_input.get_column(col) for col in df_input.columns}
# hover_data is a dict
hover_data_is_dict = (
"hover_data" in args
and args["hover_data"]
and isinstance(args["hover_data"], dict)
)
# If dict, convert all values of hover_data to tuples to simplify processing
if hover_data_is_dict:
for k in args["hover_data"]:
if _isinstance_listlike(args["hover_data"][k]):
args["hover_data"][k] = (True, args["hover_data"][k])
if not isinstance(args["hover_data"][k], tuple):
args["hover_data"][k] = (args["hover_data"][k], None)
if df_provided and args["hover_data"][k][1] is not None and k in df_input:
raise ValueError(
"Ambiguous input: values for '%s' appear both in hover_data and data_frame"
% k
)
# Loop over possible arguments
for field_name in all_attrables:
# Massaging variables
argument_list = (
[args.get(field_name)]
if field_name not in array_attrables
else args.get(field_name)
)
# argument not specified, continue
# (the original code also tested `argument_list is [None]`, but an `is`
# comparison against a fresh list literal is always False, so that clause
# has been removed; all tests pass without it)
if argument_list is None:
continue
# Argument name: field_name if the argument is not a list
# Else we give names like ["hover_data_0, hover_data_1"] etc.
field_list = (
[field_name]
if field_name not in array_attrables
else [field_name + "_" + str(i) for i in range(len(argument_list))]
)
# argument_list and field_list ready, iterate over them
# Core of the loop starts here
for i, (argument, field) in enumerate(zip(argument_list, field_list)):
length = len(df_output[next(iter(df_output))]) if len(df_output) else 0
if argument is None:
continue
col_name = None
# Case of multiindex
if is_pd_like and isinstance(argument, native_namespace.MultiIndex):
raise TypeError(
f"Argument '{field}' is a {native_namespace.__name__} MultiIndex. "
f"{native_namespace.__name__} MultiIndex is not supported by plotly "
"express at the moment."
)
# ----------------- argument is a special value ----------------------
if isinstance(argument, (Constant, Range)):
col_name = _check_name_not_reserved(
str(argument.label) if argument.label is not None else field,
reserved_names,
)
if isinstance(argument, Constant):
constants[col_name] = argument.value
else:
ranges.append(col_name)
# ----------------- argument is likely a col name ----------------------
elif isinstance(argument, str) or not hasattr(argument, "__len__"):
if (
field_name == "hover_data"
and hover_data_is_dict
and args["hover_data"][str(argument)][1] is not None
):
# hover_data has onboard data
# previously-checked to have no name-conflict with data_frame
col_name = str(argument)
real_argument = args["hover_data"][col_name][1]
if length and (real_length := len(real_argument)) != length:
raise ValueError(
"All arguments should have the same length. "
"The length of hover_data key `%s` is %d, whereas the "
"length of previously-processed arguments %s is %d"
% (
argument,
real_length,
str(list(df_output.keys())),
length,
)
)
df_output[col_name] = to_named_series(
real_argument, col_name, native_namespace
)
elif not df_provided:
raise ValueError(
"String or int arguments are only possible when a "
"DataFrame or an array is provided in the `data_frame` "
"argument. No DataFrame was provided, but argument "
"'%s' is of type str or int." % field
)
# Check validity of column name
elif argument not in df_input.columns:
if wide_mode and argument in (value_name, var_name):
continue
else:
err_msg = (
"Value of '%s' is not the name of a column in 'data_frame'. "
"Expected one of %s but received: %s"
% (field, str(list(df_input.columns)), argument)
)
if argument == "index":
err_msg += "\n To use the index, pass it in directly as `df.index`."
raise ValueError(err_msg)
elif length and (actual_len := len(df_input)) != length:
raise ValueError(
"All arguments should have the same length. "
"The length of column argument `df[%s]` is %d, whereas the "
"length of previously-processed arguments %s is %d"
% (
field,
actual_len,
str(list(df_output.keys())),
length,
)
)
else:
col_name = str(argument)
df_output[col_name] = to_named_series(
df_input.get_column(argument), col_name
)
# ----------------- argument is likely a column / array / list.... -------
else:
if df_provided and hasattr(argument, "name"):
if is_pd_like and argument is nw.maybe_get_index(df_input):
if argument.name is None or argument.name in df_input.columns:
col_name = "index"
else:
col_name = argument.name
col_name = _escape_col_name(
df_input.columns, col_name, [var_name, value_name]
)
else:
if (
argument.name is not None
and argument.name in df_input.columns
and (
to_named_series(
argument, argument.name, native_namespace
)
== df_input.get_column(argument.name)
).all()
):
col_name = argument.name
if col_name is None: # numpy array, list...
col_name = _check_name_not_reserved(field, reserved_names)
if length and (len_arg := len(argument)) != length:
raise ValueError(
"All arguments should have the same length. "
"The length of argument `%s` is %d, whereas the "
"length of previously-processed arguments %s is %d"
% (field, len_arg, str(list(df_output.keys())), length)
)
df_output[str(col_name)] = to_named_series(
x=argument,
name=str(col_name),
native_namespace=native_namespace,
)
# Finally, update argument with column name now that column exists
assert col_name is not None, (
"Data-frame processing failure, likely due to a internal bug. "
"Please report this to "
"https://github.com/plotly/plotly.py/issues/new and we will try to "
"replicate and fix it."
)
if field_name not in array_attrables:
args[field_name] = str(col_name)
elif isinstance(args[field_name], dict):
pass
else:
args[field_name][i] = str(col_name)
if field_name != "wide_variable":
wide_id_vars.add(str(col_name))
length = len(df_output[next(iter(df_output))]) if len(df_output) else 0
if native_namespace is None:
try:
import pandas as pd
native_namespace = pd
except ImportError:
msg = "Pandas installation is required if no dataframe is provided."
raise NotImplementedError(msg)
if ranges:
import numpy as np
range_series = nw.new_series(
name="__placeholder__",
values=np.arange(length),
native_namespace=native_namespace,
)
df_output.update(
{col_name: range_series.alias(col_name) for col_name in ranges}
)
df_output.update(
{
# a constant is a single value: repeat it `length` times to avoid creating NaNs when concatenating
col_name: nw.new_series(
name=col_name,
values=[constants[col_name]] * length,
native_namespace=native_namespace,
)
for col_name in constants
}
)
if df_output:
df_output = nw.from_dict(df_output)
else:
try:
import pandas as pd
except ImportError:
msg = "Pandas installation is required."
raise NotImplementedError(msg)
df_output = nw.from_native(pd.DataFrame({}), eager_only=True)
return df_output, wide_id_vars
def build_dataframe(args, constructor):
"""
Constructs a dataframe and modifies `args` in-place.
The argument values in `args` can be either strings corresponding to
existing columns of a dataframe, or data arrays (lists, numpy arrays,
pandas columns, series).
Parameters
----------
args : OrderedDict
arguments passed to the px function and subsequently modified
constructor : graph_object trace class
the trace type selected for this figure
"""
# make copies of all the fields via dict() and list()
for field in args:
if field in array_attrables and args[field] is not None:
if isinstance(args[field], dict):
args[field] = dict(args[field])
elif field in ["custom_data", "hover_data"] and isinstance(
args[field], str
):
args[field] = [args[field]]
else:
args[field] = list(args[field])
# Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.)
df_provided = args["data_frame"] is not None
# Flag that indicates if the resulting data_frame after parsing is pandas-like
# (in terms of resulting Narwhals DataFrame).
# True if pandas, modin.pandas or cudf DataFrame/Series instance, or converted from
# PySpark to pandas.
is_pd_like = False
# Flag that indicates if data_frame needs to be converted to PyArrow.
# True if Ibis, DuckDB, Vaex, or implements __dataframe__
needs_interchanging = False
# If data_frame is provided, we parse it into a narwhals DataFrame, while accounting
# for compatibility with pandas specific paths (e.g. Index/MultiIndex case).
if df_provided:
# data_frame is pandas-like DataFrame (pandas, modin.pandas, cudf)
if nw.dependencies.is_pandas_like_dataframe(args["data_frame"]):
columns = args["data_frame"].columns # This can be multi index
args["data_frame"] = nw.from_native(args["data_frame"], eager_only=True)
is_pd_like = True
# data_frame is pandas-like Series (pandas, modin.pandas, cudf)
elif nw.dependencies.is_pandas_like_series(args["data_frame"]):
args["data_frame"] = nw.from_native(
args["data_frame"], series_only=True
).to_frame()
columns = args["data_frame"].columns
is_pd_like = True
# data_frame is any other DataFrame object natively supported via Narwhals.
# With `pass_through=True`, the original object will be returned if unable to convert
# to a Narwhals DataFrame, making this condition False.
elif isinstance(
data_frame := nw.from_native(
args["data_frame"], eager_or_interchange_only=True, pass_through=True
),
nw.DataFrame,
):
args["data_frame"] = data_frame
needs_interchanging = nw.get_level(data_frame) == "interchange"
columns = args["data_frame"].columns
# data_frame is any other Series object natively supported via Narwhals.
# With `pass_through=True`, the original object will be returned if unable to convert
# to a Narwhals Series, making this condition False.
elif isinstance(
series := nw.from_native(
args["data_frame"], series_only=True, pass_through=True
),
nw.Series,
):
args["data_frame"] = series.to_frame()
columns = args["data_frame"].columns
# data_frame is PySpark: it does not support interchange protocol and it is not
# integrated in Narwhals. We use its native method to convert it to pandas.
elif hasattr(args["data_frame"], "toPandas"):
args["data_frame"] = nw.from_native(
args["data_frame"].toPandas(), eager_only=True
)
columns = args["data_frame"].columns
is_pd_like = True
# data_frame is some other object type (e.g. dict, list, ...)
# We try to import pandas, and then try to instantiate a pandas dataframe from
# this such object
else:
try:
import pandas as pd
try:
args["data_frame"] = nw.from_native(
pd.DataFrame(args["data_frame"])
)
columns = args["data_frame"].columns
is_pd_like = True
except Exception:
msg = (
f"Unable to convert data_frame of type {type(args['data_frame'])} "
"to pandas DataFrame. Please provide a supported dataframe type "
"or a type that can be passed to pd.DataFrame."
)
raise NotImplementedError(msg)
except ImportError:
msg = (
f"Attempting to convert data_frame of type {type(args['data_frame'])} "
"to pandas DataFrame, but Pandas is not installed. "
"Convert it to supported dataframe type or install pandas."
)
raise NotImplementedError(msg)
# data_frame is not provided
else:
columns = None
df_input: nw.DataFrame | None = args["data_frame"]
index = (
nw.maybe_get_index(df_input)
if df_provided and not needs_interchanging
else None
)
native_namespace = (
nw.get_native_namespace(df_input)
if df_provided and not needs_interchanging
else None
)
# now we handle special cases like wide-mode or x-xor-y specification
# by rearranging args to tee things up for process_args_into_dataframe to work
no_x = args.get("x") is None
no_y = args.get("y") is None
wide_x = (
False
if no_x
else _is_col_list(columns, args["x"], is_pd_like, native_namespace)
)
wide_y = (
False
if no_y
else _is_col_list(columns, args["y"], is_pd_like, native_namespace)
)
wide_mode = False
var_name = None # will likely be "variable" in wide_mode
wide_cross_name = None # will likely be "index" in wide_mode
value_name = None # will likely be "value" in wide_mode
hist2d_types = [go.Histogram2d, go.Histogram2dContour]
hist1d_orientation = constructor == go.Histogram or "ecdfmode" in args
if constructor in cartesians:
if wide_x and wide_y:
raise ValueError(
"Cannot accept list of column references or list of columns for both `x` and `y`."
)
if df_provided and no_x and no_y:
wide_mode = True
if is_pd_like and isinstance(columns, native_namespace.MultiIndex):
raise TypeError(
f"Data frame columns is a {native_namespace.__name__} MultiIndex. "
f"{native_namespace.__name__} MultiIndex is not supported by plotly "
"express at the moment."
)
args["wide_variable"] = list(columns)
if is_pd_like and isinstance(columns, native_namespace.Index):
var_name = columns.name
else:
var_name = None
if var_name in [None, "value", "index"] or var_name in columns:
var_name = "variable"
if constructor == go.Funnel:
wide_orientation = args.get("orientation") or "h"
else:
wide_orientation = args.get("orientation") or "v"
args["orientation"] = wide_orientation
args["wide_cross"] = None
elif wide_x != wide_y:
wide_mode = True
args["wide_variable"] = args["y"] if wide_y else args["x"]
if df_provided and is_pd_like and args["wide_variable"] is columns:
var_name = columns.name
if is_pd_like and isinstance(args["wide_variable"], native_namespace.Index):
args["wide_variable"] = list(args["wide_variable"])
if var_name in [None, "value", "index"] or (
df_provided and var_name in columns
):
var_name = "variable"
if hist1d_orientation:
wide_orientation = "v" if wide_x else "h"
else:
wide_orientation = "v" if wide_y else "h"
args["y" if wide_y else "x"] = None
args["wide_cross"] = None
if not no_x and not no_y:
wide_cross_name = "__x__" if wide_y else "__y__"
if wide_mode:
value_name = _escape_col_name(columns, "value", [])
var_name = _escape_col_name(columns, var_name, [])
# If the data_frame has interchange-only support level in Narwhals, then we need
# to convert it to a backend with full support level.
# Hence we convert interchange-only frames to PyArrow.
if needs_interchanging:
if wide_mode:
args["data_frame"] = nw.from_native(
args["data_frame"].to_arrow(), eager_only=True
)
else:
# Save precious resources by only interchanging columns that are
# actually going to be plotted. This is tricky to do in the general case,
# because Plotly allows calls like `px.line(df, x='x', y=['y1', df['y1']])`,
# but interchange-only objects (e.g. DuckDB) don't typically have a concept
# of self-standing Series. It's more important to perform projection pushdown
# here seeing as we're materialising to an (eager) PyArrow table.
necessary_columns = {
i for i in args.values() if isinstance(i, str) and i in columns
}
for field in args:
if args[field] is not None and field in array_attrables:
necessary_columns.update(i for i in args[field] if i in columns)
columns = list(necessary_columns)
args["data_frame"] = nw.from_native(
args["data_frame"].select(columns).to_arrow(), eager_only=True
)
import pyarrow as pa
native_namespace = pa
missing_bar_dim = None
if (
constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types
and not hist1d_orientation
):
if not wide_mode and (no_x != no_y):
for ax in ["x", "y"]:
if args.get(ax) is None:
args[ax] = (
index
if index is not None
else Range(
label=_escape_col_name(columns, ax, [var_name, value_name])
)
)
if constructor == go.Bar:
missing_bar_dim = ax
else:
if args["orientation"] is None:
args["orientation"] = "v" if ax == "x" else "h"
if wide_mode and wide_cross_name is None:
if no_x != no_y and args["orientation"] is None:
args["orientation"] = "v" if no_x else "h"
if df_provided and is_pd_like and index is not None:
if isinstance(index, native_namespace.MultiIndex):
raise TypeError(
f"Data frame index is a {native_namespace.__name__} MultiIndex. "
f"{native_namespace.__name__} MultiIndex is not supported by "
"plotly express at the moment."
)
args["wide_cross"] = index
else:
args["wide_cross"] = Range(
label=_escape_col_name(columns, "index", [var_name, value_name])
)
no_color = False
if isinstance(args.get("color"), str) and args["color"] == NO_COLOR:
no_color = True
args["color"] = None
# now that things have been prepped, we do the systematic rewriting of `args`
df_output, wide_id_vars = process_args_into_dataframe(
args,
wide_mode,
var_name,
value_name,
is_pd_like,
native_namespace,
)
df_output: nw.DataFrame
# now that `df_output` exists and `args` contains only references, we complete
# the special-case and wide-mode handling by further rewriting args and/or mutating
# df_output
count_name = _escape_col_name(df_output.columns, "count", [var_name, value_name])
if not wide_mode and missing_bar_dim and constructor == go.Bar:
# now that we've populated df_output, we check to see if the non-missing
# dimension is categorical: if so, then setting the missing dimension to a
# constant 1 is a less-insane thing to do than setting it to the index by
# default and we let the normal auto-orientation-code do its thing later
other_dim = "x" if missing_bar_dim == "y" else "y"
if not _is_continuous(df_output, args[other_dim]):
args[missing_bar_dim] = count_name
df_output = df_output.with_columns(nw.lit(1).alias(count_name))
else:
# on the other hand, if the non-missing dimension is continuous, then we
# can use this information to override the normal auto-orientation code
if args["orientation"] is None:
args["orientation"] = "v" if missing_bar_dim == "x" else "h"
if constructor in hist2d_types:
del args["orientation"]
if wide_mode:
# at this point, `df_output` is semi-long/semi-wide, but we know which columns
# are which, so we melt it and reassign `args` to refer to the newly-tidy
# columns, keeping track of various names and manglings set up above
wide_value_vars = [c for c in args["wide_variable"] if c not in wide_id_vars]
del args["wide_variable"]
if wide_cross_name == "__x__":
wide_cross_name = args["x"]
elif wide_cross_name == "__y__":
wide_cross_name = args["y"]
else:
wide_cross_name = args["wide_cross"]
del args["wide_cross"]
dtype = None
for v in wide_value_vars:
v_dtype = df_output.get_column(v).dtype
v_dtype = "number" if v_dtype.is_numeric() else str(v_dtype)
if dtype is None:
dtype = v_dtype
elif dtype != v_dtype:
raise ValueError(
"Plotly Express cannot process wide-form data with columns of different type."
)
df_output = df_output.unpivot(
index=wide_id_vars,
on=wide_value_vars,
variable_name=var_name,
value_name=value_name,
)
assert len(df_output.columns) == len(set(df_output.columns)), (
"Wide-mode name-inference failure, likely due to a internal bug. "
"Please report this to "
"https://github.com/plotly/plotly.py/issues/new and we will try to "
"replicate and fix it."
)
df_output = df_output.with_columns(nw.col(var_name).cast(nw.String))
orient_v = wide_orientation == "v"
if hist1d_orientation:
args["x" if orient_v else "y"] = value_name
args["y" if orient_v else "x"] = wide_cross_name
args["color"] = args["color"] or var_name
elif constructor in [go.Scatter, go.Funnel] + hist2d_types:
args["x" if orient_v else "y"] = wide_cross_name
args["y" if orient_v else "x"] = value_name
if constructor != go.Histogram2d:
args["color"] = args["color"] or var_name
if "line_group" in args:
args["line_group"] = args["line_group"] or var_name
elif constructor == go.Bar:
if _is_continuous(df_output, value_name):
args["x" if orient_v else "y"] = wide_cross_name
args["y" if orient_v else "x"] = value_name
args["color"] = args["color"] or var_name
else:
args["x" if orient_v else "y"] = value_name
args["y" if orient_v else "x"] = count_name
df_output = df_output.with_columns(nw.lit(1).alias(count_name))
args["color"] = args["color"] or var_name
elif constructor in [go.Violin, go.Box]:
args["x" if orient_v else "y"] = wide_cross_name or var_name
args["y" if orient_v else "x"] = value_name
if hist1d_orientation and constructor == go.Scatter:
if args["x"] is not None and args["y"] is not None:
args["histfunc"] = "sum"
elif args["x"] is None:
args["histfunc"] = None
args["orientation"] = "h"
args["x"] = count_name
df_output = df_output.with_columns(nw.lit(1).alias(count_name))
else:
args["histfunc"] = None
args["orientation"] = "v"
args["y"] = count_name
df_output = df_output.with_columns(nw.lit(1).alias(count_name))
if no_color:
args["color"] = None
args["data_frame"] = df_output
return args
def _check_dataframe_all_leaves(df: nw.DataFrame) -> None:
cols = df.columns
df_sorted = df.sort(by=cols, descending=False, nulls_last=True)
null_mask = df_sorted.select(nw.all().is_null())
df_sorted = df_sorted.select(nw.all().cast(nw.String()))
null_indices_mask = null_mask.select(
null_mask=nw.any_horizontal(nw.all())
).get_column("null_mask")
null_mask_filtered = null_mask.filter(null_indices_mask)
if not null_mask_filtered.is_empty():
for col_idx in range(1, null_mask_filtered.shape[1]):
# For each row, if a True value is encountered, then check that
# all values in subsequent columns are also True
null_entries_with_non_null_children = (
~null_mask_filtered[:, col_idx] & null_mask_filtered[:, col_idx - 1]
)
if nw.to_py_scalar(null_entries_with_non_null_children.any()):
row_idx = null_entries_with_non_null_children.to_list().index(True)
raise ValueError(
"None entries cannot have not-None children",
df_sorted.row(row_idx),
)
fill_series = nw.new_series(
name="fill_value",
values=[""] * len(df_sorted),
dtype=nw.String(),
native_namespace=nw.get_native_namespace(df_sorted),
)
df_sorted = df_sorted.with_columns(
**{
c: df_sorted.get_column(c).zip_with(~null_mask.get_column(c), fill_series)
for c in cols
}
)
# Convert to a list so we compare Python-native scalars rather than
# backend-specific (e.g. PyArrow) scalars
row_strings = (
df_sorted.select(
row_strings=nw.concat_str(cols, separator="", ignore_nulls=False)
)
.get_column("row_strings")
.to_list()
)
null_indices = set(null_indices_mask.arg_true().to_list())
for i, (current_row, next_row) in enumerate(
zip(row_strings[:-1], row_strings[1:]), start=1
):
if (next_row in current_row) and (i in null_indices):
raise ValueError(
"Non-leaves rows are not permitted in the dataframe \n",
df_sorted.row(i),
"is not a leaf.",
)
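# A minimal sketch of the leaf rule enforced above (illustrative only, not
# executed here; assumes pandas is installed, column names are hypothetical):
#
#     import pandas as pd
#     import narwhals.stable.v1 as nw
#
#     ok = nw.from_native(pd.DataFrame({"lvl0": ["A", "A"], "lvl1": ["x", "y"]}))
#     _check_dataframe_all_leaves(ok)  # passes: every row is a leaf
#
#     bad = nw.from_native(pd.DataFrame({"lvl0": ["A", "A"], "lvl1": ["x", None]}))
#     _check_dataframe_all_leaves(bad)  # raises: ("A", None) shadows parent "A",
#                                       # which already has the child ("A", "x")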
def process_dataframe_hierarchy(args):
"""
Build dataframe for sunburst, treemap, or icicle when the path argument is provided.
"""
df: nw.DataFrame = args["data_frame"]
path = args["path"][::-1]
_check_dataframe_all_leaves(df[path[::-1]])
discrete_color = not _is_continuous(df, args["color"]) if args["color"] else False
df = df.lazy()
new_path = [col_name + "_path_copy" for col_name in path]
df = df.with_columns(
nw.col(col_name).alias(new_col_name)
for new_col_name, col_name in zip(new_path, path)
)
path = new_path
# ------------ Define aggregation functions --------------------------------
agg_f = {}
if args["values"]:
try:
df = df.with_columns(nw.col(args["values"]).cast(nw.Float64()))
except Exception: # pandas, Polars and pyarrow exception types are different
raise ValueError(
"Column `%s` of `df` could not be converted to a numerical data type."
% args["values"]
)
if args["color"] and args["color"] == args["values"]:
new_value_col_name = args["values"] + "_sum"
df = df.with_columns(nw.col(args["values"]).alias(new_value_col_name))
args["values"] = new_value_col_name
count_colname = args["values"]
else:
# we need a count column for the first groupby and for the weighted mean of color
# trick to be sure the col name is unused: concatenate all existing names
columns = df.collect_schema().names()
count_colname = (
"count" if "count" not in columns else "".join([str(el) for el in columns])
)
# we can modify df because it's a copy of the px argument
df = df.with_columns(nw.lit(1).alias(count_colname))
args["values"] = count_colname
# Since count_colname is always in agg_f, it can be used later to normalize color
# in the continuous case after some gymnastics
agg_f[count_colname] = nw.sum(count_colname)
discrete_aggs = []
continuous_aggs = []
n_unique_token = _generate_temporary_column_name(
n_bytes=16, columns=df.collect_schema().names()
)
# In theory, to aggregate discrete columns, we would do
# `.agg(nw.col(x).unique())` in group_by and then unpack/parse it as:
# ```
# (nw.when(nw.col(x).list.len()==1)
# .then(nw.col(x).list.first())
# .otherwise(nw.lit("(?)"))
# )
# ```
# which replicates the original pandas-only codebase:
# ```
# def discrete_agg(x):
# uniques = x.unique()
# return uniques[0] if len(uniques) == 1 else "(?)"
#
# df.groupby(path[i:]).agg(...)
# ```
# However this is not possible, so the following workaround is used instead.
# We make two aggregations for the same column:
# - take the max value
# - take the number of unique values
# Finally, after the group by statement, it is unpacked via:
# ```
# (nw.when(nw.col(col_n_unique) == 1)
# .then(nw.col(col_max_value)) # which is the unique value
# .otherwise(nw.lit("(?)"))
# )
# ```
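# For example (hypothetical values): a group whose color column holds
# ["a", "a"] aggregates to max="a", n_unique=1 and unpacks to "a", while
# ["a", "b"] aggregates to max="b", n_unique=2 and unpacks to "(?)".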
if args["color"]:
if discrete_color:
discrete_aggs.append(args["color"])
agg_f[args["color"]] = nw.col(args["color"]).max()
agg_f[f"{args['color']}{n_unique_token}"] = (
nw.col(args["color"])
.n_unique()
.alias(f"{args['color']}{n_unique_token}")
)
else:
# This first needs to be multiplied by `count_colname`
continuous_aggs.append(args["color"])
agg_f[args["color"]] = nw.sum(args["color"])
# Other columns (for color, hover_data, custom_data etc.)
cols = list(set(df.collect_schema().names()).difference(path))
df = df.with_columns(nw.col(c).cast(nw.String()) for c in cols if c not in agg_f)
for col in cols: # for hover_data, custom_data etc.
if col not in agg_f:
# Similar trick as above
discrete_aggs.append(col)
agg_f[col] = nw.col(col).max()
agg_f[f"{col}{n_unique_token}"] = (
nw.col(col).n_unique().alias(f"{col}{n_unique_token}")
)
# Avoid collisions with reserved names - columns in the path have been copied already
cols = list(set(cols) - set(["labels", "parent", "id"]))
# ----------------------------------------------------------------------------
all_trees = []
if args["color"] and not discrete_color:
df = df.with_columns(
(nw.col(args["color"]) * nw.col(count_colname)).alias(args["color"])
)
def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFrame:
"""
- continuous_aggs is either [] or [args["color"]]
- discrete_aggs contains args["color"] (when color is discrete) plus any other
non-aggregated columns, e.g. those used for hover_data or custom_data
"""
return dframe.with_columns(
*[nw.col(col) / nw.col(count_colname) for col in continuous_aggs],
*[
(
nw.when(nw.col(f"{col}{n_unique_token}") == 1)
.then(nw.col(col))
.otherwise(nw.lit("(?)"))
.alias(col)
)
for col in discrete_aggs
],
).drop([f"{col}{n_unique_token}" for col in discrete_aggs])
for i, level in enumerate(path):
dfg = (
df.group_by(path[i:], drop_null_keys=True)
.agg(**agg_f)
.pipe(post_agg, continuous_aggs, discrete_aggs)
)
# Path label massaging
df_tree = dfg.with_columns(
*cols,
labels=nw.col(level).cast(nw.String()),
parent=nw.lit(""),
id=nw.col(level).cast(nw.String()),
)
if i < len(path) - 1:
_concat_str_token = _generate_temporary_column_name(
n_bytes=16, columns=[*cols, "labels", "parent", "id"]
)
df_tree = (
df_tree.with_columns(
nw.concat_str(
[
nw.col(path[j]).cast(nw.String())
for j in range(len(path) - 1, i, -1)
],
separator="/",
).alias(_concat_str_token)
)
.with_columns(
parent=nw.concat_str(
[nw.col(_concat_str_token), nw.col("parent")], separator="/"
),
id=nw.concat_str(
[nw.col(_concat_str_token), nw.col("id")], separator="/"
),
)
.drop(_concat_str_token)
)
# strip "/" if at the end of the string, equivalent to `.str.rstrip`
df_tree = df_tree.with_columns(
parent=nw.col("parent").str.replace("/?$", "").str.replace("^/?", "")
)
all_trees.append(df_tree.select(*["labels", "parent", "id", *cols]))
df_all_trees = nw.maybe_reset_index(nw.concat(all_trees, how="vertical").collect())
# we want to make sure that "(?)" is the first color of the sequence
if args["color"] and discrete_color:
sort_col_name = "sort_color_if_discrete_color"
while sort_col_name in df_all_trees.columns:
sort_col_name += "0"
df_all_trees = df_all_trees.with_columns(
nw.col(args["color"]).cast(nw.String()).alias(sort_col_name)
).sort(by=sort_col_name, nulls_last=True)
# Now modify arguments
args["data_frame"] = df_all_trees
args["path"] = None
args["ids"] = "id"
args["names"] = "labels"
args["parents"] = "parent"
if args["color"]:
if not args["hover_data"]:
args["hover_data"] = [args["color"]]
elif isinstance(args["hover_data"], dict):
if not args["hover_data"].get(args["color"]):
args["hover_data"][args["color"]] = (True, None)
else:
args["hover_data"].append(args["color"])
return args
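# For example (hypothetical columns): with path=["continent", "country"], the
# resulting frame has one row per country (id "continent/country", parent set
# to its continent) plus one row per continent (parent ""), which is the
# ids/parents structure that go.Sunburst, go.Treemap and go.Icicle consume.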
def process_dataframe_timeline(args):
"""
Massage input for bar traces for px.timeline()
"""
args["is_timeline"] = True
if args["x_start"] is None or args["x_end"] is None:
raise ValueError("Both x_start and x_end are required")
df: nw.DataFrame = args["data_frame"]
schema = df.schema
to_convert_to_datetime = [
col
for col in [args["x_start"], args["x_end"]]
if schema[col] != nw.Datetime and schema[col] != nw.Date
]
if to_convert_to_datetime:
try:
df = df.with_columns(nw.col(to_convert_to_datetime).str.to_datetime())
except Exception as exc:
raise TypeError(
"Both x_start and x_end must refer to data convertible to datetimes."
) from exc
# note that we are not adding any columns to the data frame here, so no risk of overwrite
args["data_frame"] = df.with_columns(
(nw.col(args["x_end"]) - nw.col(args["x_start"]))
.dt.total_milliseconds()
.alias(args["x_end"])
)
args["x"] = args["x_end"]
args["base"] = args["x_start"]
del args["x_start"], args["x_end"]
return args
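# A minimal sketch of the timeline massaging above (illustrative only, not
# executed here; assumes pandas is installed, "start"/"finish" are hypothetical):
#
#     import pandas as pd
#     import narwhals.stable.v1 as nw
#
#     args = dict(
#         data_frame=nw.from_native(
#             pd.DataFrame({"start": ["2024-01-01"], "finish": ["2024-01-03"]})
#         ),
#         x_start="start",
#         x_end="finish",
#     )
#     args = process_dataframe_timeline(args)
#     # "finish" now holds the bar length in milliseconds (172800000 for two
#     # days), with args["x"] == "finish" and args["base"] == "start".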
def process_dataframe_pie(args, trace_patch):
import numpy as np
names = args.get("names")
if names is None:
return args, trace_patch
order_in = args["category_orders"].get(names, {}).copy()
if not order_in:
return args, trace_patch
df: nw.DataFrame = args["data_frame"]
trace_patch["sort"] = False
trace_patch["direction"] = "clockwise"
uniques = df.get_column(names).unique(maintain_order=True).to_list()
order = [x for x in OrderedDict.fromkeys(list(order_in) + uniques) if x in uniques]
# Sort args['data_frame'] by column `names` according to order `order`.
token = nw.generate_temporary_column_name(8, df.columns)
args["data_frame"] = (
df.with_columns(
nw.col(names)
.replace_strict(order, np.arange(len(order)), return_dtype=nw.UInt32)
.alias(token)
)
.sort(token)
.drop(token)
)
return args, trace_patch
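# The sort trick above in isolation (illustrative; the "names" column and
# `order` below are hypothetical): mapping each category to its position in
# `order` and sorting on that temporary integer column reorders the rows to
# match a custom category order.
#
#     import numpy as np
#     import pandas as pd
#     import narwhals.stable.v1 as nw
#
#     df = nw.from_native(pd.DataFrame({"names": ["b", "a", "c"]}))
#     order = ["c", "a", "b"]
#     df = (
#         df.with_columns(
#             nw.col("names")
#             .replace_strict(order, np.arange(len(order)), return_dtype=nw.UInt32)
#             .alias("_rank")
#         )
#         .sort("_rank")
#         .drop("_rank")
#     )
#     # df.get_column("names").to_list() == ["c", "a", "b"]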
def infer_config(args, constructor, trace_patch, layout_patch):
attrs = [k for k in direct_attrables + array_attrables if k in args]
grouped_attrs = []
df: nw.DataFrame = args["data_frame"]
# Compute sizeref
sizeref = 0
if "size" in args and args["size"]:
sizeref = (
nw.to_py_scalar(df.get_column(args["size"]).max()) / args["size_max"] ** 2
)
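# e.g. (hypothetical numbers) a maximum size value of 400 with the default
# size_max=20 gives sizeref = 400 / 20**2 = 1.0; a larger data maximum scales
# sizeref up proportionally so the biggest marker stays near size_max pixels.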
# Compute color attributes and grouping attributes
if "color" in args:
if "color_continuous_scale" in args:
if "color_discrete_sequence" not in args:
attrs.append("color")
else:
if args["color"] and _is_continuous(df, args["color"]):
attrs.append("color")
args["color_is_continuous"] = True
elif constructor in [go.Sunburst, go.Treemap, go.Icicle]:
attrs.append("color")
args["color_is_continuous"] = False
else:
grouped_attrs.append("marker.color")
elif "line_group" in args or constructor == go.Histogram2dContour:
grouped_attrs.append("line.color")
elif constructor in [go.Pie, go.Funnelarea]:
attrs.append("color")
if args["color"]:
if args["hover_data"] is None:
args["hover_data"] = []
args["hover_data"].append(args["color"])
else:
grouped_attrs.append("marker.color")
show_colorbar = bool(
"color" in attrs
and args["color"]
and constructor not in [go.Pie, go.Funnelarea]
and (
constructor not in [go.Treemap, go.Sunburst, go.Icicle]
or args.get("color_is_continuous")
)
)
else:
show_colorbar = False
if "line_dash" in args:
grouped_attrs.append("line.dash")
if "symbol" in args:
grouped_attrs.append("marker.symbol")
if "pattern_shape" in args:
if constructor in [go.Scatter]:
grouped_attrs.append("fillpattern.shape")
else:
grouped_attrs.append("marker.pattern.shape")
if "orientation" in args:
has_x = args["x"] is not None
has_y = args["y"] is not None
if args["orientation"] is None:
if constructor in [go.Histogram, go.Scatter]:
if has_y and not has_x:
args["orientation"] = "h"
elif constructor in [go.Violin, go.Box, go.Bar, go.Funnel]:
if has_x and not has_y:
args["orientation"] = "h"
if args["orientation"] is None and has_x and has_y:
x_is_continuous = _is_continuous(df, args["x"])
y_is_continuous = _is_continuous(df, args["y"])
if x_is_continuous and not y_is_continuous:
args["orientation"] = "h"
if y_is_continuous and not x_is_continuous:
args["orientation"] = "v"
if args["orientation"] is None:
args["orientation"] = "v"
if constructor == go.Histogram:
if has_x and has_y and args["histfunc"] is None:
args["histfunc"] = trace_patch["histfunc"] = "sum"
orientation = args["orientation"]
nbins = args["nbins"]
trace_patch["nbinsx"] = nbins if orientation == "v" else None
trace_patch["nbinsy"] = None if orientation == "v" else nbins
trace_patch["bingroup"] = "x" if orientation == "v" else "y"
trace_patch["orientation"] = args["orientation"]
if constructor in [go.Violin, go.Box]:
mode = "boxmode" if constructor == go.Box else "violinmode"
if layout_patch[mode] is None and args["color"] is not None:
if args["y"] == args["color"] and args["orientation"] == "h":
layout_patch[mode] = "overlay"
elif args["x"] == args["color"] and args["orientation"] == "v":
layout_patch[mode] = "overlay"
if layout_patch[mode] is None:
layout_patch[mode] = "group"
if (
constructor == go.Histogram2d
and args["z"] is not None
and args["histfunc"] is None
):
args["histfunc"] = trace_patch["histfunc"] = "sum"
if args.get("text_auto", False) is not False:
if constructor in [go.Histogram2d, go.Histogram2dContour]:
letter = "z"
elif constructor == go.Bar:
letter = "y" if args["orientation"] == "v" else "x"
else:
letter = "value"
if args["text_auto"] is True:
trace_patch["texttemplate"] = "%{" + letter + "}"
else:
trace_patch["texttemplate"] = "%{" + letter + ":" + args["text_auto"] + "}"
if constructor in [go.Histogram2d, go.Densitymap, go.Densitymapbox]:
show_colorbar = True
trace_patch["coloraxis"] = "coloraxis1"
if "opacity" in args:
if args["opacity"] is None:
if "barmode" in args and args["barmode"] == "overlay":
trace_patch["marker"] = dict(opacity=0.5)
elif constructor in [
go.Densitymap,
go.Densitymapbox,
go.Pie,
go.Funnel,
go.Funnelarea,
]:
trace_patch["opacity"] = args["opacity"]
else:
trace_patch["marker"] = dict(opacity=args["opacity"])
if (
"line_group" in args or "line_dash" in args
): # px.line, px.line_*, px.area, px.ecdf
modes = set()
if args.get("lines", True):
modes.add("lines")
if args.get("text") or args.get("symbol") or args.get("markers"):
modes.add("markers")
if args.get("text"):
modes.add("text")
if len(modes) == 0:
modes.add("lines")
trace_patch["mode"] = "+".join(sorted(modes))
elif constructor != go.Splom and (
"symbol" in args or constructor in [go.Scattermap, go.Scattermapbox]
):
trace_patch["mode"] = "markers" + ("+text" if args["text"] else "")
if "line_shape" in args:
trace_patch["line"] = dict(shape=args["line_shape"])
elif "ecdfmode" in args:
trace_patch["line"] = dict(
shape="vh" if args["ecdfmode"] == "reversed" else "hv"
)
if "geojson" in args:
trace_patch["featureidkey"] = args["featureidkey"]
trace_patch["geojson"] = (
args["geojson"]
if not hasattr(args["geojson"], "__geo_interface__") # for geopandas
else args["geojson"].__geo_interface__
)
# Compute marginal attribute: copy to appropriate marginal_*
if "marginal" in args:
position = "marginal_x" if args["orientation"] == "v" else "marginal_y"
other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y"
args[position] = args["marginal"]
args[other_position] = None
# Ignore facet rows and columns when the data frame is empty, to prevent nrows/ncols from equaling 0
if df.is_empty():
args["facet_row"] = args["facet_col"] = None
# If both marginals and faceting are specified, faceting wins
if args.get("facet_col") is not None and args.get("marginal_y") is not None:
args["marginal_y"] = None
if args.get("facet_row") is not None and args.get("marginal_x") is not None:
args["marginal_x"] = None
# facet_col_wrap only works if no marginals or row faceting is used
if (
args.get("marginal_x") is not None
or args.get("marginal_y") is not None
or args.get("facet_row") is not None
):
args["facet_col_wrap"] = 0
if "trendline" in args and args["trendline"] is not None:
if args["trendline"] not in trendline_functions:
raise ValueError(
"Value '%s' for `trendline` must be one of %s"
% (args["trendline"], trendline_functions.keys())
)
if "trendline_options" in args and args["trendline_options"] is None:
args["trendline_options"] = dict()
if "ecdfnorm" in args:
if args.get("ecdfnorm", None) not in [None, "percent", "probability"]:
raise ValueError(
"`ecdfnorm` must be one of None, 'percent' or 'probability'. "
+ "'%s' was provided." % args["ecdfnorm"]
)
args["histnorm"] = args["ecdfnorm"]
# Compute applicable grouping attributes
grouped_attrs.extend([k for k in group_attrables if k in args])
# Create grouped mappings
grouped_mappings = [make_mapping(args, a) for a in grouped_attrs]
# Create trace specs
trace_specs = make_trace_spec(args, constructor, attrs, trace_patch)
return trace_specs, grouped_mappings, sizeref, show_colorbar
def get_groups_and_orders(args, grouper):
"""
`orders` is the user-supplied ordering with the remaining data-frame-supplied
ordering appended for each column used for grouping. It includes anything the
user gave, for any variable, including values not present in the dataset. It's
a dict keyed by column name (the values of args like "x" or "color").
`groups` is a dict of groups, ordered by the ordering above. Its keys are
tuples like ("value1", "") where each entry is the value of a single
dimension-group.
"""
orders = {} if "category_orders" not in args else args["category_orders"].copy()
df: nw.DataFrame = args["data_frame"]
# figure out orders and what the single group name would be if there were one
single_group_name = []
unique_cache = dict()
for i, col in enumerate(grouper):
if col == one_group:
single_group_name.append("")
else:
if col not in unique_cache:
unique_cache[col] = (
df.get_column(col).unique(maintain_order=True).to_list()
)
uniques = unique_cache[col]
if len(uniques) == 1:
single_group_name.append(uniques[0])
if col not in orders:
orders[col] = uniques
else:
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
if len(single_group_name) == len(grouper):
# we have a single group, so we can skip all group-by operations!
groups = {tuple(single_group_name): df}
else:
required_grouper = [group for group in orders if group in grouper]
grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__())
sorted_group_names = sorted(
grouped.keys(),
key=lambda values: [
orders[group].index(value) if value in orders[group] else -1
for group, value in zip(required_grouper, values)
],
)
# compute the full group names by inserting "" at the tuple positions of one_group columns
full_sorted_group_names = [
tuple(
[
(
""
if col == one_group
else sub_group_names[required_grouper.index(col)]
)
for col in grouper
]
)
for sub_group_names in sorted_group_names
]
groups = {
sf: grouped[s] for sf, s in zip(full_sorted_group_names, sorted_group_names)
}
return groups, orders
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
trace_patch = trace_patch or {}
layout_patch = layout_patch or {}
apply_default_cascade(args)
args = build_dataframe(args, constructor)
if constructor in [go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None:
args = process_dataframe_hierarchy(args)
if constructor in [go.Pie]:
args, trace_patch = process_dataframe_pie(args, trace_patch)
if constructor == "timeline":
constructor = go.Bar
args = process_dataframe_timeline(args)
# If we have marginal histograms, set barmode to "overlay"
if "histogram" in [args.get("marginal_x"), args.get("marginal_y")]:
layout_patch["barmode"] = "overlay"
trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
args, constructor, trace_patch, layout_patch
)
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
groups, orders = get_groups_and_orders(args, grouper)
col_labels = []
row_labels = []
nrows = ncols = 1
for m in grouped_mappings:
if m.grouper not in orders:
m.val_map[""] = m.sequence[0]
else:
sorted_values = orders[m.grouper]
if m.facet == "col":
prefix = get_label(args, args["facet_col"]) + "="
col_labels = [prefix + str(s) for s in sorted_values]
ncols = len(col_labels)
if m.facet == "row":
prefix = get_label(args, args["facet_row"]) + "="
row_labels = [prefix + str(s) for s in sorted_values]
nrows = len(row_labels)
for val in sorted_values:
if val not in m.val_map: # always False if it's an IdentityMap
m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]
subplot_type = _subplot_type_for_trace_type(constructor().type)
trace_names_by_frame = {}
frames = OrderedDict()
trendline_rows = []
trace_name_labels = None
facet_col_wrap = args.get("facet_col_wrap", 0)
for group_name, group in groups.items():
mapping_labels = OrderedDict()
trace_name_labels = OrderedDict()
frame_name = ""
for col, val, m in zip(grouper, group_name, grouped_mappings):
if col != one_group:
key = get_label(args, col)
if not isinstance(m.val_map, IdentityMap):
mapping_labels[key] = str(val)
if m.show_in_trace_name:
trace_name_labels[key] = str(val)
if m.variable == "animation_frame":
frame_name = val
trace_name = ", ".join(trace_name_labels.values())
if frame_name not in trace_names_by_frame:
trace_names_by_frame[frame_name] = set()
trace_names = trace_names_by_frame[frame_name]
for trace_spec in trace_specs:
# Create the trace
trace = trace_spec.constructor(name=trace_name)
if trace_spec.constructor not in [
go.Parcats,
go.Parcoords,
go.Choropleth,
go.Choroplethmap,
go.Choroplethmapbox,
go.Densitymap,
go.Densitymapbox,
go.Histogram2d,
go.Sunburst,
go.Treemap,
go.Icicle,
]:
trace.update(
legendgroup=trace_name,
showlegend=(trace_name != "" and trace_name not in trace_names),
)
# Set 'offsetgroup' only in group barmode (or if no barmode is set)
barmode = layout_patch.get("barmode")
if trace_spec.constructor in [go.Bar, go.Box, go.Violin, go.Histogram] and (
barmode == "group" or barmode is None
):
trace.update(alignmentgroup=True, offsetgroup=trace_name)
trace_names.add(trace_name)
# Init subplot row/col
trace._subplot_row = 1
trace._subplot_col = 1
for i, m in enumerate(grouped_mappings):
val = group_name[i]
try:
m.updater(trace, m.val_map[val]) # covers most cases
except ValueError:
# this catches some odd cases like marginals
if (
trace_spec != trace_specs[0]
and (
trace_spec.constructor in [go.Violin, go.Box]
and m.variable in ["symbol", "pattern", "dash"]
)
or (
trace_spec.constructor in [go.Histogram]
and m.variable in ["symbol", "dash"]
)
):
pass
elif (
trace_spec != trace_specs[0]
and trace_spec.constructor in [go.Histogram]
and m.variable == "color"
):
trace.update(marker=dict(color=m.val_map[val]))
elif (
trace_spec.constructor
in [go.Choropleth, go.Choroplethmap, go.Choroplethmapbox]
and m.variable == "color"
):
trace.update(
z=[1] * len(group),
colorscale=[m.val_map[val]] * 2,
showscale=False,
showlegend=True,
)
else:
raise
# Find row for trace, handling facet_row and marginal_x
if m.facet == "row":
row = m.val_map[val]
else:
if (
args.get("marginal_x") is not None # there is a marginal
and trace_spec.marginal != "x" # and we're not it
):
row = 2
else:
row = 1
# Find col for trace, handling facet_col and marginal_y
if m.facet == "col":
col = m.val_map[val]
if facet_col_wrap: # assumes no facet_row, no marginals
row = 1 + ((col - 1) // facet_col_wrap)
col = 1 + ((col - 1) % facet_col_wrap)
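# e.g. facet_col_wrap=3 maps the flat facet index col=5 to
# row = 1 + (5 - 1) // 3 == 2 and col = 1 + (5 - 1) % 3 == 2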
else:
if trace_spec.marginal == "y":
col = 2
else:
col = 1
if row > 1:
trace._subplot_row = row
if col > 1:
trace._subplot_col = col
if (
trace_specs[0].constructor == go.Histogram2dContour
and trace_spec.constructor == go.Box
and trace.line.color
):
trace.update(marker=dict(color=trace.line.color))
if "ecdfmode" in args:
base = args["x"] if args["orientation"] == "v" else args["y"]
var = args["x"] if args["orientation"] == "h" else args["y"]
ascending = args.get("ecdfmode", "standard") != "reversed"
group = group.sort(by=base, descending=not ascending, nulls_last=True)
group_sum = group.get_column(var).sum()  # compute here before the next line mutates `var`
group = group.with_columns(nw.col(var).cum_sum().alias(var))
if not ascending:
group = group.sort(by=base, descending=False, nulls_last=True)
if args.get("ecdfmode", "standard") == "complementary":
group = group.with_columns((group_sum - nw.col(var)).alias(var))
if args["ecdfnorm"] == "probability":
group = group.with_columns(nw.col(var) / group_sum)
elif args["ecdfnorm"] == "percent":
group = group.with_columns((nw.col(var) / group_sum) * 100.0)
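# e.g. (hypothetical values) var values [1, 1, 1] give cum_sum [1, 2, 3];
# "complementary" flips this to [2, 1, 0] and ecdfnorm="probability"
# rescales by group_sum=3 to [2/3, 1/3, 0].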
patch, fit_results = make_trace_kwargs(
args, trace_spec, group, mapping_labels.copy(), sizeref
)
trace.update(patch)
if fit_results is not None:
trendline_rows.append(mapping_labels.copy())
trendline_rows[-1]["px_fit_results"] = fit_results
if frame_name not in frames:
frames[frame_name] = dict(data=[], name=frame_name)
frames[frame_name]["data"].append(trace)
frame_list = [f for f in frames.values()]
if len(frame_list) > 1:
frame_list = sorted(
frame_list, key=lambda f: orders[args["animation_frame"]].index(f["name"])
)
if show_colorbar:
colorvar = (
"z"
if constructor in [go.Histogram2d, go.Densitymap, go.Densitymapbox]
else "color"
)
range_color = args["range_color"] or [None, None]
colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
layout_patch["coloraxis1"] = dict(
colorscale=colorscale_validator.validate_coerce(
args["color_continuous_scale"]
),
cmid=args["color_continuous_midpoint"],
cmin=range_color[0],
cmax=range_color[1],
colorbar=dict(
title_text=get_decorated_label(args, args[colorvar], colorvar)
),
)
for v in ["height", "width"]:
if args[v]:
layout_patch[v] = args[v]
layout_patch["legend"] = dict(tracegroupgap=0)
if trace_name_labels:
layout_patch["legend"]["title_text"] = ", ".join(trace_name_labels)
if args["title"]:
layout_patch["title_text"] = args["title"]
elif args["template"].layout.margin.t is None:
layout_patch["margin"] = {"t": 60}
if args["subtitle"]:
layout_patch["title_subtitle_text"] = args["subtitle"]
if (
"size" in args
and args["size"]
and args["template"].layout.legend.itemsizing is None
):
layout_patch["legend"]["itemsizing"] = "constant"
if facet_col_wrap:
nrows = math.ceil(ncols / facet_col_wrap)
ncols = min(ncols, facet_col_wrap)
if args.get("marginal_x") is not None:
nrows += 1
if args.get("marginal_y") is not None:
ncols += 1
fig = init_figure(
args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
)
# Position traces in subplots
for frame in frame_list:
for trace in frame["data"]:
if isinstance(trace, go.Splom):
# Special case that is not compatible with make_subplots
continue
_set_trace_grid_reference(
trace,
fig.layout,
fig._grid_ref,
nrows - trace._subplot_row + 1,
trace._subplot_col,
)
# Add traces, layout and frames to figure
fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else [])
fig.update_layout(layout_patch)
if "template" in args and args["template"] is not None:
fig.update_layout(template=args["template"], overwrite=True)
for f in frame_list:
f["name"] = str(f["name"])
fig.frames = frame_list if len(frames) > 1 else []
if args.get("trendline") and args.get("trendline_scope", "trace") == "overall":
trendline_spec = make_trendline_spec(args, constructor)
trendline_trace = trendline_spec.constructor(
name="Overall Trendline", legendgroup="Overall Trendline", showlegend=False
)
if "line" not in trendline_spec.trace_patch: # no color override
for m in grouped_mappings:
if m.variable == "color":
next_color = m.sequence[len(m.val_map) % len(m.sequence)]
trendline_spec.trace_patch["line"] = dict(color=next_color)
patch, fit_results = make_trace_kwargs(
args, trendline_spec, args["data_frame"], {}, sizeref
)
trendline_trace.update(patch)
fig.add_trace(
trendline_trace, row="all", col="all", exclude_empty_subplots=True
)
fig.update_traces(selector=-1, showlegend=True)
if fit_results is not None:
trendline_rows.append(dict(px_fit_results=fit_results))
if trendline_rows:
try:
import pandas as pd
fig._px_trendlines = pd.DataFrame(trendline_rows)
except ImportError:
msg = "Trendlines require pandas to be installed."
raise NotImplementedError(msg)
else:
fig._px_trendlines = []
configure_axes(args, constructor, fig, orders)
configure_animation_controls(args, constructor, fig)
return fig
def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
# Build subplot specs
specs = [[dict(type=subplot_type or "domain")] * ncols for _ in range(nrows)]
# Default to uniform row/column widths
column_widths = [1.0] * ncols
row_heights = [1.0] * nrows
facet_col_wrap = args.get("facet_col_wrap", 0)
# Build column_widths/row_heights
if subplot_type == "xy":
if args.get("marginal_x") is not None:
if args["marginal_x"] == "histogram" or ("color" in args and args["color"]):
main_size = 0.74
else:
main_size = 0.84
row_heights = [main_size] * (nrows - 1) + [1 - main_size]
vertical_spacing = 0.01
elif facet_col_wrap:
vertical_spacing = args.get("facet_row_spacing") or 0.07
else:
vertical_spacing = args.get("facet_row_spacing") or 0.03
if args.get("marginal_y") is not None:
if args["marginal_y"] == "histogram" or ("color" in args and args["color"]):
main_size = 0.74
else:
main_size = 0.84
column_widths = [main_size] * (ncols - 1) + [1 - main_size]
horizontal_spacing = 0.005
else:
horizontal_spacing = args.get("facet_col_spacing") or 0.02
else:
# Other subplot types:
# 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None
#
# We can customize subplot spacing per type once we enable faceting
# for all plot types
if facet_col_wrap:
vertical_spacing = args.get("facet_row_spacing") or 0.07
else:
vertical_spacing = args.get("facet_row_spacing") or 0.03
horizontal_spacing = args.get("facet_col_spacing") or 0.02
if facet_col_wrap:
subplot_labels = [None] * nrows * ncols
while len(col_labels) < nrows * ncols:
col_labels.append(None)
for i in range(nrows):
for j in range(ncols):
subplot_labels[i * ncols + j] = col_labels[(nrows - 1 - i) * ncols + j]
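# e.g. (hypothetical labels) nrows=2, ncols=2, col_labels=["A", "B", "C", None]
# yields subplot_labels == ["C", None, "A", "B"]: the row order is flipped so
# titles line up with the "bottom-left" start_cell used below.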
def _spacing_error_translator(e, direction, facet_arg):
"""
Translate spacing errors raised by the underlying make_subplots
routine into messages that name the px argument controlling that spacing.
"""
if ("%s spacing" % (direction,)) in e.args[0]:
e.args = (
e.args[0]
+ """
Use the {facet_arg} argument to adjust this spacing.""".format(facet_arg=facet_arg),
)
raise e
# Create figure with subplots
try:
fig = make_subplots(
rows=nrows,
cols=ncols,
specs=specs,
shared_xaxes="all",
shared_yaxes="all",
row_titles=[] if facet_col_wrap else list(reversed(row_labels)),
column_titles=[] if facet_col_wrap else col_labels,
subplot_titles=subplot_labels if facet_col_wrap else [],
horizontal_spacing=horizontal_spacing,
vertical_spacing=vertical_spacing,
row_heights=row_heights,
column_widths=column_widths,
start_cell="bottom-left",
)
except ValueError as e:
_spacing_error_translator(e, "Horizontal", "facet_col_spacing")
_spacing_error_translator(e, "Vertical", "facet_row_spacing")
raise
# Remove explicit font size of row/col titles so template can take over
for annot in fig.layout.annotations:
annot.update(font=None)
return fig