What is the most efficient way to plot a 3D array in Python?
For example:
volume = np.random.rand(512, 512, 512)
where each array element represents the grayscale value of one point (voxel).
The following code is far too slow:
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
# fig.gca(projection='3d') is no longer accepted by recent Matplotlib versions
ax = fig.add_subplot(projection='3d')
volume = np.random.rand(20, 20, 20)
# One ax.scatter call per voxel: 20**3 = 8000 separate calls
for x in range(volume.shape[0]):
    for y in range(volume.shape[1]):
        for z in range(volume.shape[2]):
            v = volume[x, y, z]
            ax.scatter(x, y, z, c=[(v, v, v, 1)])  # one grayscale RGBA color
plt.show()
For better performance, avoid calling ax.scatter multiple times if possible. Instead, pack all the x, y, z coordinates and colors into 1D arrays (or lists), listing the coordinates in the same order as volume.ravel(), then call ax.scatter once:
ax.scatter(x, y, z, c=volume.ravel())
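The snippet above assumes x, y and z already exist. A minimal sketch of one way to build them, assuming NumPy's np.indices (flatten_volume and plot_volume are hypothetical helper names, not part of any library): np.indices creates one integer grid per axis, and flattening each grid in C order lines it up element by element with volume.ravel(). This still creates size**3 points, so it is only practical for small volumes.

```python
import numpy as np

def flatten_volume(volume):
    """Return 1-D x, y, z index arrays and the matching 1-D values.

    np.indices builds one integer grid per axis; ravel() flattens each grid
    in C order, the same order as volume.ravel(), so element i of every
    output refers to the same voxel.
    """
    x, y, z = np.indices(volume.shape)
    return x.ravel(), y.ravel(), z.ravel(), volume.ravel()

def plot_volume(volume):
    """Draw every voxel with a single ax.scatter call."""
    import matplotlib.pyplot as plt  # local import: flatten_volume needs only NumPy
    x, y, z, values = flatten_volume(volume)
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    # 'gray' with vmin=0, vmax=1 maps 0 -> black and 1 -> white,
    # matching the question's (v, v, v) colors
    ax.scatter(x, y, z, c=values, cmap='gray', vmin=0, vmax=1)
    plt.show()
```

For the question's 20 x 20 x 20 example, plot_volume(np.random.rand(20, 20, 20)) makes one scatter call with 8000 points instead of 8000 calls with one point each.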
The problem (in terms of both CPU time and memory) grows as size**3, where size is the side length of the cube.
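To put numbers on the size**3 growth, here is a small worked calculation, assuming one float64 value (8 bytes) per point as np.random.rand produces (cube_cost is a hypothetical helper name). For size=512 it gives 134,217,728 points and exactly 1 GiB for a single float64 array, before Matplotlib adds its own per-point overhead.

```python
def cube_cost(size, bytes_per_value=8):
    """Return (number of points, bytes for one value per point) for a size**3 cube.

    bytes_per_value defaults to 8, the size of a float64.
    """
    points = size ** 3
    return points, points * bytes_per_value

for size in (20, 128, 512):
    points, nbytes = cube_cost(size)
    print(f"size={size:4d}: {points:>11,d} points, {nbytes / 2**30:.3f} GiB per float64 array")
```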
Moreover, ax.scatter will try to render all size**3 points without regard to the fact that most of those points are obscured by those on the outer shell.
It would help to reduce the number of points in volume (perhaps by summarizing or resampling/interpolating it in some way) before rendering it.
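One simple way to summarize the volume is block averaging: replace each non-overlapping factor x factor x factor block with its mean. The sketch below assumes only NumPy and drops any trailing slices that do not fill a whole block (block_mean is a hypothetical helper name). With size=512 and factor=8 it leaves 64**3 = 262,144 points instead of 134,217,728. If interpolation is preferred over averaging, scipy.ndimage.zoom is an alternative.

```python
import numpy as np

def block_mean(volume, factor):
    """Downsample a 3-D array by averaging non-overlapping factor**3 blocks.

    Trailing slices that do not fill a complete block are dropped.
    """
    nx, ny, nz = (s // factor for s in volume.shape)
    trimmed = volume[:nx * factor, :ny * factor, :nz * factor]
    # Split each axis into (block index, offset within block),
    # then average over the three offset axes.
    blocks = trimmed.reshape(nx, factor, ny, factor, nz, factor)
    return blocks.mean(axis=(1, 3, 5))

if __name__ == '__main__':
    volume = np.random.rand(128, 128, 128)
    small = block_mean(volume, 4)
    print(small.shape, small.size)  # (32, 32, 32) 32768
```

The reduced array can then be plotted with a single ax.scatter call exactly as the full one would be.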
We can also reduce the CPU and memory required for plotting from O(size**3) to O(size**2) by only plotting the outer shell:
import functools
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
def cartesian_product_broadcasted(*arrays):
    """
    http://stackoverflow.com/a/11146645/190597 (senderle)
    """
    broadcastable = np.ix_(*arrays)
    broadcasted = np.broadcast_arrays(*broadcastable)
    dtype = np.result_type(*arrays)
    rows, cols = functools.reduce(np.multiply, broadcasted[0].shape), len(broadcasted)
    out = np.empty(rows * cols, dtype=dtype)
    start, end = 0, rows
    for a in broadcasted:
        out[start:end] = a.reshape(-1)
        start, end = end, end + rows
    return out.reshape(cols, rows).T
# @profile  # used with `python -m memory_profiler script.py` to measure memory usage
def main():
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    size = 512
    volume = np.random.rand(size, size, size)
    # Every (x, y, z) index of the cube, in the same C order as volume.ravel()
    x, y, z = cartesian_product_broadcasted(*[np.arange(size, dtype='int16')]*3).T
    # Keep only the points that lie on one of the six faces
    mask = ((x == 0) | (x == size-1)
            | (y == 0) | (y == size-1)
            | (z == 0) | (z == size-1))
    x = x[mask]
    y = y[mask]
    z = z[mask]
    volume = volume.ravel()[mask]
    # 'gray' maps low values to black and high values to white, like the
    # question's (v, v, v) colors; 'Greys' would reverse that mapping
    ax.scatter(x, y, z, c=volume, cmap='gray')
    plt.show()
if __name__ == '__main__':
    main()
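The code above still builds a coordinate grid for the whole cube (three int16 values per voxel, 768 MiB for size=512) before masking it. A lower-memory variant, sketched under the assumption that only NumPy is needed for the point selection (shell_points and plot_shell are hypothetical names): mark the six faces in a boolean mask, one byte per voxel (128 MiB for size=512), and let np.nonzero return only the coordinates of the marked voxels. The shell contains 6*size**2 - 12*size + 8 points, which is 1,566,728 for size=512.

```python
import numpy as np

def shell_points(volume):
    """Return x, y, z indices and values of the voxels on the six faces of a 3-D array.

    The faces are marked in a boolean mask instead of building a full integer
    coordinate grid; np.nonzero then returns only the O(size**2) shell
    coordinates, in C order.
    """
    mask = np.zeros(volume.shape, dtype=bool)
    mask[[0, -1], :, :] = True  # x == 0 and x == size-1 faces
    mask[:, [0, -1], :] = True  # y faces
    mask[:, :, [0, -1]] = True  # z faces
    x, y, z = np.nonzero(mask)
    return x, y, z, volume[x, y, z]

def plot_shell(volume):
    """Scatter-plot only the outer shell of volume in grayscale."""
    import matplotlib.pyplot as plt  # local import: shell_points needs only NumPy
    x, y, z, values = shell_points(volume)
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(x, y, z, c=values, cmap='gray', vmin=0, vmax=1)
    plt.show()
```

The volume array itself is still O(size**3); this only avoids the extra full-size coordinate grid.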
But note that even when plotting only the outer shell, a plot with size=512 still needs around 1.3 GiB of memory. Much of that is not the plotted points at all: the volume array alone holds 512**3 float64 values (exactly 1 GiB), and main() also builds a coordinate grid for the full cube before masking it. Also beware that if the program needs more memory than your available RAM and starts using swap space, it will slow down dramatically. If you find yourself in this situation, the only remedies are to find a smarter way to render an acceptable image using fewer points, or to buy more RAM.