Optimizer and scheduler for BERT fine-tuning

geekazoid · Feb 7, 2020 · Viewed 7.2k times

I'm trying to fine-tune a model with BERT (using the transformers library), and I'm a bit unsure about the optimizer and scheduler.

First, I understand that I should use transformers.AdamW instead of PyTorch's version. Also, we should use a warmup scheduler as suggested in the paper, so the scheduler is created with the get_linear_schedule_with_warmup function from the transformers package.

The main questions I have are:

  1. get_linear_schedule_with_warmup should be called with the number of warmup steps. Is it OK to use 2 warmup steps out of 10 epochs?
  2. When should I call scheduler.step()? If I call it after train, the learning rate is zero for the first epoch. Should I call it for each batch instead?

Am I doing something wrong with this?

from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup

N_EPOCHS = 10

model = BertGRUModel(finetune_bert=True,...)
num_training_steps = N_EPOCHS+1
num_warmup_steps = 2
warmup_proportion = float(num_warmup_steps) / float(num_training_steps)  # 0.1

optimizer = AdamW(model.parameters())
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([class_weights[1]]))


scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps, 
    num_training_steps=num_training_steps
)

for epoch in range(N_EPOCHS):
    scheduler.step() # If I call it after train, LR = 0 for the first epoch
    print(optimizer.param_groups[0]["lr"])

    train(...) # here we call optimizer.step()
    evaluate(...)

My model and train routine (quite similar to this notebook):

class BERTGRUSentiment(nn.Module):
    def __init__(self,
                 bert,
                 hidden_dim,
                 output_dim,
                 n_layers=1, 
                 bidirectional=False,
                 finetune_bert=False,
                 dropout=0.2):

        super().__init__()

        self.bert = bert

        embedding_dim = bert.config.to_dict()['hidden_size']

        self.finetune_bert = finetune_bert

        self.rnn = nn.GRU(embedding_dim,
                          hidden_dim,
                          num_layers = n_layers,
                          bidirectional = bidirectional,
                          batch_first = True,
                          dropout = 0 if n_layers < 2 else dropout)

        self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)        
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):    
        #text = [batch size, sent len]

        if not self.finetune_bert:
            with torch.no_grad():
                embedded = self.bert(text)[0]
        else:
            embedded = self.bert(text)[0]
        #embedded = [batch size, sent len, emb dim]
        _, hidden = self.rnn(embedded)

        #hidden = [n layers * n directions, batch size, emb dim]

        if self.rnn.bidirectional:
            hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1))
        else:
            hidden = self.dropout(hidden[-1,:,:])

        #hidden = [batch size, hid dim]

        output = self.out(hidden)

        #output = [batch size, out dim]

        return output


import torch
from sklearn.metrics import accuracy_score, f1_score


def train(model, iterator, optimizer, criterion, max_grad_norm=None):
    """
    Trains the model for one full epoch
    """
    epoch_loss = 0
    epoch_acc = 0

    model.train()

    for i, batch in enumerate(iterator):
        optimizer.zero_grad()
        text, lens = batch.text

        predictions = model(text)

        target = batch.target

        loss = criterion(predictions.squeeze(1), target)

        prob_predictions = torch.sigmoid(predictions)

        preds = torch.round(prob_predictions).detach().cpu()
        acc = accuracy_score(preds, target.cpu())

        loss.backward()
        # Gradient clipping
        if max_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)


Answer

pashok3ddd · May 2, 2020

With get_linear_schedule_with_warmup, the learning rate increases linearly from 0 during the warmup steps and then decays linearly back to 0 over the remaining training steps.
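If you want to see that shape yourself, here is a minimal sketch (the warmup/total step counts, the learning rate, and the dummy parameter below are just placeholder values):

import torch
from transformers import AdamW, get_linear_schedule_with_warmup

# Placeholder parameter so the optimizer has something to manage.
dummy_param = torch.nn.Parameter(torch.zeros(1))
optimizer = AdamW([dummy_param], lr=2e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=10, num_training_steps=100
)

for step in range(100):
    optimizer.step()
    scheduler.step()
    # Ramps up for the first 10 steps, then decays linearly to 0.
    print(step, optimizer.param_groups[0]["lr"])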

Referring to this comment: warmup steps are used to keep the learning rate low at the beginning of training, so that sudden exposure to the new data set does not pull the model too far away from what it has already learned.

By default, the number of warmup steps is 0.

Then you take bigger steps, because you are probably not near a minimum yet. As you approach the minimum, you take smaller steps to converge to it.

Also, note that the number of training steps is the number of batches times the number of epochs, not just the number of epochs. So num_training_steps = N_EPOCHS + 1 is not correct, unless your batch size equals the size of the training set.
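For example, assuming train_iterator is the iterator you pass to train() (a sketch based on the names in your code; the 10% warmup proportion is just an illustration):

N_EPOCHS = 10
# One optimizer update per batch, so count batches per epoch * epochs.
num_training_steps = len(train_iterator) * N_EPOCHS
num_warmup_steps = int(0.1 * num_training_steps)  # warm up over the first 10% of steps

scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)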

You should call scheduler.step() after every batch, right after optimizer.step(), to update the learning rate.
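So a corrected training loop would look roughly like this (a sketch; here I pass the scheduler into your train() function, which is one way to do it):

def train(model, iterator, optimizer, criterion, scheduler=None, max_grad_norm=None):
    """Trains the model for one full epoch, stepping the scheduler once per batch."""
    epoch_loss = 0
    model.train()

    for batch in iterator:
        optimizer.zero_grad()
        text, lens = batch.text
        predictions = model(text)
        loss = criterion(predictions.squeeze(1), batch.target)
        loss.backward()
        if max_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()  # per-batch LR update: warmup, then linear decay
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

for epoch in range(N_EPOCHS):
    print(optimizer.param_groups[0]["lr"])
    train(model, train_iterator, optimizer, criterion, scheduler=scheduler)
    # evaluate(...) as before; no scheduler.step() needed at the epoch level

With this setup the learning rate printed at the start of each epoch is no longer zero, because the scheduler only advances once per optimizer update instead of once per epoch.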