Display MNIST image using matplotlib

buydadip · Feb 20, 2017 · Viewed 41.4k times

I am using TensorFlow to import some MNIST input data. I followed this tutorial: https://www.tensorflow.org/get_started/mnist/beginners

I am importing the data like so:

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
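With one_hot=True, each label in mnist.train.labels is a 10-element vector instead of a single digit. That matters if you want to show the digit next to its image, for example as a plot title. Below is a minimal NumPy sketch of the encoding and of recovering the digit with np.argmax. The labels are made up for illustration and are not read from MNIST.

```python
import numpy as np

# Hypothetical integer labels for three images (digits 7, 2 and 1).
digits = np.array([7, 2, 1])

# One-hot encoding: row i has a single 1.0 in column digits[i].
one_hot_labels = np.eye(10)[digits]

# np.argmax along each row recovers the original digit.
recovered = np.argmax(one_hot_labels, axis=1)

print(one_hot_labels.shape)  # (3, 10)
print(recovered)             # [7 2 1]
```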

I want to be able to display any of the images from the training set. I know the images are stored in mnist.train.images, so I try to access the first image and display it like so:

with tf.Session() as sess:
    #access first image
    first_image = mnist.train.images[0]

    first_image = np.array(first_image, dtype='uint8')
    pixels = first_image.reshape((28, 28))
    plt.imshow(pixels, cmap='gray')

I attempt to convert the image to a 28-by-28 NumPy array because I know that each image is 28 by 28 pixels.
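A flat 784-element row does reshape cleanly into 28 × 28. NumPy's default (C, row-major) order fills the array row by row, and that matches how MNIST stores its pixels. The sketch below checks the index mapping using synthetic data, with np.arange standing in for a real image. It also shows the same reshape applied to a whole batch of rows.

```python
import numpy as np

# Synthetic stand-in for one flattened MNIST image: 784 values in row-major order.
flat = np.arange(784, dtype=np.float32)

# reshape((28, 28)) fills rows first, so pixel (r, c) comes from flat[r * 28 + c].
pixels = flat.reshape((28, 28))
assert pixels[3, 5] == flat[3 * 28 + 5]

# The same idea on a batch of rows: (N, 784) -> (N, 28, 28).
batch = np.stack([flat, flat])        # shape (2, 784)
images = batch.reshape(-1, 28, 28)    # shape (2, 28, 28)
print(images.shape)  # (2, 28, 28)
```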

However, when I run the code, all I get is the following:

[image: output of the code above]

Clearly I am doing something wrong. When I print out the matrix, everything seems to look good, but I think I am incorrectly reshaping it.
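You can tell whether the reshape or the dtype conversion is at fault by comparing the values before and after np.array(..., dtype='uint8'). By default, read_data_sets returns pixel values as floats scaled to [0, 1]. Casting those floats to uint8 truncates them toward zero, so the image loses almost all of its intensity information. The sketch below uses random values in [0, 1) as a stand-in for a real row of mnist.train.images; the random data is an assumption of the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for one MNIST row: 784 floats in [0, 1).
first_image = rng.random(784)

# Casting to uint8 truncates toward zero, so every value below 1.0 becomes 0.
as_uint8 = np.array(first_image, dtype='uint8')
# Keeping a float dtype preserves the intensities.
as_float = np.array(first_image, dtype='float')

print(as_uint8.dtype)               # uint8
print(np.count_nonzero(as_uint8))   # 0
```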

Answer

Vinh Trieu · Nov 17, 2017

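Here is the complete code for showing an image using matplotlib. The important change is dtype='float' instead of dtype='uint8'. read_data_sets returns pixel values as floats scaled to the range [0, 1], so converting them to uint8 truncates almost every pixel to 0. No tf.Session is needed, because mnist.test.images is already a NumPy array. Calling plt.show() opens the figure window.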

from matplotlib import pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# Download MNIST into 'MNIST_data' if needed, then load it
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Each row is one flattened 28x28 image (784 floats in [0, 1])
first_image = mnist.test.images[0]
# A float dtype preserves the [0, 1] pixel values
first_image = np.array(first_image, dtype='float')
pixels = first_image.reshape((28, 28))
plt.imshow(pixels, cmap='gray')
plt.show()
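Sometimes you need 8-bit pixels, for example to save the image with a library that expects values from 0 to 255. In that case, scale by 255 before casting instead of casting the [0, 1] floats directly. The sketch below does this with a hypothetical helper, `to_uint8_image`, which is not part of TensorFlow or matplotlib. The input row is synthetic.

```python
import numpy as np

def to_uint8_image(flat_pixels):
    """Hypothetical helper: turn one flattened MNIST row with values in [0, 1]
    into a 28x28 uint8 array with values in [0, 255]."""
    pixels = np.asarray(flat_pixels, dtype=np.float64).reshape((28, 28))
    # Scale to 0..255 and round *before* casting, so intensities survive.
    return np.clip(np.rint(pixels * 255), 0, 255).astype(np.uint8)

# Usage with a synthetic row (not real MNIST data).
row = np.linspace(0.0, 1.0, 784)
img = to_uint8_image(row)
print(img.dtype, img.shape)    # uint8 (28, 28)
print(img[0, 0], img[-1, -1])  # 0 255
```

The result can be passed to plt.imshow(img, cmap='gray') just like the float array. imshow normalizes 2-D data to its own minimum and maximum by default, so the displayed image should look the same.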