matplotlib iterate subplot axis array through single list
The ax return value is a numpy array, which can be reshaped, I believe, without any copying of the data. If you use the following, you'll get a linear array that you can iterate over cleanly.
import matplotlib.pyplot as plt

nrow = 1; ncol = 2
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
# reshape(-1) flattens the array of axes to 1D
for i, ax in enumerate(axs.reshape(-1)):
    ax.set_ylabel(str(i))
This doesn't hold when ncols and nrows are both 1, since the return value is then a single axes object rather than an array. You could turn the return value into an array with one element for consistency, though it feels a bit like a kludge:
import numpy as np

nrow = 1; ncol = 1
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
axs = np.array(axs)  # wrap the single axes object in an array
for i, ax in enumerate(axs.reshape(-1)):
    ax.set_ylabel(str(i))
See the reshape docs: the argument -1 causes reshape to infer the dimensions of the output.
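On the "without any copying" point, here is a small check you can run yourself. It is only a sketch: the 2x3 object array below is a stand-in I made up for the array of axes, and the result applies to an ordinary contiguous array like this one. For such an array reshape(-1) gives a view, whereas flatten() always makes a copy.

```python
import numpy as np

# Hypothetical stand-in for the axes array: a 2x3 array of Python objects.
grid = np.empty((2, 3), dtype=object)
for k in range(grid.size):
    grid.flat[k] = "ax%d" % k

flat_view = grid.reshape(-1)   # -1 lets reshape infer the length (6 here)
flat_copy = grid.flatten()     # flatten() always returns a new array

view_shares = np.shares_memory(grid, flat_view)
copy_shares = np.shares_memory(grid, flat_copy)
print(flat_view.shape, view_shares, copy_shares)
```

Either way the elements are the same objects, so for iterating over axes the difference rarely matters.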
The array of axes returned by plt.subplots is a numpy array, so it already has a flat attribute that iterates over every element in 1D.
Why don't you try the following code?
fig, axes = plt.subplots(2, 3)
for ax in axes.flat:
    pass  # do something with this instance of 'ax'
I am not sure when it was added, but there is now a squeeze keyword argument. Passing squeeze=False makes sure the result is always a 2D numpy array. Turning that into a 1D array is easy:
fig, ax2d = plt.subplots(2, 2, squeeze=False)
axli = ax2d.flatten()
Works for any number of subplots, no trick for single ax, so a little easier than the accepted answer (perhaps squeeze didn't exist yet back then).
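To see that squeeze=False really gives a 2D array in every case, something like the following should do. It is just a sketch: the grid sizes are arbitrary examples, the shapes dict is my own bookkeeping, and the Agg backend is only selected so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt

shapes = {}
for nrow, ncol in [(1, 1), (1, 3), (2, 2)]:
    fig, ax2d = plt.subplots(nrow, ncol, squeeze=False)
    shapes[(nrow, ncol)] = ax2d.shape   # always (nrow, ncol), even for 1x1
    for i, ax in enumerate(ax2d.flat):  # same loop for every grid size
        ax.set_ylabel(str(i))
    plt.close(fig)

print(shapes)
```

The same loop body handles the single-subplot case, which is the point of using squeeze=False.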
The fig return value of plt.subplots has a list of all the axes. To iterate over all the subplots in a figure you can use:
nrow = 2
ncol = 2
fig, axs = plt.subplots(nrow, ncol)
for i, ax in enumerate(fig.axes):
    ax.set_ylabel(str(i))
This also works for nrow == ncol == 1.
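A quick way to convince yourself of both claims, as a sketch (the variable names n_grid, labels and same_object are mine, and the Agg backend is only there so it runs headless):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, no window needed
import matplotlib.pyplot as plt

# 2x2 grid: fig.axes is a plain list with one entry per subplot.
fig, axs = plt.subplots(2, 2)
n_grid = len(fig.axes)
for i, ax in enumerate(fig.axes):
    ax.set_ylabel(str(i))
labels = [ax.get_ylabel() for ax in fig.axes]
plt.close(fig)

# 1x1 case: plt.subplots returns a bare axes object, but fig.axes is still a list.
fig1, single_ax = plt.subplots(1, 1)
same_object = fig1.axes[0] is single_ax
plt.close(fig1)

print(n_grid, labels, same_object)
```

One thing to keep in mind: fig.axes lists every axes on the figure, so anything added later (a colorbar, for instance) will show up in the loop as well.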