I have successfully run the model with LSTM as the first layer. Out of curiosity, I replaced LSTM with CuDNNLSTM, but model.fit then failed with the following error message:
UnknownError: Fail to find the dnn implementation.
[[{{node cu_dnnlstm_5/CudnnRNN}} = CudnnRNN[T=DT_FLOAT, _class=["loc:@training_2/Adam/gradients/cu_dnnlstm_5/CudnnRNN_grad/CudnnRNNBackprop"], direction="unidirectional", dropout=0, input_mode="linear_input", is_training=true, rnn_mode="lstm", seed=87654321, seed2=0, _device="/job:localhost/replica:0/task:0/device:GPU:0"](cu_dnnlstm_5/transpose, cu_dnnlstm_5/ExpandDims_1, cu_dnnlstm_5/ExpandDims_1, cu_dnnlstm_5/concat_1)]]
[[{{node metrics_3/mean_squared_error/Mean_1/_1877}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_4852_metrics_3/mean_squared_error/Mean_1", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
I also ran TestCudnnLSTM() from this discussion, and the test passed:
Keras version: 2.2.4
Tensorflow version: 1.12.0
Creating Model
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
cu_dnnlstm_1 (CuDNNLSTM)     (None, 1000, 1)           16
=================================================================
Total params: 16
Trainable params: 16
Non-trainable params: 0
_________________________________________________________________
None
Model compiled
So the problem seems to appear only during model fitting, but I don't know exactly what causes it.
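For TensorFlow 2, one solution is to enable GPU memory growth. A frequently reported cause of this error is that cuDNN fails to initialize because TensorFlow has already reserved nearly all of the GPU memory; with memory growth enabled, TensorFlow allocates GPU memory on demand instead: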
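import tensorflow as tf
# Allocate GPU memory on demand rather than reserving it all at startup
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:  # avoid an IndexError on machines without a GPU
    tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)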
After that, you can build and train Keras models as usual, for example with:
from tensorflow.keras.models import Model
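A minimal sketch of training a small model this way, assuming a toy regression task on random data (the (1000, 1) input shape echoes the test summary above; the layer sizes and data are hypothetical). In TensorFlow 2, tf.keras.layers.LSTM automatically uses the cuDNN kernel on a GPU when its arguments stay compatible with it (for example activation='tanh', recurrent_activation='sigmoid', recurrent_dropout=0, unroll=False, use_bias=True), so the plain LSTM layer can stand in for CuDNNLSTM:

```python
import numpy as np
from tensorflow.keras.layers import Input, LSTM, Dense
from tensorflow.keras.models import Model

# Hypothetical toy model: sequences of 1000 timesteps with 1 feature each
inputs = Input(shape=(1000, 1))
# With its default arguments, LSTM selects the cuDNN kernel when a GPU is available
x = LSTM(1)(inputs)
outputs = Dense(1)(x)
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='mse')

# Random placeholder data, only to show that fit() runs end to end
x_train = np.random.rand(8, 1000, 1).astype('float32')
y_train = np.random.rand(8, 1).astype('float32')
history = model.fit(x_train, y_train, epochs=1, batch_size=4, verbose=0)
```

Run the memory-growth snippet first, before any model is built, so the setting takes effect.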
This worked for me. Note that it enables memory growth only for the first GPU (physical_devices[0]).
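On a machine with several GPUs, a sketch that enables memory growth for every visible GPU could look like the following. It follows the pattern in the TensorFlow GPU guide, which notes that memory growth must be set before the GPUs are initialized and currently needs to be the same across GPUs:

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    try:
        # Must run before any GPU has been initialized
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised if the GPUs were already initialized; the setting cannot change now
        print(err)

# Report the memory-growth setting for each device
for gpu in gpus:
    print(gpu.name, tf.config.experimental.get_memory_growth(gpu))
```

On a CPU-only machine the list is empty and both loops simply do nothing.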