I'm trying to fine-tune a model with BERT (using the transformers library), and I'm a bit unsure about the optimizer and scheduler.
First, I understand that I should use transformers.AdamW instead of PyTorch's version of it. Also, we should use a warmup scheduler as suggested in the paper, so the scheduler is created using the get_linear_schedule_with_warmup function from the transformers package.
The main questions I have are:
1. get_linear_schedule_with_warmup should be called with the warmup. Is it ok to use 2 for warmup out of 10 epochs?
2. When should I call scheduler.step()? If I do it after train, the learning rate is zero for the first epoch. Should I call it for each batch?
Am I doing something wrong with this?
import torch
import torch.nn as nn
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup

N_EPOCHS = 10
model = BertGRUModel(finetune_bert=True,...)
num_training_steps = N_EPOCHS+1
num_warmup_steps = 2
warmup_proportion = float(num_warmup_steps) / float(num_training_steps)  # 2 / 11, about 0.18

optimizer = AdamW(model.parameters())
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([class_weights[1]]))
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

for epoch in range(N_EPOCHS):
    scheduler.step()  # If I do after train, LR = 0 for the first epoch
    print(optimizer.param_groups[0]["lr"])
    train(...)  # here we call optimizer.step()
    evaluate(...)
My model and train routine (quite similar to this notebook):
import torch
import torch.nn as nn

class BERTGRUSentiment(nn.Module):
    def __init__(self,
                 bert,
                 hidden_dim,
                 output_dim,
                 n_layers=1,
                 bidirectional=False,
                 finetune_bert=False,
                 dropout=0.2):
        super().__init__()
        self.bert = bert
        embedding_dim = bert.config.to_dict()['hidden_size']
        self.finetune_bert = finetune_bert
        self.rnn = nn.GRU(embedding_dim,
                          hidden_dim,
                          num_layers=n_layers,
                          bidirectional=bidirectional,
                          batch_first=True,
                          dropout=0 if n_layers < 2 else dropout)
        self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        # text = [batch size, sent len]
        if not self.finetune_bert:
            # BERT is frozen: no gradients flow into it
            with torch.no_grad():
                embedded = self.bert(text)[0]
        else:
            embedded = self.bert(text)[0]
        # embedded = [batch size, sent len, emb dim]
        _, hidden = self.rnn(embedded)
        # hidden = [n layers * n directions, batch size, hid dim]
        if self.rnn.bidirectional:
            # concatenate the final forward and backward hidden states
            hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
        else:
            hidden = self.dropout(hidden[-1,:,:])
        # hidden = [batch size, hid dim] (hid dim * 2 if bidirectional)
        output = self.out(hidden)
        # output = [batch size, out dim]
        return output
import torch
from sklearn.metrics import accuracy_score, f1_score

def train(model, iterator, optimizer, criterion, max_grad_norm=None):
    """
    Trains the model for one full epoch
    """
    epoch_loss = 0
    epoch_acc = 0
    model.train()
    for i, batch in enumerate(iterator):
        optimizer.zero_grad()
        text, lens = batch.text
        predictions = model(text)
        target = batch.target
        loss = criterion(predictions.squeeze(1), target)
        prob_predictions = torch.sigmoid(predictions)
        preds = torch.round(prob_predictions).detach().cpu()
        # accuracy_score expects (y_true, y_pred)
        acc = accuracy_score(target.cpu(), preds)
        loss.backward()
        # Gradient clipping
        if max_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        epoch_loss += loss.item()
        # float() works whether acc is a Python float or a NumPy scalar
        epoch_acc += float(acc)
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
Here you can see a visualization of learning rate changes using get_linear_schedule_with_warmup.
Referring to this comment: warmup steps is a parameter used to keep the learning rate low at the start of training, in order to reduce the impact of the model deviating from what it has already learned when it is suddenly exposed to a new dataset.
By default, the number of warmup steps is 0.
After the warmup you take bigger steps, because you are probably not near the minimum yet. As you approach the minimum, you take smaller steps in order to converge to it.
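To make the shape of that schedule concrete, here is a minimal sketch in plain Python of the multiplier applied to the base learning rate at each step: it rises linearly from 0 during the warmup, then falls linearly back to 0. As far as I understand it, this is the rule get_linear_schedule_with_warmup follows, but the function name linear_warmup_multiplier and the step counts (2 warmup steps out of 10) are made up for illustration.

```python
def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """Factor applied to the base learning rate at a given optimizer step.

    Hypothetical helper for illustration; not part of transformers.
    """
    if step < num_warmup_steps:
        # Warmup: ramp linearly from 0 up to 1.
        return step / max(1, num_warmup_steps)
    # Decay: fall linearly from 1 back to 0 at the last training step.
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


# Assumed toy setting: 2 warmup steps out of 10 total steps.
multipliers = [linear_warmup_multiplier(s, 2, 10) for s in range(11)]
print(multipliers)
```

Note that the multiplier at step 0 is exactly 0, which is why the very first step runs with a learning rate of zero.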
Also, note that the number of training steps is number of batches * number of epochs, not just the number of epochs. So, basically, num_training_steps = N_EPOCHS+1 is not correct, unless your batch_size is equal to the training set size.
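As a rough sketch of that arithmetic: the dataset size, batch size and warmup fraction below are hypothetical numbers, not taken from the question. In the question's setup, len(iterator) would give the number of batches per epoch directly.

```python
import math

N_EPOCHS = 10
n_train_examples = 25_000  # hypothetical dataset size
batch_size = 32            # hypothetical batch size
warmup_fraction = 0.1      # hypothetical: warm up for 10% of all steps

# One optimizer step per batch; the last batch may be smaller, hence ceil.
batches_per_epoch = math.ceil(n_train_examples / batch_size)

# Total number of scheduler steps over the whole run.
num_training_steps = batches_per_epoch * N_EPOCHS

# Warmup expressed in steps (batches), not in epochs.
num_warmup_steps = int(warmup_fraction * num_training_steps)

print(batches_per_epoch, num_training_steps, num_warmup_steps)
```

With these numbers, a warmup of 2 steps would cover only 2 batches, not 2 epochs; warming up for 2 epochs would mean num_warmup_steps = 2 * batches_per_epoch.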
You call scheduler.step() every batch, right after optimizer.step(), to update the learning rate.
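A minimal sketch of that ordering follows. To keep it runnable with PyTorch alone, it swaps in torch.optim.AdamW and a LambdaLR that applies the same linear warmup/decay rule; with transformers you would instead pass the same num_warmup_steps and num_training_steps to get_linear_schedule_with_warmup. The tiny linear model, the random data and all the sizes are stand-ins, not the question's BERT model.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

N_EPOCHS = 2           # hypothetical, kept small so the sketch runs quickly
BATCHES_PER_EPOCH = 5  # hypothetical
num_training_steps = N_EPOCHS * BATCHES_PER_EPOCH  # batches * epochs
num_warmup_steps = 2

model = nn.Linear(4, 1)  # stand-in for the real BERT-based model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()


def lr_lambda(step):
    # Linear warmup from 0 to 1, then linear decay back to 0.
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step)
               / max(1, num_training_steps - num_warmup_steps))


scheduler = LambdaLR(optimizer, lr_lambda)

lrs = []
for epoch in range(N_EPOCHS):
    for _ in range(BATCHES_PER_EPOCH):
        inputs = torch.randn(8, 4)                   # random stand-in batch
        targets = torch.randint(0, 2, (8,)).float()  # random binary labels
        optimizer.zero_grad()
        loss = criterion(model(inputs).squeeze(1), targets)
        loss.backward()
        optimizer.step()
        scheduler.step()  # once per batch, right after optimizer.step()
        lrs.append(optimizer.param_groups[0]["lr"])  # LR for the next batch

print(lrs)
```

Stepped this way, only the first batch is trained with a learning rate of zero, rather than the whole first epoch.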