# -*- coding: utf-8 -*- """ tools ===== Functions that USERS will possibly want access to. """ from __future__ import absolute_import from collections import OrderedDict import warnings import six import math import decimal from plotly import utils from plotly import exceptions from plotly import graph_reference from plotly import session from plotly.files import (CONFIG_FILE, CREDENTIALS_FILE, FILE_CONTENT, GRAPH_REFERENCE_FILE, check_file_permissions) DEFAULT_PLOTLY_COLORS = ['rgb(31, 119, 180)', 'rgb(255, 127, 14)', 'rgb(44, 160, 44)', 'rgb(214, 39, 40)', 'rgb(148, 103, 189)', 'rgb(140, 86, 75)', 'rgb(227, 119, 194)', 'rgb(127, 127, 127)', 'rgb(188, 189, 34)', 'rgb(23, 190, 207)'] REQUIRED_GANTT_KEYS = ['Task', 'Start', 'Finish'] PLOTLY_SCALES = {'Greys': ['rgb(0,0,0)', 'rgb(255,255,255)'], 'YlGnBu': ['rgb(8,29,88)', 'rgb(255,255,217)'], 'Greens': ['rgb(0,68,27)', 'rgb(247,252,245)'], 'YlOrRd': ['rgb(128,0,38)', 'rgb(255,255,204)'], 'Bluered': ['rgb(0,0,255)', 'rgb(255,0,0)'], 'RdBu': ['rgb(5,10,172)', 'rgb(178,10,28)'], 'Reds': ['rgb(220,220,220)', 'rgb(178,10,28)'], 'Blues': ['rgb(5,10,172)', 'rgb(220,220,220)'], 'Picnic': ['rgb(0,0,255)', 'rgb(255,0,0)'], 'Rainbow': ['rgb(150,0,90)', 'rgb(255,0,0)'], 'Portland': ['rgb(12,51,131)', 'rgb(217,30,30)'], 'Jet': ['rgb(0,0,131)', 'rgb(128,0,0)'], 'Hot': ['rgb(0,0,0)', 'rgb(255,255,255)'], 'Blackbody': ['rgb(0,0,0)', 'rgb(160,200,255)'], 'Earth': ['rgb(0,0,130)', 'rgb(255,255,255)'], 'Electric': ['rgb(0,0,0)', 'rgb(255,250,220)'], 'Viridis': ['rgb(68,1,84)', 'rgb(253,231,37)']} # color constants for violin plot DEFAULT_FILLCOLOR = '#1f77b4' DEFAULT_HISTNORM = 'probability density' ALTERNATIVE_HISTNORM = 'probability' # Warning format def warning_on_one_line(message, category, filename, lineno, file=None, line=None): return '%s:%s: %s:\n\n%s\n\n' % (filename, lineno, category.__name__, message) warnings.formatwarning = warning_on_one_line try: from . import matplotlylib _matplotlylib_imported = True except ImportError: _matplotlylib_imported = False try: import IPython import IPython.core.display _ipython_imported = True except ImportError: _ipython_imported = False try: import numpy as np _numpy_imported = True except ImportError: _numpy_imported = False try: import pandas as pd _pandas_imported = True except ImportError: _pandas_imported = False try: import scipy as scp _scipy_imported = True except ImportError: _scipy_imported = False try: import scipy.spatial as scs _scipy__spatial_imported = True except ImportError: _scipy__spatial_imported = False try: import scipy.cluster.hierarchy as sch _scipy__cluster__hierarchy_imported = True except ImportError: _scipy__cluster__hierarchy_imported = False try: import scipy import scipy.stats _scipy_imported = True except ImportError: _scipy_imported = False def get_config_defaults(): """ Convenience function to check current settings against defaults. Example: if plotly_domain != get_config_defaults()['plotly_domain']: # do something """ return dict(FILE_CONTENT[CONFIG_FILE]) # performs a shallow copy def ensure_local_plotly_files(): """Ensure that filesystem is setup/filled out in a valid way. If the config or credential files aren't filled out, then write them to the disk. """ if check_file_permissions(): for fn in [CREDENTIALS_FILE, CONFIG_FILE]: utils.ensure_file_exists(fn) contents = utils.load_json_dict(fn) for key, val in list(FILE_CONTENT[fn].items()): # TODO: removed type checking below, may want to revisit if key not in contents: contents[key] = val contents_keys = list(contents.keys()) for key in contents_keys: if key not in FILE_CONTENT[fn]: del contents[key] utils.save_json_dict(fn, contents) # make a request to get graph reference if DNE. utils.ensure_file_exists(GRAPH_REFERENCE_FILE) utils.save_json_dict(GRAPH_REFERENCE_FILE, graph_reference.GRAPH_REFERENCE) else: warnings.warn("Looks like you don't have 'read-write' permission to " "your 'home' ('~') directory or to our '~/.plotly' " "directory. That means plotly's python api can't setup " "local configuration files. No problem though! You'll " "just have to sign-in using 'plotly.plotly.sign_in()'. " "For help with that: 'help(plotly.plotly.sign_in)'." "\nQuestions? support@plot.ly") ### credentials tools ### def set_credentials_file(username=None, api_key=None, stream_ids=None, proxy_username=None, proxy_password=None): """Set the keyword-value pairs in `~/.plotly_credentials`. :param (str) username: The username you'd use to sign in to Plotly :param (str) api_key: The api key associated with above username :param (list) stream_ids: Stream tokens for above credentials :param (str) proxy_username: The un associated with with your Proxy :param (str) proxy_password: The pw associated with your Proxy un """ if not check_file_permissions(): raise exceptions.PlotlyError("You don't have proper file permissions " "to run this function.") ensure_local_plotly_files() # make sure what's there is OK credentials = get_credentials_file() if isinstance(username, six.string_types): credentials['username'] = username if isinstance(api_key, six.string_types): credentials['api_key'] = api_key if isinstance(proxy_username, six.string_types): credentials['proxy_username'] = proxy_username if isinstance(proxy_password, six.string_types): credentials['proxy_password'] = proxy_password if isinstance(stream_ids, (list, tuple)): credentials['stream_ids'] = stream_ids utils.save_json_dict(CREDENTIALS_FILE, credentials) ensure_local_plotly_files() # make sure what we just put there is OK def get_credentials_file(*args): """Return specified args from `~/.plotly_credentials`. as dict. Returns all if no arguments are specified. Example: get_credentials_file('username') """ if check_file_permissions(): ensure_local_plotly_files() # make sure what's there is OK return utils.load_json_dict(CREDENTIALS_FILE, *args) else: return FILE_CONTENT[CREDENTIALS_FILE] def reset_credentials_file(): ensure_local_plotly_files() # make sure what's there is OK utils.save_json_dict(CREDENTIALS_FILE, {}) ensure_local_plotly_files() # put the defaults back ### config tools ### def set_config_file(plotly_domain=None, plotly_streaming_domain=None, plotly_api_domain=None, plotly_ssl_verification=None, plotly_proxy_authorization=None, world_readable=None, sharing=None, auto_open=None): """Set the keyword-value pairs in `~/.plotly/.config`. :param (str) plotly_domain: ex - https://plot.ly :param (str) plotly_streaming_domain: ex - stream.plot.ly :param (str) plotly_api_domain: ex - https://api.plot.ly :param (bool) plotly_ssl_verification: True = verify, False = don't verify :param (bool) plotly_proxy_authorization: True = use plotly proxy auth creds :param (bool) world_readable: True = public, False = private """ if not check_file_permissions(): raise exceptions.PlotlyError("You don't have proper file permissions " "to run this function.") ensure_local_plotly_files() # make sure what's there is OK utils.validate_world_readable_and_sharing_settings({ 'sharing': sharing, 'world_readable': world_readable}) settings = get_config_file() if isinstance(plotly_domain, six.string_types): settings['plotly_domain'] = plotly_domain elif plotly_domain is not None: raise TypeError('plotly_domain should be a string') if isinstance(plotly_streaming_domain, six.string_types): settings['plotly_streaming_domain'] = plotly_streaming_domain elif plotly_streaming_domain is not None: raise TypeError('plotly_streaming_domain should be a string') if isinstance(plotly_api_domain, six.string_types): settings['plotly_api_domain'] = plotly_api_domain elif plotly_api_domain is not None: raise TypeError('plotly_api_domain should be a string') if isinstance(plotly_ssl_verification, (six.string_types, bool)): settings['plotly_ssl_verification'] = plotly_ssl_verification elif plotly_ssl_verification is not None: raise TypeError('plotly_ssl_verification should be a boolean') if isinstance(plotly_proxy_authorization, (six.string_types, bool)): settings['plotly_proxy_authorization'] = plotly_proxy_authorization elif plotly_proxy_authorization is not None: raise TypeError('plotly_proxy_authorization should be a boolean') if isinstance(auto_open, bool): settings['auto_open'] = auto_open elif auto_open is not None: raise TypeError('auto_open should be a boolean') if isinstance(world_readable, bool): settings['world_readable'] = world_readable settings.pop('sharing') elif world_readable is not None: raise TypeError('Input should be a boolean') if isinstance(sharing, six.string_types): settings['sharing'] = sharing elif sharing is not None: raise TypeError('sharing should be a string') utils.set_sharing_and_world_readable(settings) utils.save_json_dict(CONFIG_FILE, settings) ensure_local_plotly_files() # make sure what we just put there is OK def get_config_file(*args): """Return specified args from `~/.plotly/.config`. as tuple. Returns all if no arguments are specified. Example: get_config_file('plotly_domain') """ if check_file_permissions(): ensure_local_plotly_files() # make sure what's there is OK return utils.load_json_dict(CONFIG_FILE, *args) else: return FILE_CONTENT[CONFIG_FILE] def reset_config_file(): ensure_local_plotly_files() # make sure what's there is OK f = open(CONFIG_FILE, 'w') f.close() ensure_local_plotly_files() # put the defaults back ### embed tools ### def get_embed(file_owner_or_url, file_id=None, width="100%", height=525): """Returns HTML code to embed figure on a webpage as an ").format( plotly_rest_url=plotly_rest_url, file_owner=file_owner, file_id=file_id, iframe_height=height, iframe_width=width) else: s = ("").format( plotly_rest_url=plotly_rest_url, file_owner=file_owner, file_id=file_id, share_key=share_key, iframe_height=height, iframe_width=width) return s def embed(file_owner_or_url, file_id=None, width="100%", height=525): """Embeds existing Plotly figure in IPython Notebook Plotly uniquely identifies figures with a 'file_owner'/'file_id' pair. Since each file is given a corresponding unique url, you may also simply pass a valid plotly url as the first argument. Note, if you're using a file_owner string as the first argument, you MUST specify a `file_id` keyword argument. Else, if you're using a url string as the first argument, you MUST NOT specify a `file_id` keyword argument, or file_id must be set to Python's None value. Positional arguments: file_owner_or_url (string) -- a valid plotly username OR a valid plotly url Keyword arguments: file_id (default=None) -- an int or string that can be converted to int if you're using a url, don't fill this in! width (default="100%") -- an int or string corresp. to width of the figure height (default="525") -- same as width but corresp. to the height of the figure """ try: s = get_embed(file_owner_or_url, file_id=file_id, width=width, height=height) # see if we are in the SageMath Cloud from sage_salvus import html return html(s, hide=False) except: pass if _ipython_imported: if file_id: plotly_domain = ( session.get_session_config().get('plotly_domain') or get_config_file()['plotly_domain'] ) url = "{plotly_domain}/~{un}/{fid}".format( plotly_domain=plotly_domain, un=file_owner_or_url, fid=file_id) else: url = file_owner_or_url return PlotlyDisplay(url, width, height) else: if (get_config_defaults()['plotly_domain'] != session.get_session_config()['plotly_domain']): feedback_email = 'feedback@plot.ly' else: # different domain likely means enterprise feedback_email = 'support@plot.ly' warnings.warn( "Looks like you're not using IPython or Sage to embed this " "plot. If you just want the *embed code*,\ntry using " "`get_embed()` instead." '\nQuestions? {}'.format(feedback_email)) ### mpl-related tools ### @utils.template_doc(**get_config_file()) def mpl_to_plotly(fig, resize=False, strip_style=False, verbose=False): """Convert a matplotlib figure to plotly dictionary and send. All available information about matplotlib visualizations are stored within a matplotlib.figure.Figure object. You can create a plot in python using matplotlib, store the figure object, and then pass this object to the fig_to_plotly function. In the background, mplexporter is used to crawl through the mpl figure object for appropriate information. This information is then systematically sent to the PlotlyRenderer which creates the JSON structure used to make plotly visualizations. Finally, these dictionaries are sent to plotly and your browser should open up a new tab for viewing! Optionally, if you're working in IPython, you can set notebook=True and the PlotlyRenderer will call plotly.iplot instead of plotly.plot to have the graph appear directly in the IPython notebook. Note, this function gives the user access to a simple, one-line way to render an mpl figure in plotly. If you need to trouble shoot, you can do this step manually by NOT running this fuction and entereing the following: =========================================================================== from mplexporter import Exporter from mplexporter.renderers import PlotlyRenderer # create an mpl figure and store it under a varialble 'fig' renderer = PlotlyRenderer() exporter = Exporter(renderer) exporter.run(fig) =========================================================================== You can then inspect the JSON structures by accessing these: renderer.layout -- a plotly layout dictionary renderer.data -- a list of plotly data dictionaries Positional arguments: fig -- a matplotlib figure object username -- a valid plotly username ** api_key -- a valid api_key for the above username ** notebook -- an option for use with an IPython notebook ** Don't have a username/api_key? Try looking here: {plotly_domain}/plot ** Forgot your api_key? Try signing in and looking here: {plotly_domain}/python/getting-started """ if _matplotlylib_imported: renderer = matplotlylib.PlotlyRenderer() matplotlylib.Exporter(renderer).run(fig) if resize: renderer.resize() if strip_style: renderer.strip_style() if verbose: print(renderer.msg) return renderer.plotly_fig else: warnings.warn( "To use Plotly's matplotlylib functionality, you'll need to have " "matplotlib successfully installed with all of its dependencies. " "You're getting this error because matplotlib or one of its " "dependencies doesn't seem to be installed correctly.") ### graph_objs related tools ### def get_subplots(rows=1, columns=1, print_grid=False, **kwargs): """Return a dictionary instance with the subplots set in 'layout'. Example 1: # stack two subplots vertically fig = tools.get_subplots(rows=2) fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x1', yaxis='y1')] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')] Example 2: # print out string showing the subplot grid you've put in the layout fig = tools.get_subplots(rows=3, columns=2, print_grid=True) Keywords arguments with constant defaults: rows (kwarg, int greater than 0, default=1): Number of rows, evenly spaced vertically on the figure. columns (kwarg, int greater than 0, default=1): Number of columns, evenly spaced horizontally on the figure. horizontal_spacing (kwarg, float in [0,1], default=0.1): Space between subplot columns. Applied to all columns. vertical_spacing (kwarg, float in [0,1], default=0.05): Space between subplot rows. Applied to all rows. print_grid (kwarg, True | False, default=False): If True, prints a tab-delimited string representation of your plot grid. Keyword arguments with variable defaults: horizontal_spacing (kwarg, float in [0,1], default=0.2 / columns): Space between subplot columns. vertical_spacing (kwarg, float in [0,1], default=0.3 / rows): Space between subplot rows. """ # TODO: protected until #282 from plotly.graph_objs import graph_objs warnings.warn( "tools.get_subplots is depreciated. " "Please use tools.make_subplots instead." ) # Throw exception for non-integer rows and columns if not isinstance(rows, int) or rows <= 0: raise Exception("Keyword argument 'rows' " "must be an int greater than 0") if not isinstance(columns, int) or columns <= 0: raise Exception("Keyword argument 'columns' " "must be an int greater than 0") # Throw exception if non-valid kwarg is sent VALID_KWARGS = ['horizontal_spacing', 'vertical_spacing'] for key in kwargs.keys(): if key not in VALID_KWARGS: raise Exception("Invalid keyword argument: '{0}'".format(key)) # Set 'horizontal_spacing' / 'vertical_spacing' w.r.t. rows / columns try: horizontal_spacing = float(kwargs['horizontal_spacing']) except KeyError: horizontal_spacing = 0.2 / columns try: vertical_spacing = float(kwargs['vertical_spacing']) except KeyError: vertical_spacing = 0.3 / rows fig = dict(layout=graph_objs.Layout()) # will return this at the end plot_width = (1 - horizontal_spacing * (columns - 1)) / columns plot_height = (1 - vertical_spacing * (rows - 1)) / rows plot_num = 0 for rrr in range(rows): for ccc in range(columns): xaxis_name = 'xaxis{0}'.format(plot_num + 1) x_anchor = 'y{0}'.format(plot_num + 1) x_start = (plot_width + horizontal_spacing) * ccc x_end = x_start + plot_width yaxis_name = 'yaxis{0}'.format(plot_num + 1) y_anchor = 'x{0}'.format(plot_num + 1) y_start = (plot_height + vertical_spacing) * rrr y_end = y_start + plot_height xaxis = graph_objs.XAxis(domain=[x_start, x_end], anchor=x_anchor) fig['layout'][xaxis_name] = xaxis yaxis = graph_objs.YAxis(domain=[y_start, y_end], anchor=y_anchor) fig['layout'][yaxis_name] = yaxis plot_num += 1 if print_grid: print("This is the format of your plot grid!") grid_string = "" plot = 1 for rrr in range(rows): grid_line = "" for ccc in range(columns): grid_line += "[{0}]\t".format(plot) plot += 1 grid_string = grid_line + '\n' + grid_string print(grid_string) return graph_objs.Figure(fig) # forces us to validate what we just did... def make_subplots(rows=1, cols=1, shared_xaxes=False, shared_yaxes=False, start_cell='top-left', print_grid=True, **kwargs): """Return an instance of plotly.graph_objs.Figure with the subplots domain set in 'layout'. Example 1: # stack two subplots vertically fig = tools.make_subplots(rows=2) This is the format of your plot grid: [ (1,1) x1,y1 ] [ (2,1) x2,y2 ] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')] # or see Figure.append_trace Example 2: # subplots with shared x axes fig = tools.make_subplots(rows=2, shared_xaxes=True) This is the format of your plot grid: [ (1,1) x1,y1 ] [ (2,1) x1,y2 ] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], yaxis='y2')] Example 3: # irregular subplot layout (more examples below under 'specs') fig = tools.make_subplots(rows=2, cols=2, specs=[[{}, {}], [{'colspan': 2}, None]]) This is the format of your plot grid! [ (1,1) x1,y1 ] [ (1,2) x2,y2 ] [ (2,1) x3,y3 - ] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x3', yaxis='y3')] Example 4: # insets fig = tools.make_subplots(insets=[{'cell': (1,1), 'l': 0.7, 'b': 0.3}]) This is the format of your plot grid! [ (1,1) x1,y1 ] With insets: [ x2,y2 ] over [ (1,1) x1,y1 ] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')] Example 5: # include subplot titles fig = tools.make_subplots(rows=2, subplot_titles=('Plot 1','Plot 2')) This is the format of your plot grid: [ (1,1) x1,y1 ] [ (2,1) x2,y2 ] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')] Example 6: # Include subplot title on one plot (but not all) fig = tools.make_subplots(insets=[{'cell': (1,1), 'l': 0.7, 'b': 0.3}], subplot_titles=('','Inset')) This is the format of your plot grid! [ (1,1) x1,y1 ] With insets: [ x2,y2 ] over [ (1,1) x1,y1 ] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2])] fig['data'] += [Scatter(x=[1,2,3], y=[2,1,2], xaxis='x2', yaxis='y2')] Keywords arguments with constant defaults: rows (kwarg, int greater than 0, default=1): Number of rows in the subplot grid. cols (kwarg, int greater than 0, default=1): Number of columns in the subplot grid. shared_xaxes (kwarg, boolean or list, default=False) Assign shared x axes. If True, subplots in the same grid column have one common shared x-axis at the bottom of the gird. To assign shared x axes per subplot grid cell (see 'specs'), send list (or list of lists, one list per shared x axis) of cell index tuples. shared_yaxes (kwarg, boolean or list, default=False) Assign shared y axes. If True, subplots in the same grid row have one common shared y-axis on the left-hand side of the gird. To assign shared y axes per subplot grid cell (see 'specs'), send list (or list of lists, one list per shared y axis) of cell index tuples. start_cell (kwarg, 'bottom-left' or 'top-left', default='top-left') Choose the starting cell in the subplot grid used to set the domains of the subplots. print_grid (kwarg, boolean, default=True): If True, prints a tab-delimited string representation of your plot grid. Keyword arguments with variable defaults: horizontal_spacing (kwarg, float in [0,1], default=0.2 / cols): Space between subplot columns. Applies to all columns (use 'specs' subplot-dependents spacing) vertical_spacing (kwarg, float in [0,1], default=0.3 / rows): Space between subplot rows. Applies to all rows (use 'specs' subplot-dependents spacing) subplot_titles (kwarg, list of strings, default=empty list): Title of each subplot. "" can be included in the list if no subplot title is desired in that space so that the titles are properly indexed. specs (kwarg, list of lists of dictionaries): Subplot specifications. ex1: specs=[[{}, {}], [{'colspan': 2}, None]] ex2: specs=[[{'rowspan': 2}, {}], [None, {}]] - Indices of the outer list correspond to subplot grid rows starting from the bottom. The number of rows in 'specs' must be equal to 'rows'. - Indices of the inner lists correspond to subplot grid columns starting from the left. The number of columns in 'specs' must be equal to 'cols'. - Each item in the 'specs' list corresponds to one subplot in a subplot grid. (N.B. The subplot grid has exactly 'rows' times 'cols' cells.) - Use None for blank a subplot cell (or to move pass a col/row span). - Note that specs[0][0] has the specs of the 'start_cell' subplot. - Each item in 'specs' is a dictionary. The available keys are: * is_3d (boolean, default=False): flag for 3d scenes * colspan (int, default=1): number of subplot columns for this subplot to span. * rowspan (int, default=1): number of subplot rows for this subplot to span. * l (float, default=0.0): padding left of cell * r (float, default=0.0): padding right of cell * t (float, default=0.0): padding right of cell * b (float, default=0.0): padding bottom of cell - Use 'horizontal_spacing' and 'vertical_spacing' to adjust the spacing in between the subplots. insets (kwarg, list of dictionaries): Inset specifications. - Each item in 'insets' is a dictionary. The available keys are: * cell (tuple, default=(1,1)): (row, col) index of the subplot cell to overlay inset axes onto. * is_3d (boolean, default=False): flag for 3d scenes * l (float, default=0.0): padding left of inset in fraction of cell width * w (float or 'to_end', default='to_end') inset width in fraction of cell width ('to_end': to cell right edge) * b (float, default=0.0): padding bottom of inset in fraction of cell height * h (float or 'to_end', default='to_end') inset height in fraction of cell height ('to_end': to cell top edge) """ # TODO: protected until #282 from plotly.graph_objs import graph_objs # Throw exception for non-integer rows and cols if not isinstance(rows, int) or rows <= 0: raise Exception("Keyword argument 'rows' " "must be an int greater than 0") if not isinstance(cols, int) or cols <= 0: raise Exception("Keyword argument 'cols' " "must be an int greater than 0") # Dictionary of things start_cell START_CELL_all = { 'bottom-left': { # 'natural' setup where x & y domains increase monotonically 'col_dir': 1, 'row_dir': 1 }, 'top-left': { # 'default' setup visually matching the 'specs' list of lists 'col_dir': 1, 'row_dir': -1 } # TODO maybe add 'bottom-right' and 'top-right' } # Throw exception for invalid 'start_cell' values try: START_CELL = START_CELL_all[start_cell] except KeyError: raise Exception("Invalid 'start_cell' value") # Throw exception if non-valid kwarg is sent VALID_KWARGS = ['horizontal_spacing', 'vertical_spacing', 'specs', 'insets', 'subplot_titles'] for key in kwargs.keys(): if key not in VALID_KWARGS: raise Exception("Invalid keyword argument: '{0}'".format(key)) # Set 'subplot_titles' subplot_titles = kwargs.get('subplot_titles', [""] * rows * cols) # Set 'horizontal_spacing' / 'vertical_spacing' w.r.t. rows / cols try: horizontal_spacing = float(kwargs['horizontal_spacing']) except KeyError: horizontal_spacing = 0.2 / cols try: vertical_spacing = float(kwargs['vertical_spacing']) except KeyError: if 'subplot_titles' in kwargs: vertical_spacing = 0.5 / rows else: vertical_spacing = 0.3 / rows # Sanitize 'specs' (must be a list of lists) exception_msg = "Keyword argument 'specs' must be a list of lists" try: specs = kwargs['specs'] if not isinstance(specs, list): raise Exception(exception_msg) else: for spec_row in specs: if not isinstance(spec_row, list): raise Exception(exception_msg) except KeyError: specs = [[{} for c in range(cols)] for r in range(rows)] # default 'specs' # Throw exception if specs is over or under specified if len(specs) != rows: raise Exception("The number of rows in 'specs' " "must be equal to 'rows'") for r, spec_row in enumerate(specs): if len(spec_row) != cols: raise Exception("The number of columns in 'specs' " "must be equal to 'cols'") # Sanitize 'insets' try: insets = kwargs['insets'] if not isinstance(insets, list): raise Exception("Keyword argument 'insets' must be a list") except KeyError: insets = False # Throw exception if non-valid key / fill in defaults def _check_keys_and_fill(name, arg, defaults): def _checks(item, defaults): if item is None: return if not isinstance(item, dict): raise Exception("Items in keyword argument '{name}' must be " "dictionaries or None".format(name=name)) for k in item.keys(): if k not in defaults.keys(): raise Exception("Invalid key '{k}' in keyword " "argument '{name}'".format(k=k, name=name)) for k in defaults.keys(): if k not in item.keys(): item[k] = defaults[k] for arg_i in arg: if isinstance(arg_i, list): for arg_ii in arg_i: _checks(arg_ii, defaults) elif isinstance(arg_i, dict): _checks(arg_i, defaults) # Default spec key-values SPEC_defaults = dict( is_3d=False, colspan=1, rowspan=1, l=0.0, r=0.0, b=0.0, t=0.0 # TODO add support for 'w' and 'h' ) _check_keys_and_fill('specs', specs, SPEC_defaults) # Default inset key-values if insets: INSET_defaults = dict( cell=(1, 1), is_3d=False, l=0.0, w='to_end', b=0.0, h='to_end' ) _check_keys_and_fill('insets', insets, INSET_defaults) # Set width & height of each subplot cell (excluding padding) width = (1. - horizontal_spacing * (cols - 1)) / cols height = (1. - vertical_spacing * (rows - 1)) / rows # Built row/col sequence using 'row_dir' and 'col_dir' COL_DIR = START_CELL['col_dir'] ROW_DIR = START_CELL['row_dir'] col_seq = range(cols)[::COL_DIR] row_seq = range(rows)[::ROW_DIR] # [grid] Build subplot grid (coord tuple of cell) grid = [[((width + horizontal_spacing) * c, (height + vertical_spacing) * r) for c in col_seq] for r in row_seq] # [grid_ref] Initialize the grid and insets' axis-reference lists grid_ref = [[None for c in range(cols)] for r in range(rows)] insets_ref = [None for inset in range(len(insets))] if insets else None layout = graph_objs.Layout() # init layout object # Function handling logic around 2d axis labels # Returns 'x{}' | 'y{}' def _get_label(x_or_y, r, c, cnt, shared_axes): # Default label (given strictly by cnt) label = "{x_or_y}{cnt}".format(x_or_y=x_or_y, cnt=cnt) if isinstance(shared_axes, bool): if shared_axes: if x_or_y == 'x': label = "{x_or_y}{c}".format(x_or_y=x_or_y, c=c + 1) if x_or_y == 'y': label = "{x_or_y}{r}".format(x_or_y=x_or_y, r=r + 1) if isinstance(shared_axes, list): if isinstance(shared_axes[0], tuple): shared_axes = [shared_axes] # TODO put this elsewhere for shared_axis in shared_axes: if (r + 1, c + 1) in shared_axis: label = { 'x': "x{0}".format(shared_axis[0][1]), 'y': "y{0}".format(shared_axis[0][0]) }[x_or_y] return label # Row in grid of anchor row if shared_xaxes=True ANCHOR_ROW = 0 if ROW_DIR > 0 else rows - 1 # Function handling logic around 2d axis anchors # Return 'x{}' | 'y{}' | 'free' | False def _get_anchors(r, c, x_cnt, y_cnt, shared_xaxes, shared_yaxes): # Default anchors (give strictly by cnt) x_anchor = "y{y_cnt}".format(y_cnt=y_cnt) y_anchor = "x{x_cnt}".format(x_cnt=x_cnt) if isinstance(shared_xaxes, bool): if shared_xaxes: if r != ANCHOR_ROW: x_anchor = False y_anchor = 'free' if shared_yaxes and c != 0: # TODO covers all cases? y_anchor = False return x_anchor, y_anchor elif isinstance(shared_xaxes, list): if isinstance(shared_xaxes[0], tuple): shared_xaxes = [shared_xaxes] # TODO put this elsewhere for shared_xaxis in shared_xaxes: if (r + 1, c + 1) in shared_xaxis[1:]: x_anchor = False y_anchor = 'free' # TODO covers all cases? if isinstance(shared_yaxes, bool): if shared_yaxes: if c != 0: y_anchor = False x_anchor = 'free' if shared_xaxes and r != ANCHOR_ROW: # TODO all cases? x_anchor = False return x_anchor, y_anchor elif isinstance(shared_yaxes, list): if isinstance(shared_yaxes[0], tuple): shared_yaxes = [shared_yaxes] # TODO put this elsewhere for shared_yaxis in shared_yaxes: if (r + 1, c + 1) in shared_yaxis[1:]: y_anchor = False x_anchor = 'free' # TODO covers all cases? return x_anchor, y_anchor list_of_domains = [] # added for subplot titles # Function pasting x/y domains in layout object (2d case) def _add_domain(layout, x_or_y, label, domain, anchor, position): name = label[0] + 'axis' + label[1:] graph_obj = '{X_or_Y}Axis'.format(X_or_Y=x_or_y.upper()) axis = getattr(graph_objs, graph_obj)(domain=domain) if anchor: axis['anchor'] = anchor if isinstance(position, float): axis['position'] = position layout[name] = axis list_of_domains.append(domain) # added for subplot titles # Function pasting x/y domains in layout object (3d case) def _add_domain_is_3d(layout, s_label, x_domain, y_domain): scene = graph_objs.Scene(domain={'x': x_domain, 'y': y_domain}) layout[s_label] = scene x_cnt = y_cnt = s_cnt = 1 # subplot axis/scene counters # Loop through specs -- (r, c) <-> (row, col) for r, spec_row in enumerate(specs): for c, spec in enumerate(spec_row): if spec is None: # skip over None cells continue c_spanned = c + spec['colspan'] - 1 # get spanned c r_spanned = r + spec['rowspan'] - 1 # get spanned r # Throw exception if 'colspan' | 'rowspan' is too large for grid if c_spanned >= cols: raise Exception("Some 'colspan' value is too large for " "this subplot grid.") if r_spanned >= rows: raise Exception("Some 'rowspan' value is too large for " "this subplot grid.") # Get x domain using grid and colspan x_s = grid[r][c][0] + spec['l'] x_e = grid[r][c_spanned][0] + width - spec['r'] x_domain = [x_s, x_e] # Get y domain (dep. on row_dir) using grid & r_spanned if ROW_DIR > 0: y_s = grid[r][c][1] + spec['b'] y_e = grid[r_spanned][c][1] + height - spec['t'] else: y_s = grid[r_spanned][c][1] + spec['b'] y_e = grid[r][c][1] + height - spec['t'] y_domain = [y_s, y_e] if spec['is_3d']: # Add scene to layout s_label = 'scene{0}'.format(s_cnt) _add_domain_is_3d(layout, s_label, x_domain, y_domain) grid_ref[r][c] = (s_label, ) s_cnt += 1 else: # Get axis label and anchor x_label = _get_label('x', r, c, x_cnt, shared_xaxes) y_label = _get_label('y', r, c, y_cnt, shared_yaxes) x_anchor, y_anchor = _get_anchors(r, c, x_cnt, y_cnt, shared_xaxes, shared_yaxes) # Add a xaxis to layout (N.B anchor == False -> no axis) if x_anchor: if x_anchor == 'free': x_position = y_domain[0] else: x_position = False _add_domain(layout, 'x', x_label, x_domain, x_anchor, x_position) x_cnt += 1 # Add a yaxis to layout (N.B anchor == False -> no axis) if y_anchor: if y_anchor == 'free': y_position = x_domain[0] else: y_position = False _add_domain(layout, 'y', y_label, y_domain, y_anchor, y_position) y_cnt += 1 grid_ref[r][c] = (x_label, y_label) # fill in ref # Loop through insets if insets: for i_inset, inset in enumerate(insets): r = inset['cell'][0] - 1 c = inset['cell'][1] - 1 # Throw exception if r | c is out of range if not (0 <= r < rows): raise Exception("Some 'cell' row value is out of range. " "Note: the starting cell is (1, 1)") if not (0 <= c < cols): raise Exception("Some 'cell' col value is out of range. " "Note: the starting cell is (1, 1)") # Get inset x domain using grid x_s = grid[r][c][0] + inset['l'] * width if inset['w'] == 'to_end': x_e = grid[r][c][0] + width else: x_e = x_s + inset['w'] * width x_domain = [x_s, x_e] # Get inset y domain using grid y_s = grid[r][c][1] + inset['b'] * height if inset['h'] == 'to_end': y_e = grid[r][c][1] + height else: y_e = y_s + inset['h'] * height y_domain = [y_s, y_e] if inset['is_3d']: # Add scene to layout s_label = 'scene{0}'.format(s_cnt) _add_domain_is_3d(layout, s_label, x_domain, y_domain) insets_ref[i_inset] = (s_label, ) s_cnt += 1 else: # Get axis label and anchor x_label = _get_label('x', False, False, x_cnt, False) y_label = _get_label('y', False, False, y_cnt, False) x_anchor, y_anchor = _get_anchors(r, c, x_cnt, y_cnt, False, False) # Add a xaxis to layout (N.B insets always have anchors) _add_domain(layout, 'x', x_label, x_domain, x_anchor, False) x_cnt += 1 # Add a yayis to layout (N.B insets always have anchors) _add_domain(layout, 'y', y_label, y_domain, y_anchor, False) y_cnt += 1 insets_ref[i_inset] = (x_label, y_label) # fill in ref # [grid_str] Set the grid's string representation sp = " " # space between cell s_str = "[ " # cell start string e_str = " ]" # cell end string colspan_str = ' -' # colspan string rowspan_str = ' |' # rowspan string empty_str = ' (empty) ' # empty cell string # Init grid_str with intro message grid_str = "This is the format of your plot grid:\n" # Init tmp list of lists of strings (sorta like 'grid_ref' but w/ strings) _tmp = [['' for c in range(cols)] for r in range(rows)] # Define cell string as function of (r, c) and grid_ref def _get_cell_str(r, c, ref): return '({r},{c}) {ref}'.format(r=r + 1, c=c + 1, ref=','.join(ref)) # Find max len of _cell_str, add define a padding function cell_len = max([len(_get_cell_str(r, c, ref)) for r, row_ref in enumerate(grid_ref) for c, ref in enumerate(row_ref) if ref]) + len(s_str) + len(e_str) def _pad(s, cell_len=cell_len): return ' ' * (cell_len - len(s)) # Loop through specs, fill in _tmp for r, spec_row in enumerate(specs): for c, spec in enumerate(spec_row): ref = grid_ref[r][c] if ref is None: if _tmp[r][c] == '': _tmp[r][c] = empty_str + _pad(empty_str) continue cell_str = s_str + _get_cell_str(r, c, ref) if spec['colspan'] > 1: for cc in range(1, spec['colspan'] - 1): _tmp[r][c + cc] = colspan_str + _pad(colspan_str) _tmp[r][c + spec['colspan'] - 1] = ( colspan_str + _pad(colspan_str + e_str)) + e_str else: cell_str += e_str if spec['rowspan'] > 1: for rr in range(1, spec['rowspan'] - 1): _tmp[r + rr][c] = rowspan_str + _pad(rowspan_str) for cc in range(spec['colspan']): _tmp[r + spec['rowspan'] - 1][c + cc] = ( rowspan_str + _pad(rowspan_str)) _tmp[r][c] = cell_str + _pad(cell_str) # Append grid_str using data from _tmp in the correct order for r in row_seq[::-1]: grid_str += sp.join(_tmp[r]) + '\n' # Append grid_str to include insets info if insets: grid_str += "\nWith insets:\n" for i_inset, inset in enumerate(insets): r = inset['cell'][0] - 1 c = inset['cell'][1] - 1 ref = grid_ref[r][c] grid_str += ( s_str + ','.join(insets_ref[i_inset]) + e_str + ' over ' + s_str + _get_cell_str(r, c, ref) + e_str + '\n' ) # Add subplot titles # If shared_axes is False (default) use list_of_domains # This is used for insets and irregular layouts if not shared_xaxes and not shared_yaxes: x_dom = list_of_domains[::2] y_dom = list_of_domains[1::2] subtitle_pos_x = [] subtitle_pos_y = [] for x_domains in x_dom: subtitle_pos_x.append(sum(x_domains) / 2) for y_domains in y_dom: subtitle_pos_y.append(y_domains[1]) # If shared_axes is True the domin of each subplot is not returned so the # title position must be calculated for each subplot else: subtitle_pos_x = [None] * cols subtitle_pos_y = [None] * rows delt_x = (x_e - x_s) for index in range(cols): subtitle_pos_x[index] = ((delt_x / 2) + ((delt_x + horizontal_spacing) * index)) subtitle_pos_x *= rows for index in range(rows): subtitle_pos_y[index] = (1 - ((y_e + vertical_spacing) * index)) subtitle_pos_y *= cols subtitle_pos_y = sorted(subtitle_pos_y, reverse=True) plot_titles = [] for index in range(len(subplot_titles)): if not subplot_titles[index]: pass else: plot_titles.append({'y': subtitle_pos_y[index], 'xref': 'paper', 'x': subtitle_pos_x[index], 'yref': 'paper', 'text': subplot_titles[index], 'showarrow': False, 'font': graph_objs.Font(size=16), 'xanchor': 'center', 'yanchor': 'bottom' }) layout['annotations'] = plot_titles if print_grid: print(grid_str) fig = graph_objs.Figure(layout=layout) fig.__dict__['_grid_ref'] = grid_ref fig.__dict__['_grid_str'] = grid_str return fig def get_valid_graph_obj(obj, obj_type=None): """Returns a new graph object that won't raise. CAREFUL: this will *silently* strip out invalid pieces of the object. """ # TODO: Deprecate or move. #283 from plotly.graph_objs import graph_objs try: cls = getattr(graph_objs, obj_type) except (AttributeError, KeyError): raise exceptions.PlotlyError( "'{}' is not a recognized graph_obj.".format(obj_type) ) return cls(obj, _raise=False) def validate(obj, obj_type): """Validate a dictionary, list, or graph object as 'obj_type'. This will not alter the 'obj' referenced in the call signature. It will raise an error if the 'obj' reference could not be instantiated as a valid 'obj_type' graph object. """ # TODO: Deprecate or move. #283 from plotly.graph_objs import graph_objs if obj_type not in graph_reference.CLASSES: obj_type = graph_reference.string_to_class_name(obj_type) try: cls = getattr(graph_objs, obj_type) except AttributeError: raise exceptions.PlotlyError( "'{0}' is not a recognizable graph_obj.". format(obj_type)) cls(obj) # this will raise on invalid keys/items def _replace_newline(obj): """Replaces '\n' with '
' for all strings in a collection.""" if isinstance(obj, dict): d = dict() for key, val in list(obj.items()): d[key] = _replace_newline(val) return d elif isinstance(obj, list): l = list() for index, entry in enumerate(obj): l += [_replace_newline(entry)] return l elif isinstance(obj, six.string_types): s = obj.replace('\n', '
') if s != obj: warnings.warn("Looks like you used a newline character: '\\n'.\n\n" "Plotly uses a subset of HTML escape characters\n" "to do things like newline (
), bold (),\n" "italics (), etc. Your newline characters \n" "have been converted to '
' so they will show \n" "up right on your Plotly figure!") return s else: return obj # we return the actual reference... but DON'T mutate. if _ipython_imported: class PlotlyDisplay(IPython.core.display.HTML): """An IPython display object for use with plotly urls PlotlyDisplay objects should be instantiated with a url for a plot. IPython will *choose* the proper display representation from any Python object, and using provided methods if they exist. By defining the following, if an HTML display is unusable, the PlotlyDisplay object can provide alternate representations. """ def __init__(self, url, width, height): self.resource = url self.embed_code = get_embed(url, width=width, height=height) super(PlotlyDisplay, self).__init__(data=self.embed_code) def _repr_html_(self): return self.embed_code def return_figure_from_figure_or_data(figure_or_data, validate_figure): from plotly.graph_objs import graph_objs if isinstance(figure_or_data, dict): figure = figure_or_data elif isinstance(figure_or_data, list): figure = {'data': figure_or_data} else: raise exceptions.PlotlyError("The `figure_or_data` positional " "argument must be either " "`dict`-like or `list`-like.") if validate_figure: try: graph_objs.Figure(figure) except exceptions.PlotlyError as err: raise exceptions.PlotlyError("Invalid 'figure_or_data' argument. " "Plotly will not be able to properly " "parse the resulting JSON. If you " "want to send this 'figure_or_data' " "to Plotly anyway (not recommended), " "you can set 'validate=False' as a " "plot option.\nHere's why you're " "seeing this error:\n\n{0}" "".format(err)) if not figure['data']: raise exceptions.PlotlyEmptyDataError( "Empty data list found. Make sure that you populated the " "list of data objects you're sending and try again.\n" "Questions? support@plot.ly" ) return figure # Default colours for finance charts _DEFAULT_INCREASING_COLOR = '#3D9970' # http://clrs.cc _DEFAULT_DECREASING_COLOR = '#FF4136' DIAG_CHOICES = ['scatter', 'histogram', 'box'] VALID_COLORMAP_TYPES = ['cat', 'seq'] class FigureFactory(object): """ BETA functions to create specific chart types. This is beta as in: subject to change in a backwards incompatible way without notice. Supported chart types include candlestick, open high low close, quiver, streamline, distplot, dendrogram, annotated heatmap, and tables. See FigureFactory.create_candlestick, FigureFactory.create_ohlc, FigureFactory.create_quiver, FigureFactory.create_streamline, FigureFactory.create_distplot, FigureFactory.create_dendrogram, FigureFactory.create_annotated_heatmap, or FigureFactory.create_table for more information and examples of a specific chart type. """ @staticmethod def _make_colorscale(colors, scale=None): """ Makes a colorscale from a list of colors and scale Takes a list of colors and scales and constructs a colorscale based on the colors in sequential order. If 'scale' is left empty, a linear- interpolated colorscale will be generated. If 'scale' is a specificed list, it must be the same legnth as colors and must contain all floats For documentation regarding to the form of the output, see https://plot.ly/python/reference/#mesh3d-colorscale """ colorscale = [] if not scale: for j, color in enumerate(colors): colorscale.append([j * 1./(len(colors) - 1), color]) return colorscale else: colorscale = [list(tup) for tup in zip(scale, colors)] return colorscale @staticmethod def _convert_colorscale_to_rgb(colorscale): """ Converts the colors in a colorscale to rgb colors A colorscale is an array of arrays, each with a numeric value as the first item and a color as the second. This function specifically is converting a colorscale with tuple colors (each coordinate between 0 and 1) into a colorscale with the colors transformed into rgb colors """ for color in colorscale: color[1] = FigureFactory._convert_to_RGB_255( color[1] ) for color in colorscale: color[1] = FigureFactory._label_rgb( color[1] ) return colorscale @staticmethod def _make_linear_colorscale(colors): """ Makes a list of colors into a colorscale-acceptable form For documentation regarding to the form of the output, see https://plot.ly/python/reference/#mesh3d-colorscale """ scale = 1./(len(colors) - 1) return[[i * scale, color] for i, color in enumerate(colors)] @staticmethod def create_2D_density(x, y, colorscale='Earth', ncontours=20, hist_color=(0, 0, 0.5), point_color=(0, 0, 0.5), point_size=2, title='2D Density Plot', height=600, width=600): """ Returns figure for a 2D density plot :param (list|array) x: x-axis data for plot generation :param (list|array) y: y-axis data for plot generation :param (str|tuple|list) colorscale: either a plotly scale name, an rgb or hex color, a color tuple or a list or tuple of colors. An rgb color is of the form 'rgb(x, y, z)' where x, y, z belong to the interval [0, 255] and a color tuple is a tuple of the form (a, b, c) where a, b and c belong to [0, 1]. If colormap is a list, it must contain the valid color types aforementioned as its members. :param (int) ncontours: the number of 2D contours to draw on the plot :param (str) hist_color: the color of the plotted histograms :param (str) point_color: the color of the scatter points :param (str) point_size: the color of the scatter points :param (str) title: set the title for the plot :param (float) height: the height of the chart :param (float) width: the width of the chart Example 1: Simple 2D Density Plot ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import numpy as np # Make data points t = np.linspace(-1,1.2,2000) x = (t**3)+(0.3*np.random.randn(2000)) y = (t**6)+(0.3*np.random.randn(2000)) # Create a figure fig = FF.create_2D_density(x, y) # Plot the data py.iplot(fig, filename='simple-2d-density') ``` Example 2: Using Parameters ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import numpy as np # Make data points t = np.linspace(-1,1.2,2000) x = (t**3)+(0.3*np.random.randn(2000)) y = (t**6)+(0.3*np.random.randn(2000)) # Create custom colorscale colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)', (1, 1, 0.2), (0.98,0.98,0.98)] # Create a figure fig = FF.create_2D_density( x, y, colorscale=colorscale, hist_color='rgb(255, 237, 222)', point_size=3) # Plot the data py.iplot(fig, filename='use-parameters') ``` """ from plotly.graph_objs import graph_objs from numbers import Number # validate x and y are filled with numbers only for array in [x, y]: if not all(isinstance(element, Number) for element in array): raise exceptions.PlotlyError( "All elements of your 'x' and 'y' lists must be numbers." ) # validate x and y are the same length if len(x) != len(y): raise exceptions.PlotlyError( "Both lists 'x' and 'y' must be the same length." ) colorscale = FigureFactory._validate_colors(colorscale, 'rgb') colorscale = FigureFactory._make_linear_colorscale(colorscale) # validate hist_color and point_color hist_color = FigureFactory._validate_colors(hist_color, 'rgb') point_color = FigureFactory._validate_colors(point_color, 'rgb') trace1 = graph_objs.Scatter( x=x, y=y, mode='markers', name='points', marker=dict( color=point_color[0], size=point_size, opacity=0.4 ) ) trace2 = graph_objs.Histogram2dcontour( x=x, y=y, name='density', ncontours=ncontours, colorscale=colorscale, reversescale=True, showscale=False ) trace3 = graph_objs.Histogram( x=x, name='x density', marker=dict(color=hist_color[0]), yaxis='y2' ) trace4 = graph_objs.Histogram( y=y, name='y density', marker=dict(color=hist_color[0]), xaxis='x2' ) data = [trace1, trace2, trace3, trace4] layout = graph_objs.Layout( showlegend=False, autosize=False, title=title, height=height, width=width, xaxis=dict( domain=[0, 0.85], showgrid=False, zeroline=False ), yaxis=dict( domain=[0, 0.85], showgrid=False, zeroline=False ), margin=dict( t=50 ), hovermode='closest', bargap=0, xaxis2=dict( domain=[0.85, 1], showgrid=False, zeroline=False ), yaxis2=dict( domain=[0.85, 1], showgrid=False, zeroline=False ) ) fig = graph_objs.Figure(data=data, layout=layout) return fig @staticmethod def _validate_gantt(df): """ Validates the inputted dataframe or list """ if _pandas_imported and isinstance(df, pd.core.frame.DataFrame): # validate that df has all the required keys for key in REQUIRED_GANTT_KEYS: if key not in df: raise exceptions.PlotlyError( "The columns in your dataframe must include the " "keys".format(REQUIRED_GANTT_KEYS) ) num_of_rows = len(df.index) chart = [] for index in range(num_of_rows): task_dict = {} for key in df: task_dict[key] = df.ix[index][key] chart.append(task_dict) return chart # validate if df is a list if not isinstance(df, list): raise exceptions.PlotlyError("You must input either a dataframe " "or a list of dictionaries.") # validate if df is empty if len(df) <= 0: raise exceptions.PlotlyError("Your list is empty. It must contain " "at least one dictionary.") if not isinstance(df[0], dict): raise exceptions.PlotlyError("Your list must only " "include dictionaries.") return df @staticmethod def _gantt(chart, colors, title, bar_width, showgrid_x, showgrid_y, height, width, tasks=None, task_names=None, data=None): """ Refer to FigureFactory.create_gantt() for docstring """ if tasks is None: tasks = [] if task_names is None: task_names = [] if data is None: data = [] for index in range(len(chart)): task = dict(x0=chart[index]['Start'], x1=chart[index]['Finish'], name=chart[index]['Task']) tasks.append(task) shape_template = { 'type': 'rect', 'xref': 'x', 'yref': 'y', 'opacity': 1, 'line': { 'width': 0, }, 'yref': 'y', } color_index = 0 for index in range(len(tasks)): tn = tasks[index]['name'] task_names.append(tn) del tasks[index]['name'] tasks[index].update(shape_template) tasks[index]['y0'] = index - bar_width tasks[index]['y1'] = index + bar_width # check if colors need to be looped if color_index >= len(colors): color_index = 0 tasks[index]['fillcolor'] = colors[color_index] # Add a line for hover text and autorange data.append( dict( x=[tasks[index]['x0'], tasks[index]['x1']], y=[index, index], name='', marker={'color': 'white'} ) ) color_index += 1 layout = dict( title=title, showlegend=False, height=height, width=width, shapes=[], hovermode='closest', yaxis=dict( showgrid=showgrid_y, ticktext=task_names, tickvals=list(range(len(tasks))), range=[-1, len(tasks) + 1], autorange=False, zeroline=False, ), xaxis=dict( showgrid=showgrid_x, zeroline=False, rangeselector=dict( buttons=list([ dict(count=7, label='1w', step='day', stepmode='backward'), dict(count=1, label='1m', step='month', stepmode='backward'), dict(count=6, label='6m', step='month', stepmode='backward'), dict(count=1, label='YTD', step='year', stepmode='todate'), dict(count=1, label='1y', step='year', stepmode='backward'), dict(step='all') ]) ), type='date' ) ) layout['shapes'] = tasks fig = dict(data=data, layout=layout) return fig @staticmethod def _gantt_colorscale(chart, colors, title, index_col, show_colorbar, bar_width, showgrid_x, showgrid_y, height, width, tasks=None, task_names=None, data=None): """ Refer to FigureFactory.create_gantt() for docstring """ from numbers import Number if tasks is None: tasks = [] if task_names is None: task_names = [] if data is None: data = [] showlegend = False for index in range(len(chart)): task = dict(x0=chart[index]['Start'], x1=chart[index]['Finish'], name=chart[index]['Task']) tasks.append(task) shape_template = { 'type': 'rect', 'xref': 'x', 'yref': 'y', 'opacity': 1, 'line': { 'width': 0, }, 'yref': 'y', } # compute the color for task based on indexing column if isinstance(chart[0][index_col], Number): # check that colors has at least 2 colors if len(colors) < 2: raise exceptions.PlotlyError( "You must use at least 2 colors in 'colors' if you " "are using a colorscale. However only the first two " "colors given will be used for the lower and upper " "bounds on the colormap." ) for index in range(len(tasks)): tn = tasks[index]['name'] task_names.append(tn) del tasks[index]['name'] tasks[index].update(shape_template) tasks[index]['y0'] = index - bar_width tasks[index]['y1'] = index + bar_width # unlabel color colors = FigureFactory._color_parser( colors, FigureFactory._unlabel_rgb ) lowcolor = colors[0] highcolor = colors[1] intermed = (chart[index][index_col])/100.0 intermed_color = FigureFactory._find_intermediate_color( lowcolor, highcolor, intermed ) intermed_color = FigureFactory._color_parser( intermed_color, FigureFactory._label_rgb ) tasks[index]['fillcolor'] = intermed_color # relabel colors with 'rgb' colors = FigureFactory._color_parser( colors, FigureFactory._label_rgb ) # add a line for hover text and autorange data.append( dict( x=[tasks[index]['x0'], tasks[index]['x1']], y=[index, index], name='', marker={'color': 'white'} ) ) if show_colorbar is True: # generate dummy data for colorscale visibility data.append( dict( x=[tasks[index]['x0'], tasks[index]['x0']], y=[index, index], name='', marker={'color': 'white', 'colorscale': [[0, colors[0]], [1, colors[1]]], 'showscale': True, 'cmax': 100, 'cmin': 0} ) ) if isinstance(chart[0][index_col], str): index_vals = [] for row in range(len(tasks)): if chart[row][index_col] not in index_vals: index_vals.append(chart[row][index_col]) index_vals.sort() if len(colors) < len(index_vals): raise exceptions.PlotlyError( "Error. The number of colors in 'colors' must be no less " "than the number of unique index values in your group " "column." ) # make a dictionary assignment to each index value index_vals_dict = {} # define color index c_index = 0 for key in index_vals: if c_index > len(colors) - 1: c_index = 0 index_vals_dict[key] = colors[c_index] c_index += 1 for index in range(len(tasks)): tn = tasks[index]['name'] task_names.append(tn) del tasks[index]['name'] tasks[index].update(shape_template) tasks[index]['y0'] = index - bar_width tasks[index]['y1'] = index + bar_width tasks[index]['fillcolor'] = index_vals_dict[ chart[index][index_col] ] # add a line for hover text and autorange data.append( dict( x=[tasks[index]['x0'], tasks[index]['x1']], y=[index, index], name='', marker={'color': 'white'} ) ) if show_colorbar is True: # generate dummy data to generate legend showlegend = True for k, index_value in enumerate(index_vals): data.append( dict( x=[tasks[index]['x0'], tasks[index]['x0']], y=[k, k], showlegend=True, name=str(index_value), hoverinfo='none', marker=dict( color=colors[k], size=1 ) ) ) layout = dict( title=title, showlegend=showlegend, height=height, width=width, shapes=[], hovermode='closest', yaxis=dict( showgrid=showgrid_y, ticktext=task_names, tickvals=list(range(len(tasks))), range=[-1, len(tasks) + 1], autorange=False, zeroline=False, ), xaxis=dict( showgrid=showgrid_x, zeroline=False, rangeselector=dict( buttons=list([ dict(count=7, label='1w', step='day', stepmode='backward'), dict(count=1, label='1m', step='month', stepmode='backward'), dict(count=6, label='6m', step='month', stepmode='backward'), dict(count=1, label='YTD', step='year', stepmode='todate'), dict(count=1, label='1y', step='year', stepmode='backward'), dict(step='all') ]) ), type='date' ) ) layout['shapes'] = tasks fig = dict(data=data, layout=layout) return fig @staticmethod def _gantt_dict(chart, colors, title, index_col, show_colorbar, bar_width, showgrid_x, showgrid_y, height, width, tasks=None, task_names=None, data=None): """ Refer to FigureFactory.create_gantt() for docstring """ if tasks is None: tasks = [] if task_names is None: task_names = [] if data is None: data = [] showlegend = False for index in range(len(chart)): task = dict(x0=chart[index]['Start'], x1=chart[index]['Finish'], name=chart[index]['Task']) tasks.append(task) shape_template = { 'type': 'rect', 'xref': 'x', 'yref': 'y', 'opacity': 1, 'line': { 'width': 0, }, 'yref': 'y', } index_vals = [] for row in range(len(tasks)): if chart[row][index_col] not in index_vals: index_vals.append(chart[row][index_col]) index_vals.sort() # verify each value in index column appears in colors dictionary for key in index_vals: if key not in colors: raise exceptions.PlotlyError( "If you are using colors as a dictionary, all of its " "keys must be all the values in the index column." ) for index in range(len(tasks)): tn = tasks[index]['name'] task_names.append(tn) del tasks[index]['name'] tasks[index].update(shape_template) tasks[index]['y0'] = index - bar_width tasks[index]['y1'] = index + bar_width tasks[index]['fillcolor'] = colors[chart[index][index_col]] # add a line for hover text and autorange data.append( dict( x=[tasks[index]['x0'], tasks[index]['x1']], y=[index, index], name='', marker={'color': 'white'} ) ) if show_colorbar is True: # generate dummy data to generate legend showlegend = True for k, index_value in enumerate(index_vals): data.append( dict( x=[tasks[index]['x0'], tasks[index]['x0']], y=[k, k], showlegend=True, hoverinfo='none', name=str(index_value), marker=dict( color=colors[index_value], size=1 ) ) ) layout = dict( title=title, showlegend=showlegend, height=height, width=width, shapes=[], hovermode='closest', yaxis=dict( showgrid=showgrid_y, ticktext=task_names, tickvals=list(range(len(tasks))), range=[-1, len(tasks) + 1], autorange=False, zeroline=False, ), xaxis=dict( showgrid=showgrid_x, zeroline=False, rangeselector=dict( buttons=list([ dict(count=7, label='1w', step='day', stepmode='backward'), dict(count=1, label='1m', step='month', stepmode='backward'), dict(count=6, label='6m', step='month', stepmode='backward'), dict(count=1, label='YTD', step='year', stepmode='todate'), dict(count=1, label='1y', step='year', stepmode='backward'), dict(step='all') ]) ), type='date' ) ) layout['shapes'] = tasks fig = dict(data=data, layout=layout) return fig @staticmethod def create_gantt(df, colors=None, index_col=None, show_colorbar=False, reverse_colors=False, title='Gantt Chart', bar_width=0.2, showgrid_x=False, showgrid_y=False, height=600, width=900, tasks=None, task_names=None, data=None): """ Returns figure for a gantt chart :param (array|list) df: input data for gantt chart. Must be either a a dataframe or a list. If dataframe, the columns must include 'Task', 'Start' and 'Finish'. Other columns can be included and used for indexing. If a list, its elements must be dictionaries with the same required column headers: 'Task', 'Start' and 'Finish'. :param (str|list|dict|tuple) colors: either a plotly scale name, an rgb or hex color, a color tuple or a list of colors. An rgb color is of the form 'rgb(x, y, z)' where x, y, z belong to the interval [0, 255] and a color tuple is a tuple of the form (a, b, c) where a, b and c belong to [0, 1]. If colors is a list, it must contain the valid color types aforementioned as its members. If a dictionary, all values of the indexing column must be keys in colors. :param (str|float) index_col: the column header (if df is a data frame) that will function as the indexing column. If df is a list, index_col must be one of the keys in all the items of df. :param (bool) show_colorbar: determines if colorbar will be visible. Only applies if values in the index column are numeric. :param (bool) reverse_colors: reverses the order of selected colors :param (str) title: the title of the chart :param (float) bar_width: the width of the horizontal bars in the plot :param (bool) showgrid_x: show/hide the x-axis grid :param (bool) showgrid_y: show/hide the y-axis grid :param (float) height: the height of the chart :param (float) width: the width of the chart Example 1: Simple Gantt Chart ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF # Make data for chart df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30'), dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15'), dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30')] # Create a figure fig = FF.create_gantt(df) # Plot the data py.iplot(fig, filename='Simple Gantt Chart', world_readable=True) ``` Example 2: Index by Column with Numerical Entries ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF # Make data for chart df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30', Complete=10), dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15', Complete=60), dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30', Complete=95)] # Create a figure with Plotly colorscale fig = FF.create_gantt(df, colors='Blues', index_col='Complete', show_colorbar=True, bar_width=0.5, showgrid_x=True, showgrid_y=True) # Plot the data py.iplot(fig, filename='Numerical Entries', world_readable=True) ``` Example 3: Index by Column with String Entries ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF # Make data for chart df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30', Resource='Apple'), dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15', Resource='Grape'), dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30', Resource='Banana')] # Create a figure with Plotly colorscale fig = FF.create_gantt(df, colors=['rgb(200, 50, 25)', (1, 0, 1), '#6c4774'], index_col='Resource', reverse_colors=True, show_colorbar=True) # Plot the data py.iplot(fig, filename='String Entries', world_readable=True) ``` Example 4: Use a dictionary for colors ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF # Make data for chart df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30', Resource='Apple'), dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15', Resource='Grape'), dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30', Resource='Banana')] # Make a dictionary of colors colors = {'Apple': 'rgb(255, 0, 0)', 'Grape': 'rgb(170, 14, 200)', 'Banana': (1, 1, 0.2)} # Create a figure with Plotly colorscale fig = FF.create_gantt(df, colors=colors, index_col='Resource', show_colorbar=True) # Plot the data py.iplot(fig, filename='dictioanry colors', world_readable=True) ``` Example 5: Use a pandas dataframe ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import pandas as pd # Make data as a dataframe df = pd.DataFrame([['Run', '2010-01-01', '2011-02-02', 10], ['Fast', '2011-01-01', '2012-06-05', 55], ['Eat', '2012-01-05', '2013-07-05', 94]], columns=['Task', 'Start', 'Finish', 'Complete']) # Create a figure with Plotly colorscale fig = FF.create_gantt(df, colors='Blues', index_col='Complete', show_colorbar=True, bar_width=0.5, showgrid_x=True, showgrid_y=True) # Plot the data py.iplot(fig, filename='data with dataframe', world_readable=True) ``` """ # validate gantt input data chart = FigureFactory._validate_gantt(df) if index_col: if index_col not in chart[0]: raise exceptions.PlotlyError( "In order to use an indexing column and assign colors to " "the values of the index, you must choose an actual " "column name in the dataframe or key if a list of " "dictionaries is being used.") # validate gantt index column index_list = [] for dictionary in chart: index_list.append(dictionary[index_col]) FigureFactory._validate_index(index_list) # Validate colors if isinstance(colors, dict): colors = FigureFactory._validate_colors_dict(colors, 'rgb') else: colors = FigureFactory._validate_colors(colors, 'rgb') if reverse_colors is True: colors.reverse() if not index_col: if isinstance(colors, dict): raise exceptions.PlotlyError( "Error. You have set colors to a dictionary but have not " "picked an index. An index is required if you are " "assigning colors to particular values in a dictioanry." ) fig = FigureFactory._gantt( chart, colors, title, bar_width, showgrid_x, showgrid_y, height, width, tasks=None, task_names=None, data=None ) return fig else: if not isinstance(colors, dict): fig = FigureFactory._gantt_colorscale( chart, colors, title, index_col, show_colorbar, bar_width, showgrid_x, showgrid_y, height, width, tasks=None, task_names=None, data=None ) return fig else: fig = FigureFactory._gantt_dict( chart, colors, title, index_col, show_colorbar, bar_width, showgrid_x, showgrid_y, height, width, tasks=None, task_names=None, data=None ) return fig @staticmethod def _validate_colors(colors, colortype='tuple'): """ Validates color(s) and returns a list of color(s) of a specified type """ from numbers import Number if colors is None: colors = DEFAULT_PLOTLY_COLORS if isinstance(colors, str): if colors in PLOTLY_SCALES: colors = PLOTLY_SCALES[colors] elif 'rgb' in colors or '#' in colors: colors = [colors] else: raise exceptions.PlotlyError( "If your colors variable is a string, it must be a " "Plotly scale, an rgb color or a hex color.") elif isinstance(colors, tuple): if isinstance(colors[0], Number): colors = [colors] else: colors = list(colors) # convert color elements in list to tuple color for j, each_color in enumerate(colors): if 'rgb' in each_color: each_color = FigureFactory._color_parser( each_color, FigureFactory._unlabel_rgb ) for value in each_color: if value > 255.0: raise exceptions.PlotlyError( "Whoops! The elements in your rgb colors " "tuples cannot exceed 255.0." ) each_color = FigureFactory._color_parser( each_color, FigureFactory._unconvert_from_RGB_255 ) colors[j] = each_color if '#' in each_color: each_color = FigureFactory._color_parser( each_color, FigureFactory._hex_to_rgb ) each_color = FigureFactory._color_parser( each_color, FigureFactory._unconvert_from_RGB_255 ) colors[j] = each_color if isinstance(each_color, tuple): for value in each_color: if value > 1.0: raise exceptions.PlotlyError( "Whoops! The elements in your colors tuples " "cannot exceed 1.0." ) colors[j] = each_color if colortype == 'rgb': for j, each_color in enumerate(colors): rgb_color = FigureFactory._color_parser( each_color, FigureFactory._convert_to_RGB_255 ) colors[j] = FigureFactory._color_parser( rgb_color, FigureFactory._label_rgb ) return colors @staticmethod def _validate_colors_dict(colors, colortype='tuple'): """ Validates dictioanry of color(s) """ # validate each color element in the dictionary for key in colors: if 'rgb' in colors[key]: colors[key] = FigureFactory._color_parser( colors[key], FigureFactory._unlabel_rgb ) for value in colors[key]: if value > 255.0: raise exceptions.PlotlyError( "Whoops! The elements in your rgb colors " "tuples cannot exceed 255.0." ) colors[key] = FigureFactory._color_parser( colors[key], FigureFactory._unconvert_from_RGB_255 ) if '#' in colors[key]: colors[key] = FigureFactory._color_parser( colors[key], FigureFactory._hex_to_rgb ) colors[key] = FigureFactory._color_parser( colors[key], FigureFactory._unconvert_from_RGB_255 ) if isinstance(colors[key], tuple): for value in colors[key]: if value > 1.0: raise exceptions.PlotlyError( "Whoops! The elements in your colors tuples " "cannot exceed 1.0." ) if colortype == 'rgb': for key in colors: colors[key] = FigureFactory._color_parser( colors[key], FigureFactory._convert_to_RGB_255 ) colors[key] = FigureFactory._color_parser( colors[key], FigureFactory._label_rgb ) return colors @staticmethod def _calc_stats(data): """ Calculate statistics for use in violin plot. """ import numpy as np x = np.asarray(data, np.float) vals_min = np.min(x) vals_max = np.max(x) q2 = np.percentile(x, 50, interpolation='linear') q1 = np.percentile(x, 25, interpolation='lower') q3 = np.percentile(x, 75, interpolation='higher') iqr = q3 - q1 whisker_dist = 1.5 * iqr # in order to prevent drawing whiskers outside the interval # of data one defines the whisker positions as: d1 = np.min(x[x >= (q1 - whisker_dist)]) d2 = np.max(x[x <= (q3 + whisker_dist)]) return { 'min': vals_min, 'max': vals_max, 'q1': q1, 'q2': q2, 'q3': q3, 'd1': d1, 'd2': d2 } @staticmethod def _make_half_violin(x, y, fillcolor='#1f77b4', linecolor='rgb(0, 0, 0)'): """ Produces a sideways probability distribution fig violin plot. """ from plotly.graph_objs import graph_objs text = ['(pdf(y), y)=(' + '{:0.2f}'.format(x[i]) + ', ' + '{:0.2f}'.format(y[i]) + ')' for i in range(len(x))] return graph_objs.Scatter( x=x, y=y, mode='lines', name='', text=text, fill='tonextx', fillcolor=fillcolor, line=graph_objs.Line(width=0.5, color=linecolor, shape='spline'), hoverinfo='text', opacity=0.5 ) @staticmethod def _make_violin_rugplot(vals, pdf_max, distance, color='#1f77b4'): """ Returns a rugplot fig for a violin plot. """ from plotly.graph_objs import graph_objs return graph_objs.Scatter( y=vals, x=[-pdf_max-distance]*len(vals), marker=graph_objs.Marker( color=color, symbol='line-ew-open' ), mode='markers', name='', showlegend=False, hoverinfo='y' ) @staticmethod def _make_quartiles(q1, q3): """ Makes the upper and lower quartiles for a violin plot. """ from plotly.graph_objs import graph_objs return graph_objs.Scatter( x=[0, 0], y=[q1, q3], text=['lower-quartile: ' + '{:0.2f}'.format(q1), 'upper-quartile: ' + '{:0.2f}'.format(q3)], mode='lines', line=graph_objs.Line( width=4, color='rgb(0,0,0)' ), hoverinfo='text' ) @staticmethod def _make_median(q2): """ Formats the 'median' hovertext for a violin plot. """ from plotly.graph_objs import graph_objs return graph_objs.Scatter( x=[0], y=[q2], text=['median: ' + '{:0.2f}'.format(q2)], mode='markers', marker=dict(symbol='square', color='rgb(255,255,255)'), hoverinfo='text' ) @staticmethod def _make_non_outlier_interval(d1, d2): """ Returns the scatterplot fig of most of a violin plot. """ from plotly.graph_objs import graph_objs return graph_objs.Scatter( x=[0, 0], y=[d1, d2], name='', mode='lines', line=graph_objs.Line(width=1.5, color='rgb(0,0,0)') ) @staticmethod def _make_XAxis(xaxis_title, xaxis_range): """ Makes the x-axis for a violin plot. """ from plotly.graph_objs import graph_objs xaxis = graph_objs.XAxis(title=xaxis_title, range=xaxis_range, showgrid=False, zeroline=False, showline=False, mirror=False, ticks='', showticklabels=False, ) return xaxis @staticmethod def _make_YAxis(yaxis_title): """ Makes the y-axis for a violin plot. """ from plotly.graph_objs import graph_objs yaxis = graph_objs.YAxis(title=yaxis_title, showticklabels=True, autorange=True, ticklen=4, showline=True, zeroline=False, showgrid=False, mirror=False) return yaxis @staticmethod def _violinplot(vals, fillcolor='#1f77b4', rugplot=True): """ Refer to FigureFactory.create_violin() for docstring. """ import numpy as np from scipy import stats vals = np.asarray(vals, np.float) # summary statistics vals_min = FigureFactory._calc_stats(vals)['min'] vals_max = FigureFactory._calc_stats(vals)['max'] q1 = FigureFactory._calc_stats(vals)['q1'] q2 = FigureFactory._calc_stats(vals)['q2'] q3 = FigureFactory._calc_stats(vals)['q3'] d1 = FigureFactory._calc_stats(vals)['d1'] d2 = FigureFactory._calc_stats(vals)['d2'] # kernel density estimation of pdf pdf = stats.gaussian_kde(vals) # grid over the data interval xx = np.linspace(vals_min, vals_max, 100) # evaluate the pdf at the grid xx yy = pdf(xx) max_pdf = np.max(yy) # distance from the violin plot to rugplot distance = (2.0 * max_pdf)/10 if rugplot else 0 # range for x values in the plot plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1] plot_data = [FigureFactory._make_half_violin( -yy, xx, fillcolor=fillcolor), FigureFactory._make_half_violin( yy, xx, fillcolor=fillcolor), FigureFactory._make_non_outlier_interval(d1, d2), FigureFactory._make_quartiles(q1, q3), FigureFactory._make_median(q2)] if rugplot: plot_data.append(FigureFactory._make_violin_rugplot( vals, max_pdf, distance=distance, color=fillcolor) ) return plot_data, plot_xrange @staticmethod def _violin_no_colorscale(data, data_header, group_header, colors, use_colorscale, group_stats, height, width, title): """ Refer to FigureFactory.create_violin() for docstring. Returns fig for violin plot without colorscale. """ from plotly.graph_objs import graph_objs import numpy as np # collect all group names group_name = [] for name in data[group_header]: if name not in group_name: group_name.append(name) group_name.sort() gb = data.groupby([group_header]) L = len(group_name) fig = make_subplots(rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=True) color_index = 0 for k, gr in enumerate(group_name): vals = np.asarray(gb.get_group(gr)[data_header], np.float) if color_index >= len(colors): color_index = 0 plot_data, plot_xrange = FigureFactory._violinplot( vals, fillcolor=colors[color_index] ) layout = graph_objs.Layout() for item in plot_data: fig.append_trace(item, 1, k + 1) color_index += 1 # add violin plot labels fig['layout'].update({'xaxis{}'.format(k + 1): FigureFactory._make_XAxis(group_name[k], plot_xrange)}) # set the sharey axis style fig['layout'].update( {'yaxis{}'.format(1): FigureFactory._make_YAxis('')} ) fig['layout'].update( title=title, showlegend=False, hovermode='closest', autosize=False, height=height, width=width ) return fig @staticmethod def _violin_colorscale(data, data_header, group_header, colors, use_colorscale, group_stats, height, width, title): """ Refer to FigureFactory.create_violin() for docstring. Returns fig for violin plot with colorscale. """ from plotly.graph_objs import graph_objs import numpy as np # collect all group names group_name = [] for name in data[group_header]: if name not in group_name: group_name.append(name) group_name.sort() # make sure all group names are keys in group_stats for group in group_name: if group not in group_stats: raise exceptions.PlotlyError("All values/groups in the index " "column must be represented " "as a key in group_stats.") gb = data.groupby([group_header]) L = len(group_name) fig = make_subplots(rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=True) # prepare low and high color for colorscale lowcolor = FigureFactory._color_parser( colors[0], FigureFactory._unlabel_rgb ) highcolor = FigureFactory._color_parser( colors[1], FigureFactory._unlabel_rgb ) # find min and max values in group_stats group_stats_values = [] for key in group_stats: group_stats_values.append(group_stats[key]) max_value = max(group_stats_values) min_value = min(group_stats_values) for k, gr in enumerate(group_name): vals = np.asarray(gb.get_group(gr)[data_header], np.float) # find intermediate color from colorscale intermed = (group_stats[gr] - min_value) / (max_value - min_value) intermed_color = FigureFactory._find_intermediate_color( lowcolor, highcolor, intermed ) plot_data, plot_xrange = FigureFactory._violinplot( vals, fillcolor='rgb{}'.format(intermed_color) ) layout = graph_objs.Layout() for item in plot_data: fig.append_trace(item, 1, k + 1) fig['layout'].update({'xaxis{}'.format(k + 1): FigureFactory._make_XAxis(group_name[k], plot_xrange)}) # add colorbar to plot trace_dummy = graph_objs.Scatter( x=[0], y=[0], mode='markers', marker=dict( size=2, cmin=min_value, cmax=max_value, colorscale=[[0, colors[0]], [1, colors[1]]], showscale=True), showlegend=False, ) fig.append_trace(trace_dummy, 1, L) # set the sharey axis style fig['layout'].update( {'yaxis{}'.format(1): FigureFactory._make_YAxis('')} ) fig['layout'].update( title=title, showlegend=False, hovermode='closest', autosize=False, height=height, width=width ) return fig @staticmethod def _violin_dict(data, data_header, group_header, colors, use_colorscale, group_stats, height, width, title): """ Refer to FigureFactory.create_violin() for docstring. Returns fig for violin plot without colorscale. """ from plotly.graph_objs import graph_objs import numpy as np # collect all group names group_name = [] for name in data[group_header]: if name not in group_name: group_name.append(name) group_name.sort() # check if all group names appear in colors dict for group in group_name: if group not in colors: raise exceptions.PlotlyError("If colors is a dictionary, all " "the group names must appear as " "keys in colors.") gb = data.groupby([group_header]) L = len(group_name) fig = make_subplots(rows=1, cols=L, shared_yaxes=True, horizontal_spacing=0.025, print_grid=True) for k, gr in enumerate(group_name): vals = np.asarray(gb.get_group(gr)[data_header], np.float) plot_data, plot_xrange = FigureFactory._violinplot( vals, fillcolor=colors[gr] ) layout = graph_objs.Layout() for item in plot_data: fig.append_trace(item, 1, k + 1) # add violin plot labels fig['layout'].update({'xaxis{}'.format(k + 1): FigureFactory._make_XAxis(group_name[k], plot_xrange)}) # set the sharey axis style fig['layout'].update( {'yaxis{}'.format(1): FigureFactory._make_YAxis('')} ) fig['layout'].update( title=title, showlegend=False, hovermode='closest', autosize=False, height=height, width=width ) return fig @staticmethod def create_violin(data, data_header=None, group_header=None, colors=None, use_colorscale=False, group_stats=None, height=450, width=600, title='Violin and Rug Plot'): """ Returns figure for a violin plot :param (list|array) data: accepts either a list of numerical values, a list of dictionaries all with identical keys and at least one column of numeric values, or a pandas dataframe with at least one column of numbers :param (str) data_header: the header of the data column to be used from an inputted pandas dataframe. Not applicable if 'data' is a list of numeric values :param (str) group_header: applicable if grouping data by a variable. 'group_header' must be set to the name of the grouping variable. :param (str|tuple|list|dict) colors: either a plotly scale name, an rgb or hex color, a color tuple, a list of colors or a dictionary. An rgb color is of the form 'rgb(x, y, z)' where x, y and z belong to the interval [0, 255] and a color tuple is a tuple of the form (a, b, c) where a, b and c belong to [0, 1]. If colors is a list, it must contain valid color types as its members. :param (bool) use_colorscale: Only applicable if grouping by another variable. Will implement a colorscale based on the first 2 colors of param colors. This means colors must be a list with at least 2 colors in it (Plotly colorscales are accepted since they map to a list of two rgb colors) :param (dict) group_stats: a dictioanry where each key is a unique value from the group_header column in data. Each value must be a number and will be used to color the violin plots if a colorscale is being used :param (float) height: the height of the violin plot :param (float) width: the width of the violin plot :param (str) title: the title of the violin plot Example 1: Single Violin Plot ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF from plotly.graph_objs import graph_objs import numpy as np from scipy import stats # create list of random values data_list = np.random.randn(100) data_list.tolist() # create violin fig fig = FF.create_violin(data_list, colors='#604d9e') # plot py.iplot(fig, filename='Violin Plot') ``` Example 2: Multiple Violin Plots with Qualitative Coloring ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF from plotly.graph_objs import graph_objs import numpy as np import pandas as pd from scipy import stats # create dataframe np.random.seed(619517) Nr=250 y = np.random.randn(Nr) gr = np.random.choice(list("ABCDE"), Nr) norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)] for i, letter in enumerate("ABCDE"): y[gr == letter] *=norm_params[i][1]+ norm_params[i][0] df = pd.DataFrame(dict(Score=y, Group=gr)) # create violin fig fig = FF.create_violin(df, data_header='Score', group_header='Group', height=600, width=1000) # plot py.iplot(fig, filename='Violin Plot with Coloring') ``` Example 3: Violin Plots with Colorscale ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF from plotly.graph_objs import graph_objs import numpy as np import pandas as pd from scipy import stats # create dataframe np.random.seed(619517) Nr=250 y = np.random.randn(Nr) gr = np.random.choice(list("ABCDE"), Nr) norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)] for i, letter in enumerate("ABCDE"): y[gr == letter] *=norm_params[i][1]+ norm_params[i][0] df = pd.DataFrame(dict(Score=y, Group=gr)) # define header params data_header = 'Score' group_header = 'Group' # make groupby object with pandas group_stats = {} groupby_data = df.groupby([group_header]) for group in "ABCDE": data_from_group = groupby_data.get_group(group)[data_header] # take a stat of the grouped data stat = np.median(data_from_group) # add to dictionary group_stats[group] = stat # create violin fig fig = FF.create_violin(df, data_header='Score', group_header='Group', height=600, width=1000, use_colorscale=True, group_stats=group_stats) # plot py.iplot(fig, filename='Violin Plot with Colorscale') ``` """ from plotly.graph_objs import graph_objs from numbers import Number # Validate colors if isinstance(colors, dict): valid_colors = FigureFactory._validate_colors_dict(colors, 'rgb') else: valid_colors = FigureFactory._validate_colors(colors, 'rgb') # validate data and choose plot type if group_header is None: if isinstance(data, list): if len(data) <= 0: raise exceptions.PlotlyError("If data is a list, it must be " "nonempty and contain either " "numbers or dictionaries.") if not all(isinstance(element, Number) for element in data): raise exceptions.PlotlyError("If data is a list, it must " "contain only numbers.") if _pandas_imported and isinstance(data, pd.core.frame.DataFrame): if data_header is None: raise exceptions.PlotlyError("data_header must be the " "column name with the " "desired numeric data for " "the violin plot.") data = data[data_header].values.tolist() # call the plotting functions plot_data, plot_xrange = FigureFactory._violinplot( data, fillcolor=valid_colors[0] ) layout = graph_objs.Layout( title=title, autosize=False, font=graph_objs.Font(size=11), height=height, showlegend=False, width=width, xaxis=FigureFactory._make_XAxis('', plot_xrange), yaxis=FigureFactory._make_YAxis(''), hovermode='closest' ) layout['yaxis'].update(dict(showline=False, showticklabels=False, ticks='')) fig = graph_objs.Figure(data=graph_objs.Data(plot_data), layout=layout) return fig else: if not isinstance(data, pd.core.frame.DataFrame): raise exceptions.PlotlyError("Error. You must use a pandas " "DataFrame if you are using a " "group header.") if data_header is None: raise exceptions.PlotlyError("data_header must be the column " "name with the desired numeric " "data for the violin plot.") if use_colorscale is False: if isinstance(valid_colors, dict): # validate colors dict choice below fig = FigureFactory._violin_dict( data, data_header, group_header, valid_colors, use_colorscale, group_stats, height, width, title ) return fig else: fig = FigureFactory._violin_no_colorscale( data, data_header, group_header, valid_colors, use_colorscale, group_stats, height, width, title ) return fig else: if isinstance(valid_colors, dict): raise exceptions.PlotlyError("The colors param cannot be " "a dictionary if you are " "using a colorscale.") if len(valid_colors) < 2: raise exceptions.PlotlyError("colors must be a list with " "at least 2 colors. A " "Plotly scale is allowed.") if not isinstance(group_stats, dict): raise exceptions.PlotlyError("Your group_stats param " "must be a dictionary.") fig = FigureFactory._violin_colorscale( data, data_header, group_header, valid_colors, use_colorscale, group_stats, height, width, title ) return fig @staticmethod def _find_intermediate_color(lowcolor, highcolor, intermed): """ Returns the color at a given distance between two colors This function takes two color tuples, where each element is between 0 and 1, along with a value 0 < intermed < 1 and returns a color that is intermed-percent from lowcolor to highcolor """ diff_0 = float(highcolor[0] - lowcolor[0]) diff_1 = float(highcolor[1] - lowcolor[1]) diff_2 = float(highcolor[2] - lowcolor[2]) return (lowcolor[0] + intermed * diff_0, lowcolor[1] + intermed * diff_1, lowcolor[2] + intermed * diff_2) @staticmethod def _color_parser(colors, function): """ Takes color(s) and a function and applies the function on the color(s) In particular, this function identifies whether the given color object is an iterable or not and applies the given color-parsing function to the color or iterable of colors. If given an iterable, it will only be able to work with it if all items in the iterable are of the same type - rgb string, hex string or tuple """ from numbers import Number if isinstance(colors, str): return function(colors) if isinstance(colors, tuple) and isinstance(colors[0], Number): return function(colors) if hasattr(colors, '__iter__'): if isinstance(colors, tuple): new_color_tuple = tuple(function(item) for item in colors) return new_color_tuple else: new_color_list = [function(item) for item in colors] return new_color_list @staticmethod def _unconvert_from_RGB_255(colors): """ Return a tuple where each element gets divided by 255 Takes a (list of) color tuple(s) where each element is between 0 and 255. Returns the same tuples where each tuple element is normalized to a value between 0 and 1 """ return (colors[0]/(255.0), colors[1]/(255.0), colors[2]/(255.0)) @staticmethod def _map_face2color(face, colormap, vmin, vmax): """ Normalize facecolor values by vmin/vmax and return rgb-color strings This function takes a tuple color along with a colormap and a minimum (vmin) and maximum (vmax) range of possible mean distances for the given parametrized surface. It returns an rgb color based on the mean distance between vmin and vmax """ if vmin >= vmax: raise exceptions.PlotlyError("Incorrect relation between vmin " "and vmax. The vmin value cannot be " "bigger than or equal to the value " "of vmax.") if len(colormap) == 1: # color each triangle face with the same color in colormap face_color = colormap[0] face_color = FigureFactory._convert_to_RGB_255(face_color) face_color = FigureFactory._label_rgb(face_color) else: if face == vmax: # pick last color in colormap face_color = colormap[-1] face_color = FigureFactory._convert_to_RGB_255(face_color) face_color = FigureFactory._label_rgb(face_color) else: # find the normalized distance t of a triangle face between # vmin and vmax where the distance is between 0 and 1 t = (face - vmin) / float((vmax - vmin)) low_color_index = int(t / (1./(len(colormap) - 1))) face_color = FigureFactory._find_intermediate_color( colormap[low_color_index], colormap[low_color_index + 1], t * (len(colormap) - 1) - low_color_index ) face_color = FigureFactory._convert_to_RGB_255(face_color) face_color = FigureFactory._label_rgb(face_color) return face_color @staticmethod def _trisurf(x, y, z, simplices, show_colorbar, edges_color, colormap=None, color_func=None, plot_edges=False, x_edge=None, y_edge=None, z_edge=None, facecolor=None): """ Refer to FigureFactory.create_trisurf() for docstring """ # numpy import check if _numpy_imported is False: raise ImportError("FigureFactory._trisurf() requires " "numpy imported.") import numpy as np from plotly.graph_objs import graph_objs points3D = np.vstack((x, y, z)).T simplices = np.atleast_2d(simplices) # vertices of the surface triangles tri_vertices = points3D[simplices] # Define colors for the triangle faces if color_func is None: # mean values of z-coordinates of triangle vertices mean_dists = tri_vertices[:, :, 2].mean(-1) elif isinstance(color_func, (list, np.ndarray)): # Pre-computed list / array of values to map onto color if len(color_func) != len(simplices): raise ValueError("If color_func is a list/array, it must " "be the same length as simplices.") # convert all colors in color_func to rgb for index in range(len(color_func)): if isinstance(color_func[index], str): if '#' in color_func[index]: foo = FigureFactory._hex_to_rgb(color_func[index]) color_func[index] = FigureFactory._label_rgb(foo) if isinstance(color_func[index], tuple): foo = FigureFactory._convert_to_RGB_255(color_func[index]) color_func[index] = FigureFactory._label_rgb(foo) mean_dists = np.asarray(color_func) else: # apply user inputted function to calculate # custom coloring for triangle vertices mean_dists = [] for triangle in tri_vertices: dists = [] for vertex in triangle: dist = color_func(vertex[0], vertex[1], vertex[2]) dists.append(dist) mean_dists.append(np.mean(dists)) mean_dists = np.asarray(mean_dists) # Check if facecolors are already strings and can be skipped if isinstance(mean_dists[0], str): facecolor = mean_dists else: min_mean_dists = np.min(mean_dists) max_mean_dists = np.max(mean_dists) if facecolor is None: facecolor = [] for index in range(len(mean_dists)): color = FigureFactory._map_face2color(mean_dists[index], colormap, min_mean_dists, max_mean_dists) facecolor.append(color) # Make sure facecolor is a list so output is consistent across Pythons facecolor = np.asarray(facecolor) ii, jj, kk = simplices.T triangles = graph_objs.Mesh3d(x=x, y=y, z=z, facecolor=facecolor, i=ii, j=jj, k=kk, name='') mean_dists_are_numbers = not isinstance(mean_dists[0], str) if mean_dists_are_numbers and show_colorbar is True: # make a colorscale from the colors colorscale = FigureFactory._make_colorscale(colormap) colorscale = FigureFactory._convert_colorscale_to_rgb(colorscale) colorbar = graph_objs.Scatter3d( x=x[0], y=y[0], z=z[0], mode='markers', marker=dict( size=0.1, color=[min_mean_dists, max_mean_dists], colorscale=colorscale, showscale=True), hoverinfo='None', showlegend=False ) # the triangle sides are not plotted if plot_edges is False: if mean_dists_are_numbers and show_colorbar is True: return graph_objs.Data([triangles, colorbar]) else: return graph_objs.Data([triangles]) # define the lists x_edge, y_edge and z_edge, of x, y, resp z # coordinates of edge end points for each triangle # None separates data corresponding to two consecutive triangles is_none = [ii is None for ii in [x_edge, y_edge, z_edge]] if any(is_none): if not all(is_none): raise ValueError("If any (x_edge, y_edge, z_edge) is None, " "all must be None") else: x_edge = [] y_edge = [] z_edge = [] # Pull indices we care about, then add a None column to separate tris ixs_triangles = [0, 1, 2, 0] pull_edges = tri_vertices[:, ixs_triangles, :] x_edge_pull = np.hstack([pull_edges[:, :, 0], np.tile(None, [pull_edges.shape[0], 1])]) y_edge_pull = np.hstack([pull_edges[:, :, 1], np.tile(None, [pull_edges.shape[0], 1])]) z_edge_pull = np.hstack([pull_edges[:, :, 2], np.tile(None, [pull_edges.shape[0], 1])]) # Now unravel the edges into a 1-d vector for plotting x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]]) y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]]) z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]]) if not (len(x_edge) == len(y_edge) == len(z_edge)): raise exceptions.PlotlyError("The lengths of x_edge, y_edge and " "z_edge are not the same.") # define the lines for plotting lines = graph_objs.Scatter3d( x=x_edge, y=y_edge, z=z_edge, mode='lines', line=graph_objs.Line( color=edges_color, width=1.5 ), showlegend=False ) if mean_dists_are_numbers and show_colorbar is True: return graph_objs.Data([triangles, lines, colorbar]) else: return graph_objs.Data([triangles, lines]) @staticmethod def create_trisurf(x, y, z, simplices, colormap=None, show_colorbar=True, color_func=None, title='Trisurf Plot', plot_edges=True, showbackground=True, backgroundcolor='rgb(230, 230, 230)', gridcolor='rgb(255, 255, 255)', zerolinecolor='rgb(255, 255, 255)', edges_color='rgb(50, 50, 50)', height=800, width=800, aspectratio=dict(x=1, y=1, z=1)): """ Returns figure for a triangulated surface plot :param (array) x: data values of x in a 1D array :param (array) y: data values of y in a 1D array :param (array) z: data values of z in a 1D array :param (array) simplices: an array of shape (ntri, 3) where ntri is the number of triangles in the triangularization. Each row of the array contains the indicies of the verticies of each triangle :param (str|tuple|list) colormap: either a plotly scale name, an rgb or hex color, a color tuple or a list of colors. An rgb color is of the form 'rgb(x, y, z)' where x, y, z belong to the interval [0, 255] and a color tuple is a tuple of the form (a, b, c) where a, b and c belong to [0, 1]. If colormap is a list, it must contain the valid color types aforementioned as its members :param (bool) show_colorbar: determines if colorbar is visible :param (function|list) color_func: The parameter that determines the coloring of the surface. Takes either a function with 3 arguments x, y, z or a list/array of color values the same length as simplices. If None, coloring will only depend on the z axis :param (str) title: title of the plot :param (bool) plot_edges: determines if the triangles on the trisurf are visible :param (bool) showbackground: makes background in plot visible :param (str) backgroundcolor: color of background. Takes a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive :param (str) gridcolor: color of the gridlines besides the axes. Takes a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive :param (str) zerolinecolor: color of the axes. Takes a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive :param (str) edges_color: color of the edges, if plot_edges is True :param (int|float) height: the height of the plot (in pixels) :param (int|float) width: the width of the plot (in pixels) :param (dict) aspectratio: a dictionary of the aspect ratio values for the x, y and z axes. 'x', 'y' and 'z' take (int|float) values Example 1: Sphere ``` # Necessary Imports for Trisurf import numpy as np from scipy.spatial import Delaunay import plotly.plotly as py from plotly.tools import FigureFactory as FF from plotly.graph_objs import graph_objs # Make data for plot u = np.linspace(0, 2*np.pi, 20) v = np.linspace(0, np.pi, 20) u,v = np.meshgrid(u,v) u = u.flatten() v = v.flatten() x = np.sin(v)*np.cos(u) y = np.sin(v)*np.sin(u) z = np.cos(v) points2D = np.vstack([u,v]).T tri = Delaunay(points2D) simplices = tri.simplices # Create a figure fig1 = FF.create_trisurf(x=x, y=y, z=z, colormap="Blues", simplices=simplices) # Plot the data py.iplot(fig1, filename='trisurf-plot-sphere') ``` Example 2: Torus ``` # Necessary Imports for Trisurf import numpy as np from scipy.spatial import Delaunay import plotly.plotly as py from plotly.tools import FigureFactory as FF from plotly.graph_objs import graph_objs # Make data for plot u = np.linspace(0, 2*np.pi, 20) v = np.linspace(0, 2*np.pi, 20) u,v = np.meshgrid(u,v) u = u.flatten() v = v.flatten() x = (3 + (np.cos(v)))*np.cos(u) y = (3 + (np.cos(v)))*np.sin(u) z = np.sin(v) points2D = np.vstack([u,v]).T tri = Delaunay(points2D) simplices = tri.simplices # Create a figure fig1 = FF.create_trisurf(x=x, y=y, z=z, colormap="Greys", simplices=simplices) # Plot the data py.iplot(fig1, filename='trisurf-plot-torus') ``` Example 3: Mobius Band ``` # Necessary Imports for Trisurf import numpy as np from scipy.spatial import Delaunay import plotly.plotly as py from plotly.tools import FigureFactory as FF from plotly.graph_objs import graph_objs # Make data for plot u = np.linspace(0, 2*np.pi, 24) v = np.linspace(-1, 1, 8) u,v = np.meshgrid(u,v) u = u.flatten() v = v.flatten() tp = 1 + 0.5*v*np.cos(u/2.) x = tp*np.cos(u) y = tp*np.sin(u) z = 0.5*v*np.sin(u/2.) points2D = np.vstack([u,v]).T tri = Delaunay(points2D) simplices = tri.simplices # Create a figure fig1 = FF.create_trisurf(x=x, y=y, z=z, colormap=[(0.2, 0.4, 0.6), (1, 1, 1)], simplices=simplices) # Plot the data py.iplot(fig1, filename='trisurf-plot-mobius-band') ``` Example 4: Using a Custom Colormap Function with Light Cone ``` # Necessary Imports for Trisurf import numpy as np from scipy.spatial import Delaunay import plotly.plotly as py from plotly.tools import FigureFactory as FF from plotly.graph_objs import graph_objs # Make data for plot u=np.linspace(-np.pi, np.pi, 30) v=np.linspace(-np.pi, np.pi, 30) u,v=np.meshgrid(u,v) u=u.flatten() v=v.flatten() x = u y = u*np.cos(v) z = u*np.sin(v) points2D = np.vstack([u,v]).T tri = Delaunay(points2D) simplices = tri.simplices # Define distance function def dist_origin(x, y, z): return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2) # Create a figure fig1 = FF.create_trisurf(x=x, y=y, z=z, colormap=['#604d9e', 'rgb(50, 150, 255)', (0.2, 0.2, 0.8)], simplices=simplices, color_func=dist_origin) # Plot the data py.iplot(fig1, filename='trisurf-plot-custom-coloring') ``` Example 5: Enter color_func as a list of colors ``` # Necessary Imports for Trisurf import numpy as np from scipy.spatial import Delaunay import random import plotly.plotly as py from plotly.tools import FigureFactory as FF from plotly.graph_objs import graph_objs # Make data for plot u=np.linspace(-np.pi, np.pi, 30) v=np.linspace(-np.pi, np.pi, 30) u,v=np.meshgrid(u,v) u=u.flatten() v=v.flatten() x = u y = u*np.cos(v) z = u*np.sin(v) points2D = np.vstack([u,v]).T tri = Delaunay(points2D) simplices = tri.simplices colors = [] color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd'] for index in range(len(simplices)): colors.append(random.choice(color_choices)) fig = FF.create_trisurf( x, y, z, simplices, color_func=colors, show_colorbar=True, edges_color='rgb(2, 85, 180)', title=' Modern Art' ) py.iplot(fig, filename="trisurf-plot-modern-art") ``` """ from plotly.graph_objs import graph_objs # Validate colormap colormap = FigureFactory._validate_colors(colormap, 'tuple') data1 = FigureFactory._trisurf(x, y, z, simplices, show_colorbar=show_colorbar, color_func=color_func, colormap=colormap, edges_color=edges_color, plot_edges=plot_edges) axis = dict( showbackground=showbackground, backgroundcolor=backgroundcolor, gridcolor=gridcolor, zerolinecolor=zerolinecolor, ) layout = graph_objs.Layout( title=title, width=width, height=height, scene=graph_objs.Scene( xaxis=graph_objs.XAxis(axis), yaxis=graph_objs.YAxis(axis), zaxis=graph_objs.ZAxis(axis), aspectratio=dict( x=aspectratio['x'], y=aspectratio['y'], z=aspectratio['z']), ) ) return graph_objs.Figure(data=data1, layout=layout) @staticmethod def _scatterplot(dataframe, headers, diag, size, height, width, title, **kwargs): """ Refer to FigureFactory.create_scatterplotmatrix() for docstring Returns fig for scatterplotmatrix without index """ from plotly.graph_objs import graph_objs dim = len(dataframe) fig = make_subplots(rows=dim, cols=dim) trace_list = [] # Insert traces into trace_list for listy in dataframe: for listx in dataframe: if (listx == listy) and (diag == 'histogram'): trace = graph_objs.Histogram( x=listx, showlegend=False ) elif (listx == listy) and (diag == 'box'): trace = graph_objs.Box( y=listx, name=None, showlegend=False ) else: if 'marker' in kwargs: kwargs['marker']['size'] = size trace = graph_objs.Scatter( x=listx, y=listy, mode='markers', showlegend=False, **kwargs ) trace_list.append(trace) else: trace = graph_objs.Scatter( x=listx, y=listy, mode='markers', marker=dict( size=size), showlegend=False, **kwargs ) trace_list.append(trace) trace_index = 0 indices = range(1, dim + 1) for y_index in indices: for x_index in indices: fig.append_trace(trace_list[trace_index], y_index, x_index) trace_index += 1 # Insert headers into the figure for j in range(dim): xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) fig['layout'][xaxis_key].update(title=headers[j]) for j in range(dim): yaxis_key = 'yaxis{}'.format(1 + (dim * j)) fig['layout'][yaxis_key].update(title=headers[j]) fig['layout'].update( height=height, width=width, title=title, showlegend=True ) return fig @staticmethod def _scatterplot_dict(dataframe, headers, diag, size, height, width, title, index, index_vals, endpts, colormap, colormap_type, **kwargs): """ Refer to FigureFactory.create_scatterplotmatrix() for docstring Returns fig for scatterplotmatrix with both index and colormap picked. Used if colormap is a dictionary with index values as keys pointing to colors. Forces colormap_type to behave categorically because it would not make sense colors are assigned to each index value and thus implies that a categorical approach should be taken """ from plotly.graph_objs import graph_objs theme = colormap dim = len(dataframe) fig = make_subplots(rows=dim, cols=dim) trace_list = [] legend_param = 0 # Work over all permutations of list pairs for listy in dataframe: for listx in dataframe: # create a dictionary for index_vals unique_index_vals = {} for name in index_vals: if name not in unique_index_vals: unique_index_vals[name] = [] # Fill all the rest of the names into the dictionary for name in sorted(unique_index_vals.keys()): new_listx = [] new_listy = [] for j in range(len(index_vals)): if index_vals[j] == name: new_listx.append(listx[j]) new_listy.append(listy[j]) # Generate trace with VISIBLE icon if legend_param == 1: if (listx == listy) and (diag == 'histogram'): trace = graph_objs.Histogram( x=new_listx, marker=dict( color=theme[name]), showlegend=True ) elif (listx == listy) and (diag == 'box'): trace = graph_objs.Box( y=new_listx, name=None, marker=dict( color=theme[name]), showlegend=True ) else: if 'marker' in kwargs: kwargs['marker']['size'] = size kwargs['marker']['color'] = theme[name] trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=name, showlegend=True, **kwargs ) else: trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=name, marker=dict( size=size, color=theme[name]), showlegend=True, **kwargs ) # Generate trace with INVISIBLE icon else: if (listx == listy) and (diag == 'histogram'): trace = graph_objs.Histogram( x=new_listx, marker=dict( color=theme[name]), showlegend=False ) elif (listx == listy) and (diag == 'box'): trace = graph_objs.Box( y=new_listx, name=None, marker=dict( color=theme[name]), showlegend=False ) else: if 'marker' in kwargs: kwargs['marker']['size'] = size kwargs['marker']['color'] = theme[name] trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=name, showlegend=False, **kwargs ) else: trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=name, marker=dict( size=size, color=theme[name]), showlegend=False, **kwargs ) # Push the trace into dictionary unique_index_vals[name] = trace trace_list.append(unique_index_vals) legend_param += 1 trace_index = 0 indices = range(1, dim + 1) for y_index in indices: for x_index in indices: for name in sorted(trace_list[trace_index].keys()): fig.append_trace( trace_list[trace_index][name], y_index, x_index) trace_index += 1 # Insert headers into the figure for j in range(dim): xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) fig['layout'][xaxis_key].update(title=headers[j]) for j in range(dim): yaxis_key = 'yaxis{}'.format(1 + (dim * j)) fig['layout'][yaxis_key].update(title=headers[j]) if diag == 'histogram': fig['layout'].update( height=height, width=width, title=title, showlegend=True, barmode='stack') return fig elif diag == 'box': fig['layout'].update( height=height, width=width, title=title, showlegend=True) return fig else: fig['layout'].update( height=height, width=width, title=title, showlegend=True) return fig @staticmethod def _scatterplot_theme(dataframe, headers, diag, size, height, width, title, index, index_vals, endpts, colormap, colormap_type, **kwargs): """ Refer to FigureFactory.create_scatterplotmatrix() for docstring Returns fig for scatterplotmatrix with both index and colormap picked """ from plotly.graph_objs import graph_objs # Check if index is made of string values if isinstance(index_vals[0], str): unique_index_vals = [] for name in index_vals: if name not in unique_index_vals: unique_index_vals.append(name) n_colors_len = len(unique_index_vals) # Convert colormap to list of n RGB tuples if colormap_type == 'seq': foo = FigureFactory._color_parser( colormap, FigureFactory._unlabel_rgb ) foo = FigureFactory._n_colors(foo[0], foo[1], n_colors_len) theme = FigureFactory._color_parser( foo, FigureFactory._label_rgb ) if colormap_type == 'cat': # leave list of colors the same way theme = colormap dim = len(dataframe) fig = make_subplots(rows=dim, cols=dim) trace_list = [] legend_param = 0 # Work over all permutations of list pairs for listy in dataframe: for listx in dataframe: # create a dictionary for index_vals unique_index_vals = {} for name in index_vals: if name not in unique_index_vals: unique_index_vals[name] = [] c_indx = 0 # color index # Fill all the rest of the names into the dictionary for name in sorted(unique_index_vals.keys()): new_listx = [] new_listy = [] for j in range(len(index_vals)): if index_vals[j] == name: new_listx.append(listx[j]) new_listy.append(listy[j]) # Generate trace with VISIBLE icon if legend_param == 1: if (listx == listy) and (diag == 'histogram'): trace = graph_objs.Histogram( x=new_listx, marker=dict( color=theme[c_indx]), showlegend=True ) elif (listx == listy) and (diag == 'box'): trace = graph_objs.Box( y=new_listx, name=None, marker=dict( color=theme[c_indx]), showlegend=True ) else: if 'marker' in kwargs: kwargs['marker']['size'] = size kwargs['marker']['color'] = theme[c_indx] trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=name, showlegend=True, **kwargs ) else: trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=name, marker=dict( size=size, color=theme[c_indx]), showlegend=True, **kwargs ) # Generate trace with INVISIBLE icon else: if (listx == listy) and (diag == 'histogram'): trace = graph_objs.Histogram( x=new_listx, marker=dict( color=theme[c_indx]), showlegend=False ) elif (listx == listy) and (diag == 'box'): trace = graph_objs.Box( y=new_listx, name=None, marker=dict( color=theme[c_indx]), showlegend=False ) else: if 'marker' in kwargs: kwargs['marker']['size'] = size kwargs['marker']['color'] = theme[c_indx] trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=name, showlegend=False, **kwargs ) else: trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=name, marker=dict( size=size, color=theme[c_indx]), showlegend=False, **kwargs ) # Push the trace into dictionary unique_index_vals[name] = trace if c_indx >= (len(theme) - 1): c_indx = -1 c_indx += 1 trace_list.append(unique_index_vals) legend_param += 1 trace_index = 0 indices = range(1, dim + 1) for y_index in indices: for x_index in indices: for name in sorted(trace_list[trace_index].keys()): fig.append_trace( trace_list[trace_index][name], y_index, x_index) trace_index += 1 # Insert headers into the figure for j in range(dim): xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) fig['layout'][xaxis_key].update(title=headers[j]) for j in range(dim): yaxis_key = 'yaxis{}'.format(1 + (dim * j)) fig['layout'][yaxis_key].update(title=headers[j]) if diag == 'histogram': fig['layout'].update( height=height, width=width, title=title, showlegend=True, barmode='stack') return fig elif diag == 'box': fig['layout'].update( height=height, width=width, title=title, showlegend=True) return fig else: fig['layout'].update( height=height, width=width, title=title, showlegend=True) return fig else: if endpts: intervals = FigureFactory._endpts_to_intervals(endpts) # Convert colormap to list of n RGB tuples if colormap_type == 'seq': foo = FigureFactory._color_parser( colormap, FigureFactory._unlabel_rgb ) foo = FigureFactory._n_colors(foo[0], foo[1], len(intervals)) theme = FigureFactory._color_parser( foo, FigureFactory._label_rgb ) if colormap_type == 'cat': # leave list of colors the same way theme = colormap dim = len(dataframe) fig = make_subplots(rows=dim, cols=dim) trace_list = [] legend_param = 0 # Work over all permutations of list pairs for listy in dataframe: for listx in dataframe: interval_labels = {} for interval in intervals: interval_labels[str(interval)] = [] c_indx = 0 # color index # Fill all the rest of the names into the dictionary for interval in intervals: new_listx = [] new_listy = [] for j in range(len(index_vals)): if interval[0] < index_vals[j] <= interval[1]: new_listx.append(listx[j]) new_listy.append(listy[j]) # Generate trace with VISIBLE icon if legend_param == 1: if (listx == listy) and (diag == 'histogram'): trace = graph_objs.Histogram( x=new_listx, marker=dict( color=theme[c_indx]), showlegend=True ) elif (listx == listy) and (diag == 'box'): trace = graph_objs.Box( y=new_listx, name=None, marker=dict( color=theme[c_indx]), showlegend=True ) else: if 'marker' in kwargs: kwargs['marker']['size'] = size (kwargs['marker'] ['color']) = theme[c_indx] trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=str(interval), showlegend=True, **kwargs ) else: trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=str(interval), marker=dict( size=size, color=theme[c_indx]), showlegend=True, **kwargs ) # Generate trace with INVISIBLE icon else: if (listx == listy) and (diag == 'histogram'): trace = graph_objs.Histogram( x=new_listx, marker=dict( color=theme[c_indx]), showlegend=False ) elif (listx == listy) and (diag == 'box'): trace = graph_objs.Box( y=new_listx, name=None, marker=dict( color=theme[c_indx]), showlegend=False ) else: if 'marker' in kwargs: kwargs['marker']['size'] = size (kwargs['marker'] ['color']) = theme[c_indx] trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=str(interval), showlegend=False, **kwargs ) else: trace = graph_objs.Scatter( x=new_listx, y=new_listy, mode='markers', name=str(interval), marker=dict( size=size, color=theme[c_indx]), showlegend=False, **kwargs ) # Push the trace into dictionary interval_labels[str(interval)] = trace if c_indx >= (len(theme) - 1): c_indx = -1 c_indx += 1 trace_list.append(interval_labels) legend_param += 1 trace_index = 0 indices = range(1, dim + 1) for y_index in indices: for x_index in indices: for interval in intervals: fig.append_trace( trace_list[trace_index][str(interval)], y_index, x_index) trace_index += 1 # Insert headers into the figure for j in range(dim): xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) fig['layout'][xaxis_key].update(title=headers[j]) for j in range(dim): yaxis_key = 'yaxis{}'.format(1 + (dim * j)) fig['layout'][yaxis_key].update(title=headers[j]) if diag == 'histogram': fig['layout'].update( height=height, width=width, title=title, showlegend=True, barmode='stack') return fig elif diag == 'box': fig['layout'].update( height=height, width=width, title=title, showlegend=True) return fig else: fig['layout'].update( height=height, width=width, title=title, showlegend=True) return fig else: theme = colormap # add a copy of rgb color to theme if it contains one color if len(theme) <= 1: theme.append(theme[0]) color = [] for incr in range(len(theme)): color.append([1./(len(theme)-1)*incr, theme[incr]]) dim = len(dataframe) fig = make_subplots(rows=dim, cols=dim) trace_list = [] legend_param = 0 # Run through all permutations of list pairs for listy in dataframe: for listx in dataframe: # Generate trace with VISIBLE icon if legend_param == 1: if (listx == listy) and (diag == 'histogram'): trace = graph_objs.Histogram( x=listx, marker=dict( color=theme[0]), showlegend=False ) elif (listx == listy) and (diag == 'box'): trace = graph_objs.Box( y=listx, marker=dict( color=theme[0]), showlegend=False ) else: if 'marker' in kwargs: kwargs['marker']['size'] = size kwargs['marker']['color'] = index_vals kwargs['marker']['colorscale'] = color kwargs['marker']['showscale'] = True trace = graph_objs.Scatter( x=listx, y=listy, mode='markers', showlegend=False, **kwargs ) else: trace = graph_objs.Scatter( x=listx, y=listy, mode='markers', marker=dict( size=size, color=index_vals, colorscale=color, showscale=True), showlegend=False, **kwargs ) # Generate trace with INVISIBLE icon else: if (listx == listy) and (diag == 'histogram'): trace = graph_objs.Histogram( x=listx, marker=dict( color=theme[0]), showlegend=False ) elif (listx == listy) and (diag == 'box'): trace = graph_objs.Box( y=listx, marker=dict( color=theme[0]), showlegend=False ) else: if 'marker' in kwargs: kwargs['marker']['size'] = size kwargs['marker']['color'] = index_vals kwargs['marker']['colorscale'] = color kwargs['marker']['showscale'] = False trace = graph_objs.Scatter( x=listx, y=listy, mode='markers', showlegend=False, **kwargs ) else: trace = graph_objs.Scatter( x=listx, y=listy, mode='markers', marker=dict( size=size, color=index_vals, colorscale=color, showscale=False), showlegend=False, **kwargs ) # Push the trace into list trace_list.append(trace) legend_param += 1 trace_index = 0 indices = range(1, dim + 1) for y_index in indices: for x_index in indices: fig.append_trace(trace_list[trace_index], y_index, x_index) trace_index += 1 # Insert headers into the figure for j in range(dim): xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) fig['layout'][xaxis_key].update(title=headers[j]) for j in range(dim): yaxis_key = 'yaxis{}'.format(1 + (dim * j)) fig['layout'][yaxis_key].update(title=headers[j]) if diag == 'histogram': fig['layout'].update( height=height, width=width, title=title, showlegend=True, barmode='stack') return fig elif diag == 'box': fig['layout'].update( height=height, width=width, title=title, showlegend=True) return fig else: fig['layout'].update( height=height, width=width, title=title, showlegend=True) return fig @staticmethod def _validate_index(index_vals): """ Validates if a list contains all numbers or all strings :raises: (PlotlyError) If there are any two items in the list whose types differ """ from numbers import Number if isinstance(index_vals[0], Number): if not all(isinstance(item, Number) for item in index_vals): raise exceptions.PlotlyError("Error in indexing column. " "Make sure all entries of each " "column are all numbers or " "all strings.") elif isinstance(index_vals[0], str): if not all(isinstance(item, str) for item in index_vals): raise exceptions.PlotlyError("Error in indexing column. " "Make sure all entries of each " "column are all numbers or " "all strings.") @staticmethod def _validate_dataframe(array): """ Validates all strings or numbers in each dataframe column :raises: (PlotlyError) If there are any two items in any list whose types differ """ from numbers import Number for vector in array: if isinstance(vector[0], Number): if not all(isinstance(item, Number) for item in vector): raise exceptions.PlotlyError("Error in dataframe. " "Make sure all entries of " "each column are either " "numbers or strings.") elif isinstance(vector[0], str): if not all(isinstance(item, str) for item in vector): raise exceptions.PlotlyError("Error in dataframe. " "Make sure all entries of " "each column are either " "numbers or strings.") @staticmethod def _validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs): """ Validates basic inputs for FigureFactory.create_scatterplotmatrix() :raises: (PlotlyError) If pandas is not imported :raises: (PlotlyError) If pandas dataframe is not inputted :raises: (PlotlyError) If pandas dataframe has <= 1 columns :raises: (PlotlyError) If diagonal plot choice (diag) is not one of the viable options :raises: (PlotlyError) If colormap_type is not a valid choice :raises: (PlotlyError) If kwargs contains 'size', 'color' or 'colorscale' """ if _pandas_imported is False: raise ImportError("FigureFactory.scatterplotmatrix requires " "a pandas DataFrame.") # Check if pandas dataframe if not isinstance(df, pd.core.frame.DataFrame): raise exceptions.PlotlyError("Dataframe not inputed. Please " "use a pandas dataframe to pro" "duce a scatterplot matrix.") # Check if dataframe is 1 column or less if len(df.columns) <= 1: raise exceptions.PlotlyError("Dataframe has only one column. To " "use the scatterplot matrix, use at " "least 2 columns.") # Check that diag parameter is a valid selection if diag not in DIAG_CHOICES: raise exceptions.PlotlyError("Make sure diag is set to " "one of {}".format(DIAG_CHOICES)) # Check that colormap_types is a valid selection if colormap_type not in VALID_COLORMAP_TYPES: raise exceptions.PlotlyError("Must choose a valid colormap type. " "Either 'cat' or 'seq' for a cate" "gorical and sequential colormap " "respectively.") # Check for not 'size' or 'color' in 'marker' of **kwargs if 'marker' in kwargs: FORBIDDEN_PARAMS = ['size', 'color', 'colorscale'] if any(param in kwargs['marker'] for param in FORBIDDEN_PARAMS): raise exceptions.PlotlyError("Your kwargs dictionary cannot " "include the 'size', 'color' or " "'colorscale' key words inside " "the marker dict since 'size' is " "already an argument of the " "scatterplot matrix function and " "both 'color' and 'colorscale " "are set internally.") @staticmethod def _endpts_to_intervals(endpts): """ Returns a list of intervals for categorical colormaps Accepts a list or tuple of sequentially increasing numbers and returns a list representation of the mathematical intervals with these numbers as endpoints. For example, [1, 6] returns [[-inf, 1], [1, 6], [6, inf]] :raises: (PlotlyError) If input is not a list or tuple :raises: (PlotlyError) If the input contains a string :raises: (PlotlyError) If any number does not increase after the previous one in the sequence """ length = len(endpts) # Check if endpts is a list or tuple if not (isinstance(endpts, (tuple)) or isinstance(endpts, (list))): raise exceptions.PlotlyError("The intervals_endpts argument must " "be a list or tuple of a sequence " "of increasing numbers.") # Check if endpts contains only numbers for item in endpts: if isinstance(item, str): raise exceptions.PlotlyError("The intervals_endpts argument " "must be a list or tuple of a " "sequence of increasing " "numbers.") # Check if numbers in endpts are increasing for k in range(length-1): if endpts[k] >= endpts[k+1]: raise exceptions.PlotlyError("The intervals_endpts argument " "must be a list or tuple of a " "sequence of increasing " "numbers.") else: intervals = [] # add -inf to intervals intervals.append([float('-inf'), endpts[0]]) for k in range(length - 1): interval = [] interval.append(endpts[k]) interval.append(endpts[k + 1]) intervals.append(interval) # add +inf to intervals intervals.append([endpts[length - 1], float('inf')]) return intervals @staticmethod def _convert_to_RGB_255(colors): """ Multiplies each element of a triplet by 255 Each coordinate of the color tuple is rounded to the nearest float and then is turned into an integer. If a number is of the form x.5, then if x is odd, the number rounds up to (x+1). Otherwise, it rounds down to just x. This is the way rounding works in Python 3 and in current statistical analysis to avoid rounding bias """ rgb_components = [] for component in colors: rounded_num = decimal.Decimal(str(component*255.0)).quantize( decimal.Decimal('1'), rounding=decimal.ROUND_HALF_EVEN ) # convert rounded number to an integer from 'Decimal' form rounded_num = int(rounded_num) rgb_components.append(rounded_num) return (rgb_components[0], rgb_components[1], rgb_components[2]) @staticmethod def _n_colors(lowcolor, highcolor, n_colors): """ Splits a low and high color into a list of n_colors colors in it Accepts two color tuples and returns a list of n_colors colors which form the intermediate colors between lowcolor and highcolor from linearly interpolating through RGB space """ diff_0 = float(highcolor[0] - lowcolor[0]) incr_0 = diff_0/(n_colors - 1) diff_1 = float(highcolor[1] - lowcolor[1]) incr_1 = diff_1/(n_colors - 1) diff_2 = float(highcolor[2] - lowcolor[2]) incr_2 = diff_2/(n_colors - 1) color_tuples = [] for index in range(n_colors): new_tuple = (lowcolor[0] + (index * incr_0), lowcolor[1] + (index * incr_1), lowcolor[2] + (index * incr_2)) color_tuples.append(new_tuple) return color_tuples @staticmethod def _label_rgb(colors): """ Takes tuple (a, b, c) and returns an rgb color 'rgb(a, b, c)' """ return ('rgb(%s, %s, %s)' % (colors[0], colors[1], colors[2])) @staticmethod def _unlabel_rgb(colors): """ Takes rgb color(s) 'rgb(a, b, c)' and returns tuple(s) (a, b, c) This function takes either an 'rgb(a, b, c)' color or a list of such colors and returns the color tuples in tuple(s) (a, b, c) """ str_vals = '' for index in range(len(colors)): try: float(colors[index]) str_vals = str_vals + colors[index] except ValueError: if colors[index] == ',' or colors[index] == '.': str_vals = str_vals + colors[index] str_vals = str_vals + ',' numbers = [] str_num = '' for char in str_vals: if char != ',': str_num = str_num + char else: numbers.append(float(str_num)) str_num = '' return (numbers[0], numbers[1], numbers[2]) @staticmethod def create_scatterplotmatrix(df, index=None, endpts=None, diag='scatter', height=500, width=500, size=6, title='Scatterplot Matrix', colormap=None, colormap_type='cat', dataframe=None, headers=None, index_vals=None, **kwargs): """ Returns data for a scatterplot matrix. :param (array) df: array of the data with column headers :param (str) index: name of the index column in data array :param (list|tuple) endpts: takes an increasing sequece of numbers that defines intervals on the real line. They are used to group the entries in an index of numbers into their corresponding interval and therefore can be treated as categorical data :param (str) diag: sets the chart type for the main diagonal plots :param (int|float) height: sets the height of the chart :param (int|float) width: sets the width of the chart :param (float) size: sets the marker size (in px) :param (str) title: the title label of the scatterplot matrix :param (str|tuple|list|dict) colormap: either a plotly scale name, an rgb or hex color, a color tuple, a list of colors or a dictionary. An rgb color is of the form 'rgb(x, y, z)' where x, y and z belong to the interval [0, 255] and a color tuple is a tuple of the form (a, b, c) where a, b and c belong to [0, 1]. If colormap is a list, it must contain valid color types as its members. If colormap is a dictionary, all the string entries in the index column must be a key in colormap. In this case, the colormap_type is forced to 'cat' or categorical :param (str) colormap_type: determines how colormap is interpreted. Valid choices are 'seq' (sequential) and 'cat' (categorical). If 'seq' is selected, only the first two colors in colormap will be considered (when colormap is a list) and the index values will be linearly interpolated between those two colors. This option is forced if all index values are numeric. If 'cat' is selected, a color from colormap will be assigned to each category from index, including the intervals if endpts is being used :param (dict) **kwargs: a dictionary of scatterplot arguments The only forbidden parameters are 'size', 'color' and 'colorscale' in 'marker' Example 1: Vanilla Scatterplot Matrix ``` import plotly.plotly as py from plotly.graph_objs import graph_objs from plotly.tools import FigureFactory as FF import numpy as np import pandas as pd # Create dataframe df = pd.DataFrame(np.random.randn(10, 2), columns=['Column 1', 'Column 2']) # Create scatterplot matrix fig = FF.create_scatterplotmatrix(df) # Plot py.iplot(fig, filename='Vanilla Scatterplot Matrix') ``` Example 2: Indexing a Column ``` import plotly.plotly as py from plotly.graph_objs import graph_objs from plotly.tools import FigureFactory as FF import numpy as np import pandas as pd # Create dataframe with index df = pd.DataFrame(np.random.randn(10, 2), columns=['A', 'B']) # Add another column of strings to the dataframe df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple', 'grape', 'pear', 'pear', 'apple', 'pear']) # Create scatterplot matrix fig = FF.create_scatterplotmatrix(df, index='Fruit', size=10) # Plot py.iplot(fig, filename = 'Scatterplot Matrix with Index') ``` Example 3: Styling the Diagonal Subplots ``` import plotly.plotly as py from plotly.graph_objs import graph_objs from plotly.tools import FigureFactory as FF import numpy as np import pandas as pd # Create dataframe with index df = pd.DataFrame(np.random.randn(10, 4), columns=['A', 'B', 'C', 'D']) # Add another column of strings to the dataframe df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple', 'grape', 'pear', 'pear', 'apple', 'pear']) # Create scatterplot matrix fig = FF.create_scatterplotmatrix(df, diag='box', index='Fruit', height=1000, width=1000) # Plot py.iplot(fig, filename = 'Scatterplot Matrix - Diagonal Styling') ``` Example 4: Use a Theme to Style the Subplots ``` import plotly.plotly as py from plotly.graph_objs import graph_objs from plotly.tools import FigureFactory as FF import numpy as np import pandas as pd # Create dataframe with random data df = pd.DataFrame(np.random.randn(100, 3), columns=['A', 'B', 'C']) # Create scatterplot matrix using a built-in # Plotly palette scale and indexing column 'A' fig = FF.create_scatterplotmatrix(df, diag='histogram', index='A', colormap='Blues', height=800, width=800) # Plot py.iplot(fig, filename = 'Scatterplot Matrix - Colormap Theme') ``` Example 5: Example 4 with Interval Factoring ``` import plotly.plotly as py from plotly.graph_objs import graph_objs from plotly.tools import FigureFactory as FF import numpy as np import pandas as pd # Create dataframe with random data df = pd.DataFrame(np.random.randn(100, 3), columns=['A', 'B', 'C']) # Create scatterplot matrix using a list of 2 rgb tuples # and endpoints at -1, 0 and 1 fig = FF.create_scatterplotmatrix(df, diag='histogram', index='A', colormap=['rgb(140, 255, 50)', 'rgb(170, 60, 115)', '#6c4774', (0.5, 0.1, 0.8)], endpts=[-1, 0, 1], height=800, width=800) # Plot py.iplot(fig, filename = 'Scatterplot Matrix - Intervals') ``` Example 6: Using the colormap as a Dictionary ``` import plotly.plotly as py from plotly.graph_objs import graph_objs from plotly.tools import FigureFactory as FF import numpy as np import pandas as pd import random # Create dataframe with random data df = pd.DataFrame(np.random.randn(100, 3), columns=['Column A', 'Column B', 'Column C']) # Add new color column to dataframe new_column = [] strange_colors = ['turquoise', 'limegreen', 'goldenrod'] for j in range(100): new_column.append(random.choice(strange_colors)) df['Colors'] = pd.Series(new_column, index=df.index) # Create scatterplot matrix using a dictionary of hex color values # which correspond to actual color names in 'Colors' column fig = FF.create_scatterplotmatrix( df, diag='box', index='Colors', colormap= dict( turquoise = '#00F5FF', limegreen = '#32CD32', goldenrod = '#DAA520' ), colormap_type='cat', height=800, width=800 ) # Plot py.iplot(fig, filename = 'Scatterplot Matrix - colormap dictionary ') ``` """ # TODO: protected until #282 if dataframe is None: dataframe = [] if headers is None: headers = [] if index_vals is None: index_vals = [] FigureFactory._validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs) # Validate colormap if isinstance(colormap, dict): colormap = FigureFactory._validate_colors_dict(colormap, 'rgb') else: colormap = FigureFactory._validate_colors(colormap, 'rgb') if not index: for name in df: headers.append(name) for name in headers: dataframe.append(df[name].values.tolist()) # Check for same data-type in df columns FigureFactory._validate_dataframe(dataframe) figure = FigureFactory._scatterplot(dataframe, headers, diag, size, height, width, title, **kwargs) return figure else: # Validate index selection if index not in df: raise exceptions.PlotlyError("Make sure you set the index " "input variable to one of the " "column names of your " "dataframe.") index_vals = df[index].values.tolist() for name in df: if name != index: headers.append(name) for name in headers: dataframe.append(df[name].values.tolist()) # check for same data-type in each df column FigureFactory._validate_dataframe(dataframe) FigureFactory._validate_index(index_vals) # check if all colormap keys are in the index # if colormap is a dictionary if isinstance(colormap, dict): for key in colormap: if not all(index in colormap for index in index_vals): raise exceptions.PlotlyError("If colormap is a " "dictionary, all the " "names in the index " "must be keys.") figure = FigureFactory._scatterplot_dict( dataframe, headers, diag, size, height, width, title, index, index_vals, endpts, colormap, colormap_type, **kwargs ) return figure else: figure = FigureFactory._scatterplot_theme( dataframe, headers, diag, size, height, width, title, index, index_vals, endpts, colormap, colormap_type, **kwargs ) return figure @staticmethod def _validate_equal_length(*args): """ Validates that data lists or ndarrays are the same length. :raises: (PlotlyError) If any data lists are not the same length. """ length = len(args[0]) if any(len(lst) != length for lst in args): raise exceptions.PlotlyError("Oops! Your data lists or ndarrays " "should be the same length.") @staticmethod def _validate_ohlc(open, high, low, close, direction, **kwargs): """ ohlc and candlestick specific validations Specifically, this checks that the high value is the greatest value and the low value is the lowest value in each unit. See FigureFactory.create_ohlc() or FigureFactory.create_candlestick() for params :raises: (PlotlyError) If the high value is not the greatest value in each unit. :raises: (PlotlyError) If the low value is not the lowest value in each unit. :raises: (PlotlyError) If direction is not 'increasing' or 'decreasing' """ for lst in [open, low, close]: for index in range(len(high)): if high[index] < lst[index]: raise exceptions.PlotlyError("Oops! Looks like some of " "your high values are less " "the corresponding open, " "low, or close values. " "Double check that your data " "is entered in O-H-L-C order") for lst in [open, high, close]: for index in range(len(low)): if low[index] > lst[index]: raise exceptions.PlotlyError("Oops! Looks like some of " "your low values are greater " "than the corresponding high" ", open, or close values. " "Double check that your data " "is entered in O-H-L-C order") direction_opts = ('increasing', 'decreasing', 'both') if direction not in direction_opts: raise exceptions.PlotlyError("direction must be defined as " "'increasing', 'decreasing', or " "'both'") @staticmethod def _validate_distplot(hist_data, curve_type): """ Distplot-specific validations :raises: (PlotlyError) If hist_data is not a list of lists :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or 'normal'). """ try: import pandas as pd _pandas_imported = True except ImportError: _pandas_imported = False hist_data_types = (list,) if _numpy_imported: hist_data_types += (np.ndarray,) if _pandas_imported: hist_data_types += (pd.core.series.Series,) if not isinstance(hist_data[0], hist_data_types): raise exceptions.PlotlyError("Oops, this function was written " "to handle multiple datasets, if " "you want to plot just one, make " "sure your hist_data variable is " "still a list of lists, i.e. x = " "[1, 2, 3] -> x = [[1, 2, 3]]") curve_opts = ('kde', 'normal') if curve_type not in curve_opts: raise exceptions.PlotlyError("curve_type must be defined as " "'kde' or 'normal'") if _scipy_imported is False: raise ImportError("FigureFactory.create_distplot requires scipy") @staticmethod def _validate_positive_scalars(**kwargs): """ Validates that all values given in key/val pairs are positive. Accepts kwargs to improve Exception messages. :raises: (PlotlyError) If any value is < 0 or raises. """ for key, val in kwargs.items(): try: if val <= 0: raise ValueError('{} must be > 0, got {}'.format(key, val)) except TypeError: raise exceptions.PlotlyError('{} must be a number, got {}' .format(key, val)) @staticmethod def _validate_streamline(x, y): """ Streamline-specific validations Specifically, this checks that x and y are both evenly spaced, and that the package numpy is available. See FigureFactory.create_streamline() for params :raises: (ImportError) If numpy is not available. :raises: (PlotlyError) If x is not evenly spaced. :raises: (PlotlyError) If y is not evenly spaced. """ if _numpy_imported is False: raise ImportError("FigureFactory.create_streamline requires numpy") for index in range(len(x) - 1): if ((x[index + 1] - x[index]) - (x[1] - x[0])) > .0001: raise exceptions.PlotlyError("x must be a 1 dimensional, " "evenly spaced array") for index in range(len(y) - 1): if ((y[index + 1] - y[index]) - (y[1] - y[0])) > .0001: raise exceptions.PlotlyError("y must be a 1 dimensional, " "evenly spaced array") @staticmethod def _validate_annotated_heatmap(z, x, y, annotation_text): """ Annotated-heatmap-specific validations Check that if a text matrix is supplied, it has the same dimensions as the z matrix. See FigureFactory.create_annotated_heatmap() for params :raises: (PlotlyError) If z and text matrices do not have the same dimensions. """ if annotation_text is not None and isinstance(annotation_text, list): FigureFactory._validate_equal_length(z, annotation_text) for lst in range(len(z)): if len(z[lst]) != len(annotation_text[lst]): raise exceptions.PlotlyError("z and text should have the " "same dimensions") if x: if len(x) != len(z[0]): raise exceptions.PlotlyError("oops, the x list that you " "provided does not match the " "width of your z matrix ") if y: if len(y) != len(z): raise exceptions.PlotlyError("oops, the y list that you " "provided does not match the " "length of your z matrix ") @staticmethod def _validate_table(table_text, font_colors): """ Table-specific validations Check that font_colors is supplied correctly (1, 3, or len(text) colors). :raises: (PlotlyError) If font_colors is supplied incorretly. See FigureFactory.create_table() for params """ font_colors_len_options = [1, 3, len(table_text)] if len(font_colors) not in font_colors_len_options: raise exceptions.PlotlyError("Oops, font_colors should be a list " "of length 1, 3 or len(text)") @staticmethod def _flatten(array): """ Uses list comprehension to flatten array :param (array): An iterable to flatten :raises (PlotlyError): If iterable is not nested. :rtype (list): The flattened list. """ try: return [item for sublist in array for item in sublist] except TypeError: raise exceptions.PlotlyError("Your data array could not be " "flattened! Make sure your data is " "entered as lists or ndarrays!") @staticmethod def _hex_to_rgb(value): """ Calculates rgb values from a hex color code. :param (string) value: Hex color string :rtype (tuple) (r_value, g_value, b_value): tuple of rgb values """ value = value.lstrip('#') hex_total_length = len(value) rgb_section_length = hex_total_length // 3 return tuple(int(value[i:i + rgb_section_length], 16) for i in range(0, hex_total_length, rgb_section_length)) @staticmethod def create_quiver(x, y, u, v, scale=.1, arrow_scale=.3, angle=math.pi / 9, **kwargs): """ Returns data for a quiver plot. :param (list|ndarray) x: x coordinates of the arrow locations :param (list|ndarray) y: y coordinates of the arrow locations :param (list|ndarray) u: x components of the arrow vectors :param (list|ndarray) v: y components of the arrow vectors :param (float in [0,1]) scale: scales size of the arrows(ideally to avoid overlap). Default = .1 :param (float in [0,1]) arrow_scale: value multiplied to length of barb to get length of arrowhead. Default = .3 :param (angle in radians) angle: angle of arrowhead. Default = pi/9 :param kwargs: kwargs passed through plotly.graph_objs.Scatter for more information on valid kwargs call help(plotly.graph_objs.Scatter) :rtype (dict): returns a representation of quiver figure. Example 1: Trivial Quiver ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import math # 1 Arrow from (0,0) to (1,1) fig = FF.create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1) py.plot(fig, filename='quiver') ``` Example 2: Quiver plot using meshgrid ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import numpy as np import math # Add data x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2)) u = np.cos(x)*y v = np.sin(x)*y #Create quiver fig = FF.create_quiver(x, y, u, v) # Plot py.plot(fig, filename='quiver') ``` Example 3: Styling the quiver plot ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import numpy as np import math # Add data x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5), np.arange(-math.pi, math.pi, .5)) u = np.cos(x)*y v = np.sin(x)*y # Create quiver fig = FF.create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6, name='Wind Velocity', line=Line(width=1)) # Add title to layout fig['layout'].update(title='Quiver Plot') # Plot py.plot(fig, filename='quiver') ``` """ # TODO: protected until #282 from plotly.graph_objs import graph_objs FigureFactory._validate_equal_length(x, y, u, v) FigureFactory._validate_positive_scalars(arrow_scale=arrow_scale, scale=scale) barb_x, barb_y = _Quiver(x, y, u, v, scale, arrow_scale, angle).get_barbs() arrow_x, arrow_y = _Quiver(x, y, u, v, scale, arrow_scale, angle).get_quiver_arrows() quiver = graph_objs.Scatter(x=barb_x + arrow_x, y=barb_y + arrow_y, mode='lines', **kwargs) data = [quiver] layout = graph_objs.Layout(hovermode='closest') return graph_objs.Figure(data=data, layout=layout) @staticmethod def create_streamline(x, y, u, v, density=1, angle=math.pi / 9, arrow_scale=.09, **kwargs): """ Returns data for a streamline plot. :param (list|ndarray) x: 1 dimensional, evenly spaced list or array :param (list|ndarray) y: 1 dimensional, evenly spaced list or array :param (ndarray) u: 2 dimensional array :param (ndarray) v: 2 dimensional array :param (float|int) density: controls the density of streamlines in plot. This is multiplied by 30 to scale similiarly to other available streamline functions such as matplotlib. Default = 1 :param (angle in radians) angle: angle of arrowhead. Default = pi/9 :param (float in [0,1]) arrow_scale: value to scale length of arrowhead Default = .09 :param kwargs: kwargs passed through plotly.graph_objs.Scatter for more information on valid kwargs call help(plotly.graph_objs.Scatter) :rtype (dict): returns a representation of streamline figure. Example 1: Plot simple streamline and increase arrow size ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import numpy as np import math # Add data x = np.linspace(-3, 3, 100) y = np.linspace(-3, 3, 100) Y, X = np.meshgrid(x, y) u = -1 - X**2 + Y v = 1 + X - Y**2 u = u.T # Transpose v = v.T # Transpose # Create streamline fig = FF.create_streamline(x, y, u, v, arrow_scale=.1) # Plot py.plot(fig, filename='streamline') ``` Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import numpy as np import math # Add data N = 50 x_start, x_end = -2.0, 2.0 y_start, y_end = -1.0, 1.0 x = np.linspace(x_start, x_end, N) y = np.linspace(y_start, y_end, N) X, Y = np.meshgrid(x, y) ss = 5.0 x_s, y_s = -1.0, 0.0 # Compute the velocity field on the mesh grid u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2) v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2) # Create streamline fig = FF.create_streamline(x, y, u_s, v_s, density=2, name='streamline') # Add source point point = Scatter(x=[x_s], y=[y_s], mode='markers', marker=Marker(size=14), name='source point') # Plot fig['data'].append(point) py.plot(fig, filename='streamline') ``` """ # TODO: protected until #282 from plotly.graph_objs import graph_objs FigureFactory._validate_equal_length(x, y) FigureFactory._validate_equal_length(u, v) FigureFactory._validate_streamline(x, y) FigureFactory._validate_positive_scalars(density=density, arrow_scale=arrow_scale) streamline_x, streamline_y = _Streamline(x, y, u, v, density, angle, arrow_scale).sum_streamlines() arrow_x, arrow_y = _Streamline(x, y, u, v, density, angle, arrow_scale).get_streamline_arrows() streamline = graph_objs.Scatter(x=streamline_x + arrow_x, y=streamline_y + arrow_y, mode='lines', **kwargs) data = [streamline] layout = graph_objs.Layout(hovermode='closest') return graph_objs.Figure(data=data, layout=layout) @staticmethod def _make_increasing_ohlc(open, high, low, close, dates, **kwargs): """ Makes increasing ohlc sticks _make_increasing_ohlc() and _make_decreasing_ohlc separate the increasing trace from the decreasing trace so kwargs (such as color) can be passed separately to increasing or decreasing traces when direction is set to 'increasing' or 'decreasing' in FigureFactory.create_candlestick() :param (list) open: opening values :param (list) high: high values :param (list) low: low values :param (list) close: closing values :param (list) dates: list of datetime objects. Default: None :param kwargs: kwargs to be passed to increasing trace via plotly.graph_objs.Scatter. :rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc sticks. """ (flat_increase_x, flat_increase_y, text_increase) = _OHLC(open, high, low, close, dates).get_increase() if 'name' in kwargs: showlegend = True else: kwargs.setdefault('name', 'Increasing') showlegend = False kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR, width=1)) kwargs.setdefault('text', text_increase) ohlc_incr = dict(type='scatter', x=flat_increase_x, y=flat_increase_y, mode='lines', showlegend=showlegend, **kwargs) return ohlc_incr @staticmethod def _make_decreasing_ohlc(open, high, low, close, dates, **kwargs): """ Makes decreasing ohlc sticks :param (list) open: opening values :param (list) high: high values :param (list) low: low values :param (list) close: closing values :param (list) dates: list of datetime objects. Default: None :param kwargs: kwargs to be passed to increasing trace via plotly.graph_objs.Scatter. :rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc sticks. """ (flat_decrease_x, flat_decrease_y, text_decrease) = _OHLC(open, high, low, close, dates).get_decrease() kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR, width=1)) kwargs.setdefault('text', text_decrease) kwargs.setdefault('showlegend', False) kwargs.setdefault('name', 'Decreasing') ohlc_decr = dict(type='scatter', x=flat_decrease_x, y=flat_decrease_y, mode='lines', **kwargs) return ohlc_decr @staticmethod def create_ohlc(open, high, low, close, dates=None, direction='both', **kwargs): """ BETA function that creates an ohlc chart :param (list) open: opening values :param (list) high: high values :param (list) low: low values :param (list) close: closing :param (list) dates: list of datetime objects. Default: None :param (string) direction: direction can be 'increasing', 'decreasing', or 'both'. When the direction is 'increasing', the returned figure consists of all units where the close value is greater than the corresponding open value, and when the direction is 'decreasing', the returned figure consists of all units where the close value is less than or equal to the corresponding open value. When the direction is 'both', both increasing and decreasing units are returned. Default: 'both' :param kwargs: kwargs passed through plotly.graph_objs.Scatter. These kwargs describe other attributes about the ohlc Scatter trace such as the color or the legend name. For more information on valid kwargs call help(plotly.graph_objs.Scatter) :rtype (dict): returns a representation of an ohlc chart figure. Example 1: Simple OHLC chart from a Pandas DataFrame ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF from datetime import datetime import pandas.io.data as web df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15), datetime(2008, 10, 15)) fig = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index) py.plot(fig, filename='finance/aapl-ohlc') ``` Example 2: Add text and annotations to the OHLC chart ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF from datetime import datetime import pandas.io.data as web df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15), datetime(2008, 10, 15)) fig = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index) # Update the fig - all options here: https://plot.ly/python/reference/#Layout fig['layout'].update({ 'title': 'The Great Recession', 'yaxis': {'title': 'AAPL Stock'}, 'shapes': [{ 'x0': '2008-09-15', 'x1': '2008-09-15', 'type': 'line', 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper', 'line': {'color': 'rgb(40,40,40)', 'width': 0.5} }], 'annotations': [{ 'text': "the fall of Lehman Brothers", 'x': '2008-09-15', 'y': 1.02, 'xref': 'x', 'yref': 'paper', 'showarrow': False, 'xanchor': 'left' }] }) py.plot(fig, filename='finance/aapl-recession-ohlc', validate=False) ``` Example 3: Customize the OHLC colors ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF from plotly.graph_objs import Line, Marker from datetime import datetime import pandas.io.data as web df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), datetime(2009, 4, 1)) # Make increasing ohlc sticks and customize their color and name fig_increasing = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index, direction='increasing', name='AAPL', line=Line(color='rgb(150, 200, 250)')) # Make decreasing ohlc sticks and customize their color and name fig_decreasing = FF.create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index, direction='decreasing', line=Line(color='rgb(128, 128, 128)')) # Initialize the figure fig = fig_increasing # Add decreasing data with .extend() fig['data'].extend(fig_decreasing['data']) py.iplot(fig, filename='finance/aapl-ohlc-colors', validate=False) ``` Example 4: OHLC chart with datetime objects ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF from datetime import datetime # Add data open_data = [33.0, 33.3, 33.5, 33.0, 34.1] high_data = [33.1, 33.3, 33.6, 33.2, 34.8] low_data = [32.7, 32.7, 32.8, 32.6, 32.8] close_data = [33.0, 32.9, 33.3, 33.1, 33.1] dates = [datetime(year=2013, month=10, day=10), datetime(year=2013, month=11, day=10), datetime(year=2013, month=12, day=10), datetime(year=2014, month=1, day=10), datetime(year=2014, month=2, day=10)] # Create ohlc fig = FF.create_ohlc(open_data, high_data, low_data, close_data, dates=dates) py.iplot(fig, filename='finance/simple-ohlc', validate=False) ``` """ # TODO: protected until #282 from plotly.graph_objs import graph_objs if dates is not None: FigureFactory._validate_equal_length(open, high, low, close, dates) else: FigureFactory._validate_equal_length(open, high, low, close) FigureFactory._validate_ohlc(open, high, low, close, direction, **kwargs) if direction is 'increasing': ohlc_incr = FigureFactory._make_increasing_ohlc(open, high, low, close, dates, **kwargs) data = [ohlc_incr] elif direction is 'decreasing': ohlc_decr = FigureFactory._make_decreasing_ohlc(open, high, low, close, dates, **kwargs) data = [ohlc_decr] else: ohlc_incr = FigureFactory._make_increasing_ohlc(open, high, low, close, dates, **kwargs) ohlc_decr = FigureFactory._make_decreasing_ohlc(open, high, low, close, dates, **kwargs) data = [ohlc_incr, ohlc_decr] layout = graph_objs.Layout(xaxis=dict(zeroline=False), hovermode='closest') return graph_objs.Figure(data=data, layout=layout) @staticmethod def _make_increasing_candle(open, high, low, close, dates, **kwargs): """ Makes boxplot trace for increasing candlesticks _make_increasing_candle() and _make_decreasing_candle separate the increasing traces from the decreasing traces so kwargs (such as color) can be passed separately to increasing or decreasing traces when direction is set to 'increasing' or 'decreasing' in FigureFactory.create_candlestick() :param (list) open: opening values :param (list) high: high values :param (list) low: low values :param (list) close: closing values :param (list) dates: list of datetime objects. Default: None :param kwargs: kwargs to be passed to increasing trace via plotly.graph_objs.Scatter. :rtype (list) candle_incr_data: list of the box trace for increasing candlesticks. """ increase_x, increase_y = _Candlestick( open, high, low, close, dates, **kwargs).get_candle_increase() if 'line' in kwargs: kwargs.setdefault('fillcolor', kwargs['line']['color']) else: kwargs.setdefault('fillcolor', _DEFAULT_INCREASING_COLOR) if 'name' in kwargs: kwargs.setdefault('showlegend', True) else: kwargs.setdefault('showlegend', False) kwargs.setdefault('name', 'Increasing') kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR)) candle_incr_data = dict(type='box', x=increase_x, y=increase_y, whiskerwidth=0, boxpoints=False, **kwargs) return [candle_incr_data] @staticmethod def _make_decreasing_candle(open, high, low, close, dates, **kwargs): """ Makes boxplot trace for decreasing candlesticks :param (list) open: opening values :param (list) high: high values :param (list) low: low values :param (list) close: closing values :param (list) dates: list of datetime objects. Default: None :param kwargs: kwargs to be passed to decreasing trace via plotly.graph_objs.Scatter. :rtype (list) candle_decr_data: list of the box trace for decreasing candlesticks. """ decrease_x, decrease_y = _Candlestick( open, high, low, close, dates, **kwargs).get_candle_decrease() if 'line' in kwargs: kwargs.setdefault('fillcolor', kwargs['line']['color']) else: kwargs.setdefault('fillcolor', _DEFAULT_DECREASING_COLOR) kwargs.setdefault('showlegend', False) kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR)) kwargs.setdefault('name', 'Decreasing') candle_decr_data = dict(type='box', x=decrease_x, y=decrease_y, whiskerwidth=0, boxpoints=False, **kwargs) return [candle_decr_data] @staticmethod def create_candlestick(open, high, low, close, dates=None, direction='both', **kwargs): """ BETA function that creates a candlestick chart :param (list) open: opening values :param (list) high: high values :param (list) low: low values :param (list) close: closing values :param (list) dates: list of datetime objects. Default: None :param (string) direction: direction can be 'increasing', 'decreasing', or 'both'. When the direction is 'increasing', the returned figure consists of all candlesticks where the close value is greater than the corresponding open value, and when the direction is 'decreasing', the returned figure consists of all candlesticks where the close value is less than or equal to the corresponding open value. When the direction is 'both', both increasing and decreasing candlesticks are returned. Default: 'both' :param kwargs: kwargs passed through plotly.graph_objs.Scatter. These kwargs describe other attributes about the ohlc Scatter trace such as the color or the legend name. For more information on valid kwargs call help(plotly.graph_objs.Scatter) :rtype (dict): returns a representation of candlestick chart figure. Example 1: Simple candlestick chart from a Pandas DataFrame ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF from datetime import datetime import pandas.io.data as web df = web.DataReader("aapl", 'yahoo', datetime(2007, 10, 1), datetime(2009, 4, 1)) fig = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index) py.plot(fig, filename='finance/aapl-candlestick', validate=False) ``` Example 2: Add text and annotations to the candlestick chart ``` fig = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index) # Update the fig - all options here: https://plot.ly/python/reference/#Layout fig['layout'].update({ 'title': 'The Great Recession', 'yaxis': {'title': 'AAPL Stock'}, 'shapes': [{ 'x0': '2007-12-01', 'x1': '2007-12-01', 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper', 'line': {'color': 'rgb(30,30,30)', 'width': 1} }], 'annotations': [{ 'x': '2007-12-01', 'y': 0.05, 'xref': 'x', 'yref': 'paper', 'showarrow': False, 'xanchor': 'left', 'text': 'Official start of the recession' }] }) py.plot(fig, filename='finance/aapl-recession-candlestick', validate=False) ``` Example 3: Customize the candlestick colors ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF from plotly.graph_objs import Line, Marker from datetime import datetime import pandas.io.data as web df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), datetime(2009, 4, 1)) # Make increasing candlesticks and customize their color and name fig_increasing = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index, direction='increasing', name='AAPL', marker=Marker(color='rgb(150, 200, 250)'), line=Line(color='rgb(150, 200, 250)')) # Make decreasing candlesticks and customize their color and name fig_decreasing = FF.create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index, direction='decreasing', marker=Marker(color='rgb(128, 128, 128)'), line=Line(color='rgb(128, 128, 128)')) # Initialize the figure fig = fig_increasing # Add decreasing data with .extend() fig['data'].extend(fig_decreasing['data']) py.iplot(fig, filename='finance/aapl-candlestick-custom', validate=False) ``` Example 4: Candlestick chart with datetime objects ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF from datetime import datetime # Add data open_data = [33.0, 33.3, 33.5, 33.0, 34.1] high_data = [33.1, 33.3, 33.6, 33.2, 34.8] low_data = [32.7, 32.7, 32.8, 32.6, 32.8] close_data = [33.0, 32.9, 33.3, 33.1, 33.1] dates = [datetime(year=2013, month=10, day=10), datetime(year=2013, month=11, day=10), datetime(year=2013, month=12, day=10), datetime(year=2014, month=1, day=10), datetime(year=2014, month=2, day=10)] # Create ohlc fig = FF.create_candlestick(open_data, high_data, low_data, close_data, dates=dates) py.iplot(fig, filename='finance/simple-candlestick', validate=False) ``` """ # TODO: protected until #282 from plotly.graph_objs import graph_objs if dates is not None: FigureFactory._validate_equal_length(open, high, low, close, dates) else: FigureFactory._validate_equal_length(open, high, low, close) FigureFactory._validate_ohlc(open, high, low, close, direction, **kwargs) if direction is 'increasing': candle_incr_data = FigureFactory._make_increasing_candle( open, high, low, close, dates, **kwargs) data = candle_incr_data elif direction is 'decreasing': candle_decr_data = FigureFactory._make_decreasing_candle( open, high, low, close, dates, **kwargs) data = candle_decr_data else: candle_incr_data = FigureFactory._make_increasing_candle( open, high, low, close, dates, **kwargs) candle_decr_data = FigureFactory._make_decreasing_candle( open, high, low, close, dates, **kwargs) data = candle_incr_data + candle_decr_data layout = graph_objs.Layout() return graph_objs.Figure(data=data, layout=layout) @staticmethod def create_distplot(hist_data, group_labels, bin_size=1., curve_type='kde', colors=[], rug_text=[], histnorm=DEFAULT_HISTNORM, show_hist=True, show_curve=True, show_rug=True): """ BETA function that creates a distplot similar to seaborn.distplot The distplot can be composed of all or any combination of the following 3 components: (1) histogram, (2) curve: (a) kernel density estimation or (b) normal curve, and (3) rug plot. Additionally, multiple distplots (from multiple datasets) can be created in the same plot. :param (list[list]) hist_data: Use list of lists to plot multiple data sets on the same plot. :param (list[str]) group_labels: Names for each data set. :param (list[float]|float) bin_size: Size of histogram bins. Default = 1. :param (str) curve_type: 'kde' or 'normal'. Default = 'kde' :param (str) histnorm: 'probability density' or 'probability' Default = 'probability density' :param (bool) show_hist: Add histogram to distplot? Default = True :param (bool) show_curve: Add curve to distplot? Default = True :param (bool) show_rug: Add rug to distplot? Default = True :param (list[str]) colors: Colors for traces. :param (list[list]) rug_text: Hovertext values for rug_plot, :return (dict): Representation of a distplot figure. Example 1: Simple distplot of 1 data set ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5, 3.5, 4.1, 4.4, 4.5, 4.5, 5.0, 5.0, 5.2, 5.5, 5.5, 5.5, 5.5, 5.5, 6.1, 7.0]] group_labels = ['distplot example'] fig = FF.create_distplot(hist_data, group_labels) url = py.plot(fig, filename='Simple distplot', validate=False) ``` Example 2: Two data sets and added rug text ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF # Add histogram data hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6, -0.9, -0.07, 1.95, 0.9, -0.2, -0.5, 0.3, 0.4, -0.37, 0.6] hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59, 1.0, 0.8, 1.7, 0.5, 0.8, -0.3, 1.2, 0.56, 0.3, 2.2] # Group data together hist_data = [hist1_x, hist2_x] group_labels = ['2012', '2013'] # Add text rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1', 'f1', 'g1', 'h1', 'i1', 'j1', 'k1', 'l1', 'm1', 'n1', 'o1'] rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2', 'f2', 'g2', 'h2', 'i2', 'j2', 'k2', 'l2', 'm2', 'n2', 'o2'] # Group text together rug_text_all = [rug_text_1, rug_text_2] # Create distplot fig = FF.create_distplot( hist_data, group_labels, rug_text=rug_text_all, bin_size=.2) # Add title fig['layout'].update(title='Dist Plot') # Plot! url = py.plot(fig, filename='Distplot with rug text', validate=False) ``` Example 3: Plot with normal curve and hide rug plot ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import numpy as np x1 = np.random.randn(190) x2 = np.random.randn(200)+1 x3 = np.random.randn(200)-1 x4 = np.random.randn(210)+2 hist_data = [x1, x2, x3, x4] group_labels = ['2012', '2013', '2014', '2015'] fig = FF.create_distplot( hist_data, group_labels, curve_type='normal', show_rug=False, bin_size=.4) url = py.plot(fig, filename='hist and normal curve', validate=False) Example 4: Distplot with Pandas ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import numpy as np import pandas as pd df = pd.DataFrame({'2012': np.random.randn(200), '2013': np.random.randn(200)+1}) py.iplot(FF.create_distplot([df[c] for c in df.columns], df.columns), filename='examples/distplot with pandas', validate=False) ``` """ # TODO: protected until #282 from plotly.graph_objs import graph_objs FigureFactory._validate_distplot(hist_data, curve_type) FigureFactory._validate_equal_length(hist_data, group_labels) if isinstance(bin_size, (float, int)): bin_size = [bin_size]*len(hist_data) hist = _Distplot( hist_data, histnorm, group_labels, bin_size, curve_type, colors, rug_text, show_hist, show_curve).make_hist() if curve_type == 'normal': curve = _Distplot( hist_data, histnorm, group_labels, bin_size, curve_type, colors, rug_text, show_hist, show_curve).make_normal() else: curve = _Distplot( hist_data, histnorm, group_labels, bin_size, curve_type, colors, rug_text, show_hist, show_curve).make_kde() rug = _Distplot( hist_data, histnorm, group_labels, bin_size, curve_type, colors, rug_text, show_hist, show_curve).make_rug() data = [] if show_hist: data.append(hist) if show_curve: data.append(curve) if show_rug: data.append(rug) layout = graph_objs.Layout( barmode='overlay', hovermode='closest', legend=dict(traceorder='reversed'), xaxis1=dict(domain=[0.0, 1.0], anchor='y2', zeroline=False), yaxis1=dict(domain=[0.35, 1], anchor='free', position=0.0), yaxis2=dict(domain=[0, 0.25], anchor='x1', dtick=1, showticklabels=False)) else: layout = graph_objs.Layout( barmode='overlay', hovermode='closest', legend=dict(traceorder='reversed'), xaxis1=dict(domain=[0.0, 1.0], anchor='y2', zeroline=False), yaxis1=dict(domain=[0., 1], anchor='free', position=0.0)) data = sum(data, []) return graph_objs.Figure(data=data, layout=layout) @staticmethod def create_dendrogram(X, orientation="bottom", labels=None, colorscale=None): """ BETA function that returns a dendrogram Plotly figure object. :param (ndarray) X: Matrix of observations as array of arrays :param (str) orientation: 'top', 'right', 'bottom', or 'left' :param (list) labels: List of axis category labels(observation labels) :param (list) colorscale: Optional colorscale for dendrogram tree clusters Example 1: Simple bottom oriented dendrogram ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import numpy as np X = np.random.rand(10,10) dendro = FF.create_dendrogram(X) plot_url = py.plot(dendro, filename='simple-dendrogram') ``` Example 2: Dendrogram to put on the left of the heatmap ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import numpy as np X = np.random.rand(5,5) names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark'] dendro = FF.create_dendrogram(X, orientation='right', labels=names) dendro['layout'].update({'width':700, 'height':500}) py.iplot(dendro, filename='vertical-dendrogram') ``` Example 3: Dendrogram with Pandas ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import numpy as np import pandas as pd Index= ['A','B','C','D','E','F','G','H','I','J'] df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index) fig = FF.create_dendrogram(df, labels=Index) url = py.plot(fig, filename='pandas-dendrogram') ``` """ dependencies = (_scipy_imported and _scipy__spatial_imported and _scipy__cluster__hierarchy_imported) if dependencies is False: raise ImportError("FigureFactory.create_dendrogram requires scipy, \ scipy.spatial and scipy.hierarchy") s = X.shape if len(s) != 2: exceptions.PlotlyError("X should be 2-dimensional array.") dendrogram = _Dendrogram(X, orientation, labels, colorscale) return {'layout': dendrogram.layout, 'data': dendrogram.data} @staticmethod def create_annotated_heatmap(z, x=None, y=None, annotation_text=None, colorscale='RdBu', font_colors=None, showscale=False, reversescale=False, **kwargs): """ BETA function that creates annotated heatmaps This function adds annotations to each cell of the heatmap. :param (list[list]|ndarray) z: z matrix to create heatmap. :param (list) x: x axis labels. :param (list) y: y axis labels. :param (list[list]|ndarray) annotation_text: Text strings for annotations. Should have the same dimensions as the z matrix. If no text is added, the values of the z matrix are annotated. Default = z matrix values. :param (list|str) colorscale: heatmap colorscale. :param (list) font_colors: List of two color strings: [min_text_color, max_text_color] where min_text_color is applied to annotations for heatmap values < (max_value - min_value)/2. If font_colors is not defined, the colors are defined logically as black or white depending on the heatmap's colorscale. :param (bool) showscale: Display colorscale. Default = False :param kwargs: kwargs passed through plotly.graph_objs.Heatmap. These kwargs describe other attributes about the annotated Heatmap trace such as the colorscale. For more information on valid kwargs call help(plotly.graph_objs.Heatmap) Example 1: Simple annotated heatmap with default configuration ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF z = [[0.300000, 0.00000, 0.65, 0.300000], [1, 0.100005, 0.45, 0.4300], [0.300000, 0.00000, 0.65, 0.300000], [1, 0.100005, 0.45, 0.00000]] figure = FF.create_annotated_heatmap(z) py.iplot(figure) ``` """ # TODO: protected until #282 from plotly.graph_objs import graph_objs # Avoiding mutables in the call signature font_colors = font_colors if font_colors is not None else [] FigureFactory._validate_annotated_heatmap(z, x, y, annotation_text) annotations = _AnnotatedHeatmap(z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs).make_annotations() if x or y: trace = dict(type='heatmap', z=z, x=x, y=y, colorscale=colorscale, showscale=showscale, **kwargs) layout = dict(annotations=annotations, xaxis=dict(ticks='', dtick=1, side='top', gridcolor='rgb(0, 0, 0)'), yaxis=dict(ticks='', dtick=1, ticksuffix=' ')) else: trace = dict(type='heatmap', z=z, colorscale=colorscale, showscale=showscale, **kwargs) layout = dict(annotations=annotations, xaxis=dict(ticks='', side='top', gridcolor='rgb(0, 0, 0)', showticklabels=False), yaxis=dict(ticks='', ticksuffix=' ', showticklabels=False)) data = [trace] return graph_objs.Figure(data=data, layout=layout) @staticmethod def create_table(table_text, colorscale=None, font_colors=None, index=False, index_title='', annotation_offset=.45, height_constant=30, hoverinfo='none', **kwargs): """ BETA function that creates data tables :param (pandas.Dataframe | list[list]) text: data for table. :param (str|list[list]) colorscale: Colorscale for table where the color at value 0 is the header color, .5 is the first table color and 1 is the second table color. (Set .5 and 1 to avoid the striped table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'], [1, '#ffffff']] :param (list) font_colors: Color for fonts in table. Can be a single color, three colors, or a color for each row in the table. Default=['#000000'] (black text for the entire table) :param (int) height_constant: Constant multiplied by # of rows to create table height. Default=30. :param (bool) index: Create (header-colored) index column index from Pandas dataframe or list[0] for each list in text. Default=False. :param (string) index_title: Title for index column. Default=''. :param kwargs: kwargs passed through plotly.graph_objs.Heatmap. These kwargs describe other attributes about the annotated Heatmap trace such as the colorscale. For more information on valid kwargs call help(plotly.graph_objs.Heatmap) Example 1: Simple Plotly Table ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF text = [['Country', 'Year', 'Population'], ['US', 2000, 282200000], ['Canada', 2000, 27790000], ['US', 2010, 309000000], ['Canada', 2010, 34000000]] table = FF.create_table(text) py.iplot(table) ``` Example 2: Table with Custom Coloring ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF text = [['Country', 'Year', 'Population'], ['US', 2000, 282200000], ['Canada', 2000, 27790000], ['US', 2010, 309000000], ['Canada', 2010, 34000000]] table = FF.create_table(text, colorscale=[[0, '#000000'], [.5, '#80beff'], [1, '#cce5ff']], font_colors=['#ffffff', '#000000', '#000000']) py.iplot(table) ``` Example 3: Simple Plotly Table with Pandas ``` import plotly.plotly as py from plotly.tools import FigureFactory as FF import pandas as pd df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t') df_p = df[0:25] table_simple = FF.create_table(df_p) py.iplot(table_simple) ``` """ # TODO: protected until #282 from plotly.graph_objs import graph_objs # Avoiding mutables in the call signature colorscale = \ colorscale if colorscale is not None else [[0, '#00083e'], [.5, '#ededee'], [1, '#ffffff']] font_colors = font_colors if font_colors is not None else ['#ffffff', '#000000', '#000000'] FigureFactory._validate_table(table_text, font_colors) table_matrix = _Table(table_text, colorscale, font_colors, index, index_title, annotation_offset, **kwargs).get_table_matrix() annotations = _Table(table_text, colorscale, font_colors, index, index_title, annotation_offset, **kwargs).make_table_annotations() trace = dict(type='heatmap', z=table_matrix, opacity=.75, colorscale=colorscale, showscale=False, hoverinfo=hoverinfo, **kwargs) data = [trace] layout = dict(annotations=annotations, height=len(table_matrix)*height_constant + 50, margin=dict(t=0, b=0, r=0, l=0), yaxis=dict(autorange='reversed', zeroline=False, gridwidth=2, ticks='', dtick=1, tick0=.5, showticklabels=False), xaxis=dict(zeroline=False, gridwidth=2, ticks='', dtick=1, tick0=-0.5, showticklabels=False)) return graph_objs.Figure(data=data, layout=layout) class _Quiver(FigureFactory): """ Refer to FigureFactory.create_quiver() for docstring """ def __init__(self, x, y, u, v, scale, arrow_scale, angle, **kwargs): try: x = FigureFactory._flatten(x) except exceptions.PlotlyError: pass try: y = FigureFactory._flatten(y) except exceptions.PlotlyError: pass try: u = FigureFactory._flatten(u) except exceptions.PlotlyError: pass try: v = FigureFactory._flatten(v) except exceptions.PlotlyError: pass self.x = x self.y = y self.u = u self.v = v self.scale = scale self.arrow_scale = arrow_scale self.angle = angle self.end_x = [] self.end_y = [] self.scale_uv() barb_x, barb_y = self.get_barbs() arrow_x, arrow_y = self.get_quiver_arrows() def scale_uv(self): """ Scales u and v to avoid overlap of the arrows. u and v are added to x and y to get the endpoints of the arrows so a smaller scale value will result in less overlap of arrows. """ self.u = [i * self.scale for i in self.u] self.v = [i * self.scale for i in self.v] def get_barbs(self): """ Creates x and y startpoint and endpoint pairs After finding the endpoint of each barb this zips startpoint and endpoint pairs to create 2 lists: x_values for barbs and y values for barbs :rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint x_value pairs separated by a None to create the barb of the arrow, and list of startpoint and endpoint y_value pairs separated by a None to create the barb of the arrow. """ self.end_x = [i + j for i, j in zip(self.x, self.u)] self.end_y = [i + j for i, j in zip(self.y, self.v)] empty = [None] * len(self.x) barb_x = FigureFactory._flatten(zip(self.x, self.end_x, empty)) barb_y = FigureFactory._flatten(zip(self.y, self.end_y, empty)) return barb_x, barb_y def get_quiver_arrows(self): """ Creates lists of x and y values to plot the arrows Gets length of each barb then calculates the length of each side of the arrow. Gets angle of barb and applies angle to each side of the arrowhead. Next uses arrow_scale to scale the length of arrowhead and creates x and y values for arrowhead point1 and point2. Finally x and y values for point1, endpoint and point2s for each arrowhead are separated by a None and zipped to create lists of x and y values for the arrows. :rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2 x_values separated by a None to create the arrowhead and list of point1, endpoint, point2 y_values separated by a None to create the barb of the arrow. """ dif_x = [i - j for i, j in zip(self.end_x, self.x)] dif_y = [i - j for i, j in zip(self.end_y, self.y)] # Get barb lengths(default arrow length = 30% barb length) barb_len = [None] * len(self.x) for index in range(len(barb_len)): barb_len[index] = math.hypot(dif_x[index], dif_y[index]) # Make arrow lengths arrow_len = [None] * len(self.x) arrow_len = [i * self.arrow_scale for i in barb_len] # Get barb angles barb_ang = [None] * len(self.x) for index in range(len(barb_ang)): barb_ang[index] = math.atan2(dif_y[index], dif_x[index]) # Set angles to create arrow ang1 = [i + self.angle for i in barb_ang] ang2 = [i - self.angle for i in barb_ang] cos_ang1 = [None] * len(ang1) for index in range(len(ang1)): cos_ang1[index] = math.cos(ang1[index]) seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)] sin_ang1 = [None] * len(ang1) for index in range(len(ang1)): sin_ang1[index] = math.sin(ang1[index]) seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)] cos_ang2 = [None] * len(ang2) for index in range(len(ang2)): cos_ang2[index] = math.cos(ang2[index]) seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)] sin_ang2 = [None] * len(ang2) for index in range(len(ang2)): sin_ang2[index] = math.sin(ang2[index]) seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)] # Set coordinates to create arrow for index in range(len(self.end_x)): point1_x = [i - j for i, j in zip(self.end_x, seg1_x)] point1_y = [i - j for i, j in zip(self.end_y, seg1_y)] point2_x = [i - j for i, j in zip(self.end_x, seg2_x)] point2_y = [i - j for i, j in zip(self.end_y, seg2_y)] # Combine lists to create arrow empty = [None] * len(self.end_x) arrow_x = FigureFactory._flatten(zip(point1_x, self.end_x, point2_x, empty)) arrow_y = FigureFactory._flatten(zip(point1_y, self.end_y, point2_y, empty)) return arrow_x, arrow_y class _Streamline(FigureFactory): """ Refer to FigureFactory.create_streamline() for docstring """ def __init__(self, x, y, u, v, density, angle, arrow_scale, **kwargs): self.x = np.array(x) self.y = np.array(y) self.u = np.array(u) self.v = np.array(v) self.angle = angle self.arrow_scale = arrow_scale self.density = int(30 * density) # Scale similarly to other functions self.delta_x = self.x[1] - self.x[0] self.delta_y = self.y[1] - self.y[0] self.val_x = self.x self.val_y = self.y # Set up spacing self.blank = np.zeros((self.density, self.density)) self.spacing_x = len(self.x) / float(self.density - 1) self.spacing_y = len(self.y) / float(self.density - 1) self.trajectories = [] # Rescale speed onto axes-coordinates self.u = self.u / (self.x[-1] - self.x[0]) self.v = self.v / (self.y[-1] - self.y[0]) self.speed = np.sqrt(self.u ** 2 + self.v ** 2) # Rescale u and v for integrations. self.u *= len(self.x) self.v *= len(self.y) self.st_x = [] self.st_y = [] self.get_streamlines() streamline_x, streamline_y = self.sum_streamlines() arrows_x, arrows_y = self.get_streamline_arrows() def blank_pos(self, xi, yi): """ Set up positions for trajectories to be used with rk4 function. """ return (int((xi / self.spacing_x) + 0.5), int((yi / self.spacing_y) + 0.5)) def value_at(self, a, xi, yi): """ Set up for RK4 function, based on Bokeh's streamline code """ if isinstance(xi, np.ndarray): self.x = xi.astype(np.int) self.y = yi.astype(np.int) else: self.val_x = np.int(xi) self.val_y = np.int(yi) a00 = a[self.val_y, self.val_x] a01 = a[self.val_y, self.val_x + 1] a10 = a[self.val_y + 1, self.val_x] a11 = a[self.val_y + 1, self.val_x + 1] xt = xi - self.val_x yt = yi - self.val_y a0 = a00 * (1 - xt) + a01 * xt a1 = a10 * (1 - xt) + a11 * xt return a0 * (1 - yt) + a1 * yt def rk4_integrate(self, x0, y0): """ RK4 forward and back trajectories from the initial conditions. Adapted from Bokeh's streamline -uses Runge-Kutta method to fill x and y trajectories then checks length of traj (s in units of axes) """ def f(xi, yi): dt_ds = 1. / self.value_at(self.speed, xi, yi) ui = self.value_at(self.u, xi, yi) vi = self.value_at(self.v, xi, yi) return ui * dt_ds, vi * dt_ds def g(xi, yi): dt_ds = 1. / self.value_at(self.speed, xi, yi) ui = self.value_at(self.u, xi, yi) vi = self.value_at(self.v, xi, yi) return -ui * dt_ds, -vi * dt_ds check = lambda xi, yi: (0 <= xi < len(self.x) - 1 and 0 <= yi < len(self.y) - 1) xb_changes = [] yb_changes = [] def rk4(x0, y0, f): ds = 0.01 stotal = 0 xi = x0 yi = y0 xb, yb = self.blank_pos(xi, yi) xf_traj = [] yf_traj = [] while check(xi, yi): xf_traj.append(xi) yf_traj.append(yi) try: k1x, k1y = f(xi, yi) k2x, k2y = f(xi + .5 * ds * k1x, yi + .5 * ds * k1y) k3x, k3y = f(xi + .5 * ds * k2x, yi + .5 * ds * k2y) k4x, k4y = f(xi + ds * k3x, yi + ds * k3y) except IndexError: break xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6. yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6. if not check(xi, yi): break stotal += ds new_xb, new_yb = self.blank_pos(xi, yi) if new_xb != xb or new_yb != yb: if self.blank[new_yb, new_xb] == 0: self.blank[new_yb, new_xb] = 1 xb_changes.append(new_xb) yb_changes.append(new_yb) xb = new_xb yb = new_yb else: break if stotal > 2: break return stotal, xf_traj, yf_traj sf, xf_traj, yf_traj = rk4(x0, y0, f) sb, xb_traj, yb_traj = rk4(x0, y0, g) stotal = sf + sb x_traj = xb_traj[::-1] + xf_traj[1:] y_traj = yb_traj[::-1] + yf_traj[1:] if len(x_traj) < 1: return None if stotal > .2: initxb, inityb = self.blank_pos(x0, y0) self.blank[inityb, initxb] = 1 return x_traj, y_traj else: for xb, yb in zip(xb_changes, yb_changes): self.blank[yb, xb] = 0 return None def traj(self, xb, yb): """ Integrate trajectories :param (int) xb: results of passing xi through self.blank_pos :param (int) xy: results of passing yi through self.blank_pos Calculate each trajectory based on rk4 integrate method. """ if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density: return if self.blank[yb, xb] == 0: t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y) if t is not None: self.trajectories.append(t) def get_streamlines(self): """ Get streamlines by building trajectory set. """ for indent in range(self.density // 2): for xi in range(self.density - 2 * indent): self.traj(xi + indent, indent) self.traj(xi + indent, self.density - 1 - indent) self.traj(indent, xi + indent) self.traj(self.density - 1 - indent, xi + indent) self.st_x = [np.array(t[0]) * self.delta_x + self.x[0] for t in self.trajectories] self.st_y = [np.array(t[1]) * self.delta_y + self.y[0] for t in self.trajectories] for index in range(len(self.st_x)): self.st_x[index] = self.st_x[index].tolist() self.st_x[index].append(np.nan) for index in range(len(self.st_y)): self.st_y[index] = self.st_y[index].tolist() self.st_y[index].append(np.nan) def get_streamline_arrows(self): """ Makes an arrow for each streamline. Gets angle of streamline at 1/3 mark and creates arrow coordinates based off of user defined angle and arrow_scale. :param (array) st_x: x-values for all streamlines :param (array) st_y: y-values for all streamlines :param (angle in radians) angle: angle of arrowhead. Default = pi/9 :param (float in [0,1]) arrow_scale: value to scale length of arrowhead Default = .09 :rtype (list, list) arrows_x: x-values to create arrowhead and arrows_y: y-values to create arrowhead """ arrow_end_x = np.empty((len(self.st_x))) arrow_end_y = np.empty((len(self.st_y))) arrow_start_x = np.empty((len(self.st_x))) arrow_start_y = np.empty((len(self.st_y))) for index in range(len(self.st_x)): arrow_end_x[index] = (self.st_x[index] [int(len(self.st_x[index]) / 3)]) arrow_start_x[index] = (self.st_x[index] [(int(len(self.st_x[index]) / 3)) - 1]) arrow_end_y[index] = (self.st_y[index] [int(len(self.st_y[index]) / 3)]) arrow_start_y[index] = (self.st_y[index] [(int(len(self.st_y[index]) / 3)) - 1]) dif_x = arrow_end_x - arrow_start_x dif_y = arrow_end_y - arrow_start_y streamline_ang = np.arctan(dif_y / dif_x) ang1 = streamline_ang + (self.angle) ang2 = streamline_ang - (self.angle) seg1_x = np.cos(ang1) * self.arrow_scale seg1_y = np.sin(ang1) * self.arrow_scale seg2_x = np.cos(ang2) * self.arrow_scale seg2_y = np.sin(ang2) * self.arrow_scale point1_x = np.empty((len(dif_x))) point1_y = np.empty((len(dif_y))) point2_x = np.empty((len(dif_x))) point2_y = np.empty((len(dif_y))) for index in range(len(dif_x)): if dif_x[index] >= 0: point1_x[index] = arrow_end_x[index] - seg1_x[index] point1_y[index] = arrow_end_y[index] - seg1_y[index] point2_x[index] = arrow_end_x[index] - seg2_x[index] point2_y[index] = arrow_end_y[index] - seg2_y[index] else: point1_x[index] = arrow_end_x[index] + seg1_x[index] point1_y[index] = arrow_end_y[index] + seg1_y[index] point2_x[index] = arrow_end_x[index] + seg2_x[index] point2_y[index] = arrow_end_y[index] + seg2_y[index] space = np.empty((len(point1_x))) space[:] = np.nan # Combine arrays into matrix arrows_x = np.matrix([point1_x, arrow_end_x, point2_x, space]) arrows_x = np.array(arrows_x) arrows_x = arrows_x.flatten('F') arrows_x = arrows_x.tolist() # Combine arrays into matrix arrows_y = np.matrix([point1_y, arrow_end_y, point2_y, space]) arrows_y = np.array(arrows_y) arrows_y = arrows_y.flatten('F') arrows_y = arrows_y.tolist() return arrows_x, arrows_y def sum_streamlines(self): """ Makes all streamlines readable as a single trace. :rtype (list, list): streamline_x: all x values for each streamline combined into single list and streamline_y: all y values for each streamline combined into single list """ streamline_x = sum(self.st_x, []) streamline_y = sum(self.st_y, []) return streamline_x, streamline_y class _OHLC(FigureFactory): """ Refer to FigureFactory.create_ohlc_increase() for docstring. """ def __init__(self, open, high, low, close, dates, **kwargs): self.open = open self.high = high self.low = low self.close = close self.empty = [None] * len(open) self.dates = dates self.all_x = [] self.all_y = [] self.increase_x = [] self.increase_y = [] self.decrease_x = [] self.decrease_y = [] self.get_all_xy() self.separate_increase_decrease() def get_all_xy(self): """ Zip data to create OHLC shape OHLC shape: low to high vertical bar with horizontal branches for open and close values. If dates were added, the smallest date difference is calculated and multiplied by .2 to get the length of the open and close branches. If no date data was provided, the x-axis is a list of integers and the length of the open and close branches is .2. """ self.all_y = list(zip(self.open, self.open, self.high, self.low, self.close, self.close, self.empty)) if self.dates is not None: date_dif = [] for i in range(len(self.dates) - 1): date_dif.append(self.dates[i + 1] - self.dates[i]) date_dif_min = (min(date_dif)) / 5 self.all_x = [[x - date_dif_min, x, x, x, x, x + date_dif_min, None] for x in self.dates] else: self.all_x = [[x - .2, x, x, x, x, x + .2, None] for x in range(len(self.open))] def separate_increase_decrease(self): """ Separate data into two groups: increase and decrease (1) Increase, where close > open and (2) Decrease, where close <= open """ for index in range(len(self.open)): if self.close[index] is None: pass elif self.close[index] > self.open[index]: self.increase_x.append(self.all_x[index]) self.increase_y.append(self.all_y[index]) else: self.decrease_x.append(self.all_x[index]) self.decrease_y.append(self.all_y[index]) def get_increase(self): """ Flatten increase data and get increase text :rtype (list, list, list): flat_increase_x: x-values for the increasing trace, flat_increase_y: y=values for the increasing trace and text_increase: hovertext for the increasing trace """ flat_increase_x = FigureFactory._flatten(self.increase_x) flat_increase_y = FigureFactory._flatten(self.increase_y) text_increase = (("Open", "Open", "High", "Low", "Close", "Close", '') * (len(self.increase_x))) return flat_increase_x, flat_increase_y, text_increase def get_decrease(self): """ Flatten decrease data and get decrease text :rtype (list, list, list): flat_decrease_x: x-values for the decreasing trace, flat_decrease_y: y=values for the decreasing trace and text_decrease: hovertext for the decreasing trace """ flat_decrease_x = FigureFactory._flatten(self.decrease_x) flat_decrease_y = FigureFactory._flatten(self.decrease_y) text_decrease = (("Open", "Open", "High", "Low", "Close", "Close", '') * (len(self.decrease_x))) return flat_decrease_x, flat_decrease_y, text_decrease class _Candlestick(FigureFactory): """ Refer to FigureFactory.create_candlestick() for docstring. """ def __init__(self, open, high, low, close, dates, **kwargs): self.open = open self.high = high self.low = low self.close = close if dates is not None: self.x = dates else: self.x = [x for x in range(len(self.open))] self.get_candle_increase() def get_candle_increase(self): """ Separate increasing data from decreasing data. The data is increasing when close value > open value and decreasing when the close value <= open value. """ increase_y = [] increase_x = [] for index in range(len(self.open)): if self.close[index] > self.open[index]: increase_y.append(self.low[index]) increase_y.append(self.open[index]) increase_y.append(self.close[index]) increase_y.append(self.close[index]) increase_y.append(self.close[index]) increase_y.append(self.high[index]) increase_x.append(self.x[index]) increase_x = [[x, x, x, x, x, x] for x in increase_x] increase_x = FigureFactory._flatten(increase_x) return increase_x, increase_y def get_candle_decrease(self): """ Separate increasing data from decreasing data. The data is increasing when close value > open value and decreasing when the close value <= open value. """ decrease_y = [] decrease_x = [] for index in range(len(self.open)): if self.close[index] <= self.open[index]: decrease_y.append(self.low[index]) decrease_y.append(self.open[index]) decrease_y.append(self.close[index]) decrease_y.append(self.close[index]) decrease_y.append(self.close[index]) decrease_y.append(self.high[index]) decrease_x.append(self.x[index]) decrease_x = [[x, x, x, x, x, x] for x in decrease_x] decrease_x = FigureFactory._flatten(decrease_x) return decrease_x, decrease_y class _Distplot(FigureFactory): """ Refer to TraceFactory.create_distplot() for docstring """ def __init__(self, hist_data, histnorm, group_labels, bin_size, curve_type, colors, rug_text, show_hist, show_curve): self.hist_data = hist_data self.histnorm = histnorm self.group_labels = group_labels self.bin_size = bin_size self.show_hist = show_hist self.show_curve = show_curve self.trace_number = len(hist_data) if rug_text: self.rug_text = rug_text else: self.rug_text = [None] * self.trace_number self.start = [] self.end = [] if colors: self.colors = colors else: self.colors = [ "rgb(31, 119, 180)", "rgb(255, 127, 14)", "rgb(44, 160, 44)", "rgb(214, 39, 40)", "rgb(148, 103, 189)", "rgb(140, 86, 75)", "rgb(227, 119, 194)", "rgb(127, 127, 127)", "rgb(188, 189, 34)", "rgb(23, 190, 207)"] self.curve_x = [None] * self.trace_number self.curve_y = [None] * self.trace_number for trace in self.hist_data: self.start.append(min(trace) * 1.) self.end.append(max(trace) * 1.) def make_hist(self): """ Makes the histogram(s) for FigureFactory.create_distplot(). :rtype (list) hist: list of histogram representations """ hist = [None] * self.trace_number for index in range(self.trace_number): hist[index] = dict(type='histogram', x=self.hist_data[index], xaxis='x1', yaxis='y1', histnorm=self.histnorm, name=self.group_labels[index], legendgroup=self.group_labels[index], marker=dict(color=self.colors[index]), autobinx=False, xbins=dict(start=self.start[index], end=self.end[index], size=self.bin_size[index]), opacity=.7) return hist def make_kde(self): """ Makes the kernel density estimation(s) for create_distplot(). This is called when curve_type = 'kde' in create_distplot(). :rtype (list) curve: list of kde representations """ curve = [None] * self.trace_number for index in range(self.trace_number): self.curve_x[index] = [self.start[index] + x * (self.end[index] - self.start[index]) / 500 for x in range(500)] self.curve_y[index] = (scipy.stats.gaussian_kde (self.hist_data[index]) (self.curve_x[index])) if self.histnorm == ALTERNATIVE_HISTNORM: self.curve_y[index] *= self.bin_size[index] for index in range(self.trace_number): curve[index] = dict(type='scatter', x=self.curve_x[index], y=self.curve_y[index], xaxis='x1', yaxis='y1', mode='lines', name=self.group_labels[index], legendgroup=self.group_labels[index], showlegend=False if self.show_hist else True, marker=dict(color=self.colors[index])) return curve def make_normal(self): """ Makes the normal curve(s) for create_distplot(). This is called when curve_type = 'normal' in create_distplot(). :rtype (list) curve: list of normal curve representations """ curve = [None] * self.trace_number mean = [None] * self.trace_number sd = [None] * self.trace_number for index in range(self.trace_number): mean[index], sd[index] = (scipy.stats.norm.fit (self.hist_data[index])) self.curve_x[index] = [self.start[index] + x * (self.end[index] - self.start[index]) / 500 for x in range(500)] self.curve_y[index] = scipy.stats.norm.pdf( self.curve_x[index], loc=mean[index], scale=sd[index]) if self.histnorm == ALTERNATIVE_HISTNORM: self.curve_y[index] *= self.bin_size[index] for index in range(self.trace_number): curve[index] = dict(type='scatter', x=self.curve_x[index], y=self.curve_y[index], xaxis='x1', yaxis='y1', mode='lines', name=self.group_labels[index], legendgroup=self.group_labels[index], showlegend=False if self.show_hist else True, marker=dict(color=self.colors[index])) return curve def make_rug(self): """ Makes the rug plot(s) for create_distplot(). :rtype (list) rug: list of rug plot representations """ rug = [None] * self.trace_number for index in range(self.trace_number): rug[index] = dict(type='scatter', x=self.hist_data[index], y=([self.group_labels[index]] * len(self.hist_data[index])), xaxis='x1', yaxis='y2', mode='markers', name=self.group_labels[index], legendgroup=self.group_labels[index], showlegend=(False if self.show_hist or self.show_curve else True), text=self.rug_text[index], marker=dict(color=self.colors[index], symbol='line-ns-open')) return rug class _Dendrogram(FigureFactory): """Refer to FigureFactory.create_dendrogram() for docstring.""" def __init__(self, X, orientation='bottom', labels=None, colorscale=None, width="100%", height="100%", xaxis='xaxis', yaxis='yaxis'): # TODO: protected until #282 from plotly.graph_objs import graph_objs self.orientation = orientation self.labels = labels self.xaxis = xaxis self.yaxis = yaxis self.data = [] self.leaves = [] self.sign = {self.xaxis: 1, self.yaxis: 1} self.layout = {self.xaxis: {}, self.yaxis: {}} if self.orientation in ['left', 'bottom']: self.sign[self.xaxis] = 1 else: self.sign[self.xaxis] = -1 if self.orientation in ['right', 'bottom']: self.sign[self.yaxis] = 1 else: self.sign[self.yaxis] = -1 (dd_traces, xvals, yvals, ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale) self.labels = ordered_labels self.leaves = leaves yvals_flat = yvals.flatten() xvals_flat = xvals.flatten() self.zero_vals = [] for i in range(len(yvals_flat)): if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals: self.zero_vals.append(xvals_flat[i]) self.zero_vals.sort() self.layout = self.set_figure_layout(width, height) self.data = graph_objs.Data(dd_traces) def get_color_dict(self, colorscale): """ Returns colorscale used for dendrogram tree clusters. :param (list) colorscale: Colors to use for the plot in rgb format. :rtype (dict): A dict of default colors mapped to the user colorscale. """ # These are the color codes returned for dendrograms # We're replacing them with nicer colors d = {'r': 'red', 'g': 'green', 'b': 'blue', 'c': 'cyan', 'm': 'magenta', 'y': 'yellow', 'k': 'black', 'w': 'white'} default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0])) if colorscale is None: colorscale = [ 'rgb(0,116,217)', # blue 'rgb(35,205,205)', # cyan 'rgb(61,153,112)', # green 'rgb(40,35,35)', # black 'rgb(133,20,75)', # magenta 'rgb(255,65,54)', # red 'rgb(255,255,255)', # white 'rgb(255,220,0)'] # yellow for i in range(len(default_colors.keys())): k = list(default_colors.keys())[i] # PY3 won't index keys if i < len(colorscale): default_colors[k] = colorscale[i] return default_colors def set_axis_layout(self, axis_key): """ Sets and returns default axis object for dendrogram figure. :param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc. :rtype (dict): An axis_key dictionary with set parameters. """ axis_defaults = { 'type': 'linear', 'ticks': 'outside', 'mirror': 'allticks', 'rangemode': 'tozero', 'showticklabels': True, 'zeroline': False, 'showgrid': False, 'showline': True, } if len(self.labels) != 0: axis_key_labels = self.xaxis if self.orientation in ['left', 'right']: axis_key_labels = self.yaxis if axis_key_labels not in self.layout: self.layout[axis_key_labels] = {} self.layout[axis_key_labels]['tickvals'] = \ [zv*self.sign[axis_key] for zv in self.zero_vals] self.layout[axis_key_labels]['ticktext'] = self.labels self.layout[axis_key_labels]['tickmode'] = 'array' self.layout[axis_key].update(axis_defaults) return self.layout[axis_key] def set_figure_layout(self, width, height): """ Sets and returns default layout object for dendrogram figure. """ self.layout.update({ 'showlegend': False, 'autosize': False, 'hovermode': 'closest', 'width': width, 'height': height }) self.set_axis_layout(self.xaxis) self.set_axis_layout(self.yaxis) return self.layout def get_dendrogram_traces(self, X, colorscale): """ Calculates all the elements needed for plotting a dendrogram. :param (ndarray) X: Matrix of observations as array of arrays :param (list) colorscale: Color scale for dendrogram tree clusters :rtype (tuple): Contains all the traces in the following order: (a) trace_list: List of Plotly trace objects for dendrogram tree (b) icoord: All X points of the dendrogram tree as array of arrays with length 4 (c) dcoord: All Y points of the dendrogram tree as array of arrays with length 4 (d) ordered_labels: leaf labels in the order they are going to appear on the plot (e) P['leaves']: left-to-right traversal of the leaves """ # TODO: protected until #282 from plotly.graph_objs import graph_objs d = scs.distance.pdist(X) Z = sch.linkage(d, method='complete') P = sch.dendrogram(Z, orientation=self.orientation, labels=self.labels, no_plot=True) icoord = scp.array(P['icoord']) dcoord = scp.array(P['dcoord']) ordered_labels = scp.array(P['ivl']) color_list = scp.array(P['color_list']) colors = self.get_color_dict(colorscale) trace_list = [] for i in range(len(icoord)): # xs and ys are arrays of 4 points that make up the '∩' shapes # of the dendrogram tree if self.orientation in ['top', 'bottom']: xs = icoord[i] else: xs = dcoord[i] if self.orientation in ['top', 'bottom']: ys = dcoord[i] else: ys = icoord[i] color_key = color_list[i] trace = graph_objs.Scatter( x=np.multiply(self.sign[self.xaxis], xs), y=np.multiply(self.sign[self.yaxis], ys), mode='lines', marker=graph_objs.Marker(color=colors[color_key]) ) try: x_index = int(self.xaxis[-1]) except ValueError: x_index = '' try: y_index = int(self.yaxis[-1]) except ValueError: y_index = '' trace['xaxis'] = 'x' + x_index trace['yaxis'] = 'y' + y_index trace_list.append(trace) return trace_list, icoord, dcoord, ordered_labels, P['leaves'] class _AnnotatedHeatmap(FigureFactory): """ Refer to TraceFactory.create_annotated_heatmap() for docstring """ def __init__(self, z, x, y, annotation_text, colorscale, font_colors, reversescale, **kwargs): from plotly.graph_objs import graph_objs self.z = z if x: self.x = x else: self.x = range(len(z[0])) if y: self.y = y else: self.y = range(len(z)) if annotation_text is not None: self.annotation_text = annotation_text else: self.annotation_text = self.z self.colorscale = colorscale self.reversescale = reversescale self.font_colors = font_colors def get_text_color(self): """ Get font color for annotations. The annotated heatmap can feature two text colors: min_text_color and max_text_color. The min_text_color is applied to annotations for heatmap values < (max_value - min_value)/2. The user can define these two colors. Otherwise the colors are defined logically as black or white depending on the heatmap's colorscale. :rtype (string, string) min_text_color, max_text_color: text color for annotations for heatmap values < (max_value - min_value)/2 and text color for annotations for heatmap values >= (max_value - min_value)/2 """ # Plotly colorscales ranging from a lighter shade to a darker shade colorscales = ['Greys', 'Greens', 'Blues', 'YIGnBu', 'YIOrRd', 'RdBu', 'Picnic', 'Jet', 'Hot', 'Blackbody', 'Earth', 'Electric', 'Viridis'] # Plotly colorscales ranging from a darker shade to a lighter shade colorscales_reverse = ['Reds'] if self.font_colors: min_text_color = self.font_colors[0] max_text_color = self.font_colors[-1] elif self.colorscale in colorscales and self.reversescale: min_text_color = '#000000' max_text_color = '#FFFFFF' elif self.colorscale in colorscales: min_text_color = '#FFFFFF' max_text_color = '#000000' elif self.colorscale in colorscales_reverse and self.reversescale: min_text_color = '#FFFFFF' max_text_color = '#000000' elif self.colorscale in colorscales_reverse: min_text_color = '#000000' max_text_color = '#FFFFFF' elif isinstance(self.colorscale, list): if 'rgb' in self.colorscale[0][1]: min_col = map(int, self.colorscale[0][1].strip('rgb()').split(',')) max_col = map(int, self.colorscale[-1][1].strip('rgb()').split(',')) elif '#' in self.colorscale[0][1]: min_col = FigureFactory._hex_to_rgb(self.colorscale[0][1]) max_col = FigureFactory._hex_to_rgb(self.colorscale[-1][1]) else: min_col = [255, 255, 255] max_col = [255, 255, 255] if (min_col[0]*0.299 + min_col[1]*0.587 + min_col[2]*0.114) > 186: min_text_color = '#000000' else: min_text_color = '#FFFFFF' if (max_col[0]*0.299 + max_col[1]*0.587 + max_col[2]*0.114) > 186: max_text_color = '#000000' else: max_text_color = '#FFFFFF' else: min_text_color = '#000000' max_text_color = '#000000' return min_text_color, max_text_color def get_z_mid(self): """ Get the mid value of z matrix :rtype (float) z_avg: average val from z matrix """ if _numpy_imported and isinstance(self.z, np.ndarray): z_min = np.amin(self.z) z_max = np.amax(self.z) else: z_min = min(min(self.z)) z_max = max(max(self.z)) z_mid = (z_max+z_min) / 2 return z_mid def make_annotations(self): """ Get annotations for each cell of the heatmap with graph_objs.Annotation :rtype (list[dict]) annotations: list of annotations for each cell of the heatmap """ from plotly.graph_objs import graph_objs min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self) z_mid = _AnnotatedHeatmap.get_z_mid(self) annotations = [] for n, row in enumerate(self.z): for m, val in enumerate(row): font_color = min_text_color if val < z_mid else max_text_color annotations.append( graph_objs.Annotation( text=str(self.annotation_text[n][m]), x=self.x[m], y=self.y[n], xref='x1', yref='y1', font=dict(color=font_color), showarrow=False)) return annotations class _Table(FigureFactory): """ Refer to TraceFactory.create_table() for docstring """ def __init__(self, table_text, colorscale, font_colors, index, index_title, annotation_offset, **kwargs): from plotly.graph_objs import graph_objs if _pandas_imported and isinstance(table_text, pd.DataFrame): headers = table_text.columns.tolist() table_text_index = table_text.index.tolist() table_text = table_text.values.tolist() table_text.insert(0, headers) if index: table_text_index.insert(0, index_title) for i in range(len(table_text)): table_text[i].insert(0, table_text_index[i]) self.table_text = table_text self.colorscale = colorscale self.font_colors = font_colors self.index = index self.annotation_offset = annotation_offset self.x = range(len(table_text[0])) self.y = range(len(table_text)) def get_table_matrix(self): """ Create z matrix to make heatmap with striped table coloring :rtype (list[list]) table_matrix: z matrix to make heatmap with striped table coloring. """ header = [0] * len(self.table_text[0]) odd_row = [.5] * len(self.table_text[0]) even_row = [1] * len(self.table_text[0]) table_matrix = [None] * len(self.table_text) table_matrix[0] = header for i in range(1, len(self.table_text), 2): table_matrix[i] = odd_row for i in range(2, len(self.table_text), 2): table_matrix[i] = even_row if self.index: for array in table_matrix: array[0] = 0 return table_matrix def get_table_font_color(self): """ Fill font-color array. Table text color can vary by row so this extends a single color or creates an array to set a header color and two alternating colors to create the striped table pattern. :rtype (list[list]) all_font_colors: list of font colors for each row in table. """ if len(self.font_colors) == 1: all_font_colors = self.font_colors*len(self.table_text) elif len(self.font_colors) == 3: all_font_colors = list(range(len(self.table_text))) all_font_colors[0] = self.font_colors[0] for i in range(1, len(self.table_text), 2): all_font_colors[i] = self.font_colors[1] for i in range(2, len(self.table_text), 2): all_font_colors[i] = self.font_colors[2] elif len(self.font_colors) == len(self.table_text): all_font_colors = self.font_colors else: all_font_colors = ['#000000']*len(self.table_text) return all_font_colors def make_table_annotations(self): """ Generate annotations to fill in table text :rtype (list) annotations: list of annotations for each cell of the table. """ from plotly.graph_objs import graph_objs table_matrix = _Table.get_table_matrix(self) all_font_colors = _Table.get_table_font_color(self) annotations = [] for n, row in enumerate(self.table_text): for m, val in enumerate(row): # Bold text in header and index format_text = ('' + str(val) + '' if n == 0 or self.index and m < 1 else str(val)) # Match font color of index to font color of header font_color = (self.font_colors[0] if self.index and m == 0 else all_font_colors[n]) annotations.append( graph_objs.Annotation( text=format_text, x=self.x[m] - self.annotation_offset, y=self.y[n], xref='x1', yref='y1', align="left", xanchor="left", font=dict(color=font_color), showarrow=False)) return annotations