How to use the torch.stack function

조수호 · Sep 12, 2018 · Viewed 33.1k times

I have a question about torch.stack.

I have 2 tensors, with a.shape=(2, 3, 4) and b.shape=(2, 3). How can I stack them without an in-place operation?

Answer

arjoonn · Sep 12, 2018

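torch.stack requires every tensor to have exactly the same shape, so it cannot combine a and b directly. To get a (2, 3, 5) result, one way is to unsqueeze b so both tensors have the same number of dimensions, then concatenate along the last dimension with torch.cat. Neither call modifies its inputs in place. For example:

import torch

a = torch.randn(2, 3, 4)
b = torch.randn(2, 3)
a.size()  # torch.Size([2, 3, 4])
b.size()  # torch.Size([2, 3])
b = torch.unsqueeze(b, dim=2)  # shape (2, 3, 1); rebinds the name b to a new tensor
# torch.unsqueeze(b, dim=-1) does the same thing

# torch.stack([a, b], dim=2) would raise here because the shapes differ
torch.cat([a, b], dim=2)  # shape (2, 3, 5)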
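
To make the difference between the two functions concrete, here is a minimal sketch using the shapes from the question. The tensor values (zeros and ones) and the variable names joined, b_expanded, stacked, and stack_failed are placeholders chosen for illustration. The torch.stack part assumes you actually want b repeated along the last dimension, which the question does not say.

```python
import torch

a = torch.zeros(2, 3, 4)
b = torch.ones(2, 3)

# torch.cat joins along an existing dimension, so b only needs a matching rank.
joined = torch.cat([a, b.unsqueeze(-1)], dim=2)
print(joined.shape)  # torch.Size([2, 3, 5])

# torch.stack creates a new dimension, so all inputs must share one shape.
# Assumption: broadcasting b to a's shape is acceptable for your use case.
# expand returns a view and leaves b itself unchanged.
b_expanded = b.unsqueeze(-1).expand(2, 3, 4)
stacked = torch.stack([a, b_expanded], dim=0)
print(stacked.shape)  # torch.Size([2, 2, 3, 4])

# Stacking tensors with different shapes raises a RuntimeError.
try:
    torch.stack([a, b.unsqueeze(-1)], dim=2)
    stack_failed = False
except RuntimeError:
    stack_failed = True
print(stack_failed)  # True
```

As a rule of thumb, reach for torch.cat when the result should grow along an axis that already exists, and torch.stack when you want a new axis indexing a list of same-shaped tensors.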