Setting
As mentioned in the title, I have a problem with my custom loss function when trying to load the saved model. My loss looks as follows:
from keras import backend as K

def weighted_cross_entropy(weights):
    weights = K.variable(weights)
    def loss(y_true, y_pred):
        # Clip predictions to avoid log(0)
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        loss = y_true * K.log(y_pred) * weights
        loss = -K.sum(loss, -1)
        return loss
    return loss

weighted_loss = weighted_cross_entropy([0.1, 0.9])
During training, I used weighted_loss as the loss function and everything worked well. When training finished, I saved the model as an .h5 file with the standard model.save function from the Keras API.
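To make it concrete what the loss above computes for a single sample, here is a minimal plain-Python sketch of the same arithmetic. The helper name weighted_cross_entropy_value and the sample numbers are made up for illustration, and EPSILON is assumed to match the default value of K.epsilon().

```python
import math

EPSILON = 1e-7  # assumption: same as the default K.epsilon()

def weighted_cross_entropy_value(y_true, y_pred, weights):
    """Hypothetical helper: per-sample value of the weighted cross entropy."""
    total = 0.0
    for t, p, w in zip(y_true, y_pred, weights):
        # Clip the prediction so log() never sees 0 or 1 exactly
        p = min(max(p, EPSILON), 1.0 - EPSILON)
        total += t * math.log(p) * w
    return -total

# One-hot target for the second class, class weights as in the question
value = weighted_cross_entropy_value([0.0, 1.0], [0.2, 0.8], [0.1, 0.9])
print(value)  # equals -0.9 * log(0.8), roughly 0.2008
```

Only the class marked by y_true contributes, scaled by that class's weight.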
Problem
When I try to load the model via
model = load_model(path,custom_objects={"weighted_loss":weighted_loss})
I get a ValueError telling me that the loss is unknown.
Error
The error message looks as follows:
  File "...\predict.py", line 29, in my_script
    "weighted_loss": weighted_loss})
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\saving.py", line 312, in _deserialize_model
    sample_weight_mode=sample_weight_mode)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\engine\training.py", line 139, in compile
    loss_function = losses.get(loss)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 133, in get
    return deserialize(identifier)
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\losses.py", line 114, in deserialize
    printable_module_name='loss function')
  File "...\Continuum\anaconda3\envs\processing\lib\site-packages\keras\utils\generic_utils.py", line 165, in deserialize_keras_object
    ':' + function_name)
ValueError: Unknown loss function:loss
Questions
How can I fix this problem? Could the cause be my wrapped loss definition, so that Keras doesn't know how to handle the weights variable?
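Your loss function's name is loss (i.e. def loss(y_true, y_pred):). Keras saves a loss function under its __name__, and what weighted_cross_entropy returns is the inner function loss, not something called weighted_loss. That is why the error says "Unknown loss function:loss". Therefore, when loading the model back, you need to specify 'loss' as its name: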
model = load_model(path, custom_objects={'loss': weighted_loss})
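The name mismatch can be checked without Keras at all. The sketch below uses a stand-in factory with the loss body left out, since only the function name matters here. The optional name parameter is a hypothetical addition, not part of the original code. It shows one way to give the closure a distinct name before a model is trained and saved. It would not help with a file that is already saved, which still stores 'loss'.

```python
def weighted_cross_entropy(weights, name=None):
    """Stand-in for the factory in the question (loss body omitted)."""
    def loss(y_true, y_pred):
        raise NotImplementedError  # placeholder: only the name matters here
    if name is not None:
        # Hypothetical option: rename the closure before training/saving
        loss.__name__ = name
    return loss

weighted_loss = weighted_cross_entropy([0.1, 0.9])
print(weighted_loss.__name__)  # loss

# Building the dict from __name__ keeps the key in sync with the function
custom_objects = {weighted_loss.__name__: weighted_loss}

renamed = weighted_cross_entropy([0.1, 0.9], name="weighted_loss")
print(renamed.__name__)  # weighted_loss
```

Keying custom_objects on __name__ avoids hard-coding the string, so the lookup keeps working if the inner function is renamed later.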