PyTorch RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed

Shani Gamrian · Aug 19, 2017

I’m trying to create a basic binary classifier in PyTorch that classifies whether my player plays on the right or the left side in the game Pong. The input is a 1x42x42 image and the label is my player's side (right = 1, left = 2). The code:

import torch
import torch.nn as nn
from torch.autograd import Variable

class Net(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# 42*42 flattened pixels -> 100 hidden units -> 2 class scores
net = Net(42 * 42, 100, 2)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer_net = torch.optim.Adam(net.parameters(), 0.001)
net.train()

while True:
    # get_game_img() and get_player_side() are my own helpers (not shown)
    state = get_game_img()
    state = torch.from_numpy(state).float()  # nn.Linear expects float input

    # right = 1, left = 2
    current_side = get_player_side()
    # One-element tensor holding the label (a batch of size 1)
    target = torch.LongTensor([current_side])
    x = Variable(state.view(-1, 42 * 42))
    y = Variable(target)
    optimizer_net.zero_grad()
    y_pred = net(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer_net.step()

The error I get:

  File "train.py", line 109, in train
    loss = criterion(y_pred, y)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/loss.py", line 321, in forward
    self.weight, self.size_average)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 533, in cross_entropy
    return nll_loss(log_softmax(input), target, weight, size_average)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 501, in nll_loss
    return f(input, target)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forward
    output, *self.additional_args)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at /py/conda-bld/pytorch_1493676237139/work/torch/lib/THNN/generic/ClassNLLCriterion.c:57
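The traceback shows that cross_entropy is computed as nll_loss(log_softmax(input), target), and the failing assertion lives in the NLL step, where the target is used as an index into each row of class scores. A minimal pure-Python sketch of that computation for a single sample (the function names mirror the traceback but are illustrative, not PyTorch's actual internals) shows why the target has to be a valid index:

```python
import math

def log_softmax(scores):
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]

def nll_loss(log_probs, target):
    # Same condition as the failing assertion: the target selects one
    # entry of the score row, so it must lie in [0, n_classes).
    # Without this check, Python would silently accept -1 as "last element".
    n_classes = len(log_probs)
    if not (0 <= target < n_classes):
        raise ValueError("target %d out of range for %d classes"
                         % (target, n_classes))
    return -log_probs[target]

def cross_entropy(scores, target):
    # Composition shown in the traceback: nll_loss(log_softmax(input), target)
    return nll_loss(log_softmax(scores), target)

# With 2 output scores, targets 0 and 1 are valid; 2 is rejected.
print(cross_entropy([0.5, -0.5], 1))
```

With num_classes = 2, a label of 2 (left) fails the same check, which is exactly what ClassNLLCriterion reports.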

Answer

Jing · Aug 24, 2017

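In most deep learning libraries, PyTorch included, the target (label) indices must start from 0.

That means that with n classes, every target must lie in the range [0, n). Your network has 2 outputs, so the valid labels are 0 and 1; the label 2 (left) fails the check `cur_target < n_classes`. Shift your labels down by one (right = 0, left = 1) and the assertion will pass.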
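Applied to the question, the fix is a one-off shift of the labels before building the target, e.g. torch.LongTensor([current_side - 1]) in the training loop. A small sketch of that shift with the same range check, assuming the labels are consecutive integers starting at first_label (the helper name to_class_index is hypothetical):

```python
def to_class_index(label, first_label=1, n_classes=2):
    """Shift a label that starts at `first_label` so that it starts at 0."""
    index = label - first_label
    # Same condition as the assertion: 0 <= index < n_classes
    if not 0 <= index < n_classes:
        raise ValueError("label %r maps to %d, outside [0, %d)"
                         % (label, index, n_classes))
    return index

# The question's encoding: right = 1, left = 2
for side in (1, 2):
    print(side, "->", to_class_index(side))  # prints "1 -> 0", then "2 -> 1"
```

Doing the conversion in one helper keeps the rest of the loop unchanged and fails loudly if a label outside the expected set ever appears.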