How do I initialize the weights and biases (for example, with He or Xavier initialization) of a network in PyTorch?
To initialize the weights of a single layer, use a function from torch.nn.init. For instance:
conv1 = torch.nn.Conv2d(...)
# Use the in-place variant (trailing underscore); xavier_uniform without it is deprecated.
torch.nn.init.xavier_uniform_(conv1.weight)
Alternatively, you can modify the parameters by writing to conv1.weight.data (which is a torch.Tensor). Example:
conv1.weight.data.fill_(0.01)
The same applies to biases:
conv1.bias.data.fill_(0.01)
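Since the question also asks about He initialization, here is a minimal sketch of the same single-layer approach using torch.nn.init.kaiming_normal_ and torch.nn.init.zeros_. The layer sizes (3 input channels, 16 output channels, a 3x3 kernel) are arbitrary values chosen for illustration, and nonlinearity="relu" assumes the layer is followed by a ReLU.

```python
import torch
import torch.nn as nn

# Hypothetical layer; the channel counts and kernel size are illustrative only.
conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

# He (Kaiming) normal initialization, assuming a ReLU follows this layer.
nn.init.kaiming_normal_(conv1.weight, mode="fan_in", nonlinearity="relu")

# Set the bias to zero in place.
nn.init.zeros_(conv1.bias)

# In-place edits can also be wrapped in torch.no_grad() instead of going through .data.
with torch.no_grad():
    conv1.bias.fill_(0.01)
```

The functions in torch.nn.init whose names end in an underscore modify the tensor in place, so the values they return can be ignored here.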
nn.Sequential or custom nn.Module
Pass an initialization function to torch.nn.Module.apply. It will initialize the weights in the entire nn.Module recursively.
apply(fn): Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch-nn-init).
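To make the traversal concrete, the following sketch records which modules apply passes to the function. The two-layer model and the record helper are hypothetical names used only for this illustration.

```python
import torch.nn as nn

visited = []

def record(module):
    # Store the class name of each module that apply() passes in.
    visited.append(type(module).__name__)

# Hypothetical model used only to show the traversal.
net = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
net.apply(record)

print(visited)  # ['Linear', 'ReLU', 'Sequential']
```

The children are visited before the container itself, and layers without parameters (such as nn.ReLU) are visited too, which is why an initialization function should check the module type before touching m.weight.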
Example:
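import torch
import torch.nn as nn

def init_weights(m):
    # Only fully connected layers are initialized; other modules are left unchanged.
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)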
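The same pattern extends to networks that mix layer types. The sketch below is one possible arrangement, not a recommendation: it assumes He initialization for convolutions followed by ReLU, Xavier initialization for linear layers, and zero biases. The network shape and the 8x8 input are hypothetical values picked so the example runs end to end.

```python
import torch
import torch.nn as nn

def init_weights(m):
    # He initialization for convolutions, Xavier for linear layers (illustrative choice).
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
    else:
        return  # leave modules such as ReLU or Flatten untouched
    # Layers created with bias=False have m.bias set to None.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Hypothetical small network: an 8x8 single-channel input becomes 4 x 6 x 6 after the conv.
net = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 6 * 6, 10),
)
net.apply(init_weights)

# Sanity check that the initialized network still runs a forward pass.
out = net(torch.zeros(1, 1, 8, 8))
print(out.shape)  # torch.Size([1, 10])
```

Checking for a missing bias keeps the function usable on layers constructed with bias=False.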