PyTorch: Convert FloatTensor into DoubleTensor

N8_Coder · Jun 23, 2017

I have two NumPy arrays, which I convert into tensors so I can use them with a TensorDataset.

import numpy as np
import torch
import torch.utils.data as data_utils

X = np.zeros((100,30))
Y = np.zeros((100,30))

train = data_utils.TensorDataset(torch.from_numpy(X).double(), torch.from_numpy(Y))
train_loader = data_utils.DataLoader(train, batch_size=50, shuffle=True)

When I run:

from torch.autograd import Variable  # model and optimizer are defined elsewhere

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = Variable(data), Variable(target)
    optimizer.zero_grad()
    output = model(data)               # error occurs here

I get the following error:

TypeError: addmm_ received an invalid combination of arguments - got (int, int, torch.DoubleTensor, torch.FloatTensor), but expected one of: [...]
* (float beta, float alpha, torch.DoubleTensor mat1, torch.DoubleTensor mat2) didn't match because some of the arguments have invalid types: (int, int, torch.DoubleTensor, torch.FloatTensor)
* (float beta, float alpha, torch.SparseDoubleTensor mat1, torch.DoubleTensor mat2) didn't match because some of the arguments have invalid types: (int, int, torch.DoubleTensor, torch.FloatTensor)

The last frame of the traceback is this line:

output.addmm_(0, 1, input, weight.t())
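A minimal sketch reproducing that failing call in current PyTorch, assuming a 30-feature layer and a batch of 50 (the shapes are illustrative, not taken from the model). Modern versions take beta and alpha as keyword arguments and raise a RuntimeError rather than the 2017 TypeError; the exact message differs between versions, so the sketch only checks that an error is raised.

```python
import torch

# Assumed shapes: a 30x30 weight (float32, like a default layer weight)
# and a batch of 50 float64 inputs (like torch.from_numpy on a float64 array).
weight = torch.zeros(30, 30)
inp = torch.zeros(50, 30, dtype=torch.float64)
out = torch.zeros(50, 30, dtype=torch.float64)

# Mixing float64 and float32 matrices in addmm is rejected.
try:
    torch.addmm(out, inp, weight.t(), beta=0, alpha=1)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Once both matrices share a dtype, the same call succeeds.
result = torch.addmm(out, inp, weight.t().double(), beta=0, alpha=1)
```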

As you can see in my code, I tried converting the tensor with .double(), but this did not work. Why is one array cast into a FloatTensor and the other into a DoubleTensor? Any ideas?

Answer

mbpaulus · Jun 23, 2017

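Your NumPy arrays are 64-bit floating point (np.float64, NumPy's default), so torch.from_numpy converts them to torch.DoubleTensor. That is also why calling .double() on X changed nothing: it was already a DoubleTensor. Model parameters, on the other hand, are created as 32-bit floats (torch.FloatTensor) by default. That is exactly the mismatch in your error message: addmm_ received the input as a DoubleTensor and the weight as a FloatTensor. So you need to make sure either that your model parameters are also Double, or that your NumPy arrays are cast to Float.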
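The dtypes involved can be checked directly. This sketch assumes an nn.Linear(30, 30) layer as a hypothetical stand-in for the model, and that the global default dtype has not been changed with torch.set_default_dtype.

```python
import numpy as np
import torch
import torch.nn as nn

X = np.zeros((100, 30))
x_tensor = torch.from_numpy(X)

print(X.dtype)                  # float64: NumPy's default floating-point dtype
print(x_tensor.dtype)           # torch.float64, i.e. a DoubleTensor on CPU
print(x_tensor.double().dtype)  # still torch.float64, so .double() changes nothing here

layer = nn.Linear(30, 30)       # hypothetical layer standing in for the model
print(layer.weight.dtype)       # torch.float32 under the default dtype
```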

Hence, do either of the following:

data_utils.TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(Y).float())
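To see the first option in context, here is a minimal sketch of the question's training loop with inputs and targets cast to float32. The model, optimizer, learning rate and loss are hypothetical choices for illustration. Since PyTorch 0.4, tensors work with autograd directly, so the Variable wrapper is no longer needed.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data_utils

X = np.zeros((100, 30))  # float64 by default
Y = np.zeros((100, 30))

# Option 1: cast inputs and targets down to float32 to match the model.
train = data_utils.TensorDataset(torch.from_numpy(X).float(),
                                 torch.from_numpy(Y).float())
train_loader = data_utils.DataLoader(train, batch_size=50, shuffle=True)

model = nn.Linear(30, 30)                                  # hypothetical model, float32 parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)   # hypothetical optimizer and learning rate
loss_fn = nn.MSELoss()                                     # hypothetical loss

for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = model(data)  # float32 input and float32 weight: no dtype mismatch
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
```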

or do:

model.double()
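A minimal sketch of the second option, again assuming a hypothetical nn.Linear(30, 30) model. model.double() converts the module's floating-point parameters and buffers to float64 in place, so the float64 tensor from torch.from_numpy can be passed in without any cast.

```python
import numpy as np
import torch
import torch.nn as nn

X = np.zeros((100, 30))    # float64
model = nn.Linear(30, 30)  # hypothetical model, float32 parameters by default

# Option 2: convert the model's floating-point parameters and buffers to float64.
model.double()

data = torch.from_numpy(X)  # already float64, no cast needed
output = model(data)
```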

Which one you choose depends on whether you want your model parameters, inputs and targets to be Float or Double.
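One way to keep the choice consistent is to name the dtype once and apply it to both the model and the data. This is a sketch, not a required pattern: DTYPE is a hypothetical setting and nn.Linear(30, 30) a hypothetical model. float32 is the usual choice for training, since it needs half the memory of float64 and is generally faster, especially on GPUs.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data_utils

DTYPE = torch.float32  # hypothetical setting: switch to torch.float64 to train in double precision

X = np.zeros((100, 30))
Y = np.zeros((100, 30))

model = nn.Linear(30, 30).to(DTYPE)  # hypothetical model, cast to the chosen dtype
train = data_utils.TensorDataset(torch.from_numpy(X).to(DTYPE),
                                 torch.from_numpy(Y).to(DTYPE))

data, target = train[0:50]  # first 50 samples, already in DTYPE
output = model(data)
```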