.. _api-sankey_demo_old:

api example code: sankey_demo_old.py
====================================



.. plot:: /home/tcaswell/source/p/matplotlib/doc/mpl_examples/api/sankey_demo_old.py

::

    """
    ==========================
    Yet another Sankey diagram
    ==========================

    This example showcases a more complex Sankey diagram. A self-contained
    ``sankey`` helper assembles the diagram outline from matplotlib ``Path``
    segments (straight lines and Bezier curves). It is demonstrated with
    several signed outputs and inputs branching off the main flow.
    """

    # Makes print() a function on Python 2 as well (used by the DEBUG
    # output inside sankey()).
    from __future__ import print_function

    __author__ = "Yannick Copin <ycopin@ipnl.in2p3.fr>"
    __version__ = "Time-stamp: <10/02/2010 16:49 ycopin@lyopc548.in2p3.fr>"

    import numpy as np
    
    
    def sankey(ax,
               outputs=None, outlabels=None,
               inputs=None, inlabels='',
               dx=40, dy=10, outangle=45, w=3, inangle=30, offset=2, **kwargs):
        """Draw a Sankey diagram on *ax*.

        The main flow has a total width of 100 (values are percentages).

        Each output and input is a signed value:
        - its absolute value is the width of that branch;
        - a positive value is drawn on the top edge (extending upward);
        - a negative value is drawn on the bottom edge (extending downward).

        The sign of the LAST output and of the FIRST input is ignored. Those
        two are drawn horizontally: the main arrow on the right and the main
        dip on the left.

        outputs: sequence of signed outputs; absolute values must sum to 100
            (default: a single output of 100)
        outlabels: output labels (same length as outputs),
            or None (default '%2d%%' labels) or '' (no labels)
        inputs: sequence of signed inputs, same conventions as outputs
            (default: a single input of 100)
        inlabels: input labels, same conventions as outlabels
            (default '': no labels)
        dx: horizontal elongation
        dy: vertical elongation
        outangle: output arrow tip angle [deg]
        w: output arrow shoulder
        inangle: input dip angle [deg]
        offset: distance between an arrow tip/dip and its label
        **kwargs: propagated to PathPatch (e.g., fill=False)

        Return (patch, [intexts, outtexts]):
        - patch is the PathPatch added to *ax*;
        - intexts and outtexts are lists of the Text objects created for the
          labels.

        Side effects: resets the x/y limits of *ax* to fit the diagram and
        sets an equal aspect ratio.

        Raises AssertionError in two cases:
        - the absolute outputs or inputs do not sum to 100 (within
          floating-point tolerance);
        - a label sequence does not match the length of its values.
        """
        import matplotlib.patches as mpatches
        from matplotlib.path import Path

        # None sentinels instead of mutable list defaults. The effective
        # default is still a single flow of 100.
        if outputs is None:
            outputs = [100.]
        if inputs is None:
            inputs = [100.]

        outs = np.absolute(outputs)
        ins = np.absolute(inputs)

        # Validate before touching the sign arrays below, so empty sequences
        # fail with the explicit message rather than an IndexError.
        # np.isclose tolerates float round-off in percentages, which an exact
        # "== 100" comparison would reject.
        assert np.isclose(outs.sum(), 100), "Outputs don't sum up to 100%"
        assert np.isclose(ins.sum(), 100), "Inputs don't sum up to 100%"

        outsigns = np.sign(outputs)
        outsigns[-1] = 0  # Last output: horizontal main arrow
        insigns = np.sign(inputs)
        insigns[0] = 0  # First input: horizontal main dip

        def add_output(path, loss, sign=1):
            """Append an output arrow of width *loss* to *path*.

            sign == 0 draws the horizontal final arrow; a positive/negative
            sign draws a branch curving up/down. The tip position is
            recorded in outtips.
            """
            # Arrow tip height
            h = (loss/2 + w) * np.tan(np.radians(outangle))
            _, (x, y) = path[-1]  # Use last point as reference
            if sign == 0:  # Final loss (horizontal)
                path.extend([(Path.LINETO, [x + dx, y]),
                             (Path.LINETO, [x + dx, y + w]),
                             (Path.LINETO, [x + dx + h, y - loss/2]),  # Tip
                             (Path.LINETO, [x + dx, y - loss - w]),
                             (Path.LINETO, [x + dx, y - loss])])
                outtips.append((sign, path[-3][1]))
            else:  # Intermediate loss (vertical)
                path.extend([(Path.CURVE4, [x + dx/2, y]),
                             (Path.CURVE4, [x + dx, y]),
                             (Path.CURVE4, [x + dx, y + sign*dy]),
                             (Path.LINETO, [x + dx - w, y + sign*dy]),
                             # Tip
                             (Path.LINETO, [
                              x + dx + loss/2, y + sign*(dy + h)]),
                             (Path.LINETO, [x + dx + loss + w, y + sign*dy]),
                             (Path.LINETO, [x + dx + loss, y + sign*dy]),
                             (Path.CURVE3, [x + dx + loss, y - sign*loss]),
                             (Path.CURVE3, [x + dx/2 + loss, y - sign*loss])])
                outtips.append((sign, path[-5][1]))

        def add_input(path, gain, sign=1):
            """Append an input dip of width *gain* to *path*.

            sign == 0 draws the horizontal main dip; a positive/negative sign
            draws a branch curving up/down. The dip position is recorded in
            indips.
            """
            h = (gain / 2) * np.tan(np.radians(inangle))  # Dip depth
            _, (x, y) = path[-1]  # Use last point as reference
            if sign == 0:  # First gain (horizontal)
                path.extend([(Path.LINETO, [x - dx, y]),
                             (Path.LINETO, [x - dx + h, y + gain/2]),  # Dip
                             (Path.LINETO, [x - dx, y + gain])])
                xd, yd = path[-2][1]  # Dip position
                indips.append((sign, [xd - h, yd]))
            else:  # Intermediate gain (vertical)
                path.extend([(Path.CURVE4, [x - dx/2, y]),
                             (Path.CURVE4, [x - dx, y]),
                             (Path.CURVE4, [x - dx, y + sign*dy]),
                             # Dip
                             (Path.LINETO, [
                              x - dx - gain / 2, y + sign*(dy - h)]),
                             (Path.LINETO, [x - dx - gain, y + sign*dy]),
                             (Path.CURVE3, [x - dx - gain, y - sign*gain]),
                             (Path.CURVE3, [x - dx/2 - gain, y - sign*gain])])
                xd, yd = path[-4][1]  # Dip position
                indips.append((sign, [xd, yd + sign*h]))

        outtips = []  # Output arrow tip dir. and positions
        urpath = [(Path.MOVETO, [0, 100])]  # 1st point of upper right path
        lrpath = [(Path.LINETO, [0, 0])]  # 1st point of lower right path
        for loss, sign in zip(outs, outsigns):
            add_output(urpath if sign >= 0 else lrpath, loss, sign=sign)

        indips = []  # Input arrow tip dir. and positions
        llpath = [(Path.LINETO, [0, 0])]  # 1st point of lower left path
        ulpath = [(Path.MOVETO, [0, 100])]  # 1st point of upper left path
        # Inputs are processed last-to-first, so indips ends up in reverse
        # input order (put_labels compensates for this).
        for gain, sign in reversed(list(zip(ins, insigns))):
            add_input(llpath if sign <= 0 else ulpath, gain, sign=sign)

        def revert(path):
            """A path is not just revertable by path[::-1] because of Bezier
            curves: each code must move to the vertex preceding it."""
            rpath = []
            nextmove = Path.LINETO
            for move, pos in path[::-1]:
                rpath.append((nextmove, pos))
                nextmove = move
            return rpath

        # Concatenate subpaths in correct order
        path = urpath + revert(lrpath) + llpath + revert(ulpath)

        codes, verts = zip(*path)
        verts = np.array(verts)

        # Path patch
        path = Path(verts, codes)
        patch = mpatches.PathPatch(path, **kwargs)
        ax.add_patch(patch)

        if False:  # DEBUG: flip to True to dump and plot the raw vertices
            print("urpath", urpath)
            print("lrpath", revert(lrpath))
            print("llpath", llpath)
            print("ulpath", revert(ulpath))
            xs, ys = zip(*verts)
            ax.plot(xs, ys, 'go-')

        # Labels

        def set_labels(labels, values):
            """Set or check labels according to values."""
            if labels == '':  # No labels
                return labels
            if labels is None:  # Default labels
                return ['%2d%%' % val for val in values]
            assert len(labels) == len(values)
            return labels

        def put_labels(labels, positions, output=True):
            """Put labels to positions; input positions are stored reversed."""
            texts = []
            lbls = labels if output else labels[::-1]
            for i, label in enumerate(lbls):
                s, (x, y) = positions[i]  # Label direction and position
                if s == 0:
                    t = ax.text(x + offset, y, label,
                                ha='left' if output else 'right', va='center')
                elif s > 0:
                    t = ax.text(x, y + offset, label, ha='center', va='bottom')
                else:
                    t = ax.text(x, y - offset, label, ha='center', va='top')
                texts.append(t)
            return texts

        outlabels = set_labels(outlabels, outs)
        outtexts = put_labels(outlabels, outtips, output=True)

        inlabels = set_labels(inlabels, ins)
        intexts = put_labels(inlabels, indips, output=False)

        # Axes management
        ax.set_xlim(verts[:, 0].min() - dx, verts[:, 0].max() + dx)
        ax.set_ylim(verts[:, 1].min() - dy, verts[:, 1].max() + dy)
        ax.set_aspect('equal', adjustable='datalim')

        return patch, [intexts, outtexts]
    
    
    if __name__ == '__main__':

        import matplotlib.pyplot as plt

        # Signed flows: the sign picks the side each branch is drawn on.
        outputs = [10., -20., 5., 15., -10., 40.]
        names = ['First', 'Second', 'Third', 'Fourth', 'Fifth', 'Hurray!']
        # Each label is the name plus the absolute percentage on a new line.
        outlabels = []
        for name, value in zip(names, outputs):
            outlabels.append(name + '\n%d%%' % abs(value))

        inputs = [60., -25., 15.]

        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1, xticks=[], yticks=[],
                             title="Sankey diagram")

        patch, texts = sankey(ax, outputs=outputs, outlabels=outlabels,
                              inputs=inputs, inlabels=None)
        intexts, outtexts = texts
        # Highlight the second output and embolden the final one.
        outtexts[1].set_color('r')
        outtexts[-1].set_fontweight('bold')

        plt.show()
    

Keywords: python, matplotlib, pylab, example, codex (see :ref:`how-to-search-examples`)