Keras Concatenate Layers: Difference between different types of concatenate functions

Lim Kaizhuo · Aug 1, 2018 · Viewed 9.8k times

I just recently started playing around with Keras and got into making custom layers. However, I am rather confused by the many different types of layers with slightly different names but with the same functionality.

For example, there are 3 different forms of the concatenate function from https://keras.io/layers/merge/ and https://www.tensorflow.org/api_docs/python/tf/keras/backend/concatenate

keras.layers.Concatenate(axis=-1)
keras.layers.concatenate(inputs, axis=-1)
tf.keras.backend.concatenate()

I know the 2nd one is used in the functional API, but what is the difference between the three? The documentation seems a bit unclear on this.

Also, for the 3rd one, I have seen code that does the following. Why must there be the ._keras_shape line after the concatenation?

# Concatenate the summed atom and bond features
atoms_bonds_features = K.concatenate([atoms, summed_bond_features], axis=-1)

# Compute fingerprint
atoms_bonds_features._keras_shape = (None, max_atoms, num_atom_features + num_bond_features)

Lastly, under keras.layers, there always seem to be two duplicates. For example, Add() and add(), and so on.

Answer

Daniel Möller · Aug 1, 2018

First, the backend: tf.keras.backend.concatenate()

Backend functions are supposed to be used "inside" layers. You'd only use this in Lambda layers, custom layers, custom loss functions, custom metrics, etc.

It works directly on "tensors".
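Because the backend function works on raw tensors, its axis semantics are the same as NumPy's concatenate. A quick sketch (using NumPy arrays as a stand-in for backend tensors; the shapes are made up for illustration) shows what axis=-1 does:

```python
import numpy as np

# Two batches of 3 samples, with 4 and 5 features respectively
atoms = np.zeros((3, 4))
bonds = np.zeros((3, 5))

# axis=-1 joins along the last (feature) dimension,
# just like K.concatenate([...], axis=-1) would on tensors
merged = np.concatenate([atoms, bonds], axis=-1)
print(merged.shape)  # (3, 9)

# axis=0 would instead stack along the first (batch) dimension;
# all other dimensions must then match
stacked = np.concatenate([atoms, atoms], axis=0)
print(stacked.shape)  # (6, 4)
```
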

It's not the right choice unless you're going deep into customization. (And it was a bad choice in your example code -- see details at the end).

If you dive deep into the Keras code, you will notice that the Concatenate layer uses this function internally:

import keras.backend as K
class Concatenate(_Merge):  
    #blablabla   
    def _merge_function(self, inputs):
        return K.concatenate(inputs, axis=self.axis)
    #blablabla

Then, the Layer: keras.layers.Concatenate(axis=-1)

Like any other Keras layer, you instantiate it and call it on tensors.

Pretty straightforward:

#in a functional API model:
inputTensor1 = Input(shape) #or some tensor coming out of any other layer   
inputTensor2 = Input(shape2) #or some tensor coming out of any other layer

#first parentheses are creating an instance of the layer
#second parentheses are "calling" the layer on the input tensors
outputTensor = keras.layers.Concatenate(axis=someAxis)([inputTensor1, inputTensor2])

This is not suited for sequential models, unless the previous layer outputs a list (this is possible but not common).
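To make the pattern above concrete, here is a minimal runnable sketch of a functional model using the layer (written with tf.keras; the input sizes 4 and 6 are just illustrative):

```python
import tensorflow as tf

# Two inputs with 4 and 6 features each (illustrative sizes)
inputTensor1 = tf.keras.Input(shape=(4,))
inputTensor2 = tf.keras.Input(shape=(6,))

# First parentheses create the layer instance,
# second parentheses call it on the list of input tensors
outputTensor = tf.keras.layers.Concatenate(axis=-1)([inputTensor1, inputTensor2])

model = tf.keras.Model([inputTensor1, inputTensor2], outputTensor)
print(model.output_shape)  # (None, 10): 4 + 6 features, batch dimension preserved
```
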


Finally, the concatenate function from the layers module: keras.layers.concatenate(inputs, axis=-1)

This is not a layer. This is a function that will return the tensor produced by an internal Concatenate layer.

The code is simple:

def concatenate(inputs, axis=-1, **kwargs):
    #blablabla
    return Concatenate(axis=axis, **kwargs)(inputs)
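So the lowercase function is just shorthand for instantiating and calling the layer; both forms produce the same output tensor. A sketch (with tf.keras and made-up input sizes) showing the equivalence:

```python
import tensorflow as tf

a = tf.keras.Input(shape=(4,))
b = tf.keras.Input(shape=(6,))

# Layer form: instantiate, then call on the inputs
via_layer = tf.keras.layers.Concatenate(axis=-1)([a, b])

# Function form: one call, which builds a Concatenate layer internally
via_func = tf.keras.layers.concatenate([a, b], axis=-1)

print(via_layer.shape[-1], via_func.shape[-1])  # 10 10 -- same merged feature size
```
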

Older functions

In Keras 1, people had functions that were meant to receive "layers" as input and return an output "layer". Their names were variations on the word "merge".

But since Keras 2 doesn't mention or document these, I'd avoid using them, and if old code is found, I'd update it to proper Keras 2 code.


Why the _keras_shape property?

This backend function was not supposed to be used in high-level code. The coder should have used a Concatenate layer.

atoms_bonds_features = Concatenate(axis=-1)([atoms, summed_bond_features])   
#just this line is perfect

Keras layers add the _keras_shape property to all their output tensors, and Keras uses this property for inferring the shapes of the entire model.

If you use any backend function "outside" a layer or loss/metric, your output tensor will lack this property, and an error will appear saying that _keras_shape doesn't exist.

The coder is creating a bad workaround by adding the property manually, when it should have been added by a proper Keras layer. (This may work now, but a Keras update will break this code, while proper code will remain OK.)