Table Of Contents

Search

Enter search terms or a module, class or function name.

Source code for swimpy.plot

"""
SWIM related plotting functions and the generic plot_function decorator.

Standalone functions to create plots for SWIM input/output. They are used
throughout the SWIMpy package but collected here to enable reuse.

All functions should accept an optional ax=None argument to plot to. This
argument will always be converted to a valid axes (i.e. plt.gca() if None).

Project.method or Project.plugin.methods that implement plots should use the
``plot_function`` decorator to allow generic functionality.
"""
from __future__ import print_function, absolute_import
import sys
import tempfile
import functools
import datetime as dt
import warnings
import itertools
import os

import numpy as np
import pandas as pd
from modelmanager.settings import FunctionInfo, parse_settings
import matplotlib as mpl
# needed to use matplotlib in django browser
if len(sys.argv) > 1 and sys.argv[1] == 'browser':
    mpl.use('Agg')
import matplotlib.pyplot as plt


def save(output, figure=None, tight_layout=True, **savekwargs):
    """Convenience function to set figure size and save a matplotlib figure.

    Arguments
    ---------
    output : str
        Path to save figure to. Extension determines format.
    figure : matplotlib.Figure object, optional
        Defaults to current figure.
    tight_layout : bool
        Apply ``pyplot.tight_layout`` to figure reducing figure whitespace.
    **savekwargs :
        Any keyword argument parsed to ``figure.savefig()`` method.
        Special keys:
        ``size`` : len 2 tuple, size in mm.

    Raises
    ------
    AssertionError
        If ``output`` is not a string or ``size`` does not have exactly two
        entries.
    """
    figure = figure or plt.gcf()
    assert isinstance(output, str), \
        'output %r must be string path.' % (output,)
    size = savekwargs.pop('size', None)
    if size:
        # size is wrapped in a tuple for formatting, otherwise a sequence of
        # the wrong length would raise a TypeError instead of this message
        assert len(size) == 2, \
            'size must be (width, height) not %r' % (size,)
        mmpi = 25.4  # millimetres per inch
        figure.set_size_inches(size[0]/mmpi, size[1]/mmpi)  # (width, height)
    if tight_layout:
        # tight_layout doesnt work with irregular grid plots, lets try
        try:
            figure.tight_layout()
        except RuntimeError:
            warnings.warn('Figure doesnt allow tight layout.')
    figure.savefig(output, **savekwargs)
    return
def plot_waterbalance(series, ax=None, **barkwargs):
    """Draw the water balance terms as a bar chart.

    Arguments
    ---------
    series : pd.Series
        Values to plot. Index will be used as x labels.
    ax : plt.Axes, optional
        An axes to plot to. If None given, the current axes are used.
    **barkwargs :
        Keyword arguments handed to ``series.plot.bar``.

    Returns
    -------
    The return value of ``series.plot.bar``.
    """
    if not ax:
        ax = plt.gca()
    drawn = series.plot.bar(ax=ax, **barkwargs)
    # labels are set after plotting so they are not overwritten by pandas
    ax.set_ylabel('mm per year')
    ax.set_title('Catchment mean water balance')
    return drawn
def plot_temperature_range(series, ax=None, minmax=(), **linekwargs):
    """Plot temperature with optional min-max range.

    Arguments
    ---------
    series : pd.Series
        Values to plot as a line against the (timestamp converted) index.
    ax : plt.Axes, optional
        An axes to plot to. If None given, the current axes are used.
    minmax : sequence of length 0 or 2, optional
        ``(lower, upper)`` values to shade between. Nothing is shaded if
        empty or None.
    **linekwargs :
        Keyword arguments handed to ``ax.plot``.

    Returns
    -------
    line | (line, fill)
        The ``ax.plot`` result, together with the ``ax.fill_between`` result
        if a range was given.
    """
    if minmax is None:
        minmax = ()
    # use the length rather than truthiness so that numpy arrays and
    # DataFrames (whose truth value is ambiguous) are accepted as well
    nbounds = len(minmax)
    assert nbounds in [0, 2]
    ax = ax or plt.gca()
    times = _index_to_timestamp(series.index)
    if nbounds:
        kw = dict(alpha=0.3, color='k')
        mmfill = ax.fill_between(times, minmax[0], minmax[1], **kw)
    line = ax.plot(times, series, **linekwargs)
    ax.set_ylabel('Temperature [C]')
    ax.set_xlabel('Time')
    return (line, mmfill) if nbounds else line
def plot_precipitation_bars(series, ax=None, **barkwargs):
    """Plot precipitation as bars.

    Arguments
    ---------
    series : pd.Series
        Values to plot as bars against the (timestamp converted) index.
    ax : plt.Axes, optional
        An axes to plot to. If None given, the current axes are used.
    **barkwargs :
        Keyword arguments handed to ``ax.bar``. For period indexes with an
        annual, monthly or daily frequency a default ``width`` of 80% of the
        period length is set if none is given.

    Returns
    -------
    bars
    """
    ax = ax or plt.gca()
    index = series.index
    if hasattr(index, 'to_timestamp'):
        freqstr = index.freqstr.split('-')[0][-1].lower()  # last letter
        # period length in days; 'y' is the annual alias of newer pandas
        width = {'a': 365, 'y': 365, 'm': index.days_in_month, 'd': 1}
        # other frequencies keep matplotlib's default width instead of
        # failing with a KeyError
        if freqstr in width:
            barkwargs.setdefault('width', width[freqstr]*0.8)
    bars = ax.bar(_index_to_timestamp(index), series, **barkwargs)
    ax.set_ylabel('Precipitation [mm]')
    ax.set_xlabel('Time')
    return bars
def plot_discharge(series, ax=None, **linekwargs):
    """Draw one or several discharge lines against time.

    Arguments
    ---------
    series : pd.Series | pd.DataFrame
        Values to plot, the index is converted to timestamps if needed.
    ax : plt.Axes, optional
        An axes to plot to. If None given, the current axes are used.
    **linekwargs :
        Keyword arguments handed to ``ax.plot``.

    Returns
    -------
    lines
    """
    if not ax:
        ax = plt.gca()
    xvalues = _index_to_timestamp(series.index)
    drawn = ax.plot(xvalues, series, **linekwargs)
    ax.set_ylabel('Discharge [m$^3$s$^{-1}$]')
    ax.set_xlabel('Time')
    return drawn
def plot_flow_duration(series, ax=None, **linekwargs):
    """Draw flow duration lines, i.e. the values against their index.

    Arguments
    ---------
    series : pd.Series | pd.DataFrame
        Values to plot against ``series.index``.
    ax : plt.Axes, optional
        An axes to plot to. If None given, the current axes are used.
    **linekwargs :
        Keyword arguments handed to ``ax.plot``.

    Returns
    -------
    lines
    """
    target = ax if ax else plt.gca()
    drawn = target.plot(series.index, series, **linekwargs)
    target.set_ylabel('Discharge [m$^3$s$^{-1}$]')
    target.set_xlabel('% of time exceeded')
    return drawn
def plot_flow_duration_polar(series, axes=None, percentilestep=10, freq='m',
                             colormap='jet_r', **barkw):
    """Bins the values in series into 100/percentilestep steps and displays
    the relative frequency per month or day-of-year (freq= m|d) on a polar
    bar chart of the year.

    See in action and more docs in:
    :meth:`swimpy.output.station_daily_discharge.plot_flow_duration_polar`

    Arguments
    ---------
    series : pd.Series
        Values with an index providing ``month``/``dayofyear`` attributes.
    axes : plt.Axes, optional
        Axes whose position is used for the polar axes. The given axes are
        switched off and replaced. Defaults to the current axes.
    percentilestep : int
        Width of the percentile bins, must be <= 50.
    freq : str
        ``'d'`` for day-of-year bins, anything else for monthly bins.
    colormap : str
        Name of the matplotlib colormap used for the relative frequency.
    **barkw :
        Accepted for compatibility but currently not used.

    Returns
    -------
    The newly created polar axes.
    """
    assert percentilestep <= 50
    axes = axes or plt.gca()
    # exchange axes for polar axes on the same figure the axes belong to
    # (not necessarily pyplot's current figure)
    figure = axes.get_figure()
    position = axes.get_position()
    axes.set_axis_off()
    axes = figure.add_axes(position, projection='polar')

    daily = freq.lower() == 'd'
    ssorted = series.dropna().sort_values()
    n = len(ssorted)
    dom = 'dayofyear' if daily else 'month'
    lower = 0
    count = {}
    for b in range(percentilestep, 100+1, percentilestep):
        # get percentile bin
        upper = min(int(round(n*b/100.)), n)
        pbin = ssorted.iloc[lower:upper]
        # count values in percentile bin
        count[b] = pbin.groupby(getattr(pbin.index, dom)).count()/float(n)
        lower = upper
    ntb = 365 if daily else 12
    theta = np.arange(ntb) * 2 * np.pi / ntb
    cm = plt.get_cmap(colormap)
    # one row per time step of the year so that the number of colours always
    # matches the number of bars, even if some steps are missing in the data
    # (day 366 of leap years is dropped)
    countdf = pd.DataFrame(count).reindex(range(1, ntb+1))
    countdf /= countdf.max().max()
    countdf = countdf.fillna(0)
    for b, col in countdf.items():
        axes.bar(x=theta, height=[percentilestep]*len(theta),
                 width=2*np.pi/ntb, bottom=b-percentilestep,
                 color=cm(col), edgecolor='none')
    axes.set_theta_zero_location('N')
    axes.set_theta_direction(-1)
    axes.set_rmin(0)
    axes.set_rmax(100)
    axes.grid(False)
    month_names = [dt.date(2000, i, 1).strftime('%B') for i in range(1, 13)]
    # set the ticks on the polar axes directly rather than via pyplot state
    axes.set_xticks(np.arange(12)*2*np.pi/12)
    tcklbls = axes.set_xticklabels(month_names)
    # label rotations following the circle, one per month
    rots = (list(range(0, -91, -30)) + list(range(60, -61, -30)) +
            list(range(90, 30-1, -30)))
    for label, rotation in zip(tcklbls, rots):
        label.set_rotation(rotation)
    axes.set_yticks([50])
    axes.set_yticklabels(['50%'])
    axes.grid(True, axis='y')
    return axes
def plot_objective_scatter(performances, selected=None, selected_color='r',
                           ax=None, **scatterkw):
    '''Plot scatter against all objectives combinations in a stepped subplot.

    Arguments
    ---------
    performances : pd.DataFrame
        DataFrame with performance values.
    selected : dict-like
        Highlight selected point(s).
    selected_color : matplotlib.color spec | str
        Color for the selected points.
    ax : plt.Axes | array of plt.Axes, optional
        Any axes (or the array returned by a previous call) of a figure to
        plot to. The figure's axes are reused if their number matches the
        required grid, otherwise the figure is cleared and a new grid is
        created. A new figure is created if None.
    **scatterkw :
        Keywords passed to the ``scatter`` call of the performances.

    Returns
    -------
    Array of axes with shape (n objectives - 1, n objectives - 1).
    '''
    objectives = performances.columns
    # calculate limits
    nticks = 5
    margin = 0.1  # fraction of the value range added below and above
    stats = performances.describe()
    rng = (stats.loc['max'] - stats.loc['min']) * margin
    limits = {'max': stats.loc['max'] + rng, 'min': stats.loc['min'] - rng}
    # NOTE(review): extend_limits is True on every path, confirm whether it
    # was meant to be False for freshly created axes
    extend_limits = True
    naxes = len(objectives) - 1
    if ax is not None:
        # accept the array of axes returned by this function as well, the
        # truth value of an array would be ambiguous
        anyax = ax.flat[0] if isinstance(ax, np.ndarray) else ax
        f = anyax.get_figure()
        axs = f.get_axes()
        if len(axs) == naxes**2:
            ax = np.array(axs).reshape(naxes, naxes)
            extend_limits = True
        else:
            f.clear()
            ax = None
    else:
        f = plt.figure()
    if ax is None:
        ax = f.subplots(naxes, naxes, squeeze=False)
        f.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, n in enumerate(objectives[1:]):  # row
        for ii, nn in enumerate(objectives[:-1]):  # column
            if ii <= i:
                ax[i][ii].scatter(
                    performances[nn], performances[n], **scatterkw)
                if selected is not None:
                    ax[i][ii].scatter(selected[nn], selected[n],
                                      c=selected_color)
                # axis adjustments
                xticks = mpl.ticker.MaxNLocator(nbins=nticks, prune='upper')
                ax[i][ii].xaxis.set_major_locator(xticks)
                yticks = mpl.ticker.MaxNLocator(nbins=nticks, prune='upper')
                ax[i][ii].yaxis.set_major_locator(yticks)
                if limits is not None:
                    if extend_limits:
                        # never shrink below what is already shown
                        exl, eyl = ax[i][ii].get_xlim(), ax[i][ii].get_ylim()
                        limits['min'][n] = min(limits['min'][n], eyl[0])
                        limits['max'][n] = max(limits['max'][n], eyl[1])
                        limits['min'][nn] = min(limits['min'][nn], exl[0])
                        limits['max'][nn] = max(limits['max'][nn], exl[1])
                    ax[i][ii].set_ylim(limits['min'][n], limits['max'][n])
                    ax[i][ii].set_xlim(limits['min'][nn], limits['max'][nn])
            else:  # remove unused axes
                ax[i][ii].set_frame_on(False)
                ax[i][ii].set_xticks([])
                ax[i][ii].set_yticks([])
            # labels only on the bottom row and the first column
            if i == naxes - 1:
                ax[i][ii].set_xlabel(nn)
            else:
                ax[i][ii].set_xticklabels([])
            if ii == 0:
                ax[i][0].set_ylabel(n)
            else:
                ax[i][ii].set_yticklabels([])
    return ax
def _index_to_timestamp(index): """Convert a pandas index to timestamps if needed. Needed to parse pandas PeriodIndex to pyplot plotting functions.""" return index.to_timestamp() if hasattr(index, 'to_timestamp') else index
def default_colors(n, colors=()):
    """Return n default colors starting with given colors if any.

    Arguments
    ---------
    n : int
        Number of colors to return.
    colors : iterable of color specs | str, optional
        Colors to start with. A single string is treated as one color.

    Returns
    -------
    list
        ``colors`` followed by the colors of the ``axes.prop_cycle``
        rcParam, repeated cyclically and cut to ``n`` entries.
    """
    if colors is None:
        colors = ()
    elif isinstance(colors, str):
        colors = (colors,)
    # list() allows tuples and other iterables to be given, not only lists
    defc = list(colors) + plt.rcParams['axes.prop_cycle'].by_key()['color']
    return list(itertools.islice(itertools.cycle(defc), n))
def plot_function(function):
    """Decorator for the PlotFunction class.

    This factory function is required to return a function rather than an
    object if PlotFunction was used as a decorator alone.

    Arguments
    ---------
    function : callable
        The plot function to wrap. It must satisfy the checks made in
        ``PlotFunction.__init__``.

    Returns
    -------
    function
        Wrapper that calls the ``PlotFunction`` instance. The original
        function is available as its ``decorated_function`` attribute.
    """
    pf = PlotFunction(function)

    @functools.wraps(function)
    def f(*args, **kwargs):
        return pf(*args, **kwargs)
    # functions without docstring must not break the decoration
    original_doc = pf.finfo.doc or ''
    # add signiture to beginning of docstrign if PY2
    if sys.version_info < (3, 0):
        sig = '%s(%s)\n' % (pf.finfo.name, pf.finfo.signiture)
        f.__doc__ = sig + original_doc
    # add generic docs
    docs = original_doc + PlotFunction.ax_output_docs
    if 'runs' in pf.finfo.optional_arguments:
        docs += PlotFunction.runs_docs
    function.__doc__ = docs
    # attach original function
    f.decorated_function = function
    return f
class PlotFunction(object):
    """A class that enforces and handles generic plot function tasks.

    To be used in plot_function decorator.

    - enforces name starting with 'plot'.
    - enforces ax=None arugment and ensures a valid axes is always parsed.
    - enforces to accept ``**kwargs``.
    - enforces the method instance (first function argement) to either be a
      project or have a project attribute
    - reads savefig_defaults from project
    - enforces output=None argument and allows saving of figure to file with
      that may either be string path or a dict with kwargs to save.
    - displays interactive plot if executed from commandline.
    - saves current figure to a temp path when executed in browser API.
    - allows running function with a run instance if the function has a run
      argument. The argument input is normalised (see additional_docs).
    """
    # generic documentation appended to the decorated function's docstring
    ax_output_docs = """
    Plot function arguments:
    ------------------------
    ax : <matplotlib.Axes>, optional
        Axes to plot to. Default is the current axes if None.
    output : str path | dict
        Path to writeout or dict of keywords to parse to save_or_show."""
    # appended as well if the decorated function has a ``runs`` argument
    runs_docs = """
    runs : Run | runID | iterable of Run/runID | QuerySet | (str), optional
        Show plot for runs if they have the same method or plugin.method. If a
        string is parsed, the current project will also be plot with the string
        as label. The runs argument is transformed to (run QuerySet, index) to
        enable per run stylingy.
    """

    def __init__(self, function):
        """Check the signature of ``function`` and prepare call attributes.

        Raises an AssertionError if the function has no ``output=None`` and
        ``ax=None`` arguments, its name does not start with ``plot`` (or is
        ``__call__``) or it does not accept ``**kwargs``.
        """
        # enforce arugments
        self.finfo = FunctionInfo(function)
        fname = self.finfo.name
        # takes care of decorated functions
        self.decorated_function = self.finfo.function
        oargs = dict(zip(self.finfo.optional_arguments, self.finfo.defaults))
        errmsg = fname + ' has no optional argument "%s=None".'
        for a in ['output', 'ax']:
            assert a in oargs and oargs[a] is None, errmsg % a
        errmsg = fname + ' should start with "plot".'
        assert fname.startswith('plot') or fname == '__call__', errmsg
        assert self.finfo.kwargs, fname+' must accept **kwargs.'
        # attributes assigned during call
        callattr = ('project instance args kwargs ax figure savekwargs '
                    'output runs ax_parsed')
        for a in callattr.split():
            setattr(self, a, None)
        return

    def _infer_project(self):
        """Set ``self.instance`` and ``self.project`` from the first argument.

        Raises an AttributeError if the first argument is neither a Project
        nor has a ``project`` attribute.
        """
        from .project import Project
        self.instance = self.args[0]  # assumes method
        if isinstance(self.instance, Project):
            self.project = self.instance
        elif hasattr(self.instance, 'project'):
            self.project = self.instance.project
        else:
            em = '%s is not a Project instance or has a project attribute.'
            raise AttributeError(em % self.instance)
        return

    def _interpret_args(self, args, kwargs):
        """Normalise the call arguments and store them on the instance.

        ``runs`` is removed from ``kwargs``, ``ax`` is always filled with a
        valid axes and the save keywords are assembled from the project's
        ``save_figure_defaults`` and an ``output`` dict if given.
        """
        self.args = args
        self.kwargs = kwargs
        self.runs = kwargs.pop('runs', None)
        self.output = kwargs.get('output')
        pax = kwargs.get('ax', None)
        self.ax_parsed = pax is not None
        self.ax = pax or plt.gca()
        self.kwargs['ax'] = self.ax
        self.figure = self.ax.get_figure()
        self._infer_project()
        self.savekwargs = {}
        self.savekwargs.update(self.project.save_figure_defaults)
        if isinstance(self.output, dict):
            # work on a copy, popping from the caller's dict would stop it
            # from saving when the same dict is used for another call
            outkw = dict(self.output)
            op = outkw.pop('output', None)
            self.savekwargs.update(outkw)
            # the decorated function keeps receiving the dict without the
            # path so that nested plot functions do not save again
            self.kwargs['output'] = outkw
            self.output = op
        return

    def __call__(self, *args, **kwargs):
        """Call the decorated function and save or display the figure.

        Returns the result of the decorated function (a list of results if
        ``runs`` are given) or the result of ``_display_figure`` if the
        figure is displayed.
        """
        self._interpret_args(args, kwargs)
        if self.runs:
            result = self._plot_runs()
        else:
            result = self.decorated_function(*self.args, **self.kwargs)
        if self.output:
            save(self.output, self.figure, **self.savekwargs)
        # display if from commandline or browser api (if ax parsed it cant be
        # from the commandline or browser so must be a subcall)
        elif sys.argv[0].endswith('swimpy') and not self.ax_parsed:
            result = self._display_figure()
        return result

    def _plot_runs(self):
        """Call the same method of every run in ``self.runs``.

        Returns a list with the results of all calls, starting with the
        result for the current project if a label string was given.
        """
        ispi = self.instance.__class__ != self.project.__class__
        piname = self.instance.__class__.__name__  # project if not plugin
        # extract stings as labels for current
        # NOTE(review): a bare string has __iter__ in Python 3 and would be
        # split into characters here, confirm the intended handling
        current_label = None
        if hasattr(self.runs, '__iter__'):
            current_label = [r for r in self.runs if isinstance(r, str)]
            if current_label:
                self.runs = [r for r in self.runs if r not in current_label]
                current_label = current_label[0]
        # transform runs to QuerySet
        runs = self.project.browser.runs.get_runs(self.runs)
        # plot current if current lable parsed
        if current_label:
            res = self.decorated_function(*self.args, label=current_label,
                                          **self.kwargs)
            result = [res]
        else:
            result = []
        for i, r in enumerate(runs):
            try:
                piinstance = r
                if ispi:  # if project.plugin
                    piinstance = getattr(r, piname)
                pmeth = getattr(piinstance, self.finfo.name)
            except AttributeError:
                # name the plugin only if the method lives on a plugin
                if ispi:
                    m = piname+'.'+self.finfo.name
                else:
                    m = self.finfo.name
                print('%s doesnt have a %s method.' % (r, m))
                continue
            rkw = self.kwargs.copy()
            rkw['runs'] = (runs, i)
            rkw.setdefault('label', str(r))
            # call method with different instance as first argument as
            # decorated_function is unbound
            rre = pmeth.decorated_function(piinstance, *self.args[1:], **rkw)
            result.append(rre)
        # make sure a legend is shown if not already
        if self.ax.get_legend() is None:
            self.ax.legend()
        return result

    def _display_figure(self):
        """Show the figure (CLI) or save it to a temporary png (browser).

        Returns the image path in the browser case, otherwise None.
        """
        # tight_layout doesnt work with irregular grid plots, lets try
        try:
            self.figure.tight_layout()
        except RuntimeError:
            warnings.warn('Figure doesnt allow tight layout.')
        # in Django API
        if len(sys.argv) > 1 and sys.argv[1] == 'browser':
            tmpdir = self.project.browser.settings.tmpfilesdir
            tf, imgpath = tempfile.mkstemp(suffix='.png', dir=tmpdir)
            # only the path is needed, savefig opens the file itself
            os.close(tf)
            save(imgpath, self.figure, **self.savekwargs)
            self.figure.clear()
            return imgpath
        else:  # in CLI
            plt.show(block=True)
        return
def plot_many(functions, **kwall):
    """Plot mutiple plots in a grid.

    Arguments
    ---------
    functions : list (of lists) of callables or tuple(callable, dict)
        Plotting functions to call with either just defaults or with keyword
        arguments if a tuple with the callable and a dict is parsed. List in
        the list subdivides rows in the subplot grid. All functions must
        accept the ax keyword.
    kwall :
        Keywords applied to all functions. They take precedence over the
        keywords given for a single function.

    Returns
    -------
    axes
        Flat list of the created axes.

    Raises
    ------
    AssertionError
        If ``functions`` is not a list or a tuple entry is malformed.
    TypeError
        If an entry is neither a callable, a tuple nor a list.
    """
    def norm_f(f):
        # normalise every entry to (callable, keywords), recursing into rows
        if callable(f):
            return (f, kwall)
        elif isinstance(f, tuple):
            assert len(f) == 2 and callable(f[0]) and isinstance(f[1], dict)
            # merge into a new dict, the caller's keywords stay untouched
            merged = dict(f[1])
            merged.update(kwall)
            return (f[0], merged)
        elif isinstance(f, list):
            return [norm_f(i) for i in f]
        else:
            raise TypeError('functions entry %r is neither a ' % (f,) +
                            'callable or a tuple (callable, dict).')

    def call_f(f, ax, kw):
        # re-raise errors with the name of the failing function and the
        # original traceback appended to the message
        try:
            f(ax=ax, **kw)
        except Exception:
            import traceback
            etype, ex, tb = sys.exc_info()
            raise etype('While calling %s, the below error occurred:\n\n' % f +
                        str(ex) + ':\n' + ''.join(traceback.format_tb(tb)))

    assert isinstance(functions, list)
    nrow = len(functions)
    ncol = max([1]+[len(r) for r in functions if isinstance(r, list)])
    grid = (nrow, ncol)
    axes = []
    normf = norm_f(functions)
    for irow, f in enumerate(normf):
        if isinstance(f, list):
            # rows with fewer entries than columns get wider axes; the
            # max() also guards empty rows against a division by zero
            colspan = max(1, ncol // max(1, len(f)))
            for icol, (fu, kw) in enumerate(f):
                axes += [plt.subplot2grid(grid, (irow, icol*colspan),
                                          colspan=colspan)]
                call_f(fu, axes[-1], kw)
        else:
            axes += [plt.subplot2grid(grid, (irow, 0), colspan=ncol)]
            call_f(f[0], axes[-1], f[1])
    return axes
class plot_summary(object):
    """A plugin to enable project.plot_summary and run.plot."""
    plugin = ['__call__']

    def __init__(self, project, host=None):
        """Store the project and the host the plot functions are read from.

        ``host`` defaults to the project itself.
        """
        host = host or project
        self.host = host
        self.project = project
        return

    def _getattr(self, address):
        """Deep getattr with a dotted address.

        Returns None and prints a note if any part of the address can not be
        found on the host.
        """
        a = self.host
        try:
            for i in address.split('.'):
                a = getattr(a, i)
        except AttributeError:
            print('Cant find %s, will be ignored.' % address)
            return
        return a

    def _convert(self, entry):
        """Recursively convert functions entries into callables or skip.

        Strings are looked up on the host, tuples are converted to
        ``(callable, dict)`` and lists are converted entry by entry with
        entries that could not be found dropped. Other types yield None.
        """
        if isinstance(entry, str):
            try:
                return self._getattr(entry)
            except (AttributeError, IOError):
                return None
        elif isinstance(entry, tuple):
            assert (len(entry) == 2 and isinstance(entry[0], str) and
                    isinstance(entry[1], dict))
            e = self._convert(entry[0])
            return (e, entry[1]) if e else None
        elif isinstance(entry, list):
            return [i for i in [self._convert(i) for i in entry] if i]

    @parse_settings
    def __call__(self, functions=None, output=None, runs=None, ax=None, **kw):
        """Summary plot.

        Arguments
        ---------
        functions : list (of lists) of callables or tuple(callable, dict)
            Plotting functions to call with either just defaults or with
            keyword arguments if a tuple with the callable and a dict is
            parsed. List in the list subdivides rows in the subplot grid. All
            functions must accept the ax keyword. Mainly intended to be set in
            settings with the ``plot_summary_functions`` variable.
        kw :
            Keywords to all subplots.

        Returns
        -------
        list :
            Flat list of axes that were created.

        Raises
        ------
        RuntimeError
            If none of the functions could be found on the host.
        """
        # only used to interpret the generic arguments and to save/display,
        # the plotting itself is done by plot_many
        pf = PlotFunction(self.__call__)
        pfkw = dict(ax=ax, output=output, runs=runs)
        pf._interpret_args([self], pfkw)
        normed_functions = self._convert(functions)
        if normed_functions:
            axes = plot_many(normed_functions, runs=runs, **kw)
        else:
            raise RuntimeError('No valid plots found in %s' % self.host)
        # remove axes legends and add figure legend
        legends = [a.get_legend() for a in axes]
        legends = [lg for lg in legends if lg is not None]
        if legends:
            pf.figure.legend(*axes[0].get_legend_handles_labels())
            for lg in legends:
                lg.remove()
        # tight layout
        pf.figure.tight_layout()
        # output/display
        if pf.output:
            save(pf.output, pf.figure, **pf.savekwargs)
        # display if from commandline or browser api
        elif sys.argv[0].endswith('swimpy'):
            return pf._display_figure()
        return axes
Scroll To Top