Tabular legend layout for matplotlib
Expanding on The Dude's answer, I've tried to turn this into a copy-paste solution: a function (see further below) that automates the generation of a table legend, so that instead of calling `ax.legend()` you only need to add the following to your plot:
# Use instead of ax.legend(). The number of row labels must equal
# (number of plotted series) // ncol -- three rows for the 3x3 example below.
tablelegend(ax, ncol=3, bbox_to_anchor=(1, 1),
            row_labels=['$i=1$', '$i=2$', '$i=3$'],
            col_labels=['$j=1$', '$j=2$', '$j=3$'],
            title_label='$f_{i,j}$')
`row_labels`, `col_labels` and `title_label` are all optional, so you can, e.g., create a legend table with a column header but without a row header.
Full example usage
import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
ax = plt.gca()

# Nine series: the line style encodes j (table column) and the colour encodes
# i (table row). They are plotted column by column because tablelegend takes
# the handles of column j as the consecutive slice handles[j*nrow:(j+1)*nrow].
for j, style in enumerate((":", ".", "^"), start=1):
    for i, color in enumerate(("r", "g", "b"), start=1):
        ax.plot(range(10), np.random.randn(10), color + style,
                label=f"$i={i}$, $j={j}$")

# tablelegend must already be defined (see the function further below).
tablelegend(ax, ncol=3, bbox_to_anchor=(1, 1),
            row_labels=['$i=1$', '$i=2$', '$i=3$'],
            col_labels=['$j=1$', '$j=2$', '$j=3$'],
            title_label='$f_{i,j}$')
plt.show()
The `tablelegend` function
import matplotlib.legend as mlegend
from matplotlib.patches import Rectangle
def tablelegend(ax, col_labels=None, row_labels=None, title_label="", *args, **kwargs):
"""
Place a table legend on the axes.
Creates a legend where the labels are not directly placed with the artists,
but are used as row and column headers, looking like this:
title_label | col_labels[1] | col_labels[2] | col_labels[3]
-------------------------------------------------------------
row_labels[1] |
row_labels[2] | <artists go there>
row_labels[3] |
Parameters
----------
ax : `matplotlib.axes.Axes`
The artist that contains the legend table, i.e. current axes instant.
col_labels : list of str, optional
A list of labels to be used as column headers in the legend table.
`len(col_labels)` needs to match `ncol`.
row_labels : list of str, optional
A list of labels to be used as row headers in the legend table.
`len(row_labels)` needs to match `len(handles) // ncol`.
title_label : str, optional
Label for the top left corner in the legend table.
ncol : int
Number of columns.
Other Parameters
----------------
Refer to `matplotlib.legend.Legend` for other parameters.
"""
#################### same as `matplotlib.axes.Axes.legend` #####################
handles, labels, extra_args, kwargs = mlegend._parse_legend_args([ax], *args, **kwargs)
if len(extra_args):
raise TypeError('legend only accepts two non-keyword arguments')
if col_labels is None and row_labels is None:
ax.legend_ = mlegend.Legend(ax, handles, labels, **kwargs)
ax.legend_._remove_method = ax._remove_legend
return ax.legend_
#################### modifications for table legend ############################
else:
ncol = kwargs.pop('ncol')
handletextpad = kwargs.pop('handletextpad', 0 if col_labels is None else -2)
title_label = [title_label]
# blank rectangle handle
extra = [Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)]
# empty label
empty = [""]
# number of rows infered from number of handles and desired number of columns
nrow = len(handles) // ncol
# organise the list of handles and labels for table construction
if col_labels is None:
assert nrow == len(row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (nrow, len(row_labels))
leg_handles = extra * nrow
leg_labels = row_labels
elif row_labels is None:
assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (ncol, len(col_labels))
leg_handles = []
leg_labels = []
else:
assert nrow == len(row_labels), "nrow = len(handles) // ncol = %s, but should be equal to len(row_labels) = %s." % (nrow, len(row_labels))
assert ncol == len(col_labels), "ncol = %s, but should be equal to len(col_labels) = %s." % (ncol, len(col_labels))
leg_handles = extra + extra * nrow
leg_labels = title_label + row_labels
for col in range(ncol):
if col_labels is not None:
leg_handles += extra
leg_labels += [col_labels[col]]
leg_handles += handles[col*nrow:(col+1)*nrow]
leg_labels += empty * nrow
# Create legend
ax.legend_ = mlegend.Legend(ax, leg_handles, leg_labels, ncol=ncol+int(row_labels is not None), handletextpad=handletextpad, **kwargs)
ax.legend_._remove_method = ax._remove_legend
return ax.legend_
Not a very easy question, but I figured it out. The trick is to create an empty (invisible) rectangle that acts as a placeholder handle; these additional empty handles are used to build the header row and header column of the table. I get rid of the excess space next to the handles with a negative `handletextpad`:
import numpy
import pylab
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
fig = plt.figure()
ax = fig.add_subplot(111)

# Nine series: the line style selects the table column j, the colour the row i.
# pylab is discouraged; call numpy.random.randn directly (what pylab.randn re-exports).
lines = []
for style in ("--", ".", "^"):
    for color in ("r", "g", "b"):
        line, = ax.plot(range(10), numpy.random.randn(10), color + style)
        lines.append(line)

# Invisible rectangle used as the placeholder handle of every header cell.
extra = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)

# With ncol=4 the 16 entries are laid out column by column, 4 per column:
# first the row headers, then for each j a "j = ..." header followed by that
# column's three series with empty labels.
row_header = [r"$f_{i,j}$", r"$i = 1$", r"$i = 2$", r"$i = 3$"]
legend_handles = [extra] * len(row_header)
legend_labels = list(row_header)
for j in range(3):
    legend_handles += [extra] + lines[3 * j:3 * j + 3]
    legend_labels += [rf"$j = {j + 1}$"] + [""] * 3

# Negative handletextpad removes the excess space between the (invisible)
# handles and their header text.
ax.legend(legend_handles, legend_labels,
          loc="upper center", ncol=4, shadow=True, handletextpad=-2)
plt.show()