I am trying to wrap my head around skip connections in a sequential model. With the functional API I would do something as simple as the following (a quick example, maybe not 100% syntactically correct, but it should give the idea):
x1 = self.conv1(inp)
x = self.conv2(x1)
x = self.conv3(x)
x = self.conv4(x)
x = self.deconv4(x)
x = self.deconv3(x)
x = self.deconv2(x)
x = torch.cat((x, x1), 1)
x = self.deconv1(x)
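To make the snippet above concrete, here is a minimal self-contained sketch of the same pattern with fewer layers. The class name `TinyNet` and all channel counts and kernel sizes are assumptions chosen only so the shapes line up; they are not taken from any real model.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical name and layer sizes
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1)
        # deconv1 receives 8 channels from deconv2 plus 8 skipped from conv1.
        self.deconv1 = nn.ConvTranspose2d(16, 3, kernel_size=3, padding=1)

    def forward(self, inp):
        x1 = self.conv1(inp)       # activations kept for the skip
        x = self.conv2(x1)         # downsample
        x = self.deconv2(x)        # upsample back to x1's spatial size
        x = torch.cat((x, x1), 1)  # skip: concatenate along the channel dim
        return self.deconv1(x)


y = TinyNet()(torch.randn(2, 3, 32, 32))
```

Note that the concatenation only works if `x` and `x1` have the same spatial size, and that the layer after the `cat` has to expect the summed channel count.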
I am now using a sequential model and trying to do something similar: create a skip connection that brings the activations of the first conv layer all the way to the last ConvTranspose layer. I have taken a look at the U-Net architecture implemented here and it's a bit confusing. It does something like this:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                            kernel_size=4, stride=2,
                            padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
    model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
    model = down + [submodule] + up
Isn't this just adding layers to the sequential model, well, sequentially? There is the down conv, which is followed by submodule (which recursively adds inner layers), and the result is then concatenated with up, which holds the upconv layer. I am probably missing something important about how the Sequential API works, but how does the code snippet from the U-Net actually implement the skip?
Your observations are correct, but you may have missed the definition of UnetSkipConnectionBlock.forward() (UnetSkipConnectionBlock being the Module defining the U-Net block you shared), which may clarify this implementation:
(from pytorch-CycleGAN-and-pix2pix/models/networks.py#L259)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    # ...
    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
The last line is the key (it applies to all inner blocks). The skip connection is implemented simply by concatenating the input x with the (recursive) block output self.model(x), where self.model is the list of operations you mentioned wrapped in a Sequential, so it is not so different from the functional code you wrote.
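As a rough illustration of the idea, here is a minimal sketch of a skip block that can be dropped into a plain nn.Sequential. The name `SkipBlock` and all layer sizes are hypothetical, picked only for this example; this is not the code from the linked repository.

```python
import torch
import torch.nn as nn


class SkipBlock(nn.Module):
    """Hypothetical wrapper: runs `inner` and concatenates its output with the input."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        # The skip connection: channel-wise concat of the input and the block output.
        return torch.cat([x, self.inner(x)], dim=1)


# Inner block: 8 -> 16 channels (downsample), then 16 -> 8 channels (upsample).
innermost = SkipBlock(nn.Sequential(
    nn.Conv2d(8, 16, kernel_size=4, stride=2, padding=1),
    nn.ReLU(),
    nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),
))  # output has 8 (input) + 8 (inner output) = 16 channels

# The outer model is an ordinary Sequential; the skip lives inside SkipBlock.forward().
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1),
    innermost,
    nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),
)

out = model(torch.randn(1, 3, 32, 32))
```

The Sequential container itself stays purely sequential. The branching happens because one of its elements is a module whose forward() sees both its input and its inner output, which is also why the layer following the block must accept the doubled channel count (compare inner_nc * 2 in the upconv above).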