Deploy an Image Classification Model Using Flask
- Get an overview of PyTorch and Flask
- Learn to build an image classification model in PyTorch
- Learn how to deploy the model using Flask.
Image classification plays a pivotal role in keeping social media platforms healthy. Classifying content by tags helps platforms comply with various laws and regulations, and it makes it possible to hide content from certain audiences.
While scrolling through my Instagram feed, I regularly encounter images covered by a “Sensitive Content” label. I am sure you have too. Any image related to a humanitarian crisis, terrorism, or violence is generally classified as ‘Sensitive Content’. It always intrigued me how Instagram categorizes an image, and this curiosity pushed me to understand how image classification works.
Most of these images are detected by the image classification models deployed by Instagram, complemented by a community-based feedback loop. This is one of the most important use cases of image classification. In this article, we will deploy an image classification model to detect the category of images.
Table of Contents
- What is Model Deployment?
- Introduction to PyTorch
- What is Flask?
- Installing Flask and PyTorch on your Machine
- Understanding the Problem Statement
- Setup the Pre-Trained Image Classification Model
- Build an Image Scraper
- Create the Webpage
- Setup the Flask Project
- Working of the Deployed Model
What is Model Deployment?
In a typical machine learning and deep learning project, we usually start with defining the problem statement followed by data collection and preparation, and model building, right?
Once we have successfully built and trained the model, we want it to be available for the end-users. Thus we will have to “deploy” the model so that the end-users can make use of it. Model Deployment is one of the later stages of any machine learning or deep learning project.
In this article, we will build a classification model in PyTorch and then learn how to deploy the same using Flask. Before we get into the details, let us have a quick introduction to PyTorch.
Introduction to PyTorch
PyTorch is a Python-based library that provides flexibility as a deep learning development platform. Its workflow is about as close as you can get to NumPy, Python’s scientific computing library.
PyTorch is widely used for building deep learning models. Here are some of its important advantages:
- Easy to use API: The PyTorch API is as simple as Python can be.
- Python support: PyTorch integrates smoothly with the Python data science stack.
- Dynamic computation graphs: PyTorch builds the computational graph as the code runs, so we can change it at runtime. This is valuable when the structure of the network, and therefore the memory it needs, is not known in advance.
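To make the last point concrete, here is a minimal sketch of a graph whose size depends on a value that is only known at runtime. The function name repeated_product and the loop count are made up for illustration; only torch.tensor, requires_grad, and backward come from PyTorch itself.

```python
import torch


def repeated_product(x, steps):
    """Multiply x by itself `steps` more times; the graph grows with the loop."""
    y = x
    for _ in range(steps):  # number of graph nodes depends on a runtime value
        y = y * x
    return y


x = torch.tensor(2.0, requires_grad=True)
y = repeated_product(x, steps=3)  # y = x ** 4
y.backward()                      # autograd walks the graph that was just built
print(x.grad)                     # gradient of x^4, i.e. 4 * x^3, at x = 2
```

Changing steps changes the graph that autograd differentiates, without any separate compilation step.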
Note: If you are new to PyTorch, I recommend going through the resources below:
- Enroll in the free course of PyTorch: Introduction to PyTorch for Deep Learning
- An amazing tutorial: An Introduction to PyTorch – A Simple yet Powerful Deep Learning Library
In the sections that follow, we will use a pre-trained PyTorch model to detect the category of an image and then deploy it with Flask. First, let us briefly discuss Flask.
What is Flask?
Flask is a web application framework written in Python. It provides modules that make it easier for a web developer to write applications without having to worry about details such as protocol management and thread management.
Flask is a micro-framework: its core is small, and it leaves the choice of extensions and libraries to us while still providing the essential tools for building a web application.
Installing Flask and PyTorch on your Machine
Installing Flask is simple and straightforward. Here, I am assuming you already have Python 3 and pip installed. On Debian or Ubuntu, you can install Flask from the system packages with the following command (pip install flask works on any platform):
sudo apt-get install python3-flask
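Next, we need to install PyTorch. You do not need a GPU to run the code provided in this article. The leading ! in the command below is only needed inside a Jupyter notebook cell; drop it when running the command in a terminal.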
!pip install torch torchvision
That’s it! Now let us take up a problem statement and build a model.
Understanding the Problem Statement
Let us discuss the problem statement. We want to create a web page containing a text box like the one shown below, in which users enter a URL. The task is to scrape all the images from that URL, predict the category or class of each image using an image classification model, and render the images with their categories on the web page.
Here is the workflow for the end-to-end project:
Setting up the Project Workflow
- Model Building: We will use the pre-trained DenseNet-121 model to predict the image class. It is available in the torchvision library in PyTorch. Our focus here is not on building a highly accurate classification model from scratch but on seeing how to deploy the model and use it through a web interface.
- Create an Image Scraper: We will create a web scraper using the requests and BeautifulSoup libraries. It will download all the images from a URL and store them so that we can make predictions on them.
- Design Webpage Template: We will also design a user interface where the user can submit a URL and see the results once they are calculated.
- Classify images and send results: Once we get the query from the user, we will use the model to predict the classes of the images and send the results back to the user.
Here is a representation of the steps we just saw:
Let’s discuss each of the required components of the project:
Setup the Pre-Trained Image Classification Model
We will use the pre-trained DenseNet-121 model to classify the images. If you want to build an image classification model yourself, I highly recommend going through this article: Build your First Image Classification Model in just 10 Minutes!
You can download the complete code and dataset here.
Let’s start by importing the required libraries and getting the densenet121 model from the torchvision library. Make sure to set the ‘pretrained’ parameter to True so that the ImageNet weights are loaded.
Now, we will define a function to transform the image. It creates a transform pipeline, takes the image data in bytes, applies a series of ‘transform’ functions to it, and returns a tensor. This piece of code was taken from the PyTorch documentation.
The pre-trained model returns a score for each class, and we take the index of the highest score as the predicted class id. PyTorch provides a mapping from these indices to class names so that we can see the name of the predicted class. You can download the mapping here. It has 1000 different categories.
Here is a sample of the mapping:
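As a rough illustration, assuming the mapping is the imagenet_class_index.json file used in the PyTorch tutorials, each entry maps a class index to a WordNet id and a class name. The two class names for indices 167 and 985 are the ones that appear in the predictions later in this article; their index numbers are stated here from memory and should be checked against the downloaded file.

```python
import json

# A four-entry excerpt in the assumed layout of imagenet_class_index.json:
# "<class index>": ["<WordNet id>", "<class name>"]
sample_mapping = json.loads("""
{
  "0": ["n01440764", "tench"],
  "1": ["n01443537", "goldfish"],
  "167": ["n02089973", "English_foxhound"],
  "985": ["n11939491", "daisy"]
}
""")


def lookup(class_index, mapping=sample_mapping):
    """Return [wordnet_id, class_name]; JSON keys are strings, so convert."""
    return mapping[str(class_index)]


print(lookup(985))  # ['n11939491', 'daisy']
```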
Next, we will define a function to get the category of an image. It takes the path of the image as its only parameter. It first opens the image and reads it in binary form, then transforms it and passes the transformed image to the model to get the predicted class. Finally, it uses the mapping to return the class name.
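A minimal sketch of such a function is shown below. To keep the example self-contained, the model, the transform function, and the class mapping are passed in as extra parameters; that is an assumption of this sketch, whereas the function described above takes only the image path and presumably reads the others from module-level variables.

```python
import torch


def get_category(image_path, model, transform, class_index):
    """Return the mapping entry for the highest-scoring class of one image."""
    with open(image_path, "rb") as f:   # read the image in binary form
        image_bytes = f.read()
    tensor = transform(image_bytes)     # bytes -> batched input tensor
    with torch.no_grad():               # inference only, no gradients needed
        outputs = model(tensor)         # shape: (1, number_of_classes)
    predicted_idx = outputs.argmax(dim=1).item()
    return class_index[str(predicted_idx)]  # mapping keys are strings
```

With the real model, transform, and mapping passed in, the return value has the same [WordNet id, class name] shape as the sample outputs that follow.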
Let’s try this function on a few images:
get_category(image_path='static/sample_1.jpeg') ## ['n02089973', 'English_foxhound']
get_category(image_path='static/sample_2.jpeg') ## ['n11939491', 'daisy']
Now, our model is ready to predict the classes of the image. Let’s start with building the image scraper.
Build an Image Scraper
In this section, we will build a web scraper that downloads the images from the URL provided. We will use the requests library to fetch the page and the BeautifulSoup library to find the images in it. You are free to use any other library or API that gives you the images. If you are not comfortable with web scraping, I recommend enrolling in this free course: Introduction to Web Scraping using Python.
We will start by importing the required libraries. For each URL that we scrape, a new directory will be created to store the images. We will create a function get_path that returns the path of the folder created for that URL.
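One way to derive a folder path from a URL is sketched below. The naming scheme (a URL_ prefix followed by the URL with every run of non-alphanumeric characters replaced by an underscore) and the base_dir parameter are assumptions of this sketch, not the article's exact implementation. The sketch only computes the path and leaves creating the directory to the caller.

```python
import os
import re


def get_path(url, base_dir="static"):
    """Map a URL to a folder path that is safe to use as a directory name."""
    # Replace every run of characters other than letters and digits with "_".
    folder_name = "URL_" + re.sub(r"[^A-Za-z0-9]+", "_", url).strip("_")
    return os.path.join(base_dir, folder_name)


print(get_path("https://example.com/a/b"))
```

Keeping the folder under static means Flask can serve the downloaded images directly.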
Now, we will define a function get_images. It will first create the directory using the get_path function and then send a request for the page’s source code. From the source code, we will extract the image sources from the “img” tags.
After this, we will select only the images in JPEG format. You can also add PNG images; I have filtered them out because most PNG pictures are logos. Finally, we start a counter and save the images in the specified directory, using the counter value as the file name.
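The steps above can be sketched as follows, under a few assumptions: the helper extract_jpeg_urls and the folder parameter of get_images are hypothetical names introduced here, relative image sources are resolved against the page URL, and a 10-second timeout is used for each request.

```python
import os
from urllib.parse import urljoin, urlparse

import requests
from bs4 import BeautifulSoup


def extract_jpeg_urls(html, base_url):
    """Return absolute URLs of all <img> sources that point to JPEG files."""
    soup = BeautifulSoup(html, "html.parser")
    urls = []
    for tag in soup.find_all("img"):
        src = tag.get("src")
        if not src:
            continue  # skip <img> tags without a src attribute
        absolute = urljoin(base_url, src)  # resolve relative sources
        if urlparse(absolute).path.lower().endswith((".jpg", ".jpeg")):
            urls.append(absolute)
    return urls


def get_images(url, folder):
    """Download the JPEG images found at `url` into `folder`."""
    os.makedirs(folder, exist_ok=True)
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    saved = []
    for counter, image_url in enumerate(extract_jpeg_urls(response.text, url)):
        image = requests.get(image_url, timeout=10)
        if not image.ok:
            continue  # skip images that fail to download
        path = os.path.join(folder, f"{counter}.jpeg")  # counter as file name
        with open(path, "wb") as f:
            f.write(image.content)
        saved.append(path)
    return saved
```

Separating the parsing step from the download step makes it possible to check the JPEG filter on a plain HTML string without touching the network.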
Let’s try out the scraper that we have just created!
A new directory has now been created; let’s see what it looks like. We have all the images downloaded in a single place.
Note: This image scraper is intended for learning purposes only. Always follow the robots.txt file of the target website, which implements the Robots Exclusion Protocol and tells web robots which pages not to crawl.
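As a small illustration of how such rules can be checked before scraping, the sketch below uses the standard library’s urllib.robotparser on an inline, made-up robots.txt. The helper name is_allowed and the example rules are hypothetical.

```python
from urllib.robotparser import RobotFileParser


def is_allowed(robots_txt, url, user_agent="*"):
    """Check a URL against the text of a robots.txt file."""
    parser = RobotFileParser()
    parser.parse(robots_txt.splitlines())
    return parser.can_fetch(user_agent, url)


# Made-up rules for illustration only.
rules = """User-agent: *
Disallow: /private/
"""

print(is_allowed(rules, "https://example.com/gallery/cat.jpg"))  # True
print(is_allowed(rules, "https://example.com/private/cat.jpg"))  # False
```

For a live site, RobotFileParser can also fetch the file itself through set_url and read instead of parsing an inline string.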
Create the Webpage
We will create two web pages: “home.html” and “image_class.html”.
- “home.html” is the default one which will have a text box in which a user can type the URL.
- “image_class.html” will help us to render the images with their categories.
We need to add a form tag to the home.html file to collect the data entered in the search container. In the form tag, we set the method to post and give the input field the name “search”.
This way, our backend code knows that it has received some data under the name “search”. On the backend, we need to process that data and send back the results.
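To show how the field name connects the form to the backend, here is a minimal sketch that keeps the form markup in a Python string instead of a separate home.html file. The markup, the placeholder text, and the response string are assumptions made for illustration; request.form and render_template_string are standard Flask.

```python
from flask import Flask, render_template_string, request
from markupsafe import escape

app = Flask(__name__)

# Assumed minimal markup; the real home.html will contain more styling.
HOME_TEMPLATE = """
<form method="post">
  <input type="text" name="search" placeholder="Enter a URL">
  <button type="submit">Search</button>
</form>
"""


@app.route("/", methods=["GET", "POST"])
def home():
    if request.method == "POST":
        # The key "search" matches the name attribute of the input field.
        url = request.form["search"]
        return f"Received URL: {escape(url)}"  # escape user input before echoing
    return render_template_string(HOME_TEMPLATE)
```

If the name attribute in the form and the key used with request.form do not match, Flask responds with a 400 Bad Request error.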
Once the results are calculated, another page is rendered with the results, as shown below. This page, “image_class.html”, is updated on every query. It shows the following information:
- Image category
- Frequency count of all available image categories
Here is the code to perform this:
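A rough sketch of how such a results page could be generated is given below. It assumes the predictions arrive as a dictionary that maps image paths to class names; the function name generate_html and the markup are illustrative and not the article's exact code.

```python
from collections import Counter
from html import escape


def generate_html(predictions):
    """Build an HTML fragment from a {image_path: class_name} dictionary."""
    counts = Counter(predictions.values())  # frequency of each category
    summary = [
        f"<li>{escape(name)}: {count}</li>"
        for name, count in counts.most_common()  # most frequent first
    ]
    figures = [
        f'<figure><img src="{escape(path)}" width="200">'
        f"<figcaption>{escape(name)}</figcaption></figure>"
        for path, name in predictions.items()
    ]
    return "\n".join(["<ul>", *summary, "</ul>", *figures])


page = generate_html({
    "static/0.jpeg": "daisy",
    "static/1.jpeg": "daisy",
    "static/2.jpeg": "tabby",
})
print(page)
```

Escaping the paths and class names keeps unexpected characters in scraped file names from breaking the page.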
The next step is to set up the Flask project to combine these individual pieces to solve the challenge.
Setup the Flask Project
So far, we have completed the following tasks in our project:
- Built an image classification model that is able to classify the images.
- Built the image scraper that downloads the images and stores them.
- Created the web pages to collect the query and return the results.
Now we need to connect all these files so that we have a working project.
Let’s have a look at the directory structure.
Note: Make sure that you save the images in a folder named static and the HTML files in a folder named templates. These are the folder names Flask looks for by default, and you will get an error if you change them without also configuring Flask accordingly.
Running a Flask Application
The Flask application will first render the home.html file. Whenever someone sends a request for image classification, Flask will detect the POST method and call the get_image_class function.
This function will work in the following steps:
- First, it will send a request to download the images and store them.
- Next, it will send the directory path to the get_prediction.py file, which will calculate and return the results in the form of a dictionary.
- Finally, it will send this dictionary to the generate_html.py file, which generates the output file that will be sent back to the user.
Once the above steps are done, we are ready to serve the user with the results. We will call the success function which will then render the image_class.html file.
Get Prediction for all images of Source URL
So far, we have made predictions for one image at a time. To handle every image scraped from a URL, we will modify the get_category function to accept new parameters: we will pass the path of a directory that contains multiple image files.
Next, we will define another function, get_prediction, which uses the get_category function and returns a dictionary in which the keys are the image paths and the values are the image classes.
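A minimal sketch of get_prediction is shown below. The classify parameter is an assumption of this sketch: it stands for any function that takes an image path and returns its class, such as get_category, and passing it in keeps the example independent of the model.

```python
import os

IMAGE_EXTENSIONS = (".jpg", ".jpeg")


def get_prediction(directory_path, classify):
    """Return {image_path: predicted_class} for every JPEG in a directory."""
    predictions = {}
    for file_name in sorted(os.listdir(directory_path)):
        if not file_name.lower().endswith(IMAGE_EXTENSIONS):
            continue  # ignore anything that is not a JPEG image
        image_path = os.path.join(directory_path, file_name)
        predictions[image_path] = classify(image_path)
    return predictions
```

Sorting the file names gives a stable order, so the same directory always produces the same dictionary.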
Later, we will send this dictionary to the generate_html.py file which will create the HTML file for us.
Now, all the code files are ready and we just need to connect these with the master file.
First, create an object of the Flask class, which takes the name of the current module, __name__, as an argument. The route decorator then tells the Flask application which function to call for a given URL.
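The wiring can be sketched as follows. The function names home, get_image_class, and success follow the description in this article, while run_pipeline, latest_result, and the inline templates are hypothetical stand-ins for the scraper, the prediction code, and the real HTML files.

```python
from flask import Flask, redirect, render_template_string, request, url_for

app = Flask(__name__)  # __name__ tells Flask where to find templates/ and static/

# Hypothetical in-memory store for the latest result; fine for a single-user
# demo, not for concurrent users.
latest_result = {}


def run_pipeline(url):
    """Placeholder for the scrape -> predict -> generate-HTML steps."""
    latest_result["url"] = url


@app.route("/")
def home():
    # The article renders templates/home.html here.
    return render_template_string(
        '<form method="post"><input name="search"><button>Search</button></form>'
    )


@app.route("/", methods=["POST"])
def get_image_class():
    run_pipeline(request.form["search"])
    return redirect(url_for("success"))  # send the browser to the results page


@app.route("/success")
def success():
    # The article renders templates/image_class.html here.
    return render_template_string(
        "<p>Results for {{ url }}</p>", url=latest_result.get("url", "")
    )
```

Calling app.run() at the end of the file starts Flask’s development server, which listens on localhost:5000 by default.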
Working of the Deployed Model
Now, run the file get_class.py, and the Flask server will start on localhost:5000.
Open the web browser and go to localhost:5000, and you will see that the default home page is rendered there. Now, type any URL in the text box and press the search button. It might take 20-30 seconds depending on the number of images at that URL and the speed of the Internet connection.
Let’s check out the working of the deployed model.
In this article, I briefly explained the concepts of model deployment, PyTorch, and Flask. Then we dived into the steps involved in setting up an image classification model using PyTorch and deploying it with Flask. I hope this helps you in building and deploying your own image classification model.
Note that the model was deployed on localhost. We can also deploy it on cloud services such as Google Cloud or Amazon Web Services. We will cover this in an upcoming article.
Reach out in the comments section if you have any questions. I will be happy to help.