Potato Disease Classifier — End to End Deep Learning Project

Ashok kumar Palivela
Oct 29, 2021

Build a web app to predict the diseases of potato plants using TensorFlow 2 and Flask

High-Level Overview of the Project (image by author)

1. Introduction

In this article, we will solve a machine learning problem from the agriculture domain using convolutional neural networks (CNNs) and TensorFlow 2.

We will build a web application that predicts the diseases of potato plants.

This application can help farmers identify diseases in their potato plants early, so that they can apply the appropriate treatment and protect their yield.

1.1 Image Classification

Image classification is a classic computer vision problem in which the task is to predict the class of an image from a known set of possible classes.

1.2 Problem statement

  • To classify the given potato leaf image as healthy, late blight or early blight.
  • It is a multi-class classification problem.

2. Data

We will use a Kaggle dataset for this project.
I created a subset of the original data that contains only the potato leaf images (healthy, early blight and late blight). You can find the dataset used in this project here!

Late Blight: Late blight of potato is a disease caused by Phytophthora infestans, a fungus-like water mold (oomycete).

Early Blight: Early blight of potato is a disease caused by the fungus Alternaria solani.

Healthy: Uninfected or healthy plant

Sample images from the dataset

3. Data Preparation!

Let’s start coding. Open a Jupyter notebook and import the necessary Python modules.

# Importing python modules
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import tensorflow as tf

3.1 Train/validation split

Let’s split our dataset into two sets, 80% for training the model and 20% for validation.

Found 2152 images belonging to 3 classes.
Using 1722 images for training.
Using 430 images for validation.
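The loading code is not shown above, so here is a minimal sketch under assumptions: the images sit in one sub-folder per class under a directory passed as data_dir (a hypothetical name), and the split is made with Keras's image_dataset_from_directory using validation_split=0.2. The image size, batch size and seed are illustrative. Keras takes int(validation_split × n) images for validation, and int(0.2 × 2152) = 430 is consistent with the counts above. The function is tf.keras.utils.image_dataset_from_directory in recent TensorFlow 2 releases; older ones expose it as tf.keras.preprocessing.image_dataset_from_directory.

```python
def split_counts(n_images, validation_split):
    """Training/validation sizes: validation gets int(split * n) images, training gets the rest."""
    n_val = int(validation_split * n_images)
    return n_images - n_val, n_val


def load_splits(data_dir, image_size=(256, 256), batch_size=32, seed=123):
    """Load an 80/20 train/validation split from a folder with one sub-folder per class."""
    # Imported here so split_counts can be used without TensorFlow installed.
    import tensorflow as tf

    common = dict(
        validation_split=0.2,
        seed=seed,  # the same seed in both calls keeps the two subsets disjoint
        image_size=image_size,
        batch_size=batch_size,
    )
    train_ds = tf.keras.utils.image_dataset_from_directory(data_dir, subset="training", **common)
    val_ds = tf.keras.utils.image_dataset_from_directory(data_dir, subset="validation", **common)
    return train_ds, val_ds


print(split_counts(2152, 0.2))  # (1722, 430)
```

Both calls shuffle with the same seed before splitting, which is why the training and validation subsets never share an image.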

Let’s check the class distribution of the images in the training and validation sets. Both splits should have roughly the same class distribution; otherwise the validation metrics can give a misleading picture of the model’s performance.

Train and validation sets have the same distribution
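One way to produce this comparison is to count labels per split with collections.Counter (already imported above). The sketch below assumes integer labels and takes the class order from the mapping given in section 5.2; the helper name class_distribution is hypothetical.

```python
from collections import Counter

# Index order assumed from the mapping in section 5.2.
CLASS_NAMES = ["Early blight", "healthy", "Late blight"]


def class_distribution(labels, class_names=CLASS_NAMES):
    """Return the fraction of samples in each class, given integer class labels."""
    counts = Counter(int(label) for label in labels)
    total = sum(counts.values())
    return {name: counts[i] / total for i, name in enumerate(class_names)}
```

With the tf.data datasets, the labels of a split can be gathered with np.concatenate([y.numpy() for _, y in train_ds]); running the function on both splits gives two dictionaries that should hold similar fractions.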

3.2 Data optimization

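The tf.data API helps build flexible and efficient input pipelines. cache() keeps the images in memory once they have been loaded during the first epoch, so later epochs skip reading them from disk, and prefetch() prepares the data for the next training step while the current step is still running. Together they reduce training time. Read the TensorFlow blog for a deeper understanding.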

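# Configure the datasets for performance
AUTOTUNE = tf.data.AUTOTUNE  # let tf.data tune the prefetch buffer size at runtime
# cache(): keep loaded images in memory after the first epoch
# prefetch(): prepare upcoming batches while the current batch is being used for training
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)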
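To make the prefetching idea concrete, here is a toy, framework-free sketch; it is not how tf.data is implemented. A background thread produces the next items into a bounded queue while the consumer is still working on the current one, and buffer_size plays the role of prefetch's buffer_size. The name prefetch here is just a hypothetical helper.

```python
import queue
import threading


def prefetch(iterable, buffer_size=1):
    """Yield items from iterable while a background thread prepares up to buffer_size items ahead."""
    buffer = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel marking the end of the stream
    errors = []      # exceptions raised by the producer, re-raised in the consumer

    def producer():
        try:
            for item in iterable:
                buffer.put(item)  # blocks while the buffer is full
        except Exception as exc:
            errors.append(exc)
        finally:
            buffer.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is done:
            break
        yield item  # the producer keeps filling the buffer while the caller uses this item
    if errors:
        raise errors[0]
```

The output order is unchanged; only the waiting time overlaps, which is the same trade-off tf.data's prefetch makes between a little extra memory and less idle time.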

4. Modeling

4.1 Model Architecture

We’ll build our model from scratch. After experimenting with different models, I found that the architecture below performed best.

Note that we didn’t scale our images when loading them, because we added a Rescaling layer to the model to do it. Since the layer is part of the model, new images are scaled automatically at prediction time, which simplifies deployment.
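Keras's Rescaling layer computes inputs * scale + offset. The NumPy function below mirrors that formula to show what happens to raw 0 to 255 pixel values: scale=1/255 maps them to [0, 1], and scale=1/127.5 with offset=-1 would map them to [-1, 1]. It only illustrates the arithmetic; inside the model you would use tf.keras.layers.Rescaling(1.0 / 255).

```python
import numpy as np


def rescale(images, scale=1.0 / 255, offset=0.0):
    """Same arithmetic as tf.keras.layers.Rescaling: inputs * scale + offset."""
    return np.asarray(images, dtype=np.float32) * scale + offset


raw = np.array([0, 127, 255], dtype=np.uint8)  # raw pixel intensities as loaded from disk
scaled = rescale(raw)  # now in the range [0, 1]
```

Keeping this step inside the model means the web application can pass raw pixel arrays to model.predict without repeating the preprocessing.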

The model ends with a softmax layer that outputs a probability for each of the three classes.
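The architecture itself is only shown as an image, so the build_model function below is a hedged stand-in rather than the author's model: a small CNN of the usual convolution/pooling pattern, where the 256×256×3 input, the three convolution blocks and the layer widths are my assumptions. The NumPy softmax beside it shows what the last layer does to the raw scores: exponentiate them, then normalize so the three outputs sum to 1.

```python
import numpy as np


def softmax(logits):
    """Turn raw class scores into probabilities along the last axis."""
    z = np.asarray(logits, dtype=float)
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def build_model(input_shape=(256, 256, 3), num_classes=3):
    """Hypothetical small CNN: rescaling, three conv/pool blocks, dense head, softmax output."""
    import tensorflow as tf  # imported here so softmax runs without TensorFlow

    layers = tf.keras.layers
    return tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        layers.Rescaling(1.0 / 255),  # scale raw pixels inside the model
        layers.Conv2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(64, activation="relu"),
        layers.Dense(num_classes, activation="softmax"),  # one probability per class
    ])
```

Because the last layer already applies softmax, the loss used during training should expect probabilities rather than logits.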

4.2 Model training

We now have everything we need to train the model. We use the Adam optimizer and an EarlyStopping callback to avoid overfitting.

We train the model for up to 20 epochs and store the training history in the history variable.

Model training
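A minimal training sketch, assuming the model from section 4.1 and the integer labels that image_dataset_from_directory produces by default (hence sparse categorical cross-entropy). The EarlyStopping settings (monitor="val_loss", patience=3, restore_best_weights=True) are illustrative assumptions, not the author's values. stopping_epoch is a simplified, framework-free simulation of the patience rule: training stops once the monitored loss has failed to improve for patience consecutive epochs.

```python
def stopping_epoch(val_losses, patience=3):
    """Return the 0-based epoch at which patience-based early stopping would halt, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, wait = loss, 0  # improvement: remember it and reset the counter
        else:
            wait += 1  # no improvement this epoch
            if wait >= patience:
                return epoch
    return None


def train(model, train_ds, val_ds, epochs=20):
    """Compile with Adam, fit with an EarlyStopping callback and return the History object."""
    import tensorflow as tf  # imported here so stopping_epoch runs without TensorFlow

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # model outputs probabilities
        metrics=["accuracy"],
    )
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=3, restore_best_weights=True
    )
    return model.fit(train_ds, validation_data=val_ds, epochs=epochs, callbacks=[early_stop])
```

Usage: history = train(model, train_ds, val_ds). With restore_best_weights=True the model keeps the weights from its best validation epoch, even if training continued a few epochs past it.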

After 20 epochs, the model reached 0.99 training accuracy and 0.97 validation accuracy. The small gap between the two suggests very little overfitting. Great!

5. Model Evaluation

5.1 Training history

Let’s look at the loss and accuracy of the model at each training epoch.

Journey of the model
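Curves like these can be drawn from the dictionary that Keras stores in history.history. The sketch below assumes the model was compiled with metrics=["accuracy"], so the keys loss, val_loss, accuracy and val_accuracy exist; best_epoch is a hypothetical helper that reports the epoch with the lowest validation loss.

```python
def best_epoch(history, key="val_loss"):
    """1-based epoch with the lowest value of a loss-like key in a History.history dict."""
    values = history[key]
    return min(range(len(values)), key=values.__getitem__) + 1


def plot_history(history):
    """Plot training and validation loss and accuracy per epoch from History.history."""
    import matplotlib.pyplot as plt  # imported here so best_epoch works without matplotlib

    epochs = range(1, len(history["loss"]) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
    for ax, metric in ((ax_loss, "loss"), (ax_acc, "accuracy")):
        ax.plot(epochs, history[metric], label="training")
        ax.plot(epochs, history[f"val_{metric}"], label="validation")
        ax.set_xlabel("epoch")
        ax.set_ylabel(metric)
        ax.legend()
    fig.tight_layout()
    return fig
```

Usage: plot_history(history.history) followed by plt.show(). A validation curve that flattens while the training curve keeps improving is the usual sign of overfitting.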

5.2 Confusion matrix

In a confusion matrix, each row corresponds to a true class and each column to a predicted class, so the diagonal values count the correctly classified images. The more of each row falls on the diagonal, the more accurate the model is for that class.

0=Early blight, 1=healthy, 2=Late blight

The model performs well on all classes. The weakest is the healthy class: 13% of healthy leaves are misclassified as late blight, so 87% of them are classified correctly. Adding more images of healthy leaves would likely reduce this error, but 87% is not bad for now!
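Percentages like the 13% above come from normalizing each row of the confusion matrix by its total. A minimal NumPy sketch, assuming a trained Keras model and a dataset of (images, labels) batches; the helper names are hypothetical, and confusion_counts produces the same layout as sklearn's confusion_matrix (rows = true class, columns = predicted class).

```python
import numpy as np


def collect_predictions(model, dataset):
    """Gather true labels and predicted class indices batch by batch."""
    y_true, y_pred = [], []
    for images, labels in dataset:
        probs = model.predict(images, verbose=0)
        y_true.append(np.asarray(labels))
        y_pred.append(np.argmax(probs, axis=1))  # most probable class per image
    return np.concatenate(y_true), np.concatenate(y_pred)


def confusion_counts(y_true, y_pred, num_classes=3):
    """Count matrix with true classes as rows and predicted classes as columns."""
    idx = np.asarray(y_true) * num_classes + np.asarray(y_pred)
    return np.bincount(idx, minlength=num_classes * num_classes).reshape(num_classes, num_classes)


def row_normalize(cm):
    """Divide each row by its total: entry (i, j) is the fraction of class i predicted as j."""
    cm = np.asarray(cm, dtype=float)
    return cm / cm.sum(axis=1, keepdims=True)
```

Labels and predictions are paired inside the same loop on purpose: if the dataset reshuffles between passes, collecting them in two separate passes could misalign them. The normalized matrix can be drawn with sns.heatmap(row_normalize(cm), annot=True, fmt=".2f"); its diagonal is the per-class recall.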

5.3 Classification report

Let’s look at some other metrics: precision, recall and F1-score.

precision, recall, f1-score
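These metrics can be read straight off the confusion matrix: precision divides each diagonal entry by its column total (everything predicted as that class), recall divides it by its row total (everything that truly belongs to that class), and F1 is their harmonic mean. The sketch below is a NumPy illustration of those formulas; classes with no predictions or no samples are treated as 0 here.

```python
import numpy as np


def per_class_metrics(cm):
    """Precision, recall and F1 per class from a confusion matrix (rows = true, columns = predicted)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    predicted = cm.sum(axis=0)  # images predicted as each class
    actual = cm.sum(axis=1)     # images that truly belong to each class
    with np.errstate(divide="ignore", invalid="ignore"):
        precision = np.where(predicted > 0, tp / predicted, 0.0)
        recall = np.where(actual > 0, tp / actual, 0.0)
        f1 = np.where(precision + recall > 0, 2 * precision * recall / (precision + recall), 0.0)
    return precision, recall, f1
```

In practice the scikit-learn call already imported at the top prints the same numbers plus support and averages: classification_report(y_true, y_pred, target_names=["Early blight", "healthy", "Late blight"]), where y_true and y_pred are the true and predicted labels of the validation set.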

Everything looks good so far. We’ve trained the model and saved it for later use. Now let’s build the web application.

6. Deployment!

I built the front end with HTML, CSS and Bootstrap, and used the Flask framework for the back end.

Rather than showing all of the development code here, I’ve uploaded the complete project to my GitHub.

Download the full source code here!
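The real back end is in the repository; as a rough idea of its shape, here is a minimal Flask sketch, not the author's code. The /predict route, the "file" form field, the saved-model path passed to create_app, the 256×256 input size and the use of Pillow for decoding are all assumptions. Because the model contains its own Rescaling layer, the endpoint passes raw 0 to 255 pixels straight to model.predict.

```python
import io

import numpy as np

CLASS_NAMES = ["Early blight", "healthy", "Late blight"]  # index order from section 5.2
IMAGE_SIZE = (256, 256)  # assumption: must match the size used during training


def top_prediction(probs, class_names=CLASS_NAMES):
    """Return (class name, probability) for the most likely class."""
    probs = np.asarray(probs).ravel()
    idx = int(np.argmax(probs))
    return class_names[idx], float(probs[idx])


def create_app(model_path):
    """Build a Flask app with a hypothetical /predict endpoint that accepts an uploaded leaf image."""
    # Heavy dependencies are imported here so top_prediction works on its own.
    import tensorflow as tf
    from flask import Flask, jsonify, request
    from PIL import Image

    model = tf.keras.models.load_model(model_path)
    app = Flask(__name__)

    @app.route("/predict", methods=["POST"])
    def predict():
        upload = request.files.get("file")
        if upload is None:
            return jsonify(error="no file uploaded"), 400
        image = Image.open(io.BytesIO(upload.read())).convert("RGB").resize(IMAGE_SIZE)
        # Raw pixels go in unchanged: the model's Rescaling layer scales them.
        batch = np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)
        label, confidence = top_prediction(model.predict(batch, verbose=0)[0])
        return jsonify(label=label, confidence=confidence)

    return app
```

Usage: create_app("path/to/saved_model").run(debug=True), then POST an image to /predict from the HTML form.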

6.1 Web application

7. Outro!

Building mini projects like this one keeps our interest in data science alive and motivates us to learn more and to build machine learning applications that solve complex problems.

PS: I’ll be posting more end-to-end projects and other machine learning content. Follow me for updates. 🤗 Have a nice day!!
