import plotly.graph_objs as go
import plotly.io as pio
from collections import namedtuple, OrderedDict
from ._special_inputs import IdentityMap, Constant, Range
from .trendline_functions import ols, lowess, rolling, expanding, ewm

from _plotly_utils.basevalidators import ColorscaleValidator
from plotly.colors import qualitative, sequential
import math

from plotly._subplots import (
    make_subplots,
    _set_trace_grid_reference,
    _subplot_type_for_trace_type,
)

import narwhals.stable.v1 as nw

# The reason to use narwhals.stable.v1 is to have a stable and perfectly
# backwards-compatible API, hence the confidence to not pin the Narwhals version
# exactly. This allows multiple major libraries to have Narwhals as a dependency
# without forcing users into dependency conflicts when installing them together.

NO_COLOR = "px_no_color_constant"

trendline_functions = dict(
    lowess=lowess, rolling=rolling, ewm=ewm, expanding=expanding, ols=ols
)

# Declare all supported attributes, across all plot types
direct_attrables = (
    ["base", "x", "y", "z", "a", "b", "c", "r", "theta", "size", "x_start", "x_end"]
    + ["hover_name", "text", "names", "values", "parents", "wide_cross"]
    + ["ids", "error_x", "error_x_minus", "error_y", "error_y_minus", "error_z"]
    + ["error_z_minus", "lat", "lon", "locations", "animation_group"]
)
array_attrables = ["dimensions", "custom_data", "hover_data", "path", "wide_variable"]
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
renameable_group_attrables = [
    "color",  # renamed to marker.color or line.color in infer_config
    "symbol",  # renamed to marker.symbol in infer_config
    "line_dash",  # renamed to line.dash in infer_config
    "pattern_shape",  # renamed to marker.pattern.shape in infer_config
]
all_attrables = (
    direct_attrables + array_attrables + group_attrables + renameable_group_attrables
)

cartesians = [go.Scatter, go.Scattergl, go.Bar, go.Funnel, go.Box, go.Violin]
cartesians += [go.Histogram, go.Histogram2d, go.Histogram2dContour]


class PxDefaults(object):
    __slots__ = [
        "template",
        "width",
        "height",
        "color_discrete_sequence",
        "color_discrete_map",
        "color_continuous_scale",
        "symbol_sequence",
        "symbol_map",
        "line_dash_sequence",
        "line_dash_map",
        "pattern_shape_sequence",
        "pattern_shape_map",
        "size_max",
        "category_orders",
        "labels",
    ]

    def __init__(self):
        self.reset()

    def reset(self):
        self.template = None
        self.width = None
        self.height = None
        self.color_discrete_sequence = None
        self.color_discrete_map = {}
        self.color_continuous_scale = None
        self.symbol_sequence = None
        self.symbol_map = {}
        self.line_dash_sequence = None
        self.line_dash_map = {}
        self.pattern_shape_sequence = None
        self.pattern_shape_map = {}
        self.size_max = 20
        self.category_orders = {}
        self.labels = {}


defaults = PxDefaults()
del PxDefaults

MAPBOX_TOKEN = None


def set_mapbox_access_token(token):
    """
    Arguments:
        token: A Mapbox token to be used in `plotly.express.scatter_mapbox` and \
        `plotly.express.line_mapbox` figures. See \
        https://docs.mapbox.com/help/how-mapbox-works/access-tokens/ for more details
    """
    global MAPBOX_TOKEN
    MAPBOX_TOKEN = token


def get_trendline_results(fig):
    """
    Extracts fit statistics for trendlines (when applied to figures generated with
    the `trendline` argument set to `"ols"`).

    Arguments:
        fig: the output of a `plotly.express` charting call

    Returns:
        A `pandas.DataFrame` with a column "px_fit_results" containing the
        `statsmodels` results objects, along with columns identifying the subset
        of the data the trendline was fit on.
    """
    return fig._px_trendlines
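
# A minimal usage sketch for get_trendline_results (dataframe and column
# names below are hypothetical), following the documented px workflow:
#
#     import plotly.express as px
#     fig = px.scatter(df, x="total_bill", y="tip", trendline="ols")
#     results = get_trendline_results(fig)
#     print(results.px_fit_results.iloc[0].summary())
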
""" return fig._px_trendlines Mapping = namedtuple( "Mapping", [ "show_in_trace_name", "grouper", "val_map", "sequence", "updater", "variable", "facet", ], ) TraceSpec = namedtuple("TraceSpec", ["constructor", "attrs", "trace_patch", "marginal"]) def get_label(args, column): try: return args["labels"][column] except Exception: return column def invert_label(args, column): """Invert mapping. Find key corresponding to value column in dict args["labels"]. Returns `column` if the value does not exist. """ reversed_labels = {value: key for (key, value) in args["labels"].items()} try: return reversed_labels[column] except Exception: return column def _is_continuous(df: nw.DataFrame, col_name: str) -> bool: if nw.dependencies.is_pandas_like_dataframe(df_native := df.to_native()): # fastpath for pandas: Narwhals' Series.dtype has a bit of overhead, as it # tries to distinguish between true "object" columns, and "string" columns # disguised as "object". But here, we deal with neither. return df_native[col_name].dtype.kind in "ifc" return df.get_column(col_name).dtype.is_numeric() def _to_unix_epoch_seconds(s: nw.Series) -> nw.Series: dtype = s.dtype if dtype == nw.Date: return s.dt.timestamp("ms") / 1_000 if dtype == nw.Datetime: if dtype.time_unit in ("s", "ms"): return s.dt.timestamp("ms") / 1_000 elif dtype.time_unit == "us": return s.dt.timestamp("us") / 1_000_000 elif dtype.time_unit == "ns": return s.dt.timestamp("ns") / 1_000_000_000 else: msg = "Unexpected dtype, please report a bug" raise ValueError(msg) else: msg = f"Expected Date or Datetime, got {dtype}" raise TypeError(msg) def _generate_temporary_column_name(n_bytes, columns) -> str: """Wraps of Narwhals generate_temporary_column_name to generate a token which is guaranteed to not be in columns, nor in [col + token for col in columns] """ counter = 0 while True: # This is guaranteed to not be in columns by Narwhals token = nw.generate_temporary_column_name(n_bytes, columns=columns) # Now check that it is not in the [col + token for col in columns] list if token not in {f"{c}{token}" for c in columns}: return token counter += 1 if counter > 100: msg = ( "Internal Error: Plotly was not able to generate a column name with " f"{n_bytes=} and not in {columns}.\n" "Please report this to " "https://github.com/plotly/plotly.py/issues/new and we will try to " "replicate and fix it." 
def get_decorated_label(args, column, role):
    original_label = label = get_label(args, column)
    if "histfunc" in args and (
        (role == "z")
        or (role == "x" and "orientation" in args and args["orientation"] == "h")
        or (role == "y" and "orientation" in args and args["orientation"] == "v")
    ):
        histfunc = args["histfunc"] or "count"
        if histfunc != "count":
            label = "%s of %s" % (histfunc, label)
        else:
            label = "count"

        if "histnorm" in args and args["histnorm"] is not None:
            if label == "count":
                label = args["histnorm"]
            else:
                histnorm = args["histnorm"]
                if histfunc == "sum":
                    if histnorm == "probability":
                        label = "%s of %s" % ("fraction", label)
                    elif histnorm == "percent":
                        label = "%s of %s" % (histnorm, label)
                    else:
                        label = "%s weighted by %s" % (histnorm, original_label)
                elif histnorm == "probability":
                    label = "%s of sum of %s" % ("fraction", label)
                elif histnorm == "percent":
                    label = "%s of sum of %s" % ("percent", label)
                else:
                    label = "%s of %s" % (histnorm, label)

        if "barnorm" in args and args["barnorm"] is not None:
            label = "%s (normalized as %s)" % (label, args["barnorm"])

    return label


def make_mapping(args, variable):
    if variable == "line_group" or variable == "animation_frame":
        return Mapping(
            show_in_trace_name=False,
            grouper=args[variable],
            val_map={},
            sequence=[""],
            variable=variable,
            updater=(lambda trace, v: v),
            facet=None,
        )
    if variable == "facet_row" or variable == "facet_col":
        letter = "x" if variable == "facet_col" else "y"
        return Mapping(
            show_in_trace_name=False,
            variable=letter,
            grouper=args[variable],
            val_map={},
            sequence=[i for i in range(1, 1000)],
            updater=(lambda trace, v: v),
            facet="row" if variable == "facet_row" else "col",
        )
    (parent, variable, *other_variables) = variable.split(".")
    vprefix = variable
    arg_name = variable
    if variable == "color":
        vprefix = "color_discrete"
    if variable == "dash":
        arg_name = "line_dash"
        vprefix = "line_dash"
    if variable in ["pattern", "shape"]:
        arg_name = "pattern_shape"
        vprefix = "pattern_shape"
    if args[vprefix + "_map"] == "identity":
        val_map = IdentityMap()
    else:
        val_map = args[vprefix + "_map"].copy()
    return Mapping(
        show_in_trace_name=True,
        variable=variable,
        grouper=args[arg_name],
        val_map=val_map,
        sequence=args[vprefix + "_sequence"],
        updater=lambda trace, v: trace.update(
            {parent: {".".join([variable] + other_variables): v}}
        ),
        facet=None,
    )
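
# Sketch of what make_mapping returns in the discrete-color case (only the
# args keys read above are shown; the values are hypothetical):
#
#     args = {
#         "color": "species",
#         "color_discrete_map": {},
#         "color_discrete_sequence": ["red", "blue"],
#     }
#     m = make_mapping(args, "marker.color")
#     # m.grouper == "species"; m.updater patches trace.marker.color per group
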
def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
    """Populates a dict with arguments to update trace

    Parameters
    ----------
    args : dict
        args to be used for the trace
    trace_spec : NamedTuple
        which kind of trace to be used (has constructor, marginal etc.
        attributes)
    trace_data : narwhals DataFrame
        data
    mapping_labels : dict
        to be used for hovertemplate
    sizeref : float
        marker sizeref

    Returns
    -------
    trace_patch : dict
        dict to be used to update trace
    fit_results : dict
        fit information to be used for trendlines
    """
    trace_data: nw.DataFrame
    df: nw.DataFrame = args["data_frame"]

    if "line_close" in args and args["line_close"]:
        trace_data = nw.concat([trace_data, trace_data.head(1)], how="vertical")
    trace_patch = trace_spec.trace_patch.copy() or {}
    fit_results = None
    hover_header = ""
    for attr_name in trace_spec.attrs:
        attr_value = args[attr_name]
        attr_label = get_decorated_label(args, attr_value, attr_name)
        if attr_name == "dimensions":
            dims = [
                (name, trace_data.get_column(name))
                for name in trace_data.columns
                if ((not attr_value) or (name in attr_value))
                and (trace_spec.constructor != go.Parcoords or _is_continuous(df, name))
                and (
                    trace_spec.constructor != go.Parcats
                    or (attr_value is not None and name in attr_value)
                    or nw.to_py_scalar(df.get_column(name).n_unique())
                    <= args["dimensions_max_cardinality"]
                )
            ]
            trace_patch["dimensions"] = [
                dict(label=get_label(args, name), values=column)
                for (name, column) in dims
            ]
            if trace_spec.constructor == go.Splom:
                for d in trace_patch["dimensions"]:
                    d["axis"] = dict(matches=True)
                mapping_labels["%{xaxis.title.text}"] = "%{x}"
                mapping_labels["%{yaxis.title.text}"] = "%{y}"

        elif attr_value is not None:
            if attr_name == "size":
                if "marker" not in trace_patch:
                    trace_patch["marker"] = dict()
                trace_patch["marker"]["size"] = trace_data.get_column(attr_value)
                trace_patch["marker"]["sizemode"] = "area"
                trace_patch["marker"]["sizeref"] = sizeref
                mapping_labels[attr_label] = "%{marker.size}"
            elif attr_name == "marginal_x":
                if trace_spec.constructor == go.Histogram:
                    mapping_labels["count"] = "%{y}"
            elif attr_name == "marginal_y":
                if trace_spec.constructor == go.Histogram:
                    mapping_labels["count"] = "%{x}"
            elif attr_name == "trendline":
                if (
                    args["x"]
                    and args["y"]
                    and len(
                        trace_data.select(nw.col(args["x"], args["y"])).drop_nulls()
                    )
                    > 1
                ):
                    # sorting is bad but trace_specs with "trendline" have no other attrs
                    sorted_trace_data = trace_data.sort(by=args["x"], nulls_last=True)
                    y = sorted_trace_data.get_column(args["y"])
                    x = sorted_trace_data.get_column(args["x"])

                    if x.dtype == nw.Datetime or x.dtype == nw.Date:
                        # convert to unix epoch seconds
                        x = _to_unix_epoch_seconds(x)
                    elif not x.dtype.is_numeric():
                        try:
                            x = x.cast(nw.Float64())
                        except ValueError:
                            raise ValueError(
                                "Could not convert value of 'x' ('%s') into a numeric type. "
                                "If 'x' contains stringified dates, please convert to a datetime column."
                                % args["x"]
                            )

                    if not y.dtype.is_numeric():
                        try:
                            y = y.cast(nw.Float64())
                        except ValueError:
                            raise ValueError(
                                "Could not convert value of 'y' into a numeric type."
                            )

                    # preserve original values of "x" in case they're dates,
                    # otherwise numpy/pandas can mess with the timezones.
                    # NB this means trendline functions must output one-to-one with the
                    # input series, i.e. we can't do resampling, because then the X
                    # values might not line up!
                    non_missing = ~(x.is_null() | y.is_null())
                    trace_patch["x"] = sorted_trace_data.filter(non_missing).get_column(
                        args["x"]
                    )
                    if (
                        trace_patch["x"].dtype == nw.Datetime
                        and trace_patch["x"].dtype.time_zone is not None
                    ):
                        # Remove time zone so that local time is displayed
                        trace_patch["x"] = (
                            trace_patch["x"].dt.replace_time_zone(None).to_numpy()
                        )
                    else:
                        trace_patch["x"] = trace_patch["x"].to_numpy()

                    trendline_function = trendline_functions[attr_value]
                    y_out, hover_header, fit_results = trendline_function(
                        args["trendline_options"],
                        sorted_trace_data.get_column(args["x"]),  # narwhals series
                        x.to_numpy(),  # numpy array
                        y.to_numpy(),  # numpy array
                        args["x"],
                        args["y"],
                        non_missing.to_numpy(),  # numpy array
                    )
                    assert len(y_out) == len(trace_patch["x"]), (
                        "missing-data-handling failure in trendline code"
                    )
                    trace_patch["y"] = y_out
                    mapping_labels[get_label(args, args["x"])] = "%{x}"
                    mapping_labels[get_label(args, args["y"])] = "%{y} (trend)"
            elif attr_name.startswith("error"):
                error_xy = attr_name[:7]
                arr = "arrayminus" if attr_name.endswith("minus") else "array"
                if error_xy not in trace_patch:
                    trace_patch[error_xy] = {}
                trace_patch[error_xy][arr] = trace_data.get_column(attr_value)
            elif attr_name == "custom_data":
                if len(attr_value) > 0:
                    # here we store a data frame in customdata, and it's serialized
                    # as a list of row lists, which is what we want
                    trace_patch["customdata"] = trace_data.select(nw.col(attr_value))
            elif attr_name == "hover_name":
                if trace_spec.constructor not in [
                    go.Histogram,
                    go.Histogram2d,
                    go.Histogram2dContour,
                ]:
                    trace_patch["hovertext"] = trace_data.get_column(attr_value)
                    if hover_header == "":
                        hover_header = "%{hovertext}<br><br>"
" elif attr_name == "hover_data": if trace_spec.constructor not in [ go.Histogram, go.Histogram2d, go.Histogram2dContour, ]: hover_is_dict = isinstance(attr_value, dict) customdata_cols = args.get("custom_data") or [] for col in attr_value: if hover_is_dict and not attr_value[col]: continue if col in [ args.get("x"), args.get("y"), args.get("z"), args.get("base"), ]: continue try: position = args["custom_data"].index(col) except (ValueError, AttributeError, KeyError): position = len(customdata_cols) customdata_cols.append(col) attr_label_col = get_decorated_label(args, col, None) mapping_labels[attr_label_col] = "%%{customdata[%d]}" % ( position ) if len(customdata_cols) > 0: # here we store a data frame in customdata, and it's serialized # as a list of row lists, which is what we want # dict.fromkeys(customdata_cols) allows to deduplicate column # names, yet maintaining the original order. trace_patch["customdata"] = trace_data.select( *[nw.col(c) for c in dict.fromkeys(customdata_cols)] ) elif attr_name == "color": if trace_spec.constructor in [ go.Choropleth, go.Choroplethmap, go.Choroplethmapbox, ]: trace_patch["z"] = trace_data.get_column(attr_value) trace_patch["coloraxis"] = "coloraxis1" mapping_labels[attr_label] = "%{z}" elif trace_spec.constructor in [ go.Sunburst, go.Treemap, go.Icicle, go.Pie, go.Funnelarea, ]: if "marker" not in trace_patch: trace_patch["marker"] = dict() if args.get("color_is_continuous"): trace_patch["marker"]["colors"] = trace_data.get_column( attr_value ) trace_patch["marker"]["coloraxis"] = "coloraxis1" mapping_labels[attr_label] = "%{color}" else: trace_patch["marker"]["colors"] = [] if args["color_discrete_map"] is not None: mapping = args["color_discrete_map"].copy() else: mapping = {} for cat in trace_data.get_column(attr_value).to_list(): # although trace_data.get_column(attr_value) is a Narwhals # Series, which is an iterable, explicitly calling a to_list() # makes sure that the elements we loop over are python objects # in all cases, since depending on the backend this may not be # the case (e.g. 
                            if mapping.get(cat) is None:
                                mapping[cat] = args["color_discrete_sequence"][
                                    len(mapping) % len(args["color_discrete_sequence"])
                                ]
                            trace_patch["marker"]["colors"].append(mapping[cat])
                else:
                    colorable = "marker"
                    if trace_spec.constructor in [go.Parcats, go.Parcoords]:
                        colorable = "line"
                    if colorable not in trace_patch:
                        trace_patch[colorable] = dict()
                    trace_patch[colorable]["color"] = trace_data.get_column(attr_value)
                    trace_patch[colorable]["coloraxis"] = "coloraxis1"
                    mapping_labels[attr_label] = "%%{%s.color}" % colorable
            elif attr_name == "animation_group":
                trace_patch["ids"] = trace_data.get_column(attr_value)
            elif attr_name == "locations":
                trace_patch[attr_name] = trace_data.get_column(attr_value)
                mapping_labels[attr_label] = "%{location}"
            elif attr_name == "values":
                trace_patch[attr_name] = trace_data.get_column(attr_value)
                _label = "value" if attr_label == "values" else attr_label
                mapping_labels[_label] = "%{value}"
            elif attr_name == "parents":
                trace_patch[attr_name] = trace_data.get_column(attr_value)
                _label = "parent" if attr_label == "parents" else attr_label
                mapping_labels[_label] = "%{parent}"
            elif attr_name == "ids":
                trace_patch[attr_name] = trace_data.get_column(attr_value)
                _label = "id" if attr_label == "ids" else attr_label
                mapping_labels[_label] = "%{id}"
            elif attr_name == "names":
                if trace_spec.constructor in [
                    go.Sunburst,
                    go.Treemap,
                    go.Icicle,
                    go.Pie,
                    go.Funnelarea,
                ]:
                    trace_patch["labels"] = trace_data.get_column(attr_value)
                    _label = "label" if attr_label == "names" else attr_label
                    mapping_labels[_label] = "%{label}"
                else:
                    trace_patch[attr_name] = trace_data.get_column(attr_value)
            else:
                trace_patch[attr_name] = trace_data.get_column(attr_value)
                mapping_labels[attr_label] = "%%{%s}" % attr_name
        elif (trace_spec.constructor == go.Histogram and attr_name in ["x", "y"]) or (
            trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour]
            and attr_name == "z"
        ):
            # ensure that stuff like "count" gets into the hoverlabel
            mapping_labels[attr_label] = "%%{%s}" % attr_name
    if trace_spec.constructor not in [go.Parcoords, go.Parcats]:
        # Modify mapping_labels according to hover_data keys
        # if hover_data is a dict
        mapping_labels_copy = OrderedDict(mapping_labels)
        if args["hover_data"] and isinstance(args["hover_data"], dict):
            for k, v in mapping_labels.items():
                # We need to invert the mapping here
                k_args = invert_label(args, k)
                if k_args in args["hover_data"]:
                    formatter = args["hover_data"][k_args][0]
                    if formatter:
                        if isinstance(formatter, str):
                            mapping_labels_copy[k] = v.replace("}", "%s}" % formatter)
                    else:
                        _ = mapping_labels_copy.pop(k)
        hover_lines = [k + "=" + v for k, v in mapping_labels_copy.items()]
        trace_patch["hovertemplate"] = hover_header + "<br>".join(hover_lines)
        trace_patch["hovertemplate"] += "<extra></extra>"
    return trace_patch, fit_results
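
# For example (hypothetical labels), a mapping_labels of
# {"sepal_width": "%{x}", "sepal_length": "%{y}"} yields the hovertemplate
# "sepal_width=%{x}<br>sepal_length=%{y}<extra></extra>".
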
".join(hover_lines) trace_patch["hovertemplate"] += "" return trace_patch, fit_results def configure_axes(args, constructor, fig, orders): configurators = { go.Scatter3d: configure_3d_axes, go.Scatterternary: configure_ternary_axes, go.Scatterpolar: configure_polar_axes, go.Scatterpolargl: configure_polar_axes, go.Barpolar: configure_polar_axes, go.Scattermap: configure_map, go.Choroplethmap: configure_map, go.Densitymap: configure_map, go.Scattermapbox: configure_mapbox, go.Choroplethmapbox: configure_mapbox, go.Densitymapbox: configure_mapbox, go.Scattergeo: configure_geo, go.Choropleth: configure_geo, } for c in cartesians: configurators[c] = configure_cartesian_axes if constructor in configurators: configurators[constructor](args, fig, orders) def set_cartesian_axis_opts(args, axis, letter, orders): log_key = "log_" + letter range_key = "range_" + letter if log_key in args and args[log_key]: axis["type"] = "log" if range_key in args and args[range_key]: axis["range"] = [math.log(r, 10) for r in args[range_key]] elif range_key in args and args[range_key]: axis["range"] = args[range_key] if args[letter] in orders: axis["categoryorder"] = "array" axis["categoryarray"] = ( orders[args[letter]] if isinstance(axis, go.layout.XAxis) else list(reversed(orders[args[letter]])) # top down for Y axis ) def configure_cartesian_marginal_axes(args, fig, orders): nrows = len(fig._grid_ref) ncols = len(fig._grid_ref[0]) # Set y-axis titles and axis options in the left-most column for yaxis in fig.select_yaxes(col=1): set_cartesian_axis_opts(args, yaxis, "y", orders) # Set x-axis titles and axis options in the bottom-most row for xaxis in fig.select_xaxes(row=1): set_cartesian_axis_opts(args, xaxis, "x", orders) # Configure axis ticks on marginal subplots if args["marginal_x"]: fig.update_yaxes( showticklabels=False, showline=False, ticks="", range=None, row=nrows ) if args["template"].layout.yaxis.showgrid is None: fig.update_yaxes(showgrid=args["marginal_x"] == "histogram", row=nrows) if args["template"].layout.xaxis.showgrid is None: fig.update_xaxes(showgrid=True, row=nrows) if args["marginal_y"]: fig.update_xaxes( showticklabels=False, showline=False, ticks="", range=None, col=ncols ) if args["template"].layout.xaxis.showgrid is None: fig.update_xaxes(showgrid=args["marginal_y"] == "histogram", col=ncols) if args["template"].layout.yaxis.showgrid is None: fig.update_yaxes(showgrid=True, col=ncols) # Add axis titles to non-marginal subplots y_title = get_decorated_label(args, args["y"], "y") if args["marginal_x"]: fig.update_yaxes(title_text=y_title, row=1, col=1) else: for row in range(1, nrows + 1): fig.update_yaxes(title_text=y_title, row=row, col=1) x_title = get_decorated_label(args, args["x"], "x") if args["marginal_y"]: fig.update_xaxes(title_text=x_title, row=1, col=1) else: for col in range(1, ncols + 1): fig.update_xaxes(title_text=x_title, row=1, col=col) # Configure axis type across all x-axes if "log_x" in args and args["log_x"]: fig.update_xaxes(type="log") # Configure axis type across all y-axes if "log_y" in args and args["log_y"]: fig.update_yaxes(type="log") # Configure matching and axis type for marginal y-axes matches_y = "y" + str(ncols + 1) if args["marginal_x"]: for row in range(2, nrows + 1, 2): fig.update_yaxes(matches=matches_y, type=None, row=row) if args["marginal_y"]: for col in range(2, ncols + 1, 2): fig.update_xaxes(matches="x2", type=None, col=col) def configure_cartesian_axes(args, fig, orders): if ("marginal_x" in args and args["marginal_x"]) or ( "marginal_y" in 
args and args["marginal_y"] ): configure_cartesian_marginal_axes(args, fig, orders) return # Set y-axis titles and axis options in the left-most column y_title = get_decorated_label(args, args["y"], "y") for yaxis in fig.select_yaxes(col=1): yaxis.update(title_text=y_title) set_cartesian_axis_opts(args, yaxis, "y", orders) # Set x-axis titles and axis options in the bottom-most row x_title = get_decorated_label(args, args["x"], "x") for xaxis in fig.select_xaxes(row=1): if "is_timeline" not in args: xaxis.update(title_text=x_title) set_cartesian_axis_opts(args, xaxis, "x", orders) # Configure axis type across all x-axes if "log_x" in args and args["log_x"]: fig.update_xaxes(type="log") # Configure axis type across all y-axes if "log_y" in args and args["log_y"]: fig.update_yaxes(type="log") if "is_timeline" in args: fig.update_xaxes(type="date") if "ecdfmode" in args: if args["orientation"] == "v": fig.update_yaxes(rangemode="tozero") else: fig.update_xaxes(rangemode="tozero") def configure_ternary_axes(args, fig, orders): fig.update_ternaries( aaxis=dict(title_text=get_label(args, args["a"])), baxis=dict(title_text=get_label(args, args["b"])), caxis=dict(title_text=get_label(args, args["c"])), ) def configure_polar_axes(args, fig, orders): patch = dict( angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]), radialaxis=dict(), ) for var, axis in [("r", "radialaxis"), ("theta", "angularaxis")]: if args[var] in orders: patch[axis]["categoryorder"] = "array" patch[axis]["categoryarray"] = orders[args[var]] radialaxis = patch["radialaxis"] if args["log_r"]: radialaxis["type"] = "log" if args["range_r"]: radialaxis["range"] = [math.log(x, 10) for x in args["range_r"]] else: if args["range_r"]: radialaxis["range"] = args["range_r"] if args["range_theta"]: patch["sector"] = args["range_theta"] fig.update_polars(patch) def configure_3d_axes(args, fig, orders): patch = dict( xaxis=dict(title_text=get_label(args, args["x"])), yaxis=dict(title_text=get_label(args, args["y"])), zaxis=dict(title_text=get_label(args, args["z"])), ) for letter in ["x", "y", "z"]: axis = patch[letter + "axis"] if args["log_" + letter]: axis["type"] = "log" if args["range_" + letter]: axis["range"] = [math.log(x, 10) for x in args["range_" + letter]] else: if args["range_" + letter]: axis["range"] = args["range_" + letter] if args[letter] in orders: axis["categoryorder"] = "array" axis["categoryarray"] = orders[args[letter]] fig.update_scenes(patch) def configure_mapbox(args, fig, orders): center = args["center"] if not center and "lat" in args and "lon" in args: center = dict( lat=args["data_frame"][args["lat"]].mean(), lon=args["data_frame"][args["lon"]].mean(), ) fig.update_mapboxes( accesstoken=MAPBOX_TOKEN, center=center, zoom=args["zoom"], style=args["mapbox_style"], ) def configure_map(args, fig, orders): center = args["center"] if not center and "lat" in args and "lon" in args: center = dict( lat=args["data_frame"][args["lat"]].mean(), lon=args["data_frame"][args["lon"]].mean(), ) fig.update_maps( center=center, zoom=args["zoom"], style=args["map_style"], ) def configure_geo(args, fig, orders): fig.update_geos( center=args["center"], scope=args["scope"], fitbounds=args["fitbounds"], visible=args["basemap_visible"], projection=dict(type=args["projection"]), ) def configure_animation_controls(args, constructor, fig): def frame_args(duration): return { "frame": {"duration": duration, "redraw": constructor != go.Scatter}, "mode": "immediate", "fromcurrent": True, "transition": {"duration": 
duration, "easing": "linear"}, } if "animation_frame" in args and args["animation_frame"] and len(fig.frames) > 1: fig.layout.updatemenus = [ { "buttons": [ { "args": [None, frame_args(500)], "label": "▶", "method": "animate", }, { "args": [[None], frame_args(0)], "label": "◼", "method": "animate", }, ], "direction": "left", "pad": {"r": 10, "t": 70}, "showactive": False, "type": "buttons", "x": 0.1, "xanchor": "right", "y": 0, "yanchor": "top", } ] fig.layout.sliders = [ { "active": 0, "yanchor": "top", "xanchor": "left", "currentvalue": { "prefix": get_label(args, args["animation_frame"]) + "=" }, "pad": {"b": 10, "t": 60}, "len": 0.9, "x": 0.1, "y": 0, "steps": [ { "args": [[f.name], frame_args(0)], "label": f.name, "method": "animate", } for f in fig.frames ], } ] def make_trace_spec(args, constructor, attrs, trace_patch): if constructor in [go.Scatter, go.Scatterpolar]: if "render_mode" in args and ( args["render_mode"] == "webgl" or ( args["render_mode"] == "auto" and len(args["data_frame"]) > 1000 and args.get("line_shape") != "spline" and args["animation_frame"] is None ) ): if constructor == go.Scatter: constructor = go.Scattergl if "orientation" in trace_patch: del trace_patch["orientation"] else: constructor = go.Scatterpolargl # Create base trace specification result = [TraceSpec(constructor, attrs, trace_patch, None)] # Add marginal trace specifications for letter in ["x", "y"]: if "marginal_" + letter in args and args["marginal_" + letter]: trace_spec = None axis_map = dict( xaxis="x1" if letter == "x" else "x2", yaxis="y1" if letter == "y" else "y2", ) if args["marginal_" + letter] == "histogram": trace_spec = TraceSpec( constructor=go.Histogram, attrs=[letter, "marginal_" + letter], trace_patch=dict(opacity=0.5, bingroup=letter, **axis_map), marginal=letter, ) elif args["marginal_" + letter] == "violin": trace_spec = TraceSpec( constructor=go.Violin, attrs=[letter, "hover_name", "hover_data"], trace_patch=dict(scalegroup=letter), marginal=letter, ) elif args["marginal_" + letter] == "box": trace_spec = TraceSpec( constructor=go.Box, attrs=[letter, "hover_name", "hover_data"], trace_patch=dict(notched=True), marginal=letter, ) elif args["marginal_" + letter] == "rug": symbols = {"x": "line-ns-open", "y": "line-ew-open"} trace_spec = TraceSpec( constructor=go.Box, attrs=[letter, "hover_name", "hover_data"], trace_patch=dict( fillcolor="rgba(255,255,255,0)", line={"color": "rgba(255,255,255,0)"}, boxpoints="all", jitter=0, hoveron="points", marker={"symbol": symbols[letter]}, ), marginal=letter, ) if "color" in attrs or "color" not in args: if "marker" not in trace_spec.trace_patch: trace_spec.trace_patch["marker"] = dict() first_default_color = args["color_continuous_scale"][0] trace_spec.trace_patch["marker"]["color"] = first_default_color result.append(trace_spec) # Add trendline trace specifications if args.get("trendline") and args.get("trendline_scope", "trace") == "trace": result.append(make_trendline_spec(args, constructor)) return result def make_trendline_spec(args, constructor): trace_spec = TraceSpec( constructor=( go.Scattergl if constructor == go.Scattergl # could be contour else go.Scatter ), attrs=["trendline"], trace_patch=dict(mode="lines"), marginal=None, ) if args["trendline_color_override"]: trace_spec.trace_patch["line"] = dict(color=args["trendline_color_override"]) return trace_spec def one_group(x): return "" def apply_default_cascade(args): # first we apply px.defaults to unspecified args for param in defaults.__slots__: if param in args and 
def apply_default_cascade(args):
    # first we apply px.defaults to unspecified args
    for param in defaults.__slots__:
        if param in args and args[param] is None:
            args[param] = getattr(defaults, param)

    # load the default template if set, otherwise "plotly"
    if args["template"] is None:
        if pio.templates.default is not None:
            args["template"] = pio.templates.default
        else:
            args["template"] = "plotly"

    try:
        # retrieve the actual template if we were given a name
        args["template"] = pio.templates[args["template"]]
    except Exception:
        # otherwise try to build a real template
        args["template"] = go.layout.Template(args["template"])

    # if colors not set explicitly or in px.defaults, defer to a template
    # if the template doesn't have one, we set some final fallback defaults
    if "color_continuous_scale" in args:
        if (
            args["color_continuous_scale"] is None
            and args["template"].layout.colorscale.sequential
        ):
            args["color_continuous_scale"] = [
                x[1] for x in args["template"].layout.colorscale.sequential
            ]
        if args["color_continuous_scale"] is None:
            args["color_continuous_scale"] = sequential.Viridis

    if "color_discrete_sequence" in args:
        if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
            args["color_discrete_sequence"] = args["template"].layout.colorway
        if args["color_discrete_sequence"] is None:
            args["color_discrete_sequence"] = qualitative.D3

    # if symbol_sequence/line_dash_sequence not set explicitly or in px.defaults,
    # see if we can defer to template. If not, set reasonable defaults
    if "symbol_sequence" in args:
        if args["symbol_sequence"] is None and args["template"].data.scatter:
            args["symbol_sequence"] = [
                scatter.marker.symbol for scatter in args["template"].data.scatter
            ]
        if not args["symbol_sequence"] or not any(args["symbol_sequence"]):
            args["symbol_sequence"] = ["circle", "diamond", "square", "x", "cross"]

    if "line_dash_sequence" in args:
        if args["line_dash_sequence"] is None and args["template"].data.scatter:
            args["line_dash_sequence"] = [
                scatter.line.dash for scatter in args["template"].data.scatter
            ]
        if not args["line_dash_sequence"] or not any(args["line_dash_sequence"]):
            args["line_dash_sequence"] = [
                "solid",
                "dot",
                "dash",
                "longdash",
                "dashdot",
                "longdashdot",
            ]

    if "pattern_shape_sequence" in args:
        if args["pattern_shape_sequence"] is None and args["template"].data.bar:
            args["pattern_shape_sequence"] = [
                bar.marker.pattern.shape for bar in args["template"].data.bar
            ]
        if not args["pattern_shape_sequence"] or not any(
            args["pattern_shape_sequence"]
        ):
            args["pattern_shape_sequence"] = ["", "/", "\\", "x", "+", "."]


def _check_name_not_reserved(field_name, reserved_names):
    if field_name not in reserved_names:
        return field_name
    else:
        raise NameError(
            "A name conflict was encountered for argument '%s'. "
            "A column or index with name '%s' is ambiguous." % (field_name, field_name)
        )
""" df: nw.DataFrame = args["data_frame"] reserved_names = set() for field in args: if field not in all_attrables: continue names = args[field] if field in array_attrables else [args[field]] if names is None: continue for arg in names: if arg is None: continue elif isinstance(arg, str): # no need to add ints since kw arg are not ints reserved_names.add(arg) elif nw.dependencies.is_into_series(arg): arg_series = nw.from_native(arg, series_only=True) arg_name = arg_series.name if arg_name and arg_name in df.columns: in_df = (arg_series == df.get_column(arg_name)).all() if in_df: reserved_names.add(arg_name) elif arg is nw.maybe_get_index(df) and arg.name is not None: reserved_names.add(arg.name) return reserved_names def _is_col_list(columns, arg, is_pd_like, native_namespace): """Returns True if arg looks like it's a list of columns or references to columns in df_input, and False otherwise (in which case it's assumed to be a single column or reference to a column). """ if arg is None or isinstance(arg, str) or isinstance(arg, int): return False if is_pd_like and isinstance(arg, native_namespace.MultiIndex): return False # just to keep existing behaviour for now try: iter(arg) except TypeError: return False # not iterable for c in arg: if isinstance(c, str) or isinstance(c, int): if columns is None or c not in columns: return False else: try: iter(c) except TypeError: return False # not iterable return True def _isinstance_listlike(x): """Returns True if x is an iterable which can be transformed into a pandas Series, False for the other types of possible values of a `hover_data` dict. A tuple of length 2 is a special case corresponding to a (format, data) tuple. """ if ( isinstance(x, str) or (isinstance(x, tuple) and len(x) == 2) or isinstance(x, bool) or x is None ): return False else: return True def _escape_col_name(columns, col_name, extra): if columns is None: return col_name while col_name in columns or col_name in extra: col_name = "_" + col_name return col_name def to_named_series(x, name=None, native_namespace=None): """Assuming x is list-like or even an existing Series, returns a new Series named `name`.""" # With `pass_through=True`, the original object will be returned if unable to convert # to a Narwhals Series. x = nw.from_native(x, series_only=True, pass_through=True) if isinstance(x, nw.Series): return x.rename(name) elif native_namespace is not None: return nw.new_series(name=name, values=x, native_namespace=native_namespace) else: try: import pandas as pd return nw.new_series(name=name, values=x, native_namespace=pd) except ImportError: msg = "Pandas installation is required if no dataframe is provided." raise NotImplementedError(msg) def process_args_into_dataframe( args, wide_mode, var_name, value_name, is_pd_like, native_namespace ): """ After this function runs, the `all_attrables` keys of `args` all contain only references to columns of `df_output`. This function handles the extraction of data from `args["attrable"]` and column-name-generation as appropriate, and adds the data to `df_output` and then replaces `args["attrable"]` with the appropriate reference. """ df_input: nw.DataFrame | None = args["data_frame"] df_provided = df_input is not None # we use a dict instead of a dataframe directly so that it doesn't cause # PerformanceWarning by pandas by repeatedly setting the columns. # a dict is used instead of a list as the columns needs to be overwritten. 
def process_args_into_dataframe(
    args, wide_mode, var_name, value_name, is_pd_like, native_namespace
):
    """
    After this function runs, the `all_attrables` keys of `args` all contain only
    references to columns of `df_output`. This function handles the extraction of data
    from `args["attrable"]` and column-name-generation as appropriate, and adds the
    data to `df_output` and then replaces `args["attrable"]` with the appropriate
    reference.
    """
    df_input: nw.DataFrame | None = args["data_frame"]
    df_provided = df_input is not None

    # We use a dict instead of a dataframe directly so that pandas does not raise a
    # PerformanceWarning from repeatedly setting the columns; a dict is used instead
    # of a list as columns may need to be overwritten.
    df_output = {}
    constants = {}
    ranges = []
    wide_id_vars = set()
    reserved_names = _get_reserved_col_names(args) if df_provided else set()

    # Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords
    if "dimensions" in args and args["dimensions"] is None:
        if not df_provided:
            raise ValueError(
                "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument."
            )
        else:
            df_output = {col: df_input.get_column(col) for col in df_input.columns}

    # hover_data is a dict
    hover_data_is_dict = (
        "hover_data" in args
        and args["hover_data"]
        and isinstance(args["hover_data"], dict)
    )
    # If dict, convert all values of hover_data to tuples to simplify processing
    if hover_data_is_dict:
        for k in args["hover_data"]:
            if _isinstance_listlike(args["hover_data"][k]):
                args["hover_data"][k] = (True, args["hover_data"][k])
            if not isinstance(args["hover_data"][k], tuple):
                args["hover_data"][k] = (args["hover_data"][k], None)
            if df_provided and args["hover_data"][k][1] is not None and k in df_input:
                raise ValueError(
                    "Ambiguous input: values for '%s' appear both in hover_data and data_frame"
                    % k
                )
    # Loop over possible arguments
    for field_name in all_attrables:
        # Massaging variables
        argument_list = (
            [args.get(field_name)]
            if field_name not in array_attrables
            else args.get(field_name)
        )
        # argument not specified, continue
        # The original also tested `or argument_list is [None]` but
        # that clause is always False, so it has been removed. The
        # alternative fix would have been to test that `argument_list`
        # is of length 1 and its sole element is `None`, but that
        # feels pedantic. All tests pass with the change below; let's
        # see if the world decides we were wrong.
        if argument_list is None:
            continue
        # Argument name: field_name if the argument is not a list
        # Else we give names like ["hover_data_0", "hover_data_1"] etc.
        field_list = (
            [field_name]
            if field_name not in array_attrables
            else [field_name + "_" + str(i) for i in range(len(argument_list))]
        )
        # argument_list and field_list ready, iterate over them
        # Core of the loop starts here
        for i, (argument, field) in enumerate(zip(argument_list, field_list)):
            length = len(df_output[next(iter(df_output))]) if len(df_output) else 0
            if argument is None:
                continue
            col_name = None
            # Case of multiindex
            if is_pd_like and isinstance(argument, native_namespace.MultiIndex):
                raise TypeError(
                    f"Argument '{field}' is a {native_namespace.__name__} MultiIndex. "
                    f"{native_namespace.__name__} MultiIndex is not supported by plotly "
                    "express at the moment."
                )
            # ----------------- argument is a special value ----------------------
            if isinstance(argument, (Constant, Range)):
                col_name = _check_name_not_reserved(
                    str(argument.label) if argument.label is not None else field,
                    reserved_names,
                )
                if isinstance(argument, Constant):
                    constants[col_name] = argument.value
                else:
                    ranges.append(col_name)
            # ----------------- argument is likely a col name ----------------------
            elif isinstance(argument, str) or not hasattr(argument, "__len__"):
                if (
                    field_name == "hover_data"
                    and hover_data_is_dict
                    and args["hover_data"][str(argument)][1] is not None
                ):
                    # hover_data has onboard data
                    # previously-checked to have no name-conflict with data_frame
                    col_name = str(argument)
                    real_argument = args["hover_data"][col_name][1]

                    if length and (real_length := len(real_argument)) != length:
                        raise ValueError(
                            "All arguments should have the same length. "
" "The length of hover_data key `%s` is %d, whereas the " "length of previously-processed arguments %s is %d" % ( argument, real_length, str(list(df_output.keys())), length, ) ) df_output[col_name] = to_named_series( real_argument, col_name, native_namespace ) elif not df_provided: raise ValueError( "String or int arguments are only possible when a " "DataFrame or an array is provided in the `data_frame` " "argument. No DataFrame was provided, but argument " "'%s' is of type str or int." % field ) # Check validity of column name elif argument not in df_input.columns: if wide_mode and argument in (value_name, var_name): continue else: err_msg = ( "Value of '%s' is not the name of a column in 'data_frame'. " "Expected one of %s but received: %s" % (field, str(list(df_input.columns)), argument) ) if argument == "index": err_msg += "\n To use the index, pass it in directly as `df.index`." raise ValueError(err_msg) elif length and (actual_len := len(df_input)) != length: raise ValueError( "All arguments should have the same length. " "The length of column argument `df[%s]` is %d, whereas the " "length of previously-processed arguments %s is %d" % ( field, actual_len, str(list(df_output.keys())), length, ) ) else: col_name = str(argument) df_output[col_name] = to_named_series( df_input.get_column(argument), col_name ) # ----------------- argument is likely a column / array / list.... ------- else: if df_provided and hasattr(argument, "name"): if is_pd_like and argument is nw.maybe_get_index(df_input): if argument.name is None or argument.name in df_input.columns: col_name = "index" else: col_name = argument.name col_name = _escape_col_name( df_input.columns, col_name, [var_name, value_name] ) else: if ( argument.name is not None and argument.name in df_input.columns and ( to_named_series( argument, argument.name, native_namespace ) == df_input.get_column(argument.name) ).all() ): col_name = argument.name if col_name is None: # numpy array, list... col_name = _check_name_not_reserved(field, reserved_names) if length and (len_arg := len(argument)) != length: raise ValueError( "All arguments should have the same length. " "The length of argument `%s` is %d, whereas the " "length of previously-processed arguments %s is %d" % (field, len_arg, str(list(df_output.keys())), length) ) df_output[str(col_name)] = to_named_series( x=argument, name=str(col_name), native_namespace=native_namespace, ) # Finally, update argument with column name now that column exists assert col_name is not None, ( "Data-frame processing failure, likely due to a internal bug. " "Please report this to " "https://github.com/plotly/plotly.py/issues/new and we will try to " "replicate and fix it." ) if field_name not in array_attrables: args[field_name] = str(col_name) elif isinstance(args[field_name], dict): pass else: args[field_name][i] = str(col_name) if field_name != "wide_variable": wide_id_vars.add(str(col_name)) length = len(df_output[next(iter(df_output))]) if len(df_output) else 0 if native_namespace is None: try: import pandas as pd native_namespace = pd except ImportError: msg = "Pandas installation is required if no dataframe is provided." raise NotImplementedError(msg) if ranges: import numpy as np range_series = nw.new_series( name="__placeholder__", values=np.arange(length), native_namespace=native_namespace, ) df_output.update( {col_name: range_series.alias(col_name) for col_name in ranges} ) df_output.update( { # constant is single value. 
def build_dataframe(args, constructor):
    """
    Constructs a dataframe and modifies `args` in-place.

    The argument values in `args` can be either strings corresponding to
    existing columns of a dataframe, or data arrays (lists, numpy arrays,
    pandas columns, series).

    Parameters
    ----------
    args : OrderedDict
        arguments passed to the px function and subsequently modified
    constructor : graph_object trace class
        the trace type selected for this figure
    """

    # make copies of all the fields via dict() and list()
    for field in args:
        if field in array_attrables and args[field] is not None:
            if isinstance(args[field], dict):
                args[field] = dict(args[field])
            elif field in ["custom_data", "hover_data"] and isinstance(
                args[field], str
            ):
                args[field] = [args[field]]
            else:
                args[field] = list(args[field])

    # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.)
    df_provided = args["data_frame"] is not None

    # Flag that indicates if the resulting data_frame after parsing is pandas-like
    # (in terms of resulting Narwhals DataFrame).
    # True if pandas, modin.pandas or cudf DataFrame/Series instance, or converted from
    # PySpark to pandas.
    is_pd_like = False

    # Flag that indicates if data_frame needs to be converted to PyArrow.
    # True for Ibis, DuckDB, Vaex, or any object that implements `__dataframe__`.
    needs_interchanging = False

    # If data_frame is provided, we parse it into a narwhals DataFrame, while accounting
    # for compatibility with pandas specific paths (e.g. Index/MultiIndex case).
    if df_provided:
        # data_frame is pandas-like DataFrame (pandas, modin.pandas, cudf)
        if nw.dependencies.is_pandas_like_dataframe(args["data_frame"]):
            columns = args["data_frame"].columns  # This can be multi index
            args["data_frame"] = nw.from_native(args["data_frame"], eager_only=True)
            is_pd_like = True

        # data_frame is pandas-like Series (pandas, modin.pandas, cudf)
        elif nw.dependencies.is_pandas_like_series(args["data_frame"]):
            args["data_frame"] = nw.from_native(
                args["data_frame"], series_only=True
            ).to_frame()
            columns = args["data_frame"].columns
            is_pd_like = True

        # data_frame is any other DataFrame object natively supported via Narwhals.
        # With `pass_through=True`, the original object will be returned if unable to
        # convert to a Narwhals DataFrame, making this condition False.
        elif isinstance(
            data_frame := nw.from_native(
                args["data_frame"], eager_or_interchange_only=True, pass_through=True
            ),
            nw.DataFrame,
        ):
            args["data_frame"] = data_frame
            needs_interchanging = nw.get_level(data_frame) == "interchange"
            columns = args["data_frame"].columns

        # data_frame is any other Series object natively supported via Narwhals.
        # With `pass_through=True`, the original object will be returned if unable to
        # convert to a Narwhals Series, making this condition False.
        elif isinstance(
            series := nw.from_native(
                args["data_frame"], series_only=True, pass_through=True
            ),
            nw.Series,
        ):
            args["data_frame"] = series.to_frame()
            columns = args["data_frame"].columns

        # data_frame is PySpark: it does not support the interchange protocol and is
        # not integrated in Narwhals. We use its native method to convert it to pandas.
        elif hasattr(args["data_frame"], "toPandas"):
            args["data_frame"] = nw.from_native(
                args["data_frame"].toPandas(), eager_only=True
            )
            columns = args["data_frame"].columns
            is_pd_like = True

        # data_frame is some other object type (e.g. dict, list, ...)
        # We try to import pandas, and then try to instantiate a pandas dataframe
        # from such an object
        else:
            try:
                import pandas as pd

                try:
                    args["data_frame"] = nw.from_native(
                        pd.DataFrame(args["data_frame"])
                    )
                    columns = args["data_frame"].columns
                    is_pd_like = True
                except Exception:
                    msg = (
                        f"Unable to convert data_frame of type {type(args['data_frame'])} "
                        "to pandas DataFrame. Please provide a supported dataframe type "
                        "or a type that can be passed to pd.DataFrame."
                    )
                    raise NotImplementedError(msg)
            except ImportError:
                msg = (
                    f"Attempting to convert data_frame of type {type(args['data_frame'])} "
                    "to pandas DataFrame, but Pandas is not installed. "
                    "Convert it to supported dataframe type or install pandas."
                )
                raise NotImplementedError(msg)

    # data_frame is not provided
    else:
        columns = None

    df_input: nw.DataFrame | None = args["data_frame"]
    index = (
        nw.maybe_get_index(df_input)
        if df_provided and not needs_interchanging
        else None
    )
    native_namespace = (
        nw.get_native_namespace(df_input)
        if df_provided and not needs_interchanging
        else None
    )

    # now we handle special cases like wide-mode or x-xor-y specification
    # by rearranging args to tee things up for process_args_into_dataframe to work
    no_x = args.get("x") is None
    no_y = args.get("y") is None
    wide_x = (
        False if no_x else _is_col_list(columns, args["x"], is_pd_like, native_namespace)
    )
    wide_y = (
        False if no_y else _is_col_list(columns, args["y"], is_pd_like, native_namespace)
    )

    wide_mode = False
    var_name = None  # will likely be "variable" in wide_mode
    wide_cross_name = None  # will likely be "index" in wide_mode
    value_name = None  # will likely be "value" in wide_mode
    hist2d_types = [go.Histogram2d, go.Histogram2dContour]
    hist1d_orientation = constructor == go.Histogram or "ecdfmode" in args
    if constructor in cartesians:
        if wide_x and wide_y:
            raise ValueError(
                "Cannot accept list of column references or list of columns for both `x` and `y`."
            )
        if df_provided and no_x and no_y:
            wide_mode = True
            if is_pd_like and isinstance(columns, native_namespace.MultiIndex):
                raise TypeError(
                    f"Data frame columns is a {native_namespace.__name__} MultiIndex. "
                    f"{native_namespace.__name__} MultiIndex is not supported by plotly "
                    "express at the moment."
                )
            args["wide_variable"] = list(columns)
            if is_pd_like and isinstance(columns, native_namespace.Index):
                var_name = columns.name
            else:
                var_name = None
            if var_name in [None, "value", "index"] or var_name in columns:
                var_name = "variable"
            if constructor == go.Funnel:
                wide_orientation = args.get("orientation") or "h"
            else:
                wide_orientation = args.get("orientation") or "v"
            args["orientation"] = wide_orientation
            args["wide_cross"] = None
        elif wide_x != wide_y:
            wide_mode = True
            args["wide_variable"] = args["y"] if wide_y else args["x"]
            if df_provided and is_pd_like and args["wide_variable"] is columns:
                var_name = columns.name
            if is_pd_like and isinstance(args["wide_variable"], native_namespace.Index):
                args["wide_variable"] = list(args["wide_variable"])
            if var_name in [None, "value", "index"] or (
                df_provided and var_name in columns
            ):
                var_name = "variable"
            if hist1d_orientation:
                wide_orientation = "v" if wide_x else "h"
            else:
                wide_orientation = "v" if wide_y else "h"
            args["y" if wide_y else "x"] = None
            args["wide_cross"] = None
            if not no_x and not no_y:
                wide_cross_name = "__x__" if wide_y else "__y__"

    if wide_mode:
        value_name = _escape_col_name(columns, "value", [])
        var_name = _escape_col_name(columns, var_name, [])

    # If the data_frame has interchange-only support level in Narwhals, then we need
    # to convert it to a full support level backend, hence we convert it to PyArrow.
    if needs_interchanging:
        if wide_mode:
            args["data_frame"] = nw.from_native(
                args["data_frame"].to_arrow(), eager_only=True
            )
        else:
            # Save precious resources by only interchanging columns that are
            # actually going to be plotted. This is tricky to do in the general case,
            # because Plotly allows calls like `px.line(df, x='x', y=['y1', df['y1']])`,
            # but interchange-only objects (e.g. DuckDB) don't typically have a concept
            # of self-standing Series. It's more important to perform projection
            # pushdown here seeing as we're materialising to an (eager) PyArrow table.
            necessary_columns = {
                i for i in args.values() if isinstance(i, str) and i in columns
            }
            for field in args:
                if args[field] is not None and field in array_attrables:
                    necessary_columns.update(i for i in args[field] if i in columns)
            columns = list(necessary_columns)
            args["data_frame"] = nw.from_native(
                args["data_frame"].select(columns).to_arrow(), eager_only=True
            )
        import pyarrow as pa

        native_namespace = pa

    missing_bar_dim = None
    if (
        constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types
        and not hist1d_orientation
    ):
        if not wide_mode and (no_x != no_y):
            for ax in ["x", "y"]:
                if args.get(ax) is None:
                    args[ax] = (
                        index
                        if index is not None
                        else Range(
                            label=_escape_col_name(columns, ax, [var_name, value_name])
                        )
                    )
                    if constructor == go.Bar:
                        missing_bar_dim = ax
                    else:
                        if args["orientation"] is None:
                            args["orientation"] = "v" if ax == "x" else "h"
        if wide_mode and wide_cross_name is None:
            if no_x != no_y and args["orientation"] is None:
                args["orientation"] = "v" if no_x else "h"
            if df_provided and is_pd_like and index is not None:
                if isinstance(index, native_namespace.MultiIndex):
                    raise TypeError(
                        f"Data frame index is a {native_namespace.__name__} MultiIndex. "
                        f"{native_namespace.__name__} MultiIndex is not supported by "
                        "plotly express at the moment."
                    )
                args["wide_cross"] = index
            else:
                args["wide_cross"] = Range(
                    label=_escape_col_name(columns, "index", [var_name, value_name])
                )

    no_color = False
    if isinstance(args.get("color"), str) and args["color"] == NO_COLOR:
        no_color = True
        args["color"] = None
    # now that things have been prepped, we do the systematic rewriting of `args`
    df_output, wide_id_vars = process_args_into_dataframe(
        args,
        wide_mode,
        var_name,
        value_name,
        is_pd_like,
        native_namespace,
    )
    df_output: nw.DataFrame

    # now that `df_output` exists and `args` contains only references, we complete
    # the special-case and wide-mode handling by further rewriting args and/or mutating
    # df_output

    count_name = _escape_col_name(df_output.columns, "count", [var_name, value_name])
    if not wide_mode and missing_bar_dim and constructor == go.Bar:
        # now that we've populated df_output, we check to see if the non-missing
        # dimension is categorical: if so, then setting the missing dimension to a
        # constant 1 is a less-insane thing to do than setting it to the index by
        # default and we let the normal auto-orientation-code do its thing later
        other_dim = "x" if missing_bar_dim == "y" else "y"
        if not _is_continuous(df_output, args[other_dim]):
            args[missing_bar_dim] = count_name
            df_output = df_output.with_columns(nw.lit(1).alias(count_name))
        else:
            # on the other hand, if the non-missing dimension is continuous, then we
            # can use this information to override the normal auto-orientation code
            if args["orientation"] is None:
                args["orientation"] = "v" if missing_bar_dim == "x" else "h"

    if constructor in hist2d_types:
        del args["orientation"]

    if wide_mode:
        # at this point, `df_output` is semi-long/semi-wide, but we know which columns
        # are which, so we melt it and reassign `args` to refer to the newly-tidy
        # columns, keeping track of various names and manglings set up above
        wide_value_vars = [c for c in args["wide_variable"] if c not in wide_id_vars]
        del args["wide_variable"]
        if wide_cross_name == "__x__":
            wide_cross_name = args["x"]
        elif wide_cross_name == "__y__":
            wide_cross_name = args["y"]
        else:
            wide_cross_name = args["wide_cross"]
        del args["wide_cross"]
        dtype = None
        for v in wide_value_vars:
            v_dtype = df_output.get_column(v).dtype
            v_dtype = "number" if v_dtype.is_numeric() else str(v_dtype)
            if dtype is None:
                dtype = v_dtype
            elif dtype != v_dtype:
                raise ValueError(
                    "Plotly Express cannot process wide-form data with columns of different type."
                )
        df_output = df_output.unpivot(
            index=wide_id_vars,
            on=wide_value_vars,
            variable_name=var_name,
            value_name=value_name,
        )
        assert len(df_output.columns) == len(set(df_output.columns)), (
            "Wide-mode name-inference failure, likely due to an internal bug. "
            "Please report this to "
            "https://github.com/plotly/plotly.py/issues/new and we will try to "
            "replicate and fix it."
        )
        df_output = df_output.with_columns(nw.col(var_name).cast(nw.String))

        orient_v = wide_orientation == "v"

        if hist1d_orientation:
            args["x" if orient_v else "y"] = value_name
            args["y" if orient_v else "x"] = wide_cross_name
            args["color"] = args["color"] or var_name
        elif constructor in [go.Scatter, go.Funnel] + hist2d_types:
            args["x" if orient_v else "y"] = wide_cross_name
            args["y" if orient_v else "x"] = value_name
            if constructor != go.Histogram2d:
                args["color"] = args["color"] or var_name
            if "line_group" in args:
                args["line_group"] = args["line_group"] or var_name
        elif constructor == go.Bar:
            if _is_continuous(df_output, value_name):
                args["x" if orient_v else "y"] = wide_cross_name
                args["y" if orient_v else "x"] = value_name
                args["color"] = args["color"] or var_name
            else:
                args["x" if orient_v else "y"] = value_name
                args["y" if orient_v else "x"] = count_name
                df_output = df_output.with_columns(nw.lit(1).alias(count_name))
                args["color"] = args["color"] or var_name
        elif constructor in [go.Violin, go.Box]:
            args["x" if orient_v else "y"] = wide_cross_name or var_name
            args["y" if orient_v else "x"] = value_name

    if hist1d_orientation and constructor == go.Scatter:
        if args["x"] is not None and args["y"] is not None:
            args["histfunc"] = "sum"
        elif args["x"] is None:
            args["histfunc"] = None
            args["orientation"] = "h"
            args["x"] = count_name
            df_output = df_output.with_columns(nw.lit(1).alias(count_name))
        else:
            args["histfunc"] = None
            args["orientation"] = "v"
            args["y"] = count_name
            df_output = df_output.with_columns(nw.lit(1).alias(count_name))

    if no_color:
        args["color"] = None
    args["data_frame"] = df_output
    return args
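
# Sketch of the wide-mode rewriting performed above (hypothetical frame with
# numeric columns "a" and "b"):
#
#     px.line(df, y=["a", "b"])
#     # df_output is unpivoted to long form with (escaped) "variable"/"value"
#     # columns, and args["x"]/args["y"]/args["color"] are rewritten to refer
#     # to them.
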
""" df: nw.DataFrame = args["data_frame"] path = args["path"][::-1] _check_dataframe_all_leaves(df[path[::-1]]) discrete_color = not _is_continuous(df, args["color"]) if args["color"] else False df = df.lazy() new_path = [col_name + "_path_copy" for col_name in path] df = df.with_columns( nw.col(col_name).alias(new_col_name) for new_col_name, col_name in zip(new_path, path) ) path = new_path # ------------ Define aggregation functions -------------------------------- agg_f = {} if args["values"]: try: df = df.with_columns(nw.col(args["values"]).cast(nw.Float64())) except Exception: # pandas, Polars and pyarrow exception types are different raise ValueError( "Column `%s` of `df` could not be converted to a numerical data type." % args["values"] ) if args["color"] and args["color"] == args["values"]: new_value_col_name = args["values"] + "_sum" df = df.with_columns(nw.col(args["values"]).alias(new_value_col_name)) args["values"] = new_value_col_name count_colname = args["values"] else: # we need a count column for the first groupby and the weighted mean of color # trick to be sure the col name is unused: take the sum of existing names columns = df.collect_schema().names() count_colname = ( "count" if "count" not in columns else "".join([str(el) for el in columns]) ) # we can modify df because it's a copy of the px argument df = df.with_columns(nw.lit(1).alias(count_colname)) args["values"] = count_colname # Since count_colname is always in agg_f, it can be used later to normalize color # in the continuous case after some gymnastic agg_f[count_colname] = nw.sum(count_colname) discrete_aggs = [] continuous_aggs = [] n_unique_token = _generate_temporary_column_name( n_bytes=16, columns=df.collect_schema().names() ) # In theory, for discrete columns aggregation, we should have a way to do # `.agg(nw.col(x).unique())` in group_by and successively unpack/parse it as: # ``` # (nw.when(nw.col(x).list.len()==1) # .then(nw.col(x).list.first()) # .otherwise(nw.lit("(?)")) # ) # ``` # which replicates the original pandas only codebase: # ``` # def discrete_agg(x): # uniques = x.unique() # return uniques[0] if len(uniques) == 1 else "(?)" # # df.groupby(path[i:]).agg(...) # ``` # However this is not possible, therefore the following workaround is provided. # We make two aggregations for the same column: # - take the max value # - take the number of unique values # Finally, after the group by statement, it is unpacked via: # ``` # (nw.when(nw.col(col_n_unique) == 1) # .then(nw.col(col_max_value)) # which is the unique value # .otherwise(nw.lit("(?)")) # ) # ``` if args["color"]: if discrete_color: discrete_aggs.append(args["color"]) agg_f[args["color"]] = nw.col(args["color"]).max() agg_f[f"{args['color']}{n_unique_token}"] = ( nw.col(args["color"]) .n_unique() .alias(f"{args['color']}{n_unique_token}") ) else: # This first needs to be multiplied by `count_colname` continuous_aggs.append(args["color"]) agg_f[args["color"]] = nw.sum(args["color"]) # Other columns (for color, hover_data, custom_data etc.) cols = list(set(df.collect_schema().names()).difference(path)) df = df.with_columns(nw.col(c).cast(nw.String()) for c in cols if c not in agg_f) for col in cols: # for hover_data, custom_data etc. 
    for col in cols:  # for hover_data, custom_data etc.
        if col not in agg_f:
            # Similar trick as above
            discrete_aggs.append(col)
            agg_f[col] = nw.col(col).max()
            agg_f[f"{col}{n_unique_token}"] = (
                nw.col(col).n_unique().alias(f"{col}{n_unique_token}")
            )
    # Avoid collisions with reserved names - columns in the path have already
    # been copied
    cols = list(set(cols) - set(["labels", "parent", "id"]))
    # ----------------------------------------------------------------------------
    all_trees = []

    if args["color"] and not discrete_color:
        df = df.with_columns(
            (nw.col(args["color"]) * nw.col(count_colname)).alias(args["color"])
        )

    def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFrame:
        """
        - continuous_aggs is either [] or [args["color"]]
        - discrete_aggs is either [args["color"], <other cols>] or [<other cols>]
        """
        return dframe.with_columns(
            *[nw.col(col) / nw.col(count_colname) for col in continuous_aggs],
            *[
                (
                    nw.when(nw.col(f"{col}{n_unique_token}") == 1)
                    .then(nw.col(col))
                    .otherwise(nw.lit("(?)"))
                    .alias(col)
                )
                for col in discrete_aggs
            ],
        ).drop([f"{col}{n_unique_token}" for col in discrete_aggs])

    for i, level in enumerate(path):
        dfg = (
            df.group_by(path[i:], drop_null_keys=True)
            .agg(**agg_f)
            .pipe(post_agg, continuous_aggs, discrete_aggs)
        )

        # Path label massaging
        df_tree = dfg.with_columns(
            *cols,
            labels=nw.col(level).cast(nw.String()),
            parent=nw.lit(""),
            id=nw.col(level).cast(nw.String()),
        )
        if i < len(path) - 1:
            _concat_str_token = _generate_temporary_column_name(
                n_bytes=16, columns=[*cols, "labels", "parent", "id"]
            )
            df_tree = (
                df_tree.with_columns(
                    nw.concat_str(
                        [
                            nw.col(path[j]).cast(nw.String())
                            for j in range(len(path) - 1, i, -1)
                        ],
                        separator="/",
                    ).alias(_concat_str_token)
                )
                .with_columns(
                    parent=nw.concat_str(
                        [nw.col(_concat_str_token), nw.col("parent")], separator="/"
                    ),
                    id=nw.concat_str(
                        [nw.col(_concat_str_token), nw.col("id")], separator="/"
                    ),
                )
                .drop(_concat_str_token)
            )

        # strip any leading/trailing "/", equivalent to `.str.strip("/")`
        df_tree = df_tree.with_columns(
            parent=nw.col("parent").str.replace("/?$", "").str.replace("^/?", "")
        )

        all_trees.append(df_tree.select(*["labels", "parent", "id", *cols]))

    df_all_trees = nw.maybe_reset_index(nw.concat(all_trees, how="vertical").collect())
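
    # Illustration of the ids built above (hypothetical path): with
    # path=["continent", "country"], a France row gets labels="France",
    # id="Europe/France", parent="Europe", while the continent-level row gets
    # labels="Europe", id="Europe", parent="".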

    # we want to make sure that "(?)" is the first color of the sequence
    if args["color"] and discrete_color:
        sort_col_name = "sort_color_if_discrete_color"
        while sort_col_name in df_all_trees.columns:
            sort_col_name += "0"
        df_all_trees = df_all_trees.with_columns(
            nw.col(args["color"]).cast(nw.String()).alias(sort_col_name)
        ).sort(by=sort_col_name, nulls_last=True)

    # Now modify arguments
    args["data_frame"] = df_all_trees
    args["path"] = None
    args["ids"] = "id"
    args["names"] = "labels"
    args["parents"] = "parent"
    if args["color"]:
        if not args["hover_data"]:
            args["hover_data"] = [args["color"]]
        elif isinstance(args["hover_data"], dict):
            if not args["hover_data"].get(args["color"]):
                args["hover_data"][args["color"]] = (True, None)
        else:
            args["hover_data"].append(args["color"])
    return args


def process_dataframe_timeline(args):
    """
    Massage input for bar traces for px.timeline()
    """
    args["is_timeline"] = True
    if args["x_start"] is None or args["x_end"] is None:
        raise ValueError("Both x_start and x_end are required")

    df: nw.DataFrame = args["data_frame"]
    schema = df.schema
    to_convert_to_datetime = [
        col
        for col in [args["x_start"], args["x_end"]]
        if schema[col] != nw.Datetime and schema[col] != nw.Date
    ]

    if to_convert_to_datetime:
        try:
            df = df.with_columns(nw.col(to_convert_to_datetime).str.to_datetime())
        except Exception as exc:
            raise TypeError(
                "Both x_start and x_end must refer to data convertible to datetimes."
            ) from exc

    # note that we are not adding any columns to the data frame here, so no risk
    # of overwrite
    args["data_frame"] = df.with_columns(
        (nw.col(args["x_end"]) - nw.col(args["x_start"]))
        .dt.total_milliseconds()
        .alias(args["x_end"])
    )
    args["x"] = args["x_end"]
    args["base"] = args["x_start"]
    del args["x_start"], args["x_end"]
    return args


def process_dataframe_pie(args, trace_patch):
    import numpy as np

    names = args.get("names")
    if names is None:
        return args, trace_patch
    order_in = args["category_orders"].get(names, {}).copy()
    if not order_in:
        return args, trace_patch
    df: nw.DataFrame = args["data_frame"]
    trace_patch["sort"] = False
    trace_patch["direction"] = "clockwise"
    uniques = df.get_column(names).unique(maintain_order=True).to_list()
    order = [x for x in OrderedDict.fromkeys(list(order_in) + uniques) if x in uniques]

    # Sort args['data_frame'] by column `names` according to order `order`.
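    # Illustration (hypothetical values): with uniques ["b", "a", "c"] and
    # category_orders {"names_col": ["a", "b"]}, `order` becomes ["a", "b", "c"]:
    # user-specified entries first, then the remaining uniques in first-seen
    # order; the frame is then sorted into that order below.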
    token = nw.generate_temporary_column_name(8, df.columns)
    args["data_frame"] = (
        df.with_columns(
            nw.col(names)
            .replace_strict(order, np.arange(len(order)), return_dtype=nw.UInt32)
            .alias(token)
        )
        .sort(token)
        .drop(token)
    )
    return args, trace_patch


def infer_config(args, constructor, trace_patch, layout_patch):
    attrs = [k for k in direct_attrables + array_attrables if k in args]
    grouped_attrs = []

    df: nw.DataFrame = args["data_frame"]

    # Compute sizeref
    sizeref = 0
    if "size" in args and args["size"]:
        sizeref = (
            nw.to_py_scalar(df.get_column(args["size"]).max()) / args["size_max"] ** 2
        )

    # Compute color attributes and grouping attributes
    if "color" in args:
        if "color_continuous_scale" in args:
            if "color_discrete_sequence" not in args:
                attrs.append("color")
            else:
                if args["color"] and _is_continuous(df, args["color"]):
                    attrs.append("color")
                    args["color_is_continuous"] = True
                elif constructor in [go.Sunburst, go.Treemap, go.Icicle]:
                    attrs.append("color")
                    args["color_is_continuous"] = False
                else:
                    grouped_attrs.append("marker.color")
        elif "line_group" in args or constructor == go.Histogram2dContour:
            grouped_attrs.append("line.color")
        elif constructor in [go.Pie, go.Funnelarea]:
            attrs.append("color")
            if args["color"]:
                if args["hover_data"] is None:
                    args["hover_data"] = []
                args["hover_data"].append(args["color"])
        else:
            grouped_attrs.append("marker.color")

        show_colorbar = bool(
            "color" in attrs
            and args["color"]
            and constructor not in [go.Pie, go.Funnelarea]
            and (
                constructor not in [go.Treemap, go.Sunburst, go.Icicle]
                or args.get("color_is_continuous")
            )
        )
    else:
        show_colorbar = False

    if "line_dash" in args:
        grouped_attrs.append("line.dash")

    if "symbol" in args:
        grouped_attrs.append("marker.symbol")

    if "pattern_shape" in args:
        if constructor in [go.Scatter]:
            grouped_attrs.append("fillpattern.shape")
        else:
            grouped_attrs.append("marker.pattern.shape")

    if "orientation" in args:
        has_x = args["x"] is not None
        has_y = args["y"] is not None
        if args["orientation"] is None:
            if constructor in [go.Histogram, go.Scatter]:
                if has_y and not has_x:
                    args["orientation"] = "h"
            elif constructor in [go.Violin, go.Box, go.Bar, go.Funnel]:
                if has_x and not has_y:
                    args["orientation"] = "h"

        if args["orientation"] is None and has_x and has_y:
            x_is_continuous = _is_continuous(df, args["x"])
            y_is_continuous = _is_continuous(df, args["y"])
            if x_is_continuous and not y_is_continuous:
                args["orientation"] = "h"
            if y_is_continuous and not x_is_continuous:
                args["orientation"] = "v"

        if args["orientation"] is None:
            args["orientation"] = "v"

        if constructor == go.Histogram:
            if has_x and has_y and args["histfunc"] is None:
                args["histfunc"] = trace_patch["histfunc"] = "sum"

            orientation = args["orientation"]
            nbins = args["nbins"]
            trace_patch["nbinsx"] = nbins if orientation == "v" else None
            trace_patch["nbinsy"] = None if orientation == "v" else nbins
            trace_patch["bingroup"] = "x" if orientation == "v" else "y"
        trace_patch["orientation"] = args["orientation"]

        if constructor in [go.Violin, go.Box]:
            mode = "boxmode" if constructor == go.Box else "violinmode"
            if layout_patch[mode] is None and args["color"] is not None:
                if args["y"] == args["color"] and args["orientation"] == "h":
                    layout_patch[mode] = "overlay"
                elif args["x"] == args["color"] and args["orientation"] == "v":
                    layout_patch[mode] = "overlay"
            if layout_patch[mode] is None:
                layout_patch[mode] = "group"

    if (
        constructor == go.Histogram2d
        and args["z"] is not None
        and args["histfunc"] is None
    ):
        args["histfunc"] = trace_patch["histfunc"] = "sum"
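    # Illustration of the text_auto handling below (hypothetical arguments): on a
    # vertical bar chart, text_auto=True yields texttemplate="%{y}", while
    # text_auto=".2f" yields texttemplate="%{y:.2f}".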
    if args.get("text_auto", False) is not False:
        if constructor in [go.Histogram2d, go.Histogram2dContour]:
            letter = "z"
        elif constructor == go.Bar:
            letter = "y" if args["orientation"] == "v" else "x"
        else:
            letter = "value"
        if args["text_auto"] is True:
            trace_patch["texttemplate"] = "%{" + letter + "}"
        else:
            trace_patch["texttemplate"] = "%{" + letter + ":" + args["text_auto"] + "}"

    if constructor in [go.Histogram2d, go.Densitymap, go.Densitymapbox]:
        show_colorbar = True
        trace_patch["coloraxis"] = "coloraxis1"

    if "opacity" in args:
        if args["opacity"] is None:
            if "barmode" in args and args["barmode"] == "overlay":
                trace_patch["marker"] = dict(opacity=0.5)
        elif constructor in [
            go.Densitymap,
            go.Densitymapbox,
            go.Pie,
            go.Funnel,
            go.Funnelarea,
        ]:
            trace_patch["opacity"] = args["opacity"]
        else:
            trace_patch["marker"] = dict(opacity=args["opacity"])

    if (
        "line_group" in args or "line_dash" in args
    ):  # px.line, px.line_*, px.area, px.ecdf
        modes = set()
        if args.get("lines", True):
            modes.add("lines")
        if args.get("text") or args.get("symbol") or args.get("markers"):
            modes.add("markers")
        if args.get("text"):
            modes.add("text")
        if len(modes) == 0:
            modes.add("lines")
        trace_patch["mode"] = "+".join(sorted(modes))
    elif constructor != go.Splom and (
        "symbol" in args or constructor in [go.Scattermap, go.Scattermapbox]
    ):
        trace_patch["mode"] = "markers" + ("+text" if args["text"] else "")

    if "line_shape" in args:
        trace_patch["line"] = dict(shape=args["line_shape"])
    elif "ecdfmode" in args:
        trace_patch["line"] = dict(
            shape="vh" if args["ecdfmode"] == "reversed" else "hv"
        )

    if "geojson" in args:
        trace_patch["featureidkey"] = args["featureidkey"]
        trace_patch["geojson"] = (
            args["geojson"]
            if not hasattr(args["geojson"], "__geo_interface__")  # for geopandas
            else args["geojson"].__geo_interface__
        )

    # Compute marginal attribute: copy to appropriate marginal_*
    if "marginal" in args:
        position = "marginal_x" if args["orientation"] == "v" else "marginal_y"
        other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y"
        args[position] = args["marginal"]
        args[other_position] = None

    # Ignore facet rows and columns when the data frame is empty, so as to
    # prevent nrows/ncols from equaling 0
    if df.is_empty():
        args["facet_row"] = args["facet_col"] = None

    # If both marginals and faceting are specified, faceting wins
    if args.get("facet_col") is not None and args.get("marginal_y") is not None:
        args["marginal_y"] = None

    if args.get("facet_row") is not None and args.get("marginal_x") is not None:
        args["marginal_x"] = None

    # facet_col_wrap only works if no marginals or row faceting is used
    if (
        args.get("marginal_x") is not None
        or args.get("marginal_y") is not None
        or args.get("facet_row") is not None
    ):
        args["facet_col_wrap"] = 0

    if "trendline" in args and args["trendline"] is not None:
        if args["trendline"] not in trendline_functions:
            raise ValueError(
                "Value '%s' for `trendline` must be one of %s"
                % (args["trendline"], trendline_functions.keys())
            )

    if "trendline_options" in args and args["trendline_options"] is None:
        args["trendline_options"] = dict()
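    # For illustration (assumed px-level usage, not executed here): the options
    # dict prepared above is forwarded to the selected trendline function, e.g.
    #   px.scatter(df, x="x", y="y", trendline="rolling",
    #              trendline_options=dict(window=5))
    # passes window=5 through to the `rolling` trendline function.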
% args["ecdfnorm"] ) args["histnorm"] = args["ecdfnorm"] # Compute applicable grouping attributes grouped_attrs.extend([k for k in group_attrables if k in args]) # Create grouped mappings grouped_mappings = [make_mapping(args, a) for a in grouped_attrs] # Create trace specs trace_specs = make_trace_spec(args, constructor, attrs, trace_patch) return trace_specs, grouped_mappings, sizeref, show_colorbar def get_groups_and_orders(args, grouper): """ `orders` is the user-supplied ordering with the remaining data-frame-supplied ordering appended if the column is used for grouping. It includes anything the user gave, for any variable, including values not present in the dataset. It's a dict where the keys are e.g. "x" or "color" `groups` is the dicts of groups, ordered by the order above. Its keys are tuples like [("value1", ""), ("value2", "")] where each tuple contains the name of a single dimension-group """ orders = {} if "category_orders" not in args else args["category_orders"].copy() df: nw.DataFrame = args["data_frame"] # figure out orders and what the single group name would be if there were one single_group_name = [] unique_cache = dict() for i, col in enumerate(grouper): if col == one_group: single_group_name.append("") else: if col not in unique_cache: unique_cache[col] = ( df.get_column(col).unique(maintain_order=True).to_list() ) uniques = unique_cache[col] if len(uniques) == 1: single_group_name.append(uniques[0]) if col not in orders: orders[col] = uniques else: orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques)) if len(single_group_name) == len(grouper): # we have a single group, so we can skip all group-by operations! groups = {tuple(single_group_name): df} else: required_grouper = [group for group in orders if group in grouper] grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__()) sorted_group_names = sorted( grouped.keys(), key=lambda values: [ orders[group].index(value) if value in orders[group] else -1 for group, value in zip(required_grouper, values) ], ) # calculate the full group_names by inserting "" in the tuple index for one_group groups full_sorted_group_names = [ tuple( [ ( "" if col == one_group else sub_group_names[required_grouper.index(col)] ) for col in grouper ] ) for sub_group_names in sorted_group_names ] groups = { sf: grouped[s] for sf, s in zip(full_sorted_group_names, sorted_group_names) } return groups, orders def make_figure(args, constructor, trace_patch=None, layout_patch=None): trace_patch = trace_patch or {} layout_patch = layout_patch or {} apply_default_cascade(args) args = build_dataframe(args, constructor) if constructor in [go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None: args = process_dataframe_hierarchy(args) if constructor in [go.Pie]: args, trace_patch = process_dataframe_pie(args, trace_patch) if constructor == "timeline": constructor = go.Bar args = process_dataframe_timeline(args) # If we have marginal histograms, set barmode to "overlay" if "histogram" in [args.get("marginal_x"), args.get("marginal_y")]: layout_patch["barmode"] = "overlay" trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( args, constructor, trace_patch, layout_patch ) grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] groups, orders = get_groups_and_orders(args, grouper) col_labels = [] row_labels = [] nrows = ncols = 1 for m in grouped_mappings: if m.grouper not in orders: m.val_map[""] = m.sequence[0] else: sorted_values = orders[m.grouper] if m.facet == "col": 


def make_figure(args, constructor, trace_patch=None, layout_patch=None):
    trace_patch = trace_patch or {}
    layout_patch = layout_patch or {}
    apply_default_cascade(args)

    args = build_dataframe(args, constructor)
    if constructor in [go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None:
        args = process_dataframe_hierarchy(args)
    if constructor in [go.Pie]:
        args, trace_patch = process_dataframe_pie(args, trace_patch)
    if constructor == "timeline":
        constructor = go.Bar
        args = process_dataframe_timeline(args)

    # If we have marginal histograms, set barmode to "overlay"
    if "histogram" in [args.get("marginal_x"), args.get("marginal_y")]:
        layout_patch["barmode"] = "overlay"

    trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
        args, constructor, trace_patch, layout_patch
    )
    grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
    groups, orders = get_groups_and_orders(args, grouper)

    col_labels = []
    row_labels = []
    nrows = ncols = 1
    for m in grouped_mappings:
        if m.grouper not in orders:
            m.val_map[""] = m.sequence[0]
        else:
            sorted_values = orders[m.grouper]
            if m.facet == "col":
                prefix = get_label(args, args["facet_col"]) + "="
                col_labels = [prefix + str(s) for s in sorted_values]
                ncols = len(col_labels)
            if m.facet == "row":
                prefix = get_label(args, args["facet_row"]) + "="
                row_labels = [prefix + str(s) for s in sorted_values]
                nrows = len(row_labels)
            for val in sorted_values:
                if val not in m.val_map:  # always False if it's an IdentityMap
                    m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]

    subplot_type = _subplot_type_for_trace_type(constructor().type)

    trace_names_by_frame = {}
    frames = OrderedDict()
    trendline_rows = []
    trace_name_labels = None
    facet_col_wrap = args.get("facet_col_wrap", 0)

    for group_name, group in groups.items():
        mapping_labels = OrderedDict()
        trace_name_labels = OrderedDict()
        frame_name = ""
        for col, val, m in zip(grouper, group_name, grouped_mappings):
            if col != one_group:
                key = get_label(args, col)
                if not isinstance(m.val_map, IdentityMap):
                    mapping_labels[key] = str(val)
                    if m.show_in_trace_name:
                        trace_name_labels[key] = str(val)
                if m.variable == "animation_frame":
                    frame_name = val
        trace_name = ", ".join(trace_name_labels.values())
        if frame_name not in trace_names_by_frame:
            trace_names_by_frame[frame_name] = set()
        trace_names = trace_names_by_frame[frame_name]

        for trace_spec in trace_specs:
            # Create the trace
            trace = trace_spec.constructor(name=trace_name)
            if trace_spec.constructor not in [
                go.Parcats,
                go.Parcoords,
                go.Choropleth,
                go.Choroplethmap,
                go.Choroplethmapbox,
                go.Densitymap,
                go.Densitymapbox,
                go.Histogram2d,
                go.Sunburst,
                go.Treemap,
                go.Icicle,
            ]:
                trace.update(
                    legendgroup=trace_name,
                    showlegend=(trace_name != "" and trace_name not in trace_names),
                )

            # Set 'offsetgroup' only in group barmode (or if no barmode is set)
            barmode = layout_patch.get("barmode")
            if trace_spec.constructor in [go.Bar, go.Box, go.Violin, go.Histogram] and (
                barmode == "group" or barmode is None
            ):
                trace.update(alignmentgroup=True, offsetgroup=trace_name)
            trace_names.add(trace_name)

            # Init subplot row/col
            trace._subplot_row = 1
            trace._subplot_col = 1

            for i, m in enumerate(grouped_mappings):
                val = group_name[i]
                try:
                    m.updater(trace, m.val_map[val])  # covers most cases
                except ValueError:
                    # this catches some odd cases like marginals
                    if (
                        trace_spec != trace_specs[0]
                        and (
                            trace_spec.constructor in [go.Violin, go.Box]
                            and m.variable in ["symbol", "pattern", "dash"]
                        )
                        or (
                            trace_spec.constructor in [go.Histogram]
                            and m.variable in ["symbol", "dash"]
                        )
                    ):
                        pass
                    elif (
                        trace_spec != trace_specs[0]
                        and trace_spec.constructor in [go.Histogram]
                        and m.variable == "color"
                    ):
                        trace.update(marker=dict(color=m.val_map[val]))
                    elif (
                        trace_spec.constructor
                        in [go.Choropleth, go.Choroplethmap, go.Choroplethmapbox]
                        and m.variable == "color"
                    ):
                        trace.update(
                            z=[1] * len(group),
                            colorscale=[m.val_map[val]] * 2,
                            showscale=False,
                            showlegend=True,
                        )
                    else:
                        raise

                # Find row for trace, handling facet_row and marginal_x
                if m.facet == "row":
                    row = m.val_map[val]
                else:
                    if (
                        args.get("marginal_x") is not None  # there is a marginal
                        and trace_spec.marginal != "x"  # and we're not it
                    ):
                        row = 2
                    else:
                        row = 1

                # Find col for trace, handling facet_col and marginal_y
                if m.facet == "col":
                    col = m.val_map[val]
                    if facet_col_wrap:  # assumes no facet_row, no marginals
                        row = 1 + ((col - 1) // facet_col_wrap)
                        col = 1 + ((col - 1) % facet_col_wrap)
                else:
                    if trace_spec.marginal == "y":
                        col = 2
                    else:
                        col = 1

                if row > 1:
                    trace._subplot_row = row

                if col > 1:
                    trace._subplot_col = col

            if (
                trace_specs[0].constructor == go.Histogram2dContour
                and trace_spec.constructor == go.Box
                and trace.line.color
            ):
                trace.update(marker=dict(color=trace.line.color))
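            # Worked example of the ECDF accumulation below (hypothetical data):
            # with x=[1, 2, 3] and y=[1, 1, 1] in "standard" mode, the cumulative
            # sum turns y into [1, 2, 3]; ecdfnorm="probability" then rescales it
            # to [1/3, 2/3, 1.0].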
            if "ecdfmode" in args:
                base = args["x"] if args["orientation"] == "v" else args["y"]
                var = args["x"] if args["orientation"] == "h" else args["y"]
                ascending = args.get("ecdfmode", "standard") != "reversed"
                group = group.sort(by=base, descending=not ascending, nulls_last=True)
                # compute the total before the next line replaces the column with
                # its cumulative sum
                group_sum = group.get_column(var).sum()
                group = group.with_columns(nw.col(var).cum_sum().alias(var))

                if not ascending:
                    group = group.sort(by=base, descending=False, nulls_last=True)

                if args.get("ecdfmode", "standard") == "complementary":
                    group = group.with_columns((group_sum - nw.col(var)).alias(var))

                if args["ecdfnorm"] == "probability":
                    group = group.with_columns(nw.col(var) / group_sum)
                elif args["ecdfnorm"] == "percent":
                    group = group.with_columns((nw.col(var) / group_sum) * 100.0)

            patch, fit_results = make_trace_kwargs(
                args, trace_spec, group, mapping_labels.copy(), sizeref
            )
            trace.update(patch)
            if fit_results is not None:
                trendline_rows.append(mapping_labels.copy())
                trendline_rows[-1]["px_fit_results"] = fit_results

            if frame_name not in frames:
                frames[frame_name] = dict(data=[], name=frame_name)
            frames[frame_name]["data"].append(trace)

    frame_list = [f for f in frames.values()]
    if len(frame_list) > 1:
        frame_list = sorted(
            frame_list, key=lambda f: orders[args["animation_frame"]].index(f["name"])
        )

    if show_colorbar:
        colorvar = (
            "z"
            if constructor in [go.Histogram2d, go.Densitymap, go.Densitymapbox]
            else "color"
        )
        range_color = args["range_color"] or [None, None]

        colorscale_validator = ColorscaleValidator("colorscale", "make_figure")
        layout_patch["coloraxis1"] = dict(
            colorscale=colorscale_validator.validate_coerce(
                args["color_continuous_scale"]
            ),
            cmid=args["color_continuous_midpoint"],
            cmin=range_color[0],
            cmax=range_color[1],
            colorbar=dict(
                title_text=get_decorated_label(args, args[colorvar], colorvar)
            ),
        )

    for v in ["height", "width"]:
        if args[v]:
            layout_patch[v] = args[v]

    layout_patch["legend"] = dict(tracegroupgap=0)
    if trace_name_labels:
        layout_patch["legend"]["title_text"] = ", ".join(trace_name_labels)

    if args["title"]:
        layout_patch["title_text"] = args["title"]
    elif args["template"].layout.margin.t is None:
        layout_patch["margin"] = {"t": 60}

    if args["subtitle"]:
        layout_patch["title_subtitle_text"] = args["subtitle"]

    if (
        "size" in args
        and args["size"]
        and args["template"].layout.legend.itemsizing is None
    ):
        layout_patch["legend"]["itemsizing"] = "constant"

    if facet_col_wrap:
        nrows = math.ceil(ncols / facet_col_wrap)
        ncols = min(ncols, facet_col_wrap)

    if args.get("marginal_x") is not None:
        nrows += 1

    if args.get("marginal_y") is not None:
        ncols += 1

    fig = init_figure(
        args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
    )

    # Position traces in subplots
    for frame in frame_list:
        for trace in frame["data"]:
            if isinstance(trace, go.Splom):
                # Special case that is not compatible with make_subplots
                continue
            _set_trace_grid_reference(
                trace,
                fig.layout,
                fig._grid_ref,
                nrows - trace._subplot_row + 1,
                trace._subplot_col,
            )

    # Add traces, layout and frames to figure
    fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else [])
    fig.update_layout(layout_patch)
    if "template" in args and args["template"] is not None:
        fig.update_layout(template=args["template"], overwrite=True)

    for f in frame_list:
        f["name"] = str(f["name"])
    fig.frames = frame_list if len(frames) > 1 else []
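    # Illustration of the scope handled below (assumed px-level usage): e.g.
    #   px.scatter(df, x="x", y="y", color="c", trendline="ols",
    #              trendline_scope="overall")
    # fits one trendline across all traces instead of one per color group.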
    if args.get("trendline") and args.get("trendline_scope", "trace") == "overall":
        trendline_spec = make_trendline_spec(args, constructor)
        trendline_trace = trendline_spec.constructor(
            name="Overall Trendline", legendgroup="Overall Trendline", showlegend=False
        )
        if "line" not in trendline_spec.trace_patch:  # no color override
            for m in grouped_mappings:
                if m.variable == "color":
                    next_color = m.sequence[len(m.val_map) % len(m.sequence)]
                    trendline_spec.trace_patch["line"] = dict(color=next_color)
        patch, fit_results = make_trace_kwargs(
            args, trendline_spec, args["data_frame"], {}, sizeref
        )
        trendline_trace.update(patch)
        fig.add_trace(
            trendline_trace, row="all", col="all", exclude_empty_subplots=True
        )
        fig.update_traces(selector=-1, showlegend=True)
        if fit_results is not None:
            trendline_rows.append(dict(px_fit_results=fit_results))

    if trendline_rows:
        try:
            import pandas as pd

            fig._px_trendlines = pd.DataFrame(trendline_rows)
        except ImportError:
            msg = "Trendlines require pandas to be installed."
            raise NotImplementedError(msg)
    else:
        fig._px_trendlines = []

    configure_axes(args, constructor, fig, orders)
    configure_animation_controls(args, constructor, fig)
    return fig


def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
    # Build subplot specs
    specs = [[dict(type=subplot_type or "domain")] * ncols for _ in range(nrows)]

    # Default row/column widths uniform
    column_widths = [1.0] * ncols
    row_heights = [1.0] * nrows
    facet_col_wrap = args.get("facet_col_wrap", 0)

    # Build column_widths/row_heights
    if subplot_type == "xy":
        if args.get("marginal_x") is not None:
            if args["marginal_x"] == "histogram" or ("color" in args and args["color"]):
                main_size = 0.74
            else:
                main_size = 0.84

            row_heights = [main_size] * (nrows - 1) + [1 - main_size]
            vertical_spacing = 0.01
        elif facet_col_wrap:
            vertical_spacing = args.get("facet_row_spacing") or 0.07
        else:
            vertical_spacing = args.get("facet_row_spacing") or 0.03

        if args.get("marginal_y") is not None:
            if args["marginal_y"] == "histogram" or ("color" in args and args["color"]):
                main_size = 0.74
            else:
                main_size = 0.84

            column_widths = [main_size] * (ncols - 1) + [1 - main_size]
            horizontal_spacing = 0.005
        else:
            horizontal_spacing = args.get("facet_col_spacing") or 0.02
    else:
        # Other subplot types:
        # 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None
        #
        # We can customize subplot spacing per type once we enable faceting
        # for all plot types
        if facet_col_wrap:
            vertical_spacing = args.get("facet_row_spacing") or 0.07
        else:
            vertical_spacing = args.get("facet_row_spacing") or 0.03
        horizontal_spacing = args.get("facet_col_spacing") or 0.02

    if facet_col_wrap:
        subplot_labels = [None] * nrows * ncols
        while len(col_labels) < nrows * ncols:
            col_labels.append(None)
        for i in range(nrows):
            for j in range(ncols):
                subplot_labels[i * ncols + j] = col_labels[(nrows - 1 - i) * ncols + j]
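        # Worked example of the flip above (hypothetical grid): with nrows=2 and
        # ncols=2, position (i=0, j=0), the bottom-left cell given
        # start_cell="bottom-left", receives col_labels[2], so the wrapped titles
        # still read left-to-right, top-to-bottom.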
""" if ("%s spacing" % (direction,)) in e.args[0]: e.args = ( e.args[0] + """ Use the {facet_arg} argument to adjust this spacing.""".format(facet_arg=facet_arg), ) raise e # Create figure with subplots try: fig = make_subplots( rows=nrows, cols=ncols, specs=specs, shared_xaxes="all", shared_yaxes="all", row_titles=[] if facet_col_wrap else list(reversed(row_labels)), column_titles=[] if facet_col_wrap else col_labels, subplot_titles=subplot_labels if facet_col_wrap else [], horizontal_spacing=horizontal_spacing, vertical_spacing=vertical_spacing, row_heights=row_heights, column_widths=column_widths, start_cell="bottom-left", ) except ValueError as e: _spacing_error_translator(e, "Horizontal", "facet_col_spacing") _spacing_error_translator(e, "Vertical", "facet_row_spacing") raise # Remove explicit font size of row/col titles so template can take over for annot in fig.layout.annotations: annot.update(font=None) return fig