I'm currently training a feedforward neural network on the MNIST dataset using Keras. I'm loading the dataset like this:
from keras.datasets import mnist
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
But I only want to train my model on the digits 0 and 4, not all of them. How do I select only those 2 digits? I'm fairly new to Python and can't figure out how to filter the MNIST dataset.
Y_train and Y_test give you the labels of the images, so you can use them with numpy.where to select only the samples labeled 0 or 4. All your variables are NumPy arrays, so you can simply do:
import numpy as np
# Indices of the samples whose label is 0 or 4
train_filter = np.where((Y_train == 0) | (Y_train == 4))
test_filter = np.where((Y_test == 0) | (Y_test == 4))
Then use these filters to select the matching subset of each array by index:
X_train, Y_train = X_train[train_filter], Y_train[train_filter]
X_test, Y_test = X_test[test_filter], Y_test[test_filter]
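To see what np.where actually returns and how the indexing behaves, here is a small sketch that runs without downloading MNIST. It assumes a hypothetical stand-in: 20 made-up labels and random 28x28 uint8 images, roughly the shapes mnist.load_data() gives you.

```python
import numpy as np

# Hypothetical stand-in for MNIST: 20 made-up labels and random 28x28 uint8 images.
rng = np.random.default_rng(0)
Y_train = np.array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9], dtype=np.uint8)
X_train = rng.integers(0, 256, size=(len(Y_train), 28, 28), dtype=np.uint8)

# With a single condition, np.where returns a 1-tuple holding the matching indices.
train_filter = np.where((Y_train == 0) | (Y_train == 4))
print(train_filter[0])  # [1 2 9]

# Indexing with that tuple selects along the first axis of both arrays,
# so images and labels stay aligned.
X_small, Y_small = X_train[train_filter], Y_train[train_filter]
print(X_small.shape, Y_small)  # (3, 28, 28) [0 4 4]
```

Because the same index tuple is applied to X_train and Y_train, each kept image still lines up with its label.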
If you are interested in more than 2 labels, the syntax gets hairy with where and a chain of | (or) comparisons. Instead, you can use numpy.isin to create boolean masks:
train_mask = np.isin(Y_train, [0, 4])
test_mask = np.isin(Y_test, [0, 4])
You can use these masks for boolean indexing the same way as before, e.g. X_train[train_mask], Y_train[train_mask].
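A minimal sketch of the mask approach with three digits, wrapped in a small helper. The helper name filter_digits, the made-up labels, and the tiny 2x2 "images" are all hypothetical and only there to keep the example short; the calls themselves are plain np.isin and boolean indexing.

```python
import numpy as np

def filter_digits(X, Y, digits):
    """Keep only the samples of (X, Y) whose label is in `digits` (hypothetical helper)."""
    mask = np.isin(Y, digits)  # boolean array, same shape as Y
    return X[mask], Y[mask]    # boolean indexing keeps rows where mask is True

# Made-up labels and tiny 2x2 "images" (real MNIST images are 28x28).
Y_train = np.array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9])
X_train = np.arange(len(Y_train) * 4).reshape(len(Y_train), 2, 2)

X_sub, Y_sub = filter_digits(X_train, Y_train, [0, 4, 7])
print(Y_sub)  # [0 4 4 7]

# Same selection as chaining == and | inside np.where.
same = np.where((Y_train == 0) | (Y_train == 4) | (Y_train == 7))
assert np.array_equal(Y_sub, Y_train[same])
```

Adding another digit only means adding it to the list, and the same helper can be called on X_test, Y_test.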