How do I select only a specific digit from the MNIST dataset provided by Keras?

sherry · Jul 6, 2018 · Viewed 7.3k times

I'm currently training a feedforward neural network on the MNIST data set using Keras. I'm loading the data set like this:

from keras.datasets import mnist

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

but I only want to train my model on the digits 0 and 4, not all of them. How do I select only those 2 digits? I am fairly new to Python and can't figure out how to filter the MNIST dataset...

Answer

umutto · Jul 6, 2018

Y_train and Y_test hold the labels of the images, so you can use them with numpy.where to find the indices of the samples labelled 0 or 4. All your variables are numpy arrays, so you can simply do:

import numpy as np

train_filter = np.where((Y_train == 0) | (Y_train == 4))
test_filter = np.where((Y_test == 0) | (Y_test == 4))

You can then index the arrays with these filters to get the matching subset:

X_train, Y_train = X_train[train_filter], Y_train[train_filter]
X_test, Y_test = X_test[test_filter], Y_test[test_filter]
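
To see what such a filter actually contains, here is a minimal sketch on small made-up arrays rather than the real MNIST data (X_toy, Y_toy and toy_filter are illustrative names, not part of the original code). With a single condition, np.where returns a tuple of index arrays, which is why the result can be used directly as an index.

```python
import numpy as np

# Toy stand-ins for the MNIST arrays (assumption: 8 samples of shape 2x2
# instead of the real 28x28 images).
Y_toy = np.array([5, 0, 4, 1, 9, 4, 0, 7])
X_toy = np.arange(len(Y_toy) * 4).reshape(len(Y_toy), 2, 2)

# Same pattern as above: a tuple holding the indices of labels 0 and 4.
toy_filter = np.where((Y_toy == 0) | (Y_toy == 4))

# Indexing with the tuple selects along the first (sample) axis.
X_sub, Y_sub = X_toy[toy_filter], Y_toy[toy_filter]

print(toy_filter[0])  # [1 2 5 6]
print(Y_sub)          # [0 4 4 0]
print(X_sub.shape)    # (4, 2, 2)
```

The images and labels stay aligned because both are indexed with the same filter.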

If you are interested in more than 2 labels, chaining conditions with | inside np.where gets unwieldy. In that case you can use numpy.isin to create boolean masks instead:

train_mask = np.isin(Y_train, [0, 4])
test_mask = np.isin(Y_test, [0, 4])

You can index the arrays with these boolean masks directly, in the same way as with the np.where filters above.
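
As a rough sketch of how the mask approach could be wrapped up for reuse, the hypothetical helper below (filter_digits is not a Keras or NumPy function, just an illustrative name) keeps only the samples whose label is in a given list. It is shown on made-up arrays so it runs without downloading MNIST.

```python
import numpy as np

def filter_digits(X, Y, digits):
    """Hypothetical helper: keep only the samples whose label is in `digits`."""
    mask = np.isin(Y, digits)  # boolean array, True where the label matches
    return X[mask], Y[mask]

# Toy stand-ins for MNIST (assumption: 6 blank 28x28 images).
Y_toy = np.array([3, 0, 4, 4, 8, 0])
X_toy = np.zeros((6, 28, 28), dtype=np.uint8)

X_sub, Y_sub = filter_digits(X_toy, Y_toy, [0, 4])
print(X_sub.shape)  # (4, 28, 28)
print(Y_sub)        # [0 4 4 0]
```

With the real data this would be called as X_train, Y_train = filter_digits(X_train, Y_train, [0, 4]), and likewise for the test set.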