Content-aware image resizing using Deep Learning
Problem statement —
Ever tried uploading a vacation pic to social media, only to find that it cropped your image and cut out some important parts (you!!!)? This blog explores using deep learning to reduce the size of an image while retaining its important parts. We will also look at the steps to build a Flask application, containerize it, and deploy it on AWS EC2.
Approach —
An image can be considered as a matrix of pixels. To resize an image without losing the important objects in it, we have to identify which pixels are important. There are several ways to achieve this:
- Object detection + Seam carving
- Saliency map (see the saliency map article in the references)
Let us discuss object detection + seam carving.
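In simple terms, seam carving is a technique in which paths of low-energy pixels running from top to bottom or from left to right (called seams) are identified and removed, so that the image shrinks without affecting its important parts. A pixel's energy is typically measured from the image gradient at that point, and a vertical seam contains exactly one pixel per row, each adjacent (directly or diagonally) to the one above it.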
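To make the idea concrete, here is a minimal NumPy sketch of classic seam carving, assuming a gradient-magnitude energy (|dI/dx| + |dI/dy|) and dynamic programming to find the cheapest vertical seam. The function names (energy_map, find_vertical_seam, remove_vertical_seam, carve_width) are illustrative and not taken from the project's code.

```python
import numpy as np


def energy_map(gray: np.ndarray) -> np.ndarray:
    """Gradient-magnitude energy |d/dx| + |d/dy| of a grayscale image."""
    gy, gx = np.gradient(gray.astype(float))  # derivatives along rows, then columns
    return np.abs(gx) + np.abs(gy)


def find_vertical_seam(energy: np.ndarray) -> np.ndarray:
    """Return, for each row, the column of the minimum-energy top-to-bottom seam.

    Dynamic programming: cost[i, j] = energy[i, j] + the cheapest of the three
    neighbours (j-1, j, j+1) in row i-1, so the seam stays connected.
    """
    h, w = energy.shape
    cost = energy.astype(float).copy()
    for i in range(1, h):
        prev = cost[i - 1]
        left = np.r_[np.inf, prev[:-1]]
        right = np.r_[prev[1:], np.inf]
        cost[i] += np.minimum(np.minimum(left, prev), right)
    # Backtrack from the cheapest pixel in the last row.
    seam = np.empty(h, dtype=int)
    seam[-1] = int(np.argmin(cost[-1]))
    for i in range(h - 2, -1, -1):
        j = seam[i + 1]
        lo, hi = max(j - 1, 0), min(j + 2, w)
        seam[i] = lo + int(np.argmin(cost[i, lo:hi]))
    return seam


def remove_vertical_seam(image: np.ndarray, seam: np.ndarray) -> np.ndarray:
    """Drop one pixel per row; works for (H, W) and (H, W, C) arrays."""
    h, w = image.shape[:2]
    keep = np.ones((h, w), dtype=bool)
    keep[np.arange(h), seam] = False
    return image[keep].reshape((h, w - 1) + image.shape[2:])


def carve_width(image: np.ndarray, n_seams: int) -> np.ndarray:
    """Reduce the width by n_seams, recomputing the energy after each removal."""
    for _ in range(n_seams):
        gray = image.mean(axis=2) if image.ndim == 3 else image
        seam = find_vertical_seam(energy_map(gray))
        image = remove_vertical_seam(image, seam)
    return image
```

Recomputing the energy after every removal is slow for large images but keeps each seam correct for the current image; horizontal seams can be removed the same way by working on the transposed array.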
On the left is the original image that has to be resized; notice the gap between the castle and the person. The center image shows the seams identified by seam carving, and the image on the right shows the result after removing them. After generating seam-carved images for a small set of images, I found that seams can sometimes pass through objects, thereby losing important parts of the object or image. To overcome this, we can first run the image through Amazon Rekognition to find all the objects, and then generate seams with seam carving while keeping them out of the detected objects.
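One way to keep seams out of detected objects is to turn Rekognition's bounding boxes into a protection mask and add a large penalty to the energy inside it. The sketch below assumes the response shape returned by Rekognition's DetectLabels operation (a "Labels" list whose entries may carry "Instances" with a "BoundingBox" given as ratios of the image width and height); the penalty value and the function names are hypothetical, and the project may handle this differently.

```python
import numpy as np

# In a real pipeline the response would come from Amazon Rekognition, e.g.
#   boto3.client("rekognition").detect_labels(Image={"Bytes": jpeg_bytes})
# Only labels with "Instances" carry bounding boxes; box values are ratios
# of the image width/height, so they are scaled to pixels here.


def protection_mask(response: dict, height: int, width: int) -> np.ndarray:
    """Boolean (H, W) mask that is True inside every detected object's box."""
    mask = np.zeros((height, width), dtype=bool)
    for label in response.get("Labels", []):
        for inst in label.get("Instances", []):
            box = inst["BoundingBox"]
            left = int(np.floor(box["Left"] * width))
            top = int(np.floor(box["Top"] * height))
            right = int(np.ceil((box["Left"] + box["Width"]) * width))
            bottom = int(np.ceil((box["Top"] + box["Height"]) * height))
            # Clip to the image in case a box extends past its edges.
            top, bottom = np.clip([top, bottom], 0, height)
            left, right = np.clip([left, right], 0, width)
            mask[top:bottom, left:right] = True
    return mask


def protect_energy(energy: np.ndarray, mask: np.ndarray, penalty: float = 1e6) -> np.ndarray:
    """Add a large (hypothetical) penalty inside protected regions so seams route around them."""
    return energy + penalty * mask
```

With a large penalty, a minimum-energy seam only crosses a protected box when no cheaper path exists, for example when the detected objects span the whole width of the image.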
Now the question arises: how can deep learning be used to generate a seam-carved image? I framed seam identification as a segmentation problem. We can build a model that identifies the low-energy pixels (seams), which can then be removed in a post-processing step. U-Net is a well-known architecture for segmentation problems, and many excellent blogs explain it; the references include one that explains semantic segmentation and the U-Net architecture.
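The post-processing step can be sketched as follows, assuming the model outputs an (H, W) map of per-pixel seam probabilities. This version simply drops the n highest-scoring pixels in each row, so every row loses the same number of pixels and the output stays rectangular. It is a hypothetical simplification, not necessarily how the project's post-processing works, and it does not enforce that the removed pixels form connected seams.

```python
import numpy as np


def remove_predicted_seams(image: np.ndarray, seam_prob: np.ndarray, n_remove: int) -> np.ndarray:
    """Shrink the width by n_remove using a predicted (H, W) seam-probability map.

    In every row, drop the n_remove pixels the model scores as most likely to
    be seam pixels. Removing the same count per row keeps the output
    rectangular even when the predicted mask is not a clean set of seams.
    """
    h, w = seam_prob.shape
    if not 0 <= n_remove < w:
        raise ValueError("n_remove must be in [0, width)")
    # Column indices of the n_remove highest-probability pixels in each row.
    drop = np.argsort(seam_prob, axis=1)[:, w - n_remove:]
    keep = np.ones((h, w), dtype=bool)
    keep[np.arange(h)[:, None], drop] = False
    # Boolean indexing keeps row-major order, so reshaping restores the rows.
    return image[keep].reshape((h, w - n_remove) + image.shape[2:])
```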
Data Preparation —
I started off trying to download lots of images from Google, but I could not get the download script working on my old laptop. I ended up using the PASCAL Visual Object Classes (VOC) dataset instead. In retrospect, I should have spent more time getting the Google download script to work, because the images in the VOC dataset are not of great quality.
Data → Amazon Rekognition → Seam carving to generate masks with seams that do not pass through the objects
Each image is sent through Amazon Rekognition to identify the objects in it. It is then passed through a seam carving function that generates a mask whose seams do not pass through those objects. So the input for training the model is the image, and the target is the mask.
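A sketch of this mask-generation step, assuming a grayscale image and a boolean protection mask built from the Rekognition boxes: seams are carved one at a time with protected pixels heavily penalized, and an index map records each remaining pixel's original column so that every removed seam can be drawn onto a full-size target mask. The function names, the penalty, and the 0/1 uint8 mask format are assumptions for illustration.

```python
import numpy as np


def min_vertical_seam(energy: np.ndarray) -> np.ndarray:
    """Column index per row of the cheapest connected top-to-bottom seam (DP)."""
    h, w = energy.shape
    cost = energy.astype(float).copy()
    for i in range(1, h):
        prev = cost[i - 1]
        left = np.r_[np.inf, prev[:-1]]
        right = np.r_[prev[1:], np.inf]
        cost[i] += np.minimum(np.minimum(left, prev), right)
    seam = np.empty(h, dtype=int)
    seam[-1] = int(np.argmin(cost[-1]))
    for i in range(h - 2, -1, -1):
        lo = max(seam[i + 1] - 1, 0)
        seam[i] = lo + int(np.argmin(cost[i, lo:seam[i + 1] + 2]))
    return seam


def seam_target_mask(gray: np.ndarray, protect: np.ndarray, n_seams: int,
                     penalty: float = 1e6) -> np.ndarray:
    """Training target: (H, W) uint8 mask with 1 on every removed seam pixel.

    Seams are found on a progressively narrower image, so `cols` tracks which
    original column each remaining pixel came from.
    """
    h, w = gray.shape
    target = np.zeros((h, w), dtype=np.uint8)
    img = gray.astype(float)
    prot = protect.astype(bool).copy()
    cols = np.tile(np.arange(w), (h, 1))  # original column of each pixel
    rows = np.arange(h)
    for _ in range(n_seams):
        gy, gx = np.gradient(img)
        energy = np.abs(gx) + np.abs(gy) + penalty * prot
        seam = min_vertical_seam(energy)
        target[rows, cols[rows, seam]] = 1  # draw the seam in original coordinates
        keep = np.ones(img.shape, dtype=bool)
        keep[rows, seam] = False
        new_w = img.shape[1] - 1
        img = img[keep].reshape(h, new_w)
        prot = prot[keep].reshape(h, new_w)
        cols = cols[keep].reshape(h, new_w)
    return target
```

The (image, target) pairs produced this way are what the segmentation model is trained on.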
The code for this project is on GitHub (see the GitHub link in the references).
After generating masks for all the images and building and training the model, I found that the resized images were all cropped towards the right, i.e. the rightmost 25% of each image was missing. Realizing that something was wrong with the data, I decided to analyze the input data. I built a heatmap of all the generated seams and found that most of them were in the right half of the image: if an image is divided vertically into 4 parts, most of the seams were in the 4th (rightmost) part.
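The heatmap analysis can be reproduced with a few lines of NumPy. This sketch assumes all masks have been resized to the same (H, W) shape, such as the model's input size; averaging them gives a per-pixel seam frequency, and summing the columns in four vertical strips shows how the seams are distributed from left to right. The function names are illustrative, and plotting (for example with matplotlib's imshow) is left out.

```python
import numpy as np


def seam_heatmap(masks) -> np.ndarray:
    """Average of equally sized (H, W) 0/1 seam masks: per-pixel seam frequency."""
    return np.mean(np.stack([np.asarray(m, dtype=float) for m in masks]), axis=0)


def seam_share_per_strip(heatmap: np.ndarray, n_strips: int = 4) -> np.ndarray:
    """Share of all seam pixels falling in each vertical strip, left to right."""
    column_totals = heatmap.sum(axis=0)
    strip_totals = np.array([s.sum() for s in np.array_split(column_totals, n_strips)])
    total = strip_totals.sum()
    return strip_totals / total if total > 0 else np.zeros(n_strips)
```

A dataset like the one described above would show most of its share in the last strip.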
So I built a dataset by selecting only the images whose masks had seams in all parts of the image. This reduces the size of the dataset drastically, which is fine because quality is more important than quantity.
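Selecting the balanced subset can be done with a simple filter. In this sketch a mask counts as having seams in all parts when each of four vertical strips holds at least a minimum share of its seam pixels; the 10% threshold and the function names are hypothetical choices, not values from the project.

```python
import numpy as np


def has_seams_everywhere(mask, n_strips: int = 4, min_share: float = 0.1) -> bool:
    """True if every vertical strip holds at least min_share of the mask's seam pixels."""
    col = np.asarray(mask, dtype=float).sum(axis=0)  # seam pixels per column
    total = col.sum()
    if total == 0:
        return False
    shares = np.array([s.sum() for s in np.array_split(col, n_strips)]) / total
    return bool((shares >= min_share).all())


def balanced_subset(pairs, **kwargs):
    """Keep only (image, mask) pairs whose masks pass has_seams_everywhere."""
    return [(img, m) for img, m in pairs if has_seams_everywhere(m, **kwargs)]
```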
Observations —
After training the model on the balanced dataset, the resized images were better, though there was still scope for improvement. The first loss function I tried was MSE. The output images were good, but in some images the seams passed mostly through the left and right parts of the image. I also tried other loss functions, including Focal Loss and Binary Cross Entropy, but the resulting resized images were very poor. I then tried a weighted mean squared error, in which seams predicted in the two middle sections of the image are given more weight than seams predicted in the left and right sections. This encourages the model to predict seams in the center of the image as well. The output looks like this.
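The weighted MSE idea can be written as L = (1/(H*W)) * sum over i, j of w_j * (y_ij - y_hat_ij)^2, where w_j is larger for columns in the two middle quarters of the image. The NumPy sketch below shows the computation under that interpretation; the weights 2.0 and 1.0 are hypothetical, and during training the same expression would be written in whichever deep learning framework is used so that it can be differentiated.

```python
import numpy as np


def strip_weights(width: int, middle_weight: float = 2.0, edge_weight: float = 1.0) -> np.ndarray:
    """Per-column weights: the two middle quarters get middle_weight, the outer ones edge_weight."""
    w = np.full(width, edge_weight, dtype=float)
    q = width // 4
    w[q:width - q] = middle_weight
    return w


def weighted_mse(y_true, y_pred, col_weights: np.ndarray) -> float:
    """Mean of col_weights * (y_true - y_pred)**2; weights broadcast across rows."""
    sq_err = (np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)) ** 2
    return float(np.mean(sq_err * col_weights))
```

Because errors in the middle columns cost more, a model that ignores seams in the center is penalized more heavily than under plain MSE.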
This model could be improved by using a better dataset and better techniques (other than seam carving) to generate the original masks the model learns from.
Deployment —
Now that the model is ready, let us build a Flask web application so that anyone can get their image resized. If you are new to Flask, the quickstart in the official Flask documentation shows how to build a basic application.
When the user pastes a link to an image into the web page (index.html) and hits the resize button, the URL is sent to a function in main.py, which converts the image to a NumPy array, loads the model, and invokes the run_inference function in ml_model.py. The image is passed through the model, which outputs a predicted mask, and the resized image is derived from that mask in a post-processing step. This image is then sent back to the front end of the application. I got this running on my laptop before building a Docker container.
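The request flow above can be sketched as a small Flask app. Everything here is an assumption for illustration: the /resize route and url form field, the 25% width reduction, the per-row post-processing, and run_inference, which is only a placeholder with a hypothetical signature standing in for the project's function in ml_model.py. It needs Flask, Pillow, and NumPy installed.

```python
import io
import urllib.request

import numpy as np
from flask import Flask, request, send_file
from PIL import Image

app = Flask(__name__)


def load_image_from_url(url: str) -> np.ndarray:
    """Download the image at `url` and return it as an (H, W, 3) uint8 array."""
    with urllib.request.urlopen(url, timeout=10) as resp:
        data = resp.read()
    return np.asarray(Image.open(io.BytesIO(data)).convert("RGB"))


def run_inference(image: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for run_inference in ml_model.py.

    The real function runs the trained U-Net and returns a predicted (H, W)
    seam mask; this placeholder returns zeros so the route can be exercised.
    """
    return np.zeros(image.shape[:2], dtype=float)


def remove_predicted_seams(image: np.ndarray, seam_prob: np.ndarray, n_remove: int) -> np.ndarray:
    """Post-processing: drop the n_remove highest-probability pixels in each row."""
    h, w = seam_prob.shape
    drop = np.argsort(seam_prob, axis=1)[:, w - n_remove:]
    keep = np.ones((h, w), dtype=bool)
    keep[np.arange(h)[:, None], drop] = False
    return image[keep].reshape((h, w - n_remove) + image.shape[2:])


@app.route("/resize", methods=["POST"])
def resize():
    url = request.form.get("url", "")
    if not url:
        return "missing 'url' form field", 400
    image = load_image_from_url(url)
    seam_prob = run_inference(image)
    n_remove = image.shape[1] // 4  # hypothetical target: 25% narrower
    resized = remove_predicted_seams(image, seam_prob, n_remove)
    buf = io.BytesIO()
    Image.fromarray(resized).save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
```

As a design choice, loading the model once when the app starts, rather than on every request, usually keeps response times lower; the app itself can be served with the flask command-line tool or a WSGI server such as gunicorn.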
Building the Docker container and deploying it on AWS took some time, as I ran into many errors while building the container; I eventually arrived at a working Dockerfile (as of December 2020). Deploying to an EC2 instance was comparatively easy. I followed the guides listed in the references for the initial setup of the EC2 instance, for installing Docker on EC2, and for deploying the Flask application on EC2.
Demo -
Conclusion -
The current model does a decent job of resizing images. It could be improved by (a) using a better dataset, (b) using better techniques to resize images while preparing the training dataset, (c) incorporating user feedback and using poorly resized images in the test set during training, and (d) using better loss functions. Features such as (a) enlarging images and (b) letting the user choose whether an image should be enlarged or reduced could also be added to this project.
References
[Segmentation and U-Net]: https://towardsdatascience.com/understanding-semantic-segmentation-with-unet-6be4f42d4b47
[Initial setup of EC2]: https://docs.aws.amazon.com/AmazonECS/latest/developerguide/docker-basics.html
[Install Docker on EC2]: https://docs.docker.com/engine/install/ubuntu/
[Deploy Flask application on EC2]: https://medium.datadriveninvestor.com/dockerizing-and-hosting-your-flask-web-app-rest-api-on-aws-ec2-9f9c189bf563
[GitHub link for the code]: https://github.com/ThePrecious/content_aware_image_resize
[Saliency map]: https://rajat-tripathi-08.medium.com/content-aware-image-resizing-with-deep-learning-e4c8d45efc64