Color Restoration with Generative Adversarial Network


Fast.ai has a two-part Deep Learning Course, the first being Practical Deep Learning for Coders and the second being Deep Learning from the Foundations, each with a different approach and intended for a different audience. In the 7th lecture of Part 1, Jeremy Howard covered several modern architectures such as the Residual Network (ResNet), the U-Net, and the Generative Adversarial Network (GAN).

Generative Adversarial Networks

GANs were invented by Ian Goodfellow, one of the prominent figures in the Deep Learning world. GANs can be used for various tasks such as Style Transfer, Pix2Pix, CycleGAN, and more. What I'll be experimenting with today is Image Restoration.


Style Transfer Result | Tensorflow Tutorials

Image Restoration

There are different elements of an image one can attempt to restore, and the example shown by Jeremy was restoring low resolution images into higher resolution images, which produces something like the following:


Image Restoration Result | fast.ai

Jeremy also mentioned that GANs are capable of not only restoring an image's resolution, but also other elements, such as clearing JPEG-like artifacts, removing different kinds of noise, or even restoring colors. With that, I was immediately hooked to finish the lecture and try out what I'd learned, and thus came this project.

Color Restoration

Instead of turning low resolution images into high resolution images, I wanted to build a network that is able to recolor black and white images. The approach to do so is still similar in terms of how a GAN works, except for a few tweaks which we'll discuss further down.

Code Source

Since this is the first time I've worked with generative networks like GANs, I decided to base my code heavily on a fast.ai notebook, lesson7-superres-gan.ipynb.

The code provided below isn't complete; only the important blocks of code are shown.

The GAN Approach

A GAN is sort of like a game between two entities, one being the artist (formally, the generator) and the other being the critic (formally, the discriminator). Each has its own role: the artist has to produce an image, while the critic has to decide whether the image produced by the artist is a real image or a fake/generated one.

Both of them have to get better at what they do: the critic has to get better at differentiating real images from fake ones, while the artist has to improve the images it produces in order to fool the critic. Applying this concept to a task like image restoration works much the same way. That is, the artist has to produce a higher resolution image from the low resolution input, while the critic learns to distinguish between the two possibilities.

Now, to apply this to color restoration, instead of differentiating low resolution from high resolution images, the critic has to tell artist-generated images apart from genuinely colored ones, and while doing so the artist has to learn to recolor its images better and better to outsmart the critic.
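
To make the game concrete, below is a minimal, self-contained sketch of one alternating training step in plain PyTorch. The toy models, tensors, and hyperparameters are stand-ins for illustration only, not the fastai code used in the rest of this post.

import torch
import torch.nn as nn

# Toy stand-ins: the artist maps a grayscale image to RGB, the critic maps an
# RGB image to a single real/fake score.
artist = nn.Sequential(nn.Conv2d(1, 3, 3, padding=1), nn.Sigmoid())
critic = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1),
                       nn.Flatten(), nn.Linear(8, 1))

opt_a = torch.optim.Adam(artist.parameters(), lr=1e-4)
opt_c = torch.optim.Adam(critic.parameters(), lr=1e-4)
bce = nn.BCEWithLogitsLoss()

gray = torch.rand(4, 1, 64, 64)    # batch of grayscaled inputs
color = torch.rand(4, 3, 64, 64)   # corresponding colored originals

# Critic step: score real (colored) images as 1 and generated images as 0.
fake = artist(gray).detach()
loss_c = bce(critic(color), torch.ones(4, 1)) + bce(critic(fake), torch.zeros(4, 1))
opt_c.zero_grad()
loss_c.backward()
opt_c.step()

# Artist step: produce images the critic scores as real.
fake = artist(gray)
loss_a = bce(critic(fake), torch.ones(4, 1))
opt_a.zero_grad()
loss_a.backward()
opt_a.step()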

Data Modification

In order to build a network that is able both to learn to recolor images and to classify real from fake images, we need to provide it with two sets of data, namely colored images and their corresponding black-and-white versions. To do so, we use the Oxford-IIIT Pets dataset, which is colored, and create a function to grayscale the images. Jeremy calls such a function a crappifier, which in our case only grayscales the images. Once we have our colored and grayscaled images, we can use them later to train the network.

from PIL import Image

class crappifier(object):
    "Save a grayscale copy of each colored image from path_hr into path_lr."
    def __init__(self, path_lr, path_hr):
        self.path_lr = path_lr    # destination folder for grayscale images
        self.path_hr = path_hr    # source folder of colored images

    def __call__(self, fn, i):
        # Mirror the source folder structure under path_lr.
        dest = self.path_lr/fn.relative_to(self.path_hr)
        dest.parent.mkdir(parents=True, exist_ok=True)
        img = Image.open(fn).convert('L')   # 'L' = single-channel grayscale
        img.save(dest, quality=100)
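
To run the crappifier over the whole dataset, the lesson7 notebook applies it in parallel to every colored image. A minimal sketch, assuming path_hr holds the colored originals and path_lr is the destination for the grayscale copies:

il = ImageList.from_folder(path_hr)               # every colored image
parallel(crappifier(path_lr, path_hr), il.items)  # grayscale them all in parallel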

Grayscaled Images

Pre-train Generator/Artist

Now, we will train our generator first before using it in the GAN. The architecture we'll use is a U-Net with ResNet34 as its base model, and all it's trained to do is recolor the images so that they look more like their colored counterparts. Notice also that we're using Mean Squared Error, or MSELossFlat, as our loss function.

arch = models.resnet34      # backbone for the U-Net encoder
loss_gen = MSELossFlat()    # pixel-wise MSE against the colored target

learn_gen = unet_learner(data_gen, arch, wd=wd, blur=True, norm_type=NormType.Weight,
                         self_attention=True, y_range=y_range, loss_func=loss_gen)
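
The call above references data_gen, wd, and y_range, which were set up earlier. Roughly following the lesson7 notebook, they could be created as below; bs, size, and the folder paths are assumptions:

wd = 1e-3            # weight decay used for all learners
y_range = (-3., 3.)  # output range for the U-Net's final activation
bs, size = 32, 128   # batch size and image size

src = ImageImageList.from_folder(path_lr).split_by_rand_pct(0.1, seed=42)

def get_data(bs, size):
    # Pair each grayscaled input with its colored original as the target.
    data = (src.label_from_func(lambda x: path_hr/x.name)
               .transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
               .databunch(bs=bs).normalize(imagenet_stats, do_y=True))
    data.c = 3       # targets are 3-channel (RGB) images
    return data

data_gen = get_data(bs, size)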

Once we have the generative model, we can train the model head for a few epochs, unfreeze, and train for several more epochs.

learn_gen.fit_one_cycle(2, pct_start=0.8)
epoch  train_loss  valid_loss  time
0      0.109306    0.111038    02:37
1      0.096312    0.102479    02:40
learn_gen.unfreeze()
learn_gen.fit_one_cycle(3, slice(1e-6,1e-3))
epoch  train_loss  valid_loss  time
0      0.089206    0.100583    02:41
1      0.087562    0.094716    02:44
2      0.086839    0.094106    02:45

The resulting generated images after a total of 5 epochs look like the following:


Generated Images

As you can see, the generator did poorly on some areas of the images, while it did great on others. Regardless, we'll save these generated images to be used as the fake-image dataset for the critic to learn from.
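
Saving the generated images amounts to running the generator over the whole training set and writing each prediction to disk, roughly following the notebook's save_preds helper; path here is the dataset root and is an assumption:

name_gen = 'image_gen'
path_gen = path/name_gen
path_gen.mkdir(exist_ok=True)

def save_preds(dl):
    # Run the generator over a dataloader and save each reconstructed
    # prediction under its original filename inside path_gen.
    names = dl.dataset.items
    i = 0
    for b in dl:
        preds = learn_gen.pred_batch(batch=b, reconstruct=True)
        for pred in preds:
            pred.save(path_gen/names[i].name)
            i += 1

save_preds(data_gen.fix_dl)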

Train Discriminator/Critic

After generating the two sets of images, we'll feed the data to a critic and let it learn to distinguish the real images from the artist-generated ones. Below is a sample batch of data, where the real images are labelled simply as images and the generated ones as image_gen:


Real and Generated Images
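
For reference, data_crit is a plain classification DataBunch whose two classes are exactly those folders. A sketch of how it could be assembled, again following the notebook; path is the dataset root:

def get_crit_data(classes, bs, size):
    # Label each image by the folder it lives in: 'images' (real) or 'image_gen' (generated).
    src = ImageList.from_folder(path, include=classes).split_by_rand_pct(0.1, seed=42)
    ll = src.label_from_folder(classes=classes)
    data = (ll.transform(get_transforms(max_zoom=2.), size=size)
              .databunch(bs=bs).normalize(imagenet_stats))
    return data

data_crit = get_crit_data(['image_gen', 'images'], bs=bs, size=size)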

To create the critic, we'll be using fast.ai's built-in gan_critic, which is just a simple convolutional neural network with residual blocks. Unlike the generator, the loss function we'll use is Binary Cross Entropy, since there are only two possible predictions, and we also wrap it with AdaptiveLoss.

loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())   # binary cross entropy, adapted to the critic's output

learn_critic = Learner(data_crit, gan_critic(), metrics=accuracy_thresh_expand,
                       loss_func=loss_critic, wd=wd)

Once the Learner has been created, we can proceed with training the critic for several epochs.

learn_critic.fit_one_cycle(6, 1e-3)
epoch  train_loss  valid_loss  accuracy_thresh_expand  time
0      0.170356    0.105095    0.958804                03:34
1      0.041809    0.022646    0.992365                03:27
2      0.026520    0.013480    0.996638                03:26
3      0.011859    0.005585    0.999117                03:25
4      0.012674    0.005655    0.999288                03:25
5      0.013518    0.005413    0.999288                03:24

GAN

With both the generator and the critic pretrained, we can finally put them together and commence the game of outsmarting each other that defines a GAN. We will be utilizing AdaptiveGANSwitcher, which switches between the generator and the critic whenever the loss goes below a certain threshold.

switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)

Wrapping both the generator and the critic inside a GAN learner:

learn = GANLearner.from_learners(learn_gen, learn_critic, weights_gen=(1.,50.), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd)

Another callback we'll use is GANDiscriminativeLR, which multiplies the critic's learning rate by a given factor.

learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

Finally, we can train the GAN for 40 epochs before switching to a larger image size and training for another 10 epochs.

lr = 1e-4
learn.fit(40, lr)
epoch  train_loss  valid_loss  gen_loss  disc_loss  time
0      3.718557    3.852783                         03:27
1      3.262025    3.452096                         03:29
2      3.241105    3.499610                         03:29
3      3.098072    3.511492                         03:31
4      3.161309    3.211511                         03:30
5      3.108723    2.590987                         03:29
6      3.049329    3.215695                         03:29
7      3.156122    3.255158                         03:29
8      3.039921    3.255423                         03:30
9      3.136142    3.109873                         03:30
10     2.969435    3.096309                         03:30
11     2.967517    3.532753                         03:30
12     3.066835    3.302504                         03:28
13     2.979472    3.147814                         03:29
14     2.848181    3.229101                         03:29
15     2.981036    3.370961                         03:30
16     2.874022    3.646701                         03:32
17     2.816335    3.517284                         03:33
18     2.886316    3.336793                         03:33
19     2.851927    3.596783                         03:33
20     2.885449    3.560956                         03:33
21     3.081255    3.357426                         03:31
22     2.812135    3.340290                         03:33
23     2.933871    3.475993                         03:32
24     3.084240    3.034758                         03:31
25     2.983608    3.113349                         03:33
26     2.746827    2.865806                         03:32
27     2.789029    3.173259                         03:33
28     2.952777    3.227012                         03:32
29     2.825185    3.053979                         03:34
30     2.782907    3.444182                         03:34
31     2.805190    3.343132                         03:33
32     2.901620    3.299375                         03:33
33     2.744463    3.279421                         03:32
34     2.818238    3.048206                         03:32
35     2.755671    2.975504                         03:32
36     2.764382    3.075425                         03:32
37     2.714343    3.076662                         03:32
38     2.805259    3.291719                         03:32
39     2.787018    3.172551                         03:32
learn.data = get_data(16, 192)
learn.fit(10, lr/2)
epoch  train_loss  valid_loss  gen_loss  disc_loss  time
0      2.789968    3.127500                         08:28
1      2.842687    3.226334                         08:22
2      2.764777    3.127393                         08:24
3      2.783910    3.183345                         08:23
4      2.731649    3.279976                         08:21
5      2.652934    3.143363                         08:23
6      2.664248    2.998718                         08:22
7      2.777635    3.185632                         08:27
8      2.718668    3.357025                         08:26
9      2.660009    2.887908                         08:23

The resulting training images look like the following:


GAN Produced Images

And as you can see, our model was able to recolor the images with a fair degree of accuracy. Not bad, but GANs do have their weaknesses, which we'll discuss in the last section. Before we wrap up the GAN section, let's try feeding the model external images, that is, images it hasn't seen before.

Recoloring External Images

The following pet images were taken randomly from the internet. I manually grayscaled them before letting the model predict their recolored versions.
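
The prediction step itself is simple: grayscale the downloaded photo the same way the training data was crappified, load it back, and call predict on the generator. A minimal sketch; 'external.jpg' is a placeholder path:

from fastai.vision import open_image
from PIL import Image

Image.open('external.jpg').convert('L').save('external_bw.jpg')  # mimic the crappifier

img = open_image('external_bw.jpg')       # loaded back as a 3-channel fastai Image
recolored, _, _ = learn_gen.predict(img)  # the trained generator recolors it
recolored.show(figsize=(6, 6))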


GAN Produced Images

The colors produced, especially on the animals' fur, are less saturated than in the original images. However, natural backgrounds like grass and the sky are still acceptable, although different from the originals.

Lastly, I fed the model images that are neither cats nor dogs: pictures of actual people. The top row's picture was already black and white when I received it, whereas the bottom row's image went through the same grayscaling process as the images right above.


GAN Produced Images

A few things to notice here. For the first prediction, the model is biased towards green and yellow colors, hence the floor color in the first output. Secondly, aside from coloring the person in front, the model also colored the person on the phone's screen.

On the other hand, the second prediction was great at coloring the backdrop of mountains and the sky, but poor at coloring the supposedly bright-red car as well as the person, both of which remained mostly grey.

The most likely reason behind the poor recoloring of people is the dataset used to train the GAN, which in this case contains only pets.

Closing Remarks

Weaknesses of GANs

GANs are well known for being troublesome to handle, especially during training, hence all the configuration and knobs we needed in order for them to behave well. Moreover, they take quite a long time to train in comparison to other architectures.

Possible Replacement of GANs

As shown in the remainder of Lecture 7, there are other approaches which are as good as or even better than GANs, one of which is Feature Loss coupled with U-Nets, with shorter training times and better results in several cases. I have tried that approach, but will not be discussing it here.

Conclusion

GANs are great: the tasks they can do vary from one architecture to another, and they are one of the methods for letting a model “dream” and have its own form of creativity. However, they have certain weaknesses, which include long training times and the need for careful tweaking. They are definitely modern, and research in this domain is still very much open and fun to do if you're into this particular field.

That’s it! Thanks for your time and I hope you’ve learned something!