How can I multiply a vector and a matrix in tensorflow without reshaping?

Mr_and_Mrs_D · Apr 7, 2017


import numpy as np
a = np.array([1, 2, 1])
w = np.array([[.5, .6], [.7, .8], [.7, .8]])

print(np.dot(a, w))
# [ 2.6  3. ] # plain nice old matrix multiplication n x (n, m) -> m

import tensorflow as tf

a = tf.constant(a, dtype=tf.float64)
w = tf.constant(w)

with tf.Session() as sess:
    print(tf.matmul(a, w).eval())

results in:

C:\_\Python35\python.exe C:/Users/MrD/.PyCharm2017.1/config/scratches/
[ 2.6  3. ]
# bunch of errors in windows...
Traceback (most recent call last):
  File "C:\_\Python35\lib\site-packages\tensorflow\python\framework\", line 671, in _call_cpp_shape_fn_impl
    input_tensors_as_shapes, status)
  File "C:\_\Python35\lib\", line 66, in __exit__
  File "C:\_\Python35\lib\site-packages\tensorflow\python\framework\", line 466, in raise_exception_on_not_ok_status
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shape must be rank 2 but is rank 1 for 'MatMul' (op: 'MatMul') with input shapes: [3], [3,2].

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:/Users/MrD/.PyCharm2017.1/config/scratches/", line 14, in <module>
    print(tf.matmul(a, w).eval())
  File "C:\_\Python35\lib\site-packages\tensorflow\python\ops\", line 1765, in matmul
    a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
  File "C:\_\Python35\lib\site-packages\tensorflow\python\ops\", line 1454, in _mat_mul
    transpose_b=transpose_b, name=name)
  File "C:\_\Python35\lib\site-packages\tensorflow\python\framework\", line 763, in apply_op
  File "C:\_\Python35\lib\site-packages\tensorflow\python\framework\", line 2329, in create_op
  File "C:\_\Python35\lib\site-packages\tensorflow\python\framework\", line 1717, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "C:\_\Python35\lib\site-packages\tensorflow\python\framework\", line 1667, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "C:\_\Python35\lib\site-packages\tensorflow\python\framework\", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "C:\_\Python35\lib\site-packages\tensorflow\python\framework\", line 676, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Shape must be rank 2 but is rank 1 for 'MatMul' (op: 'MatMul') with input shapes: [3], [3,2].

Process finished with exit code 1

(I'm not sure why the same exception is raised again while it is being handled.)

The solution suggested in "Tensorflow exception with matmul" is to reshape the vector into a matrix, but this leads to needlessly complicated code. Is there still no other way to multiply a vector by a matrix?

Incidentally, using expand_dims (as suggested in the link above) with its default arguments raises a ValueError; that's not mentioned in the docs and defeats the purpose of having a default argument.
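For reference, the reshaping workaround boils down to the shape bookkeeping below. This is a minimal sketch written with NumPy, whose np.expand_dims and np.squeeze behave like tf.expand_dims and tf.squeeze for this purpose; passing axis=0 explicitly is what sidesteps the default-argument problem. In TensorFlow 1.x the graph version would be roughly tf.squeeze(tf.matmul(tf.expand_dims(a, 0), w), axis=[0]).

```python
import numpy as np

a = np.array([1.0, 2.0, 1.0])                 # shape (3,)
w = np.array([[.5, .6], [.7, .8], [.7, .8]])  # shape (3, 2)

# Promote the vector to a 1 x n row matrix so a rank-2-only matmul accepts it
row = np.expand_dims(a, axis=0)               # shape (1, 3)
prod = np.matmul(row, w)                      # shape (1, 2)

# Drop the extra leading axis again to get back a plain vector
result = np.squeeze(prod, axis=0)             # shape (2,)

print(result.shape, result)
```

Two extra calls and an intermediate shape to keep track of, just to get back the same vector np.dot(a, w) returns directly.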


dsalaj · Sep 21, 2017

tf.einsum lets you do exactly what you need in a concise and intuitive form:

with tf.Session() as sess:
    print(tf.einsum('n,nm->m', a, w).eval())
    # [ 2.6  3. ] 

You even get to write your comment n x (n, m) -> m explicitly, as the subscript string. In my opinion this is more readable and intuitive.
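NumPy's np.einsum accepts the same subscript notation as tf.einsum, so the 'n,nm->m' string can be sanity-checked against the plain matrix product from the question without building a graph. A minimal sketch, reusing the question's a and w:

```python
import numpy as np

a = np.array([1, 2, 1])
w = np.array([[.5, .6], [.7, .8], [.7, .8]])

# 'n,nm->m': the repeated index n is summed over and m is kept,
# i.e. out[m] = sum over n of a[n] * w[n, m]
via_einsum = np.einsum('n,nm->m', a, w)

# The "plain nice old" vector-matrix product for comparison
via_dot = np.dot(a, w)

print(via_einsum, via_dot)
```

In TensorFlow 2.x, eager execution is the default and tf.Session only survives as tf.compat.v1.Session, so the answer's line reduces to calling tf.einsum('n,nm->m', a, w) directly, without a session.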

My favorite use case is when you want to multiply a batch of matrices with a weight vector:

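import numpy as np
import tensorflow as tf

n_in = 10
n_step = 6
# batch of n_step x n_in matrices; the batch size is left open (None)
inputs = tf.placeholder(dtype=tf.float32, shape=(None, n_step, n_in))
# a single (n_in, 1) weight column shared by every batch element
weights = tf.Variable(tf.truncated_normal((n_in, 1), stddev=1.0/np.sqrt(n_in)))
Y_predict = tf.einsum('ijk,kl->ijl', inputs, weights)
# Y_predict has static shape (?, 6, 1)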
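The batch case can be checked the same way. In this sketch a random array with a hypothetical batch size of 4 stands in for the placeholder, and the weights are drawn from a plain normal distribution instead of tf.truncated_normal (both choices are illustrative only). The einsum result is compared against an explicit loop that applies the same (n_in, 1) weights to every batch element.

```python
import numpy as np

n_in = 10
n_step = 6
batch = 4  # hypothetical batch size standing in for the placeholder's None

rng = np.random.default_rng(0)
inputs = rng.standard_normal((batch, n_step, n_in))
weights = rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, 1))

# 'ijk,kl->ijl': sum over k (the n_in axis), keep batch i, step j, output l
y = np.einsum('ijk,kl->ijl', inputs, weights)
print(y.shape)  # (4, 6, 1)

# Reference: the same weights applied to each batch element one at a time
y_loop = np.stack([inputs[i] @ weights for i in range(batch)])
```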

So you can easily apply the weights across all batches with no transformations or duplication, which you cannot do by expanding dimensions as in the other answer. This way you avoid the tf.matmul requirement that the batch and other outer dimensions match:

The inputs must, following any transpositions, be tensors of rank >= 2 where the inner 2 dimensions specify valid matrix multiplication arguments, and any further outer dimensions match.
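To make the contrast concrete, here is a sketch of what satisfying that rule with a batched matmul would involve: the (n_in, 1) weights are duplicated along a new batch axis so that the outer dimensions match. It is written in NumPy, whose np.matmul would in fact broadcast the weights on its own; the explicit np.tile is there only to mirror the matching-outer-dimensions requirement quoted above, and the sizes are illustrative.

```python
import numpy as np

batch, n_step, n_in = 4, 6, 10  # illustrative sizes

rng = np.random.default_rng(1)
inputs = rng.standard_normal((batch, n_step, n_in))
weights = rng.standard_normal((n_in, 1))

# Duplicate the weights once per batch element: (n_in, 1) -> (batch, n_in, 1)
tiled = np.tile(weights[np.newaxis, :, :], (batch, 1, 1))

# Batched matmul with matching outer dimension: (4, 6, 10) @ (4, 10, 1)
y_tiled = np.matmul(inputs, tiled)

# The einsum version needs neither the extra axis nor the copy
y_einsum = np.einsum('ijk,kl->ijl', inputs, weights)

print(tiled.shape, y_tiled.shape)
```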