How to add group labels for bar charts in matplotlib?

Varicus picture Varicus · Oct 4, 2013 · Viewed 11.5k times · Source

I want to plot data of the following form using matplotlib's bar plot feature:

# Nested layout: room -> shelf -> item -> value to plot as a bar.
data = {
    'Room A': {
        'Shelf 1': {'Milk': 10, 'Water': 20},
        'Shelf 2': {'Sugar': 5, 'Honey': 6},
    },
    'Room B': {
        'Shelf 1': {'Wheat': 4, 'Corn': 7},
        'Shelf 2': {'Chicken': 2, 'Cow': 1},
    },
}

The bar chart is supposed to look like this. The bar groups should be visible from the labels on the x axis. Is there any way to do this with matplotlib?

Answer

Varicus picture Varicus · Oct 8, 2013

Since I could not find a built-in solution for this in matplotlib, I coded my own:

#!/usr/bin/env python

from matplotlib import pyplot as plt

def mk_groups(data):
    """Flatten a nested mapping into one list of (label, number) pairs per level.

    The result is ordered from the outermost nesting level to the innermost:

        [[(label, bars_to_span), ...], ..., [(tick_label, bar_value), ...]]

    In every list except the last, the number is how many bars the label
    spans; in the last list it is the bar value itself.

    Anything with an ``items()`` method is treated as a group; anything
    without one is treated as a leaf value. The function is written for
    mappings nested to a uniform depth: if leaves and sub-groups are siblings,
    their entries end up mixed in the same level list.

    Returns ``None`` when *data* itself has no ``items()`` method.
    """
    try:
        items = data.items()
    except AttributeError:
        # No items() method: this is a leaf value, not a group. Only
        # AttributeError is caught so that real failures inside items()
        # (or a KeyboardInterrupt) are not silently mistaken for a leaf.
        return None

    thisgroup = []
    groups = []
    for key, value in items:
        newgroups = mk_groups(value)
        if newgroups is None:
            # Leaf: keep the value itself.
            thisgroup.append((key, value))
        else:
            # Sub-group: its width is the number of leaves in its last level.
            thisgroup.append((key, len(newgroups[-1])))
            if groups:
                # Append this child's levels to the matching levels
                # collected from earlier siblings.
                groups = [g + n for n, g in zip(newgroups, groups)]
            else:
                groups = newgroups
    return [thisgroup] + groups

def add_line(ax, xpos, ypos):
    """Add a short vertical black line to *ax*, positioned in axes coordinates.

    The line sits at horizontal position *xpos* and runs from *ypos* up to
    ``ypos + .1``.
    """
    top = ypos + .1
    separator = plt.Line2D(
        [xpos, xpos],
        [top, ypos],
        transform=ax.transAxes,
        color='black',
    )
    # Clipping is switched off so the line stays visible outside the axes area.
    separator.set_clip_on(False)
    ax.add_line(separator)

def label_group_bar(ax, data):
    """Draw a bar chart of nested *data* on *ax* with group labels beneath it.

    *data* is converted with ``mk_groups``. The innermost level supplies the
    tick labels and bar heights; every outer level is written as a row of
    text labels below the tick labels, each centered under the bars it spans
    and delimited by vertical separator lines.

    Label rows are placed in axes coordinates below the axes (starting at
    y = -.2 and moving down .1 per level), so the caller has to leave enough
    bottom margin in the figure for them.
    """
    groups = mk_groups(data)
    # The last level holds the (tick_label, bar_value) pairs.
    xy = groups.pop()
    x, y = zip(*xy)
    ly = len(y)
    xticks = range(1, ly + 1)

    ax.bar(xticks, y, align='center')
    ax.set_xticks(xticks)
    ax.set_xticklabels(x)
    # Bars are centered on 1..ly, so these limits give every bar an equal
    # 1/ly share of the axes width, which the label positions rely on.
    ax.set_xlim(.5, ly + .5)
    ax.yaxis.grid(True)

    # Width of one bar slot in axes coordinates.
    scale = 1. / ly
    # Separator at every bar boundary, directly below the axes.
    # range (not xrange) keeps this working on both Python 2 and Python 3.
    for pos in range(ly + 1):
        add_line(ax, pos * scale, -.1)
    ypos = -.2
    # Remaining levels are consumed from the innermost group level outwards.
    while groups:
        group = groups.pop()
        pos = 0
        for label, rpos in group:
            # Center the label under the rpos bars it spans.
            lxpos = (pos + .5 * rpos) * scale
            ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes)
            add_line(ax, pos * scale, ypos)
            pos += rpos
        # Closing separator after the last group of this level.
        add_line(ax, pos * scale, ypos)
        ypos -= .1

if __name__ == '__main__':
    # Example input: room -> shelf -> item -> bar height.
    data = {
        'Room A': {
            'Shelf 1': {'Milk': 10, 'Water': 20},
            'Shelf 2': {'Sugar': 5, 'Honey': 6},
        },
        'Room B': {
            'Shelf 1': {'Wheat': 4, 'Corn': 7},
            'Shelf 2': {'Chicken': 2, 'Cow': 1},
        },
    }

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    label_group_bar(ax, data)

    # Enlarge the bottom margin before the figure is written to disk.
    fig.subplots_adjust(bottom=0.3)
    fig.savefig('label_group_bar_example.png')

The mk_groups function takes a dictionary (or anything with an items() method, like collections.OrderedDict) and converts it into the data format that is then used to create the chart. The result is a list with one entry per nesting level, ordered from the outermost groups down to the individual bars:

[ [(label, bars_to_span), ...], ..., [(tick_label, bar_value), ...] ]

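The add_line function draws a short vertical line in the subplot at the specified position (in axes coordinates). The line runs from ypos up to ypos + 0.1 and is not clipped, so it remains visible below the plot area.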

The label_group_bar function takes a subplot and a dictionary and creates the bar chart in that subplot with the group labels beneath. The result from the example then looks like this.

Easier or better solutions and suggestions are still very much appreciated.

bar chart with groups