Mask R-CNN for Instance Segmentation Using PyTorch

DERBEL MohamedAziz 22 Feb, 2023 • 10 min read

Introduction

Over the past two decades, many convolutional neural networks have emerged, each trying to push past the limits of its predecessors by applying state-of-the-art techniques. The ultimate goal of these deep learning algorithms is to mimic the human eye’s capacity to perceive the surrounding environment. Image classification, object detection, optical character recognition, and image segmentation remain the main focus of today’s research. This guide focuses on image segmentation. First, we will briefly define image segmentation in general. Then, I will cover the difference between semantic segmentation and instance segmentation. The third part of this guide presents widely used model architectures for both subcategories. Next, as the final preparation step, I will explain the theory behind the model used in the demo that follows. Finally, I will implement a demo with the deep learning framework PyTorch and its open-source library Torchvision to show how these models work in practice on an instance segmentation task.

By the end of this guide, you will be able to properly:

  • Define image segmentation and explain what it is used for.
  • Differentiate between semantic segmentation and instance segmentation, and name some pre-trained models for these tasks.
  • Implement an end-to-end image segmentation program using PyTorch.

If you are here for the demo, feel free to skip the first three parts, as they only provide context for the demo.

This article was published as a part of the Data Science Blogathon.


Image Segmentation and Its Use Cases

Image segmentation is a key building block of computer vision technologies and algorithms. It is used in many practical applications, including medical image analysis, computer vision for autonomous vehicles, and face recognition. It is a more advanced technique than image classification and object detection because it enables more detailed image analysis and content extraction. Still, image classification and object detection models have their own use cases, and research to improve their performance is ongoing.

The theoretical idea behind image segmentation is to separate the different components of a given image. This separation happens at the pixel level: every pixel is assigned a label, or class. As a result, each image component is defined as a group of pixels that share the same label. The following example shows how image segmentation separates the cat from the image background. In other words, the model can differentiate between the cat and its surroundings.

Image Segmentation model
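To make the idea of pixel-wise labels concrete, here is a minimal NumPy sketch. The tiny 4×6 label map is made up for illustration (0 stands for the background and 1 for the cat), and each segment is recovered as the set of pixels that share a label.

```python
import numpy as np

# Hypothetical 4x6 label map: 0 = background, 1 = cat.
# Each entry is the class label assigned to one pixel.
label_map = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 1, 1, 0, 0, 0],
    [0, 1, 1, 1, 0, 0],
    [0, 0, 0, 0, 0, 0],
])

# A segment is simply the set of pixel coordinates that share a label.
segments = {
    int(label): np.argwhere(label_map == label)
    for label in np.unique(label_map)
}

for label, pixels in segments.items():
    print(f"label {label}: {len(pixels)} pixels")  # label 0: 19 pixels, label 1: 5 pixels
```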

In practice, image segmentation is used in several industries. In medical diagnostics, especially when evaluating X-ray and MRI scans, systems using semantic segmentation can help classify the relevant regions of an image, making diagnostic tests easier. In satellite imagery, semantic segmentation has been used to analyze land use, areas suffering from deforestation, and agricultural land.


In the automotive industry, image segmentation is used to identify objects such as other cars and traffic signs, and regions such as road lanes and sidewalks. It is a building block of autonomous driving because it serves as the system’s eyes, forwarding its predictions to a central unit that takes the necessary actions accordingly.


Semantic Segmentation and Instance Segmentation

Computer vision tasks range from simple image classification to real-time motion detection. Each has its own use cases and benefits, so knowing when and where to apply each technique is vital. Image segmentation, for example, creates pixel-wise masks for each object, which makes it useful for understanding the granular details of an image. Image classification, in contrast, assigns a single label to the whole image based on its overall patterns. Two main approaches are available today for this kind of pixel-level analysis: semantic segmentation and instance segmentation.

Semantic segmentation is the process of assigning a class label to each pixel in the image. As a result, the generated segments are class-based, and the model ignores how many instances of each class are present. For example, two cats in a single image are masked and grouped together as one segment.

Meanwhile, instance segmentation associates every pixel with one instance of a class. In other words, two groups of pixels sharing the same class label are distinguished and detected separately, so the model knows how many instances of each class are present. You can think of this as applying an object detector on top of a semantic segmentation layer.
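The difference between the two kinds of output can be sketched with NumPy, assuming two hypothetical cat instances stored as separate boolean masks (roughly the form an instance segmentation model returns). Flattening them into a single label map, as semantic segmentation does, merges both cats into one segment, while the instance representation still knows there are two.

```python
import numpy as np

H, W = 4, 8
# Two hypothetical cat instances, each stored as its own boolean mask.
cat_a = np.zeros((H, W), dtype=bool)
cat_a[1:3, 1:3] = True
cat_b = np.zeros((H, W), dtype=bool)
cat_b[1:3, 5:7] = True

instance_masks = np.stack([cat_a, cat_b])  # shape (num_instances, H, W)
instance_labels = np.array([1, 1])         # both instances belong to class 1 ("cat")

# Semantic segmentation keeps only one class label per pixel,
# so both cats collapse into a single "cat" segment.
semantic_map = np.zeros((H, W), dtype=np.int64)
for mask, label in zip(instance_masks, instance_labels):
    semantic_map[mask] = label

print("instances of class 1:", int((instance_labels == 1).sum()))    # 2
print("classes in semantic map:", np.unique(semantic_map).tolist())  # [0, 1]
```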

Image Segmentation Models

Many efficient image segmentation models, shared openly within the community, have achieved state-of-the-art results that outperform their predecessors. FCN_RESNET50, for example, is a fully convolutional network (FCN) with a ResNet-50 backbone for semantic segmentation. In Torchvision, its pre-trained weights were trained on a subset of COCO train2017 restricted to the 20 categories present in the Pascal VOC dataset. The FCN architecture itself was introduced in 2015, and Torchvision reports a mean IoU of 60.5 and a global pixel-wise accuracy of 91.4% for these weights on COCO val2017.
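Mean IoU and global pixel accuracy can both be derived from a confusion matrix between the ground-truth and predicted label maps. The sketch below shows one common way to compute them with NumPy; the helper names are hypothetical, and Torchvision’s reference evaluation scripts may differ in details such as how classes absent from both maps are handled. For a class c, IoU is TP / (TP + FP + FN); mean IoU averages it over classes, and pixel accuracy is the fraction of correctly labeled pixels. Reported scores such as 60.5 are these values expressed as percentages.

```python
import numpy as np

def confusion_matrix(target, pred, num_classes):
    # Rows are ground-truth classes, columns are predicted classes.
    idx = num_classes * target.reshape(-1) + pred.reshape(-1)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def segmentation_metrics(target, pred, num_classes):
    cm = confusion_matrix(target, pred, num_classes)
    tp = np.diag(cm)
    pixel_acc = tp.sum() / cm.sum()
    # IoU_c = TP / (TP + FP + FN); a class absent from both maps gives NaN and is skipped.
    denom = cm.sum(axis=0) + cm.sum(axis=1) - tp
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = tp / denom
    return pixel_acc, np.nanmean(iou)

# Toy 2x2 example: one background pixel is wrongly predicted as class 1.
target = np.array([[0, 0], [1, 1]])
pred = np.array([[0, 1], [1, 1]])
pixel_acc, mean_iou = segmentation_metrics(target, pred, num_classes=2)
print(pixel_acc, mean_iou)  # 0.75 and roughly 0.583
```

Here class 0 has an IoU of 1/2 and class 1 an IoU of 2/3, so the mean IoU is about 0.583 while three out of four pixels are correct.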

Similarly, the Mask R-CNN architecture has been widely adopted for instance segmentation to accurately detect object instances in digital images. Mask R-CNN is a region-based convolutional neural network that extends the Faster R-CNN architecture: alongside the existing branches for classification and bounding box regression, it adds a third branch that predicts the object masks in parallel. The accompanying figure gives a global overview of its architecture.

Mask R-CNN for Segmentation

I chose the Mask R-CNN architecture for the instance segmentation demo using the deep learning framework PyTorch. The same pre-trained architecture is available in Torchvision under the name ‘maskrcnn_resnet50_fpn’. This version uses a ResNet-50 backbone with a feature pyramid network (FPN) and was trained on a subset of the COCO 2017 dataset.

Instance Segmentation Demo

Now that we have covered the most important notions, let’s put our knowledge into practice. This demo is a practical guide to using a pre-trained model for an instance segmentation task with PyTorch. As mentioned, I will use the Mask R-CNN architecture to segment arbitrarily chosen images from the internet.

Personally, I used Google Colab notebooks for this demo, since all the necessary dependencies are installed by default and there is no need to install them. Colab notebooks also provide GPU support, which speeds up model inference and might not be available on every workstation. Feel free to use any development environment, as long as the following Python packages are installed (Torchvision must be version 0.13 or later, since the code relies on its multi-weight API):

  • PyTorch
  • Torchvision
  • NumPy
  • Matplotlib

Now, I think we are ready to go.

Import the necessary packages and select the device

import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Lines 1 to 4: I imported torch, NumPy, Matplotlib, and Torchvision’s functional transforms module (as F).

Line 5: I selected the hardware on which all tensor operations will be performed. GPU support is available in Colab notebooks and saves a lot of time during inference; the expression falls back to the CPU when no compatible GPU is found.

Import and visualize data

from torchvision.utils import make_grid
from torchvision.io import read_image
import torchvision.transforms as T
from pathlib import Path

def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False,figsize=(12, 12))
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

Lines 1 to 4: Again, some import statements, related to data manipulation this time. I imported two handy functions from the torchvision package: ‘make_grid’ creates a grid of images from a given list, and ‘read_image’ loads an input image from storage as a uint8 tensor. The transforms module, imported as T, provides many utility functions for input pre-processing, such as resizing and scaling, and Path is used to build the file paths.

Line 6: I defined a function named ‘show’. It takes an image tensor, such as a grid of images, or a list of image tensors, and plots them side by side using Matplotlib’s subplots function. I will use it to display my input images and, later, the model output.

image_1 = read_image(str(Path('drive/MyDrive/Computer_vision_data') / 'dog3.png'))
image_2 = read_image(str(Path('drive/MyDrive/Computer_vision_data') / 'dog2.jpg'))
image_list = [T.Resize(size=(500,500))(image_1),T.Resize(size=(500,500))(image_2)]
grid = make_grid(image_list) # return a tensor containing a grid of images
show(grid)

Lines 1 to 3: I loaded my two input images from my personal drive; adjust the path to match your own setup. Then, I resized them to 500×500 and stored them in a Python list. Resizing was mandatory so that the images could fit into one grid, since make_grid requires all images to have the same shape; it is not necessary from the model’s perspective, as Mask R-CNN handles images of different shapes. I chose 500×500 to limit the distortion of my images (resizing to a fixed height and width does not preserve the aspect ratio). I used only two images for this demo. Feel free to add as many as you want; just remember to add each one to the image list.

In lines 4 and 5 of the loading code, I created the grid of images and passed it to the show function to display its content. Here is an example of what I got:

Model loading

from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights

weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
model = maskrcnn_resnet50_fpn(weights=weights, progress=False).to(device)
model = model.eval()

Line 1: I imported the Mask R-CNN model builder and its associated pre-trained weights enum.

Lines 3 and 4: I selected the default version of the weights and then instantiated the model architecture with them. Note the ‘.to()’ call at the end, which tells PyTorch where to place the model weights for future computation. The device variable stores the hardware we selected in the first code block.

Line 5: I switched the model to evaluation mode in preparation for the inference step. In evaluation mode, layers such as dropout and batch normalization use their inference behavior, and Torchvision detection models return predictions instead of training losses. Note that eval() does not turn off autograd; to skip gradient tracking during inference, wrap the forward pass in torch.no_grad() or torch.inference_mode().
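The following sketch illustrates the difference between eval() and disabling autograd, using a hypothetical two-layer model rather than Mask R-CNN: after eval() the output still requires gradients, whereas inside torch.inference_mode() it does not.

```python
import torch
from torch import nn

# Hypothetical tiny model with a dropout layer, used only for illustration.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

model.eval()  # dropout now acts as the identity...
out_eval = model(x)
print(out_eval.requires_grad)  # True: ...but autograd still records the graph

with torch.inference_mode():  # this is what actually disables gradient tracking
    out_inf = model(x)
print(out_inf.requires_grad)  # False

# In eval mode dropout is deterministic, so repeated calls agree.
print(torch.equal(model(x), model(x)))  # True
```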

Input pre-processing

transforms = weights.transforms()
images = [transforms(d).to(device) for d in image_list]
for image in images:
  print(image.shape)

Line 1: I initialized a transforms variable with the callable pre-processing preset attached to the weights. For these Mask R-CNN weights, it converts each uint8 image into a float tensor with values in [0, 1], while resizing and normalization happen inside the model itself. Line 2: I used a list comprehension to apply this pre-processing to each image and move the result to the selected device.

Lines 3 and 4: This is an optional loop I used to print the shape of each pre-processed image and keep track of what is happening.

Model inference and output inspection

output = model(images) # list of dict
print(len(output)) # equals to how many images that were fed into the model
print(output[0].keys()) # dict_keys(['boxes', 'labels', 'scores', 'masks'])
def inspect_model_output(output):
  for index,prediction in enumerate(output):
    print(f'Input {index + 1} has { len(prediction.get("scores")) } detected instances')

The inference step is as simple as running a single line of code: the model object is called with the list of images as its argument. I stored the model response in the ‘output’ variable. In the following lines, I inspected the structure of the model’s output.

Line 2: Display the length of the output list; in my case, it was 2, since I used only two images as input.

Line 3: Inspect the keys of the first prediction. Each input image is associated with a Python dictionary containing the boxes, labels, scores, and masks of the predicted instances in that image. The boxes have shape [N, 4] in (x1, y1, x2, y2) format, and the masks have shape [N, 1, H, W] with per-pixel values between 0 and 1, where N is the number of detected instances.

Lines 4 to 6: I defined a function that prints the number of predicted instances in each input image. It is a basic print statement into which I inject the information I want to display; feel free to add your own.

Output processing

from torchvision.utils import draw_bounding_boxes
from torchvision.utils import draw_segmentation_masks

def filter_model_output(output, score_threshold):
  filtered_output = list()
  for image in output:
    filtered_image = dict()
    for key in image.keys():
      filtered_image[key] = image[key][image['scores'] >= score_threshold]
    filtered_output.append(filtered_image)
  return filtered_output

def get_boolean_mask(output):
  for index,pred in enumerate(output):
    output[index]['masks'] = pred['masks'] > 0.5
    output[index]['masks'] = output[index]['masks'].squeeze(1)
  return output

Lines 1 and 2: I imported two utility functions from the torchvision package. As their names suggest, they draw bounding boxes and segmentation masks on top of images; only the second one is used in this demo.

Line 4: I defined a function to filter the model’s output according to a score threshold. It loops through the model output and drops the instances whose prediction score is below the threshold passed as an argument.

Line 13: I created another function called get_boolean_mask, which converts the predicted masks of each input from probabilities to boolean values. I hard-coded the cutoff to 0.5, so every pixel with a probability of 0.5 or less becomes False and the rest become True. The function also squeezes the singleton channel dimension, turning masks of shape [N, 1, H, W] into [N, H, W]. You might wonder why this function is needed: the built-in draw_segmentation_masks function imported in line 2 expects boolean masks of shape [N, H, W] to draw the instance masks on top of the original images.
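The two helpers can be checked on a hand-made prediction dictionary. The sketch below uses made-up boxes, labels, scores, and random soft masks for a single 4×4 image, and rewrites the helpers as comprehensions that return new dictionaries instead of modifying the model output in place; otherwise the behavior matches the demo’s versions, with the 0.5 cutoff exposed as a hypothetical mask_threshold parameter.

```python
import torch

def filter_model_output(output, score_threshold):
    # Keep, for every image, only the instances whose score reaches the threshold.
    return [
        {key: value[pred["scores"] >= score_threshold] for key, value in pred.items()}
        for pred in output
    ]

def get_boolean_mask(output, mask_threshold=0.5):
    # Soft masks of shape [N, 1, H, W] -> boolean masks of shape [N, H, W].
    return [
        {**pred, "masks": (pred["masks"] > mask_threshold).squeeze(1)}
        for pred in output
    ]

# Hand-made prediction for one 4x4 image with three instances (made-up values).
fake_output = [{
    "boxes": torch.tensor([[0., 0., 2., 2.], [1., 1., 3., 3.], [0., 2., 4., 4.]]),
    "labels": torch.tensor([1, 18, 77]),
    "scores": torch.tensor([0.95, 0.50, 0.85]),
    "masks": torch.rand(3, 1, 4, 4),
}]

kept = get_boolean_mask(filter_model_output(fake_output, score_threshold=0.8))
print(kept[0]["labels"].tolist())  # [1, 77]
print(kept[0]["masks"].dtype, tuple(kept[0]["masks"].shape))  # torch.bool (2, 4, 4)
```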

Output Visualization

score_threshold = .8
output = filter_model_output(output=output,score_threshold=score_threshold)
output = get_boolean_mask(output)
show([
    draw_segmentation_masks(image, prediction.get('masks').cpu(), alpha=0.9)  # masks must be on the CPU, like the images
    for index, (image, prediction) in enumerate(zip(image_list, output))
])

Line 1: Set the score threshold to 0.8, so we only keep instances whose confidence score is at least 0.8.

Lines 2 & 3: Execute the pre-defined output processing functions as discussed previously.

Lines 4 to 7: I used a list comprehension to collect the output of draw_segmentation_masks, which returns each original image with its detected masks drawn on top. The show function then plots these newly generated images.

Here is an example of the final result I got:

The model detected 3 instances in the 2 images: a person, a dog, and a phone. However, it failed to mask the dog’s head in the first image because of the score filter we applied. You can lower or raise the score threshold to gain more insight into the model’s behavior.

Conclusion

Throughout this guide, I focused on one of the most advanced computer vision techniques: image segmentation. We have seen what image segmentation means in general, then discussed the differences between its two subcategories: semantic segmentation and instance segmentation. I covered, at a high level, two of the most widely adopted model architectures for these tasks before choosing one for a practical image segmentation demo, the final step in putting all the previously discussed notions into practice.

Applying transfer learning and using pre-trained models for computer vision tasks is highly beneficial, since training deep learning models for computer vision from scratch is time-consuming and compute-intensive. This implies additional costs for entrepreneurs who want to take advantage of these technologies as soon as possible, given their potential to generate new revenue. Training these models from scratch also has a non-negligible carbon footprint that must be considered. To conclude, we should prioritize existing models for similar tasks, and when an existing model does not fit the current need as-is, consider fine-tuning it rather than training from scratch. This way, time and energy are saved, costs are minimized, and the environmental impact is reduced.

The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion. 
