Populating matplotlib subplots through a loop and a function

Charles · Dec 19, 2014

I need to draw the subplots of a figure over the iterations of a loop. Each iteration calls a function defined in another module (i.e. another .py file), which draws a pair of subplots. Here is what I tried, which unfortunately does not work:

1) Before the loop, create a figure with the appropriate number of rows and 2 columns:

 import matplotlib.pyplot as plt     
 fig, axarr = plt.subplots(nber_rows,2)

2) Inside the loop, at iteration number iter_nber, call the function that draws the pair of subplots for that row:

 fig, axarr = module.graph_function(fig,axarr,iter_nber,some_parameters, some_data)

3) The function in question is basically like this; each iteration creates a pair of subplots on the same row:

 def graph_function(fig,axarr,iter_nber,some_parameters, some_data):

     axarr[iter_nber,1].plot(--some plotting 1--)
     axarr[iter_nber,2].plot(--some plotting 2--)

     return fig,axarr

This does not work: I end up with an empty figure at the end of the loop. I have tried various combinations of the above, such as returning only axarr from the function, to no avail. Clearly I do not understand how a figure and its subplots fit together.
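To make the figure/subplot relationship concrete, here is a minimal sketch of what plt.subplots(nber_rows, 2) returns. The value nber_rows = 3 and the plotted numbers are assumptions for illustration, and the Agg backend is selected only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

nber_rows = 3  # assumed value for illustration
fig, axarr = plt.subplots(nber_rows, 2)

# One Figure holds all the subplots; axarr is a 2-D NumPy array of Axes objects,
# one per subplot, laid out as (rows, columns)
print(type(axarr).__name__, axarr.shape)  # ndarray (3, 2)

# Each element is a separate Axes; plotting on one draws into that subplot of fig
ax = axarr[1, 0]  # second row, first column (indices are zero-based)
ax.plot([0, 1, 4, 9])
print(len(ax.get_lines()), len(axarr[1, 1].get_lines()))  # 1 0
```

The Figure is the container; the Axes objects in axarr are where the actual drawing happens, so any function that receives an Axes can draw into the shared figure.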

Any suggestions are much appreciated.

Answer

Joe Kington · Dec 19, 2014

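The code you've posted seems largely correct. Other than the indexing that @hitzg mentioned (indices start at 0, so the two columns in each row are axarr[iter_nber, 0] and axarr[iter_nber, 1], not 1 and 2), nothing you're doing looks terribly out of the ordinary.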
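As a sketch of the indexing point, here is the question's loop-plus-function structure with zero-based column indices. The simplified signature graph_function(axarr, iter_nber, data1, data2) and the random-walk data standing in for some_parameters/some_data are hypothetical; only the indexing reflects the fix.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

def graph_function(axarr, iter_nber, data1, data2):
    # Hypothetical simplified signature: draws one pair of subplots on row iter_nber
    axarr[iter_nber, 0].plot(data1)  # first column is index 0
    axarr[iter_nber, 1].plot(data2)  # second column is index 1

nber_rows = 3  # assumed for illustration
fig, axarr = plt.subplots(nber_rows, 2)
rng = np.random.default_rng(0)

for iter_nber in range(nber_rows):
    # Random walks as placeholder data for the question's elided some_data
    graph_function(axarr, iter_nber,
                   rng.normal(size=50).cumsum(), rng.normal(size=50).cumsum())

# The original column index 2 is out of range for a 2-column grid
try:
    axarr[0, 2]
except IndexError as err:
    print("IndexError:", err)
```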

However, there's not much point in returning the figure and axes array from your plotting function: the function draws on the Axes objects in place, so the caller's fig and axarr already reflect the changes. (If you need access to the figure object, you can always get it through ax.figure.) Passing them in and returning them won't change anything, though.
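A small check of that claim, using a hypothetical stand-in for the question's function that draws on one row and then returns its inputs: the "returned" objects are the very same objects that were passed in, and each Axes already knows its figure.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

def graph_function(fig, axarr, iter_nber):
    # Hypothetical stand-in for the question's function: draws on one row, returns its inputs
    axarr[iter_nber, 0].plot([0, 1, 2])
    axarr[iter_nber, 1].plot([2, 1, 0])
    return fig, axarr

fig, axarr = plt.subplots(2, 2)
returned_fig, returned_axarr = graph_function(fig, axarr, 0)

# The function handed back the very same objects it received...
print(returned_fig is fig, returned_axarr is axarr)  # True True
# ...and the drawing happened on those shared objects, not on copies
print(len(axarr[0, 0].get_lines()))  # 1
# Any Axes knows its parent figure, so passing fig around is optional
print(axarr[0, 0].figure is fig)  # True
```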

Here's a quick example of the type of thing it sounds like you're trying to do. Maybe it will help clear up some confusion.

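import numpy as np
import matplotlib.pyplot as plt

def main():
    nrows = 3
    # axes is a 2-D array of Axes objects with shape (nrows, 2)
    fig, axes = plt.subplots(nrows, 2)

    # Iterating over a 2-D array yields its rows: each row holds two Axes
    for row in axes:
        x = np.random.normal(0, 1, 100).cumsum()
        y = np.random.normal(0, 0.5, 100).cumsum()
        plot(row, x, y)

    plt.show()

def plot(axrow, x, y):
    # Draws on the Axes in place; nothing needs to be returned
    axrow[0].plot(x, color='red')
    axrow[1].plot(y, color='green')

if __name__ == '__main__':
    main()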
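One caveat with iterating over axes row by row: with the default squeeze=True, plt.subplots(1, 2) returns a 1-D array of two Axes, so the loop would hand plot() single Axes objects and axrow[0] would fail. A minimal sketch, assuming you want the same code to work for any row count, passes squeeze=False. Making nrows a parameter of main and returning fig, axes are changes for illustration only, and plt.show() is dropped so it runs headlessly (call it, or fig.savefig(...), when running interactively).

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend; plt.show() is omitted for the same reason
import matplotlib.pyplot as plt

def plot(axrow, x, y):
    axrow[0].plot(x, color='red')
    axrow[1].plot(y, color='green')

def main(nrows):
    # squeeze=False always returns a 2-D (nrows, 2) array, even when nrows == 1,
    # so iterating over it still yields rows of two Axes each
    fig, axes = plt.subplots(nrows, 2, squeeze=False)
    for row in axes:
        x = np.random.normal(0, 1, 100).cumsum()
        y = np.random.normal(0, 0.5, 100).cumsum()
        plot(row, x, y)
    return fig, axes

fig, axes = main(1)
print(axes.shape)  # (1, 2)
```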
