Plotting a 3D scatter in matplotlib

user248237 · Mar 30, 2011

I have a collection of Nx3 matrices in scipy/numpy, and I'd like to make a 3-dimensional scatter plot of them. The X and Y coordinates should come from the first and second columns of each matrix, the height of each point from the third column, and the number of points is N.

Each matrix represents a different data group. I want each group plotted in a different color, with a single legend for the entire figure.

I have the following code:

fig = pylab.figure()
s = plt.subplot(1, 1, 1)
colors = ['k', "#B3C95A", 'b', '#63B8FF', 'g', "#FF3300",
          'r', 'k']
ax = Axes3D(fig)
plots = []
index = 0

for data, curr_color in zip(datasets, colors):
    p = ax.scatter(log2(data[:, 0]), log2(data[:, 1]),
                   log2(data[:, 2]), c=curr_color, label=my_labels[index])

    s.legend()
    index += 1

    plots.append(p)

    ax.set_zlim3d([-1, 9])
    ax.set_ylim3d([-1, 9])
    ax.set_xlim3d([-1, 9])

The issue is that ax.scatter draws the points with some transparency, and I'd like to remove that. I'd also like to set the x, y, and z ticks. How can I do that?

Finally, the legend does not appear, even though I pass a label to each scatter call. How can I get the legend to appear?

Thanks very much for your help.

Answer

Daan · Apr 4, 2011

Try replacing ax.scatter with ax.plot, using the 'o' format string to get similar circle markers. This fixes both the transparency and the legend.

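import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older versions
from numpy.random import random

mpl.rcParams['legend.fontsize'] = 10

fig = plt.figure(1)
fig.clf()
# Create the 3D axes through the figure so it is actually attached to it
# (recent Matplotlib versions no longer auto-add Axes3D(fig) to the figure)
ax = fig.add_subplot(111, projection='3d')

# Sample data: eight 100x3 datasets with values in [0, 512)
datasets = random((8, 100, 3)) * 512
my_labels = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']

colors = ['k', "#B3C95A", 'b', '#63B8FF', 'g', "#FF3300",
          'r', 'k']

for data, curr_color, label in zip(datasets, colors, my_labels):
    # plot(..., 'o') draws opaque circle markers with no connecting line,
    # and the resulting lines are picked up by the legend
    ax.plot(np.log2(data[:, 0]), np.log2(data[:, 1]),
            np.log2(data[:, 2]), 'o', c=curr_color, label=label)

ax.set_zlim3d([-1, 9])
ax.set_ylim3d([-1, 9])
ax.set_xlim3d([-1, 9])

ax.set_xticks(range(0, 11))
ax.set_yticks([1, 2, 8])
ax.set_zticks(np.arange(0, 9, .5))

# Call legend on the 3D axes that holds the labelled lines
ax.legend(loc='upper left')

plt.show()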

I added a few lines to generate sample data and get the rest of your demo working; adapting it to your real data should be straightforward.

Setting the ticks requires the August 2010 update to mplot3d, as described here; I got the latest mplot3d from SourceForge. I'm not sure whether Matplotlib 1.0.1 includes this update, as I'm still running Python 2.6 with Matplotlib 1.0.0.
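On newer Matplotlib versions you may not need to switch to plot at all. As far as I know, ax.scatter on a 3D axes accepts depthshade=False, which turns off the depth-based fading that looks like transparency, and labelled scatter collections do show up in ax.legend() as long as legend is called on the 3D axes itself. A minimal sketch, assuming a recent Matplotlib; opaque_scatter and the sample data are illustrative names, not part of any API:

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on old versions)


def opaque_scatter(datasets, colors, labels):
    """Hypothetical helper: opaque 3D scatter with a legend, keeping ax.scatter."""
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    for data, color, label in zip(datasets, colors, labels):
        ax.scatter(np.log2(data[:, 0]), np.log2(data[:, 1]), np.log2(data[:, 2]),
                   c=color, label=label,
                   depthshade=False)  # no depth-based fading of far points
    # Same axis limits as the question; ticks every 2 units inside them
    for set_lim in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_lim(-1, 9)
    ax.set_xticks(range(0, 9, 2))
    ax.set_yticks(range(0, 9, 2))
    ax.set_zticks(range(0, 9, 2))
    # Legend on the 3D axes that owns the labelled scatters
    ax.legend(loc="upper left")
    return fig, ax


if __name__ == "__main__":
    # Hypothetical sample data in [1, 512) so log2 stays finite
    rng = np.random.default_rng(0)
    datasets = rng.uniform(1, 512, size=(8, 100, 3))
    colors = ['k', "#B3C95A", 'b', '#63B8FF', 'g', "#FF3300", 'r', 'k']
    opaque_scatter(datasets, colors, list("abcdefgh"))
    plt.show()
```

The key difference from the question's code is that legend is called on ax, the 3D axes holding the labelled artists, rather than on a separate 2D subplot.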

Edit

A quick and dirty dummy plot for the legend, while keeping the 3D transparency effect you get from scatter:

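for data, curr_color, label in zip(datasets, colors, my_labels):
    # scatter keeps its depth-shading effect; it gets no label, so the
    # legend only receives the dummy entry below
    ax.scatter(np.log2(data[:, 0]), np.log2(data[:, 1]),
               np.log2(data[:, 2]), marker='o', c=curr_color)
    # Empty plot: draws nothing, but gives the legend a marker in this color
    ax.plot([], [], 'o', c=curr_color, label=label)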
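A variant of the same trick that avoids adding empty lines to the axes is to build the legend entries as proxy artists: matplotlib.lines.Line2D objects that are never added to the axes and are passed to ax.legend(handles=...). This is a sketch assuming a reasonably recent Matplotlib; shaded_scatter_with_proxy_legend is a hypothetical name for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on old versions)


def shaded_scatter_with_proxy_legend(datasets, colors, labels):
    """Hypothetical helper: keep scatter's depth shading, legend from proxy artists."""
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    handles = []
    for data, color, label in zip(datasets, colors, labels):
        # Default depth shading is kept; no label, so scatter adds no legend entry
        ax.scatter(np.log2(data[:, 0]), np.log2(data[:, 1]), np.log2(data[:, 2]),
                   c=color)
        # Proxy artist: a marker-only Line2D used only for the legend
        handles.append(Line2D([], [], linestyle="none", marker="o",
                              color=color, label=label))
    ax.legend(handles=handles, loc="upper left")
    return fig, ax


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    datasets = rng.uniform(1, 512, size=(8, 100, 3))  # hypothetical sample data
    colors = ['k', "#B3C95A", 'b', '#63B8FF', 'g', "#FF3300", 'r', 'k']
    shaded_scatter_with_proxy_legend(datasets, colors, list("abcdefgh"))
    plt.show()
```

Because the proxies are never added to the axes, nothing extra is drawn and the legend contains exactly one entry per dataset.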