Drawing a correlation graph in matplotlib

Yuval Adam · Nov 16, 2011

Suppose I have a data set of discrete vectors with n=2:

DATA = [
    ('a', 4),
    ('b', 5),
    ('c', 5),
    ('d', 4),
    ('e', 2),
    ('f', 5),
]

How can I plot that data set with matplotlib so as to visualize any correlation between the two variables?

Any simple code examples would be great.

Answer

Yann · Nov 16, 2011

Joe Kington has the correct answer, but your DATA is probably more complicated than what is represented here. It might have multiple values at 'a'. The way Joe builds the x-axis values is quick, but it only works for a list of unique values. There may be a faster way to do this, but this is how I accomplished it:

import matplotlib.pyplot as plt

def assignIDs(values):
    '''Take a list of strings, and for each unique value assign a number.
    Returns a dict mapping "unique-val" -> id, in sorted order.
    '''
    sortedList = sorted(values)

    # Order-preserving de-duplication, taken from
    # http://stackoverflow.com/questions/480214/how-do-you-remove-duplicates-from-a-list-in-python-whilst-preserving-order/480227#480227
    seen = set()
    seen_add = seen.add
    uniqueList = [x for x in sortedList if x not in seen and not seen_add(x)]

    return dict(zip(uniqueList, range(len(uniqueList))))

def plotData(inData, color):
    # Split the (label, value) pairs into two sequences
    x, y = zip(*inData)

    # Repeated labels map to the same integer x position
    xMap = assignIDs(x)
    xAsInts = [xMap[i] for i in x]

    plt.scatter(xAsInts, y, color=color)
    # Convert the dict views to lists so matplotlib gets real sequences (Python 3)
    plt.xticks(list(xMap.values()), list(xMap.keys()))


DATA = [
    ('a', 4),
    ('b', 5),
    ('c', 5),
    ('d', 4),
    ('e', 2),
    ('f', 5),
]


DATA2 = [
    ('a', 3),
    ('b', 4),
    ('c', 4),
    ('d', 3),
    ('e', 1),
    ('f', 4),
    ('a', 5),
    ('b', 7),
    ('c', 7),
    ('d', 6),
    ('e', 4),
    ('f', 7),
]

plotData(DATA,'blue')
plotData(DATA2,'red')

plt.gcf().savefig("correlation.png")

My DATA2 set has two values for every x-axis value. It's plotted in red below:

[Figure: DATA plotted in blue and DATA2 in red]
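Because assignIDs sorts the list before removing duplicates, its result is the same as enumerating sorted(set(...)). A minimal sketch of that shorter mapping (assign_ids and the sample data are hypothetical, not part of the answer), showing that repeated labels land on the same x position:

```python
def assign_ids(labels):
    """Map each unique label to an integer x position, in sorted order."""
    # set() drops the duplicates; sorted() gives a deterministic order
    return {label: i for i, label in enumerate(sorted(set(labels)))}

# Hypothetical sample with a repeated label on each x position
sample = [('b', 4), ('a', 3), ('b', 7), ('a', 5)]
x, y = zip(*sample)
x_map = assign_ids(x)
x_as_ints = [x_map[label] for label in x]
print(x_map)      # {'a': 0, 'b': 1}
print(x_as_ints)  # [1, 0, 1, 0]
```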

EDIT

The question you asked is very broad. I searched for 'correlation', and Wikipedia has a good discussion of Pearson's product-moment correlation coefficient, which measures how closely the data follow a straight line (it equals the slope of a linear fit only after both variables are standardized). Keep in mind that this value is only a guide: it does not tell you whether a linear fit is a reasonable assumption in the first place (see the notes on correlation and linearity in that Wikipedia article). Here is an updated plotData method, which uses numpy.linalg.lstsq to do the linear regression and numpy.corrcoef to calculate Pearson's r:

import matplotlib.pyplot as plt
import numpy as np

# Requires assignIDs() from the first snippet
def plotData(inData, color):
    x, y = zip(*inData)

    xMap = assignIDs(x)
    xAsInts = np.array([xMap[i] for i in x])

    # Pearson's r between the integer x positions and y
    pearR = np.corrcoef(xAsInts, y)[1, 0]
    # Least squares from:
    # http://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.lstsq.html
    # Solve [x, 1] @ [m, c] = y for slope m and intercept c
    A = np.vstack([xAsInts, np.ones(len(xAsInts))]).T
    m, c = np.linalg.lstsq(A, np.array(y), rcond=None)[0]

    plt.scatter(xAsInts, y, label='Data ' + color, color=color)
    plt.plot(xAsInts, xAsInts * m + c, color=color,
             label="Fit %6s, r = %6.2f" % (color, pearR))
    plt.xticks(list(xMap.values()), list(xMap.keys()))
    plt.legend(loc=3)

The new figure is:

[Figure: both data sets with their linear fits, Pearson's r shown in the legend]
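To check the numbers that end up in the legend without plotting anything, the lstsq and corrcoef steps of plotData can be pulled into a standalone helper. A sketch (linear_fit_stats is a hypothetical name), applied to DATA2 with its labels already mapped to 0 to 5. It also checks how r relates to the slope: the raw slope is r rescaled by std(y)/std(x), and it equals r only once both variables are standardized.

```python
import numpy as np

def linear_fit_stats(x, y):
    """Return (slope, intercept, pearson_r) for the line y ~ slope*x + intercept."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Design matrix with columns [x, 1], as in plotData
    A = np.vstack([x, np.ones_like(x)]).T
    (m, c), *_ = np.linalg.lstsq(A, y, rcond=None)
    r = np.corrcoef(x, y)[0, 1]
    return m, c, r

# DATA2 with the labels 'a'..'f' already mapped to 0..5
x = np.array([0, 1, 2, 3, 4, 5] * 2, dtype=float)
y = np.array([3, 4, 4, 3, 1, 4, 5, 7, 7, 6, 4, 7], dtype=float)
m, c, r = linear_fit_stats(x, y)
print("slope = %.3f, intercept = %.3f, r = %.2f" % (m, c, r))

# The raw slope is r rescaled by std(y)/std(x) ...
assert np.isclose(m, r * y.std() / x.std())
# ... and after standardizing both variables, the fitted slope is r itself
zx = (x - x.mean()) / x.std()
zy = (y - y.mean()) / y.std()
z_slope, _, _ = linear_fit_stats(zx, zy)
assert np.isclose(z_slope, r)
```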

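It might also be useful to flatten each direction, that is, to project the data onto each axis and look at the individual distributions of x and y. There are examples of doing this in matplotlib:

[Figure: matplotlib example showing the distribution of the data along each axis]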
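A minimal sketch of that idea: a scatter plot with a histogram of each variable along its own axis. It is written in the spirit of matplotlib's scatter-with-histograms gallery examples rather than copied from them, assumes matplotlib 3.0 or later (for Figure.add_gridspec), and uses DATA2 with the labels already mapped to 0 to 5; the output filename is arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line for interactive use
import matplotlib.pyplot as plt

# DATA2 with the labels 'a'..'f' already mapped to 0..5
x = [0, 1, 2, 3, 4, 5] * 2
y = [3, 4, 4, 3, 1, 4, 5, 7, 7, 6, 4, 7]

fig = plt.figure(figsize=(6, 6))
gs = fig.add_gridspec(2, 2, width_ratios=(4, 1), height_ratios=(1, 4),
                      wspace=0.05, hspace=0.05)
ax = fig.add_subplot(gs[1, 0])                   # main scatter plot
ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)  # distribution of x (top)
ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)  # distribution of y (right)

ax.scatter(x, y)
# One bin per integer value; the last edge closes the final bin
x_counts, _, _ = ax_histx.hist(x, bins=range(7))
y_counts, _, _ = ax_histy.hist(y, bins=range(9), orientation="horizontal")

# Hide tick labels that would duplicate those of the main axes
ax_histx.tick_params(labelbottom=False)
ax_histy.tick_params(labelleft=False)

fig.savefig("correlation_marginals.png")
```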

If a linear approximation is useful, which you can judge qualitatively just by looking at the fit, you might want to subtract out this trend before flattening the y direction. This would help show whether what remains is a Gaussian random distribution about a linear trend.
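A sketch of that detrending step with numpy only, again using DATA2 with its labels mapped to 0 to 5. Here np.polyfit stands in for the lstsq call in plotData (both compute the same least-squares line), and a text histogram of the residuals takes the place of a plot.

```python
import numpy as np

# DATA2 with the labels 'a'..'f' already mapped to 0..5
x = np.array([0, 1, 2, 3, 4, 5] * 2, dtype=float)
y = np.array([3, 4, 4, 3, 1, 4, 5, 7, 7, 6, 4, 7], dtype=float)

# Fit the linear trend and subtract it out
m, c = np.polyfit(x, y, 1)
residuals = y - (m * x + c)

# With an intercept in the fit, least-squares residuals average to zero,
# so what is left is the scatter about the trend
print("mean residual: %.2e" % residuals.mean())
# ddof=2 because two parameters (slope and intercept) were fitted
print("residual std:  %.3f" % residuals.std(ddof=2))

# The "flattened" y direction after detrending, as a text histogram
counts, edges = np.histogram(residuals, bins=5)
for n, lo, hi in zip(counts, edges[:-1], edges[1:]):
    print("%6.2f .. %6.2f | %s" % (lo, hi, "#" * int(n)))
```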