No broadcasting for tf.matmul in TensorFlow

Alessio B · Jun 27, 2016

I have a problem I've been struggling with. It is related to tf.matmul() and its lack of broadcasting.

I am aware of a similar issue on https://github.com/tensorflow/tensorflow/issues/216, but tf.batch_matmul() doesn't look like a solution for my case.

I need to encode my input data as a 4D tensor:

X = tf.placeholder(tf.float32, shape=(None, None, None, 100))

The first dimension is the size of a batch, the second is the number of entries in the batch. You can imagine each entry as a composition of a number of objects (third dimension). Finally, each object is described by a vector of 100 float values.

Note that I used None for the second and third dimensions because the actual sizes may change in each batch. However, for simplicity, let's shape the tensor with actual numbers:

X = tf.placeholder(tf.float32, shape=(5, 10, 4, 100))
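For instance, with the None dimensions the same placeholder accepts feeds of different sizes (the arrays below are purely illustrative):

import numpy as np
batch_a = np.zeros((5, 10, 4, 100), dtype=np.float32)  # 5 batches, 10 entries, 4 objects each
batch_b = np.zeros((3, 7, 2, 100), dtype=np.float32)   # the first three sizes may differ between feeds
# Either array could be fed for X, since only the last dimension (100) is fixed.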

These are the steps of my computation:

  1. Compute a function of each vector of 100 float values (e.g., a linear function):

     W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1))
     Y = tf.matmul(X, W)

     Problem: no broadcasting for tf.matmul(), and no success using tf.batch_matmul(). Expected shape of Y: (5, 10, 4, 50).

  2. Apply average pooling for each entry of the batch (over the objects of each entry):

     Y_avg = tf.reduce_mean(Y, 2)

     Expected shape of Y_avg: (5, 10, 50). (A NumPy sketch of both steps follows this list.)
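To make the expected result concrete, here is a plain NumPy sketch of the two steps (the arrays are dummy data, just for illustration):

import numpy as np

X_np = np.random.rand(5, 10, 4, 100).astype(np.float32)
W_np = np.random.rand(100, 50).astype(np.float32)

# Step 1: the same linear map applied to every length-100 vector;
# np.matmul broadcasts the 2-D W over the leading dimensions.
Y_np = np.matmul(X_np, W_np)   # shape (5, 10, 4, 50)

# Step 2: average over the objects of each entry (axis 2).
Y_avg_np = Y_np.mean(axis=2)   # shape (5, 10, 50)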

I expected that tf.matmul() would support broadcasting. Then I found tf.batch_matmul(), but it still doesn't seem to apply to my case (e.g., W needs to have at least 3 dimensions, and it is not clear why).

BTW, above I used a simple linear function (whose weights are stored in W). But in my model I have a deep network instead. So the more general problem I have is automatically computing a function for each slice of a tensor. This is why I expected tf.matmul() to have broadcasting behavior (if it did, maybe tf.batch_matmul() wouldn't even be necessary).

Looking forward to learning from you! Alessio

Answer

lballes · Jun 27, 2016

You could achieve that by reshaping X to shape [n, d], where d is the dimensionality of a single "instance" of computation (100 in your example) and n is the number of those instances in your multi-dimensional object (5*10*4=200 in your example). After reshaping, you can use tf.matmul and then reshape back to the desired shape. The fact that the first three dimensions can vary makes this a little tricky, but you can use tf.shape to determine the actual shapes at run time. Finally, you can perform the second step of your computation, which should be a simple tf.reduce_mean over the respective dimension. All in all, it would look like this:

import tensorflow as tf

X = tf.placeholder(tf.float32, shape=(None, None, None, 100))
W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1))
X_ = tf.reshape(X, [-1, 100])                 # collapse the leading dimensions: shape [n, 100]
Y_ = tf.matmul(X_, W)                         # ordinary 2-D matmul: shape [n, 50]
X_shape = tf.gather(tf.shape(X), [0, 1, 2])   # extract the first three (dynamic) dimensions
target_shape = tf.concat(0, [X_shape, [50]])  # build the output shape [d0, d1, d2, 50] (old-style concat: axis first)
Y = tf.reshape(Y_, target_shape)              # restore the leading dimensions
Y_avg = tf.reduce_mean(Y, 2)                  # average over the objects dimension
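As a quick, optional sanity check (random dummy input; tf.initialize_all_variables was the initializer in TF 0.x at the time), you can verify that the output shape adapts to the actual input shape:

import numpy as np

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    dummy = np.random.rand(5, 10, 4, 100).astype(np.float32)
    print(sess.run(Y_avg, feed_dict={X: dummy}).shape)   # (5, 10, 50)
    dummy = np.random.rand(3, 7, 2, 100).astype(np.float32)
    print(sess.run(Y_avg, feed_dict={X: dummy}).shape)   # (3, 7, 50)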