Extending the ImageDataGenerator in Keras and TensorFlow
This article was published as a part of the Data Science Blogathon.
Understanding the Problem
When working on computer vision problems, we often need to apply some form of transformation to the entire dataset. The ImageDataGenerator class in Keras provides a variety of built-in transformations, such as flipping and normalization.
However, applying a custom transformation that Keras does not provide takes a little more work. In our example, we will apply a denoising algorithm as a pre-processing step. One simple way to do this is to run the denoising function over every image in the dataset and save the processed images in another directory, but this costs both time and disk space. Another way is to perform the transformation on the fly using the preprocessing_function argument of ImageDataGenerator.
To load the images for training, I am using the .flow_from_directory() method implemented in Keras. Denoising is fairly straightforward with OpenCV, which provides several built-in algorithms for it. In this article, I will show how to define our own preprocessing function, pass it to the training generator, and feed the images directly into the model, eliminating the need to save them.
This tutorial is divided into two parts:
Implementing the denoising algorithm
Extending the preprocessing function
Let’s get started right away!
Part 1 – Implementing the denoising algorithm
Let us prepare a function that takes an image as input, applies OpenCV's built-in denoising algorithm, and returns the processed image.
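    import cv2
    import numpy as np

    def preprocessing_fun(img):
        # Keras passes each image as a float NumPy array of shape (H, W, 3)
        # with values in [0, 255]; OpenCV's denoiser expects 8-bit input.
        img_uint8 = np.clip(img, 0, 255).astype(np.uint8)
        dst = cv2.fastNlMeansDenoisingColored(img_uint8, None, 10, 10, 7, 21)
        # Return an array with the same shape and a float dtype.
        return dst.astype(np.float32)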
We are using the fastNlMeansDenoisingColored algorithm from OpenCV because it works on color images. OpenCV also provides other algorithms, such as fastNlMeansDenoising, that work on single-channel (grayscale) images.
Now that we have implemented our algorithm, let’s use it in the ImageDataGenerator class.
Part 2 – Extending the preprocessing function
Here, we use the function defined in the previous section in our training generator.
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    img_datagen = ImageDataGenerator(rescale=1./255,
                                     preprocessing_function=preprocessing_fun)
    training_gen = img_datagen.flow_from_directory(
        PATH, target_size=(224, 224), color_mode='rgb',
        batch_size=32, shuffle=True)
Where we create the ImageDataGenerator object, notice that we pass our denoising function to the preprocessing_function parameter. By doing this, we instruct the data generator to apply this function to every image as a preprocessing step before feeding it to the model. Keras runs the function after the image has been resized and augmented, and applies the rescale factor afterwards. This way, we eliminate the need to process all the images and write them to a separate directory.
Pro Tip: If you need to perform a series of transformations that are defined in separate functions, you can wrap them in a single function and use it in your training generator in the following way.
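Keras documents that the preprocessing function receives one image as a rank-3 NumPy tensor and must return a tensor of the same shape. Before plugging a custom function into the generator, it can help to test it against that contract on a dummy image. The sketch below uses only NumPy. The helper name check_preprocessing_fn and the lambda transformations are hypothetical, chosen for illustration.

```python
import numpy as np

def check_preprocessing_fn(fn, shape=(224, 224, 3), seed=0):
    """Call fn on a random image and verify it keeps the rank-3 shape."""
    rng = np.random.default_rng(seed)
    # Mimic what the generator passes in: float32 pixels in [0, 255].
    sample = rng.uniform(0, 255, size=shape).astype(np.float32)
    out = np.asarray(fn(sample))
    if out.shape != sample.shape:
        raise ValueError(f"expected shape {sample.shape}, got {out.shape}")
    return out

# A stand-in transformation (horizontal flip) used only for this demo.
flipped = check_preprocessing_fn(lambda img: img[:, ::-1, :])
print(flipped.shape)

# A function that drops the channel axis breaks the contract.
try:
    check_preprocessing_fn(lambda img: img.mean(axis=-1))
    rejected = False
except ValueError as err:
    rejected = True
    print("rejected:", err)
```

Running a check like this first surfaces shape mistakes immediately, rather than as an error in the middle of training.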
    def transform1(img):
        # Applies a horizontal flip and returns the image
        return cv2.flip(img, 1)

    def transform2(img):
        # Applies a vertical flip and returns the image
        return cv2.flip(img, 0)

    def transform3(img):
        # Applies a 180-degree rotation and returns the image
        return cv2.rotate(img, cv2.ROTATE_180)

    def our_preprocessing_function(img):
        # Combines all the transformations. Keras passes the image as a
        # NumPy array, so it can be transformed directly.
        img1 = transform1(img)
        img2 = transform2(img1)
        final_img = transform3(img2)
        return final_img

    img_datagen = ImageDataGenerator(rescale=1./255,
                                     preprocessing_function=our_preprocessing_function)
    training_generator = img_datagen.flow_from_directory(
        PATH, target_size=(224, 224), color_mode='rgb', batch_size=32,
        class_mode='categorical', shuffle=True)
In this way, we can provide a series of custom transformations, wrap them in a function, and apply them to our dataset. This method is simple yet powerful and comes in handy when working in a resource-constrained environment.
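When the list of transformations changes often, the wrapping step itself can be made generic. The following sketch defines a small compose helper that chains any number of image functions from left to right. The helper and the NumPy stand-ins for the flips and the rotation are assumptions made for illustration, not part of Keras or OpenCV.

```python
import numpy as np

def compose(*transforms):
    """Return a single function that applies the transforms left to right."""
    def pipeline(img):
        for transform in transforms:
            img = transform(img)
        return img
    return pipeline

# NumPy stand-ins for a horizontal flip, a vertical flip and a 180-degree rotation.
def hflip(img):
    return img[:, ::-1, :]

def vflip(img):
    return img[::-1, :, :]

def rot180(img):
    return np.rot90(img, 2)  # rotates in the (height, width) plane

chained = compose(hflip, vflip, rot180)

image = np.arange(2 * 3 * 3, dtype=np.float32).reshape(2, 3, 3)
result = chained(image)
print(result.shape)
```

The function returned by compose takes one image and returns one image, so it can be passed as preprocessing_function in the same way as a hand-written wrapper.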