Image classification using TensorFlow 2
Build a deep learning model to detect Malaria infection in cell images.
Image classification is the process of feeding an image into a model that returns the class the image belongs to, or the probability of each class. Problems of this type are very common in computer vision.
Table of contents
· Prerequisites
· Project Description
· Set-up
· Modeling
· Model Evaluation
· Prediction on unseen images
· Conclusion
Prerequisites
- Python Programming
- Basic understanding of Neural Networks and CNNs
Project Description
Malaria is a life-threatening disease. It’s typically transmitted through the bite of an infected Anopheles mosquito. Infected mosquitoes carry the Plasmodium parasite. When this mosquito bites you, the parasite is released into your bloodstream. Once the parasites are inside your body, they travel to the liver, where they mature. After several days, the mature parasites enter the bloodstream and begin to infect red blood cells.
In the United States, the Centers for Disease Control and Prevention (CDC) reports about 1,700 cases of malaria annually.
In this post, we will build an image classification model using deep learning and TensorFlow that can detect whether a cell is parasitized (infected by malaria) or not.
Set-up
The data set is about 334 MB in size, and it took me about 18–20 minutes to train the model. But wait! You don’t have to train the model again! You can download my model, which gave me about 95% validation accuracy.
Download the data set from Kaggle: Malaria-cell-dataset
Download the pre-trained model: pretrained model (vgg.h5)
Virtual environment (Recommended!)
Python virtual environments are designed to provide an isolated environment for Python projects. This means that each project can have its own set of dependencies, independent of other projects.
💡 See this video to learn about virtual environments. If you use the Anaconda distribution, create a new conda environment instead.
- Open your Command Prompt (windows 10)
- create a virtual env :
python -m venv malaria
- activate the virtual env :
malaria\Scripts\activate.bat
- Upgrade pip :
python -m pip install --upgrade pip
- Install libraries
pip install tensorflow==2.4.1
pip install flask
pip install jupyter notebook
pip install numpy matplotlib
- Change directory :
cd Desktop/Malaria-app
- open jupyter notebook :
jupyter notebook
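Before opening a notebook, it can help to confirm that the packages actually landed in the active environment. The snippet below is a minimal sketch, assuming Python 3.8 or newer (for `importlib.metadata`); the helper name `installed_versions` is my own and not part of any library.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Package names taken from the pip commands above.
wanted = ["tensorflow", "flask", "notebook", "numpy", "matplotlib"]
for name, version in installed_versions(wanted).items():
    print(f"{name:12s} {version or 'NOT INSTALLED'}")
```

If `tensorflow` shows up as not installed, the virtual environment is probably not activated in the current prompt.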
Now, create a new IPython notebook and we’re all set! Let’s get our hands dirty…!
Modeling
Import tensorflow and other libraries
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.callbacks import EarlyStopping
Build an input pipeline
Train set : 80% of data
Validation set : 20% of data
# Data generator parameters
WIDTH = 150
HEIGHT = 150
BATCH_SIZE = 128
VAL_SET_SIZE = 0.2
Image data generator : Generates batches of tensor image data, optionally with real-time data augmentation. Since we have enough data, we don’t augment it; we only rescale the pixel values to the [0, 1] range.
generator = ImageDataGenerator(rescale=1/255.0,
                               validation_split=VAL_SET_SIZE)
Training pipe
train_gen = generator.flow_from_directory(
    directory='cell_images/',
    target_size=(HEIGHT, WIDTH),  # Keras expects (height, width)
    class_mode='binary',
    batch_size=BATCH_SIZE,
    subset='training')
Validation pipe
val_gen = generator.flow_from_directory(
    directory='cell_images/',
    target_size=(HEIGHT, WIDTH),  # Keras expects (height, width)
    class_mode='binary',
    batch_size=BATCH_SIZE,
    subset='validation')
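To get a feel for how many images each pipe will see, here is a small sketch of the split arithmetic. It assumes, as far as I understand the Keras directory iterator, that `validation_split` is applied to each class folder separately and that the validation count is truncated rather than rounded. The folder sizes below are hypothetical, not the real counts of this data set.

```python
def split_sizes(n_files, val_fraction):
    """Approximate (train, validation) counts for one class folder."""
    n_val = int(val_fraction * n_files)  # truncated, not rounded
    return n_files - n_val, n_val


# Hypothetical folder sizes, for illustration only.
example_counts = {"Parasitized": 1000, "Uninfected": 1003}
for class_name, n_files in example_counts.items():
    n_train, n_val = split_sizes(n_files, 0.2)
    print(f"{class_name}: {n_train} training / {n_val} validation")
```

The numbers printed by `flow_from_directory` ("Found ... images belonging to 2 classes.") are the ones to trust; this sketch only explains where they come from.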
# class labels
print(train_gen.class_indices)
print(val_gen.class_indices)
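The printed dictionary matters later, when the sigmoid output has to be turned back into a label. To my knowledge, `flow_from_directory` infers the classes from the sub-folder names and assigns the indices in sorted order. The sketch below mimics that rule in plain Python; the helper name is made up for illustration.

```python
def class_indices_from_names(folder_names):
    """Map class folder names to integer labels in sorted order."""
    return {name: index for index, name in enumerate(sorted(folder_names))}


mapping = class_indices_from_names(["Uninfected", "Parasitized"])
print(mapping)  # {'Parasitized': 0, 'Uninfected': 1}
```

With `class_mode='binary'`, label 1 is therefore "Uninfected", so the model output can be read as the probability that a cell is uninfected.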
Examine and understand data
Parasitized cell images
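path = r'cell_images/Parasitized/'
n_images = 9
# keep only .png files (the folder may also contain non-image files)
images = [name for name in os.listdir(path) if name.lower().endswith('.png')][:n_images]
f, axes = plt.subplots(3, 3, figsize=(9, 9))
axes = np.ravel(axes)
for i, img in enumerate(images):
    img = plt.imread(os.path.join(path, img))
    axes[i].imshow(img)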
Uninfected cell images
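path = r'cell_images/Uninfected/'
n_images = 9
# keep only .png files (the folder may also contain non-image files)
images = [name for name in os.listdir(path) if name.lower().endswith('.png')][:n_images]
f, axes = plt.subplots(3, 3, figsize=(9, 9))
axes = np.ravel(axes)
for i, img in enumerate(images):
    img = plt.imread(os.path.join(path, img))
    axes[i].imshow(img)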
Build the model
We will build a customized, smaller version of the VGG16 model architecture to solve our problem. Since this is a binary classification task, the network ends in a single sigmoid unit.
# function for creating a vgg block
def vgg_block(layer_in, n_filters, n_conv):
    # n_conv 3x3 convolutions followed by one 2x2 max pooling
    for _ in range(n_conv):
        layer_in = Conv2D(n_filters, (3, 3),
                          padding='same',
                          activation='relu')(layer_in)
    layer_in = MaxPooling2D((2, 2), strides=(2, 2))(layer_in)
    return layer_in
# define model input
visible = Input(shape=(150, 150, 3))
# add vgg module
layer = vgg_block(visible, 64, 2)
# add vgg module
layer = vgg_block(layer, 128, 2)
# add vgg module
layer = vgg_block(layer, 256, 2)
# classifier head
x = Flatten()(layer)
x = Dense(128, activation='sigmoid')(x)
y = Dense(1, activation='sigmoid')(x)
# create model
model = Model(inputs=visible, outputs=y, name='VGG')
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
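It is worth checking where the parameters of this network live. The sketch below redoes the shape and parameter arithmetic in plain Python, assuming 'same' padding for the convolutions and 2x2 pooling with stride 2 as in `vgg_block`; the function names are my own. The result can be compared against `model.summary()`.

```python
def conv_params(in_channels, out_channels, kernel=3):
    """Weights plus biases of one Conv2D layer."""
    return kernel * kernel * in_channels * out_channels + out_channels


def summarize(input_size=150, in_channels=3,
              blocks=((64, 2), (128, 2), (256, 2))):
    size, channels, total = input_size, in_channels, 0
    for n_filters, n_conv in blocks:
        for _ in range(n_conv):
            # 'same' padding keeps the spatial size unchanged
            total += conv_params(channels, n_filters)
            channels = n_filters
        size //= 2  # 2x2 max pooling halves the size, rounding down
    flat = size * size * channels
    total += flat * 128 + 128  # Dense(128)
    total += 128 * 1 + 1       # Dense(1)
    return size, flat, total


size, flat, total = summarize()
print(size, flat, total)  # 18 82944 11762497
```

Roughly 10.6 of the 11.8 million parameters sit in the first dense layer, because the flattened feature map has 18 × 18 × 256 = 82,944 values.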
Model Training
Early stopping
Early stopping is a method that allows you to specify an arbitrarily large number of training epochs and stop training once the model performance stops improving on a hold-out validation data set.
patience : Number of epochs with no improvement after which training will be stopped.
Note : We already have the pre-trained model with us. If you want, you can train it again to see the results.
early_stopping = EarlyStopping(monitor='val_loss', patience=3)
EPOCHS = 20
history = model.fit(train_gen,
                    epochs=EPOCHS,
                    validation_data=val_gen,
                    callbacks=[early_stopping])
Training stopped after 11 epochs with a validation accuracy of 94.7%.
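The patience rule is easy to reason about once it is written out. The following is a simplified sketch of the idea, not the Keras implementation (it ignores options such as `min_delta`), and the validation losses are made up for illustration.

```python
def epochs_run(val_losses, patience=3):
    """Return how many epochs run before patience-based stopping kicks in."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, wait = loss, 0  # improvement: reset the counter
        else:
            wait += 1             # no improvement this epoch
            if wait >= patience:
                return epoch      # stop here
    return len(val_losses)        # stopping was never triggered


# Made-up validation losses, for illustration only.
losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.36, 0.30]
print(epochs_run(losses, patience=3))  # 6
```

In this example the seventh epoch would have been the best one, which shows the trade-off of a small patience value. Also note that `EarlyStopping` keeps the weights of the last epoch unless `restore_best_weights=True` is passed.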
Model Evaluation
Loss
# plot the loss
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.legend()
plt.show()
Accuracy
# plot the accuracy
plt.plot(history.history['accuracy'], label='train acc')
plt.plot(history.history['val_accuracy'], label='val acc')
plt.legend()
plt.show()
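Besides eyeballing the curves, the same dictionary can be queried directly. The sketch below uses made-up numbers shaped like `history.history` (they are not results from this model) to find the epoch with the lowest validation loss and the final gap between training and validation loss; a widening gap is a common sign of overfitting.

```python
# Made-up numbers shaped like history.history, for illustration only.
example_history = {
    "loss":         [0.60, 0.30, 0.20, 0.15],
    "val_loss":     [0.50, 0.28, 0.25, 0.27],
    "accuracy":     [0.70, 0.88, 0.92, 0.94],
    "val_accuracy": [0.78, 0.90, 0.93, 0.92],
}


def best_epoch(history_dict, metric="val_loss"):
    """Return (1-based epoch, value) at which the metric is lowest."""
    values = history_dict[metric]
    index = min(range(len(values)), key=values.__getitem__)
    return index + 1, values[index]


epoch, value = best_epoch(example_history)
gap = example_history["val_loss"][-1] - example_history["loss"][-1]
print(f"best val_loss {value:.2f} at epoch {epoch}; final gap {gap:.2f}")
```

To use it on a real run, pass `history.history` instead of `example_history`.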
Prediction on unseen images
Let’s create a function that takes an image as an argument and returns the prediction. Load our pre-trained model or use your own.
# Load the pre-trained model
from tensorflow.keras.models import load_model
model = load_model("vgg.h5")

# function to predict on unseen image
def predict(img):
    x = image.img_to_array(img)
    x = x / 255.0                  # same rescaling as during training
    x = np.expand_dims(x, axis=0)  # add the batch dimension
    proba = model.predict(x)[0][0]
    # class index 1 is "Uninfected" (see class_indices), so proba is P(Uninfected)
    y = "Uninfected" if proba > 0.5 else "Parasitized"
    return y, proba
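The two steps around `model.predict` can be tried without loading the model. This is a minimal sketch using NumPy only, with a random array standing in for a cell image; `preprocess` and `decode` are hypothetical helper names that mirror what `predict` does.

```python
import numpy as np


def preprocess(pixels):
    """Scale 0-255 pixel values to [0, 1] and add a leading batch axis."""
    x = np.asarray(pixels, dtype="float32") / 255.0
    return np.expand_dims(x, axis=0)


def decode(proba, threshold=0.5):
    """Turn the sigmoid output, read as P(Uninfected), into a label."""
    return "Uninfected" if proba > threshold else "Parasitized"


# A random stand-in for a 150x150 RGB image, not a real cell image.
fake_image = np.random.randint(0, 256, size=(150, 150, 3), dtype=np.uint8)
batch = preprocess(fake_image)
print(batch.shape)                 # (1, 150, 150, 3)
print(decode(0.03), decode(0.97))  # Parasitized Uninfected
```

Since the output is read as P(Uninfected), `1 - proba` gives the estimated probability that the cell is parasitized.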
Let’s check our function
# prediction on parasitized image
path = "cell_images/Parasitized/"
im = "C33P1thinF_IMG_20150619_114756a_cell_179.png"
img = image.load_img(os.path.join(path, im), target_size=(150, 150))
plt.figure(figsize=(4,4))
plt.title("True label : Parasitized")
y, proba = predict(img)  # Prediction
plt.xlabel(f"Predicted label : {y}", fontsize=14)
plt.imshow(img)
Yayy!! The model is doing well and ready for serving!!
Now, save the model using the save method if you trained it again!
model.save("vgg.h5")
Conclusion
We have successfully built our CNN model with pretty good performance.
To make this an end-to-end ML project, create a simple UI / API using Flask and try to deploy it to the cloud (AWS or GCP).