Use PyTorch SSIM loss function in my model

sealpuppy · Dec 28, 2018 · Viewed 12.1k times

I am trying out the SSIM loss implemented in this repo for image restoration.

Following the original sample code on the author's GitHub, I tried:

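import cv2
import numpy as np
import pytorch_ssim

criterion = pytorch_ssim.SSIM()  # SSIM module from the repo
# model, trainloader, optimizer, epoch and bs are defined elsewhere

for epo in range(epoch):
    for i, data in enumerate(trainloader, 0):
        inputs = data.view(bs, 1, 128, 128)
        optimizer.zero_grad()
        top = model.upward(inputs)
        outputs = model.downward(top, shortcut=True)
        outputs = outputs.view(bs, 1, 128, 128)

        if i % 20 == 0:
            # Save the first reconstruction of the batch as an 8-bit image
            out = outputs[0].view(128, 128).detach().cpu().numpy() * 255
            out = np.clip(out, 0, 255).astype(np.uint8)
            cv2.imwrite("/home/tk/Documents/recover/SSIM/" + str(epo) + "_" + str(i) + "_re.png", out)

        loss = - criterion(inputs, outputs)  # negated SSIM, as in the author's sample
        ssim_value = - loss.item()
        print(ssim_value)

        loss.backward()
        optimizer.step()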

However, the results didn't come out as I expected: after the first 10 epochs, the saved output images were all black.

The author proposes loss = - criterion(inputs, outputs). However, classical PyTorch training code uses loss = criterion(y_pred, target), so I thought it should be loss = criterion(inputs, outputs) here.

I also tried loss = criterion(inputs, outputs), but the results were still the same.

Can anyone share some thoughts about how to properly utilize SSIM loss? Thanks.


Kinal · Feb 7, 2019

The author is trying to maximize the SSIM value. A PyTorch optimizer works by minimizing the loss, but SSIM is a quality measure, so higher is better. Hence the author uses
loss = - criterion(inputs, outputs)

You can instead try using
loss = 1 - criterion(inputs, outputs)
as described in this paper.
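Because the optimizer only sees gradients, -SSIM and 1 - SSIM produce the same updates; the constant 1 just shifts the loss so that it is non-negative and equals 0 for a perfect reconstruction, which makes the printed value easier to read. The minimal sketch below re-implements SSIM in plain PyTorch so it runs without the repo. The helper names gaussian_window and ssim are illustrative, not the repo's API, and the settings (11x11 Gaussian window, sigma = 1.5, C1 = 0.01², C2 = 0.03², images in [0, 1]) are assumed to mirror the usual pytorch_ssim defaults. It checks that both loss variants give identical gradients.

```python
import torch
import torch.nn.functional as F


def gaussian_window(window_size: int = 11, sigma: float = 1.5, channels: int = 1) -> torch.Tensor:
    # 1-D Gaussian normalized to sum to 1, then an outer product gives the 2-D kernel
    coords = torch.arange(window_size, dtype=torch.float32) - (window_size - 1) / 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()
    kernel_2d = g[:, None] @ g[None, :]
    # One kernel per channel, shaped (C, 1, k, k) for a grouped (depthwise) convolution
    return kernel_2d.expand(channels, 1, window_size, window_size).contiguous()


def ssim(img1: torch.Tensor, img2: torch.Tensor, window_size: int = 11) -> torch.Tensor:
    """Mean SSIM of two (N, C, H, W) images with values assumed to be in [0, 1]."""
    channels = img1.size(1)
    window = gaussian_window(window_size, channels=channels).to(img1.device, img1.dtype)
    pad = window_size // 2
    c1, c2 = 0.01 ** 2, 0.03 ** 2  # stabilizing constants for a dynamic range of 1

    # Local (Gaussian-weighted) means, variances and covariance
    mu1 = F.conv2d(img1, window, padding=pad, groups=channels)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channels)
    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channels) - mu1 ** 2
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channels) - mu2 ** 2
    sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channels) - mu1 * mu2

    ssim_map = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / (
        (mu1 ** 2 + mu2 ** 2 + c1) * (sigma1_sq + sigma2_sq + c2)
    )
    return ssim_map.mean()


torch.manual_seed(0)
target = torch.rand(1, 1, 32, 32)
pred = torch.rand(1, 1, 32, 32, requires_grad=True)

# Variant 1: the author's loss, -SSIM
loss_neg = -ssim(pred, target)
(grad_neg,) = torch.autograd.grad(loss_neg, pred)

# Variant 2: 1 - SSIM, which is 0 only when pred matches target
loss_one_minus = 1 - ssim(pred, target)
(grad_one_minus,) = torch.autograd.grad(loss_one_minus, pred)

print(torch.allclose(grad_neg, grad_one_minus))  # True: the constant 1 has zero gradient
print(loss_one_minus.item())
```

So switching from -SSIM to 1 - SSIM changes how the loss reads, not how the model is trained.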

Modified code (for testing the above using this repo):

import pytorch_ssim
import torch
from torch import optim
import cv2
import numpy as np

npImg1 = cv2.imread("einstein.png")

# img1 is the fixed target; img2 starts as noise and is optimized towards it
device = "cuda" if torch.cuda.is_available() else "cpu"
img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0).to(device) / 255.0
img2 = torch.rand(img1.size(), device=device, requires_grad=True)

# Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True)
ssim_value = 1 - pytorch_ssim.ssim(img1, img2).item()
print("Initial SSIM loss (1 - SSIM):", ssim_value)

# Module: pytorch_ssim.SSIM(window_size = 11, size_average = True)
ssim_loss = pytorch_ssim.SSIM()

optimizer = optim.Adam([img2], lr=0.01)

# Minimize 1 - SSIM until SSIM rises above 0.95
while ssim_value > 0.05:
    optimizer.zero_grad()
    ssim_out = 1 - ssim_loss(img1, img2)
    ssim_value = ssim_out.item()
    print(ssim_value)
    ssim_out.backward()
    optimizer.step()
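To use SSIM in a training loop like the one in the question, one option is to wrap 1 - SSIM in an nn.Module, so the loss already follows the "lower is better" convention and no extra minus sign is needed. The sketch below is self-contained and does not use pytorch_ssim. SSIMLoss, the toy model and the random batch are hypothetical stand-ins for the question's model and trainloader. It assumes images in [0, 1] (the constants C1 = 0.01² and C2 = 0.03² presume a dynamic range of 1), which is why the toy model ends with a sigmoid.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SSIMLoss(nn.Module):
    """Hypothetical helper: 1 - mean SSIM with an 11x11 Gaussian window (sigma = 1.5).

    Expects (N, C, H, W) tensors with values in [0, 1].
    """

    def __init__(self, window_size: int = 11, sigma: float = 1.5):
        super().__init__()
        coords = torch.arange(window_size, dtype=torch.float32) - (window_size - 1) / 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g = g / g.sum()
        # Registered as a buffer so model.to(device) also moves the window
        self.register_buffer("window", (g[:, None] @ g[None, :])[None, None])
        self.pad = window_size // 2

    def _blur(self, x: torch.Tensor) -> torch.Tensor:
        # Depthwise Gaussian filtering: one copy of the window per channel
        c = x.size(1)
        w = self.window.expand(c, 1, -1, -1).contiguous().to(x.dtype)
        return F.conv2d(x, w, padding=self.pad, groups=c)

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        mu1, mu2 = self._blur(pred), self._blur(target)
        s11 = self._blur(pred * pred) - mu1 ** 2
        s22 = self._blur(target * target) - mu2 ** 2
        s12 = self._blur(pred * target) - mu1 * mu2
        c1, c2 = 0.01 ** 2, 0.03 ** 2
        ssim_map = ((2 * mu1 * mu2 + c1) * (2 * s12 + c2)) / (
            (mu1 ** 2 + mu2 ** 2 + c1) * (s11 + s22 + c2)
        )
        return 1 - ssim_map.mean()


torch.manual_seed(0)
# Toy reconstruction network; the sigmoid keeps outputs in [0, 1]
model = nn.Sequential(
    nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 1, 3, padding=1), nn.Sigmoid(),
)
criterion = SSIMLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

inputs = torch.rand(4, 1, 32, 32)  # random batch standing in for trainloader data
losses = []
for step in range(100):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, inputs)  # already "lower is better": no extra minus sign
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

With real data, the toy model and batch would be replaced by the question's model and trainloader; the essential steps per batch are optimizer.zero_grad(), loss.backward() and optimizer.step().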