pytorch skip connection in a sequential model

powder · Aug 9, 2018 · Viewed 8.6k times

I am trying to wrap my head around skip connections in a sequential model. With the functional API, I would do something as simple as this (a quick example that may not be 100% syntactically correct, but it should convey the idea):

x1 = self.conv1(inp)  # keep the first activation for the skip connection
x = self.conv2(x1)
x = self.conv3(x)
x = self.conv4(x)

x = self.deconv4(x)
x = self.deconv3(x)
x = self.deconv2(x)
x = torch.cat((x, x1), 1)  # concatenate along the channel dimension
x = self.deconv1(x)
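For reference, a self-contained sketch of that forward-based approach. The class name SkipAutoencoder, the channel counts and the kernel/stride/padding values are illustrative assumptions, chosen so that deconv2's output has the same spatial size as conv1's output; the skip itself only requires deconv1 to accept the extra channels.

```python
import torch
import torch.nn as nn


class SkipAutoencoder(nn.Module):
    """Hypothetical 4-level encoder/decoder; sizes are illustrative assumptions."""

    def __init__(self, in_ch=3, ch=16):
        super().__init__()
        # Each conv halves the spatial size (kernel 4, stride 2, padding 1).
        self.conv1 = nn.Conv2d(in_ch, ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(ch, ch * 2, 4, 2, 1)
        self.conv3 = nn.Conv2d(ch * 2, ch * 4, 4, 2, 1)
        self.conv4 = nn.Conv2d(ch * 4, ch * 8, 4, 2, 1)
        # Each transposed conv doubles the spatial size.
        self.deconv4 = nn.ConvTranspose2d(ch * 8, ch * 4, 4, 2, 1)
        self.deconv3 = nn.ConvTranspose2d(ch * 4, ch * 2, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(ch * 2, ch, 4, 2, 1)
        # deconv1 receives deconv2's output (ch) concatenated with conv1's output (ch).
        self.deconv1 = nn.ConvTranspose2d(ch * 2, in_ch, 4, 2, 1)
        self.act = nn.ReLU()

    def forward(self, inp):
        x1 = self.act(self.conv1(inp))  # saved for the skip connection
        x = self.act(self.conv2(x1))
        x = self.act(self.conv3(x))
        x = self.act(self.conv4(x))
        x = self.act(self.deconv4(x))
        x = self.act(self.deconv3(x))
        x = self.act(self.deconv2(x))
        x = torch.cat((x, x1), 1)  # skip: join along the channel dimension
        return self.deconv1(x)


model = SkipAutoencoder()
out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 3, 32, 32])
```

Because torch.cat joins along dim 1 (channels), deconv1 has to be declared with ch + ch input channels; forgetting this is a common source of shape errors.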

I am now using a sequential model and trying to do something similar: create a skip connection that carries the activations of the first conv layer all the way to the last ConvTranspose2d. I have looked at the U-Net architecture implemented in pytorch-CycleGAN-and-pix2pix, and it is a bit confusing; it does something like this:

upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                            kernel_size=4, stride=2,
                            padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]

if use_dropout:
    model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
    model = down + [submodule] + up

Isn't this just adding layers to the sequential model, well, sequentially? There is the down conv, followed by submodule (which recursively adds the inner layers), and that list is then concatenated with up, which holds the upconv layer. I am probably missing something important about how the Sequential API works, but how does the code snippet from U-Net actually implement the skip?
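The reading in the question is correct as far as nn.Sequential goes: it only feeds each module's output into the next one. A quick check, with arbitrary layer sizes chosen for illustration:

```python
import torch
import torch.nn as nn

down = nn.Conv2d(4, 8, kernel_size=4, stride=2, padding=1)
act = nn.ReLU()
up = nn.ConvTranspose2d(8, 4, kernel_size=4, stride=2, padding=1)

# Building a Sequential from a list, like nn.Sequential(*(down + [submodule] + up)).
seq = nn.Sequential(down, act, up)
x = torch.randn(1, 4, 16, 16)

# nn.Sequential just calls its modules one after another: up(act(down(x))).
print(torch.allclose(seq(x), up(act(down(x)))))  # True
# Nothing in the chain keeps a reference to x, so there is no skip connection yet.
print(seq(x).shape)  # torch.Size([1, 4, 16, 16])
```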

Answer

benjaminplanche · Aug 10, 2018

Your observations are correct, but you may have missed the definition of UnetSkipConnectionBlock.forward() (UnetSkipConnectionBlock being the Module defining the U-Net block you shared), which may clarify this implementation:

(from pytorch-CycleGAN-and-pix2pix/models/networks.py#L259)

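import torch
import torch.nn as nn

# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):

    # ...

    def forward(self, x):
        if self.outermost:
            # The outermost block returns its output directly (no skip).
            return self.model(x)
        else:
            # Inner blocks concatenate input x and block output along channels (dim 1).
            return torch.cat([x, self.model(x)], 1)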

The last line is the key (it applies to every block except the outermost one). The skip connection is implemented simply by concatenating the input x with the (recursive) block output self.model(x), where self.model is built from the list of operations you mentioned, so it is not so different from the functional code you wrote.
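To see the whole pattern in one place, here is a simplified, hypothetical version of such a block. The name SkipBlock is an assumption, and normalization, dropout and the separate outermost layer choices of the original are left out. The layers still live in a plain nn.Sequential; the skip is added only in forward():

```python
import torch
import torch.nn as nn


class SkipBlock(nn.Module):
    """Simplified, hypothetical U-Net block: down -> submodule -> up, plus a skip."""

    def __init__(self, outer_nc, inner_nc, submodule=None, outermost=False):
        super().__init__()
        self.outermost = outermost
        # Downsampling halves the spatial size (kernel 4, stride 2, padding 1).
        down = [nn.LeakyReLU(0.2), nn.Conv2d(outer_nc, inner_nc, 4, 2, 1)]
        # A submodule returns cat([x, f(x)]), i.e. 2 * inner_nc channels;
        # the innermost block has no submodule and keeps inner_nc channels.
        up_in = inner_nc if submodule is None else inner_nc * 2
        up = [nn.ReLU(), nn.ConvTranspose2d(up_in, outer_nc, 4, 2, 1)]
        middle = [submodule] if submodule is not None else []
        # Plain list concatenation: the Sequential itself contains no skip.
        self.model = nn.Sequential(*(down + middle + up))

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # The skip connection: input and block output, joined along channels.
        return torch.cat([x, self.model(x)], 1)


innermost = SkipBlock(16, 32)
mid_block = SkipBlock(8, 16, submodule=innermost)
net = SkipBlock(3, 8, submodule=mid_block, outermost=True)

out = net(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 3, 32, 32])
```

Each non-outermost block returns cat([x, self.model(x)]), so its output has twice as many channels as its input. That is why the parent's up-convolution takes inner_nc * 2 input channels, as in the nn.ConvTranspose2d(inner_nc * 2, outer_nc, ...) line from the question.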