Banana Leaf Disease Detection Using a Vision Transformer Model
Banana cultivation is a significant agricultural activity in many tropical and subtropical regions, providing a vital source of income and nutrition. However, the health and productivity of banana plants are frequently threatened by various leaf diseases. These diseases can drastically reduce yield and quality, leading to substantial economic losses for farmers and impacting food security.
Traditional methods of disease detection in banana plants rely heavily on manual inspection by experts, which can be time-consuming, labor-intensive, and prone to human error. Additionally, by the time symptoms are visually noticeable, the disease may have already spread extensively, making effective management and control more challenging.
To address these challenges, the Banana Leaf Disease Detection project leverages advanced technologies such as machine learning, computer vision, and remote sensing. By utilizing these approaches, the project aims to develop an automated and efficient system for the early detection and diagnosis of banana leaf diseases.
Explanation of the Code
Step 1:
Import and install the necessary packages.
Importing Libraries
You can mount your Google Drive in a Google Colab notebook with this piece of code. This gives the Colab environment direct access to files stored in Google Drive for data manipulation, analysis, and model training.
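A minimal sketch of the mount step, assuming the default Colab mount point /content/drive. The helper name mount_drive_if_colab is hypothetical. It calls google.colab.drive.mount only when that module can be imported, so the same notebook can also run outside Colab against a local dataset path.

```python
def mount_drive_if_colab(mount_point="/content/drive"):
    """Mount Google Drive when running inside Colab; return the mount point or None."""
    try:
        # google.colab is only importable inside a Colab runtime
        from google.colab import drive
    except ImportError:
        return None  # not in Colab: read the dataset from a local path instead
    drive.mount(mount_point)  # asks for authorization the first time
    return mount_point


drive_root = mount_drive_if_colab()
print("Drive mounted at:", drive_root)
```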
Step 2:
Data collection and preparation
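We split the dataset for training and testing: 80% of the data is used for training and the remaining 20% as test data. The same 20% subset serves as the validation set during training (val_generator) and for the final evaluation.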
This block of code loads and preprocesses an image dataset from Google Drive for Keras model training. It sets up an ImageDataGenerator for data augmentation and rescaling pixel values to [0, 1]. The dataset is split into 80% for training and 20% for validation. Using flow_from_directory, it generates batches of 32 images resized to 224x224 pixels with categorical labels for multi-class classification.
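The generator setup can be sketched as follows. It assumes TensorFlow 2.x with tf.keras, where ImageDataGenerator still works but is deprecated in favour of tf.keras.utils.image_dataset_from_directory. The settings mirror the description: rescale=1/255, validation_split=0.2, 224x224 targets, batch size 32 and categorical labels.

A few parts are assumptions:
- The helper make_generators and its data_dir argument are hypothetical.
- Augmentation arguments are left out because the text does not list them.
- shuffle=False on the validation subset is an added choice. It keeps val_generator.classes in the same order as the model's predictions.

```python
# Settings taken from the description above; no augmentation arguments are
# included because the text does not specify which transformations were used.
GENERATOR_KWARGS = {"rescale": 1.0 / 255, "validation_split": 0.2}
FLOW_KWARGS = {"target_size": (224, 224), "batch_size": 32, "class_mode": "categorical"}


def make_generators(data_dir):
    """Return (train_generator, val_generator) for a folder with one sub-folder per class."""
    # Imported lazily so the settings above can be inspected without TensorFlow
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    datagen = ImageDataGenerator(**GENERATOR_KWARGS)
    train_generator = datagen.flow_from_directory(data_dir, subset="training", **FLOW_KWARGS)
    # shuffle=False keeps val_generator.classes aligned with model.predict() output
    val_generator = datagen.flow_from_directory(
        data_dir, subset="validation", shuffle=False, **FLOW_KWARGS
    )
    return train_generator, val_generator
```

One ImageDataGenerator serves both subsets. Any augmentation arguments added to it would therefore also be applied to the validation images.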
This block of code calculates and displays class weights to address class imbalance in a dataset and prints class labels for a sample batch. It uses class_weight.compute_class_weight from sklearn to calculate balanced class weights based on the training data, storing them in a dictionary with class indices as keys. The mapping of class names to numerical labels is printed, and a new dictionary, class_weights_with_labels, maps class names to their corresponding weights. Finally, it prints class names for a sample batch from the training generator by decoding the class indices, providing a quick check of class labels and class distribution.
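The balanced weighting computed by class_weight.compute_class_weight follows the formula n_samples / (n_classes * count_c). A class that is three times rarer therefore gets a weight three times larger.

The NumPy sketch below reproduces that formula. balanced_class_weights and weights_by_name are hypothetical helper names. In the notebook, the labels would come from train_generator.classes and the name map from train_generator.class_indices.

```python
import numpy as np


def balanced_class_weights(labels):
    """Weight for class c = n_samples / (n_classes * count_c), as in sklearn's 'balanced' mode."""
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    weights = len(labels) / (len(classes) * counts)
    # Keras' class_weight argument expects {class_index: weight}
    return {int(c): float(w) for c, w in zip(classes, weights)}


def weights_by_name(weights, class_indices):
    """Re-key weights by class name using a {name: index} map like train_generator.class_indices."""
    index_to_name = {idx: name for name, idx in class_indices.items()}
    return {index_to_name[idx]: w for idx, w in weights.items()}


# Toy labels: class 0 is three times as frequent as class 1
w = balanced_class_weights([0, 0, 0, 1])
print(w)  # -> {0: 0.6666666666666666, 1: 2.0}
print(weights_by_name(w, {"class_a": 0, "class_b": 1}))
```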
Step 3:
Display the images of every class
This block of code visualizes a batch of images from the training dataset with their class labels using Matplotlib and Seaborn. It sets the plot style to "dark" with sns.set(style="dark") and fetches a batch of images and labels from train_generator. A function plot_images is defined to display these images in a 4x8 grid, removing axes and setting each subplot's title to the corresponding class name (decoded from the one-hot encoded label). Finally, it calls plot_images to show the images and their labels, providing a visual check of the batch.
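A sketch of the plotting function is shown below. It assumes a batch obtained with next(train_generator), with images already rescaled to [0, 1] and one-hot encoded labels.

The signature of plot_images and the helper decode_one_hot are assumptions, not the notebook's exact code. Matplotlib is imported inside the function, so the label decoding can also be used on its own.

```python
import numpy as np


def decode_one_hot(labels, class_indices):
    """Map one-hot label rows to class names using a {name: index} dict (e.g. train_generator.class_indices)."""
    index_to_name = {idx: name for name, idx in class_indices.items()}
    return [index_to_name[int(i)] for i in np.argmax(labels, axis=1)]


def plot_images(images, labels, class_indices, rows=4, cols=8):
    """Show a batch of images in a rows x cols grid, titled with their class names."""
    import matplotlib.pyplot as plt  # imported lazily; only needed for drawing

    names = decode_one_hot(labels, class_indices)
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows))
    for ax, img, name in zip(axes.ravel(), images, names):
        ax.imshow(img)  # pixel values are already in [0, 1]
        ax.set_title(name, fontsize=8)
    for ax in axes.ravel():
        ax.axis("off")  # also hides empty cells if the batch is smaller than the grid
    plt.tight_layout()
    plt.show()


# Usage in the notebook:
#   images, labels = next(train_generator)
#   plot_images(images, labels, train_generator.class_indices)
```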
Plot the Sobel-filtered (grayscale) images
This block of code defines and uses a function to visualize a batch of Sobel-filtered images with their class labels. The plot_sobel_images function takes the original images, the Sobel-filtered images, the labels, and a class index-to-name mapping. It creates a 4x8 grid of subplots using Matplotlib, displays each Sobel-filtered image in grayscale with its class name as the title, and turns off the axes for a cleaner look. The function is then called to show the Sobel-filtered images along with their class labels, offering a visual representation of the filtered images and their classes.
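The text does not show how the Sobel-filtered images were produced; a library routine such as OpenCV's cv2.Sobel is a common choice. As an illustration only, the NumPy sketch below does two things:
- converts RGB to grayscale with the ITU-R BT.601 luminance weights;
- computes the Sobel gradient magnitude sqrt(Gx² + Gy²).

to_grayscale and sobel_magnitude are hypothetical names.

```python
import numpy as np


def to_grayscale(rgb):
    """Luminance from an (H, W, 3) RGB array using ITU-R BT.601 weights."""
    return rgb[..., :3] @ np.array([0.299, 0.587, 0.114])


def sobel_magnitude(gray):
    """Gradient magnitude sqrt(Gx^2 + Gy^2) from the 3x3 Sobel kernels."""
    g = np.pad(gray.astype(float), 1, mode="edge")  # replicate the border pixels
    # Gx: right column minus left column, weighted 1-2-1
    gx = (g[:-2, 2:] + 2 * g[1:-1, 2:] + g[2:, 2:]) - (g[:-2, :-2] + 2 * g[1:-1, :-2] + g[2:, :-2])
    # Gy: bottom row minus top row, weighted 1-2-1
    gy = (g[2:, :-2] + 2 * g[2:, 1:-1] + g[2:, 2:]) - (g[:-2, :-2] + 2 * g[:-2, 1:-1] + g[:-2, 2:])
    return np.hypot(gx, gy)


# A vertical step edge: dark on the left, bright from column 3 onward
img = np.zeros((5, 6))
img[:, 3:] = 1.0
print(sobel_magnitude(img)[0])  # strong response only at the two columns around the edge
```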
Step 4:
Build Custom CNN Model
Convolutional Neural Network
A Convolutional Neural Network (CNN) processes an input image by first applying a series of convolutional layers. Each layer slides small filters (kernels) over the image to extract features and is followed by a ReLU activation function that introduces non-linearity. These layers are interspersed with pooling layers that reduce the spatial dimensions of the feature maps, making computation more efficient while retaining the essential information. After several convolution and pooling operations, the resulting feature maps are flattened into a single vector and passed through fully connected layers, which act like a traditional neural network. The final output layer uses a softmax activation function to produce a probability distribution over the possible classes, for example whether the image shows a horse, a zebra, or a dog. This combination of feature extraction and classification allows CNNs to recognize and categorize images effectively.
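To make this pipeline concrete, here is a toy single-channel forward pass in NumPy. It applies one 3x3 convolution, ReLU, 2x2 max-pooling, flattening, one dense layer and softmax.

The weights are random and the sizes (an 8x8 input, three classes) are arbitrary. The example only shows how each operation changes the shape of the data; it is not the project's model.

```python
import numpy as np


def conv2d_valid(image, kernel):
    """Slide a kernel over a 2D image without padding (cross-correlation, as in CNN layers)."""
    kh, kw = kernel.shape
    h, w = image.shape[0] - kh + 1, image.shape[1] - kw + 1
    out = np.empty((h, w))
    for i in range(h):
        for j in range(w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


def relu(x):
    return np.maximum(x, 0.0)


def max_pool_2x2(x):
    """Keep the largest value in each non-overlapping 2x2 window (odd edges are dropped)."""
    h, w = x.shape[0] // 2 * 2, x.shape[1] // 2 * 2
    return x[:h, :w].reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))


def softmax(z):
    e = np.exp(z - np.max(z))  # subtract the max for numerical stability
    return e / e.sum()


rng = np.random.default_rng(0)
image = rng.random((8, 8))
features = max_pool_2x2(relu(conv2d_valid(image, rng.standard_normal((3, 3)))))
print(features.shape)  # (3, 3): 8x8 -> 6x6 after a 3x3 conv -> 3x3 after pooling
weights = rng.standard_normal((features.size, 3))  # toy dense layer with 3 outputs
probs = softmax(features.ravel() @ weights)
print(probs, probs.sum())  # three class probabilities that sum to 1
```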
This block of code defines a convolutional neural network (CNN) model using the Sequential API from Keras, designed for image classification tasks. The model starts with an input layer for 128x128 RGB images, followed by four blocks of convolutional layers with ReLU activation and batch normalization to stabilize learning. Each block has two convolutional layers, followed by max-pooling layers to downsample the feature maps and dropout layers to prevent overfitting. The number of filters increases progressively (64, 128, 256, 512) to capture more complex features. After the convolutional blocks, a global average pooling layer reduces the spatial dimensions, followed by fully connected dense layers with dropout and batch normalization to further reduce over-fitting and enhance learning stability. The final dense layer with a softmax activation outputs probabilities for four classes, making the model suitable for multi-class classification tasks.
This block of code compiles a Convolutional Neural Network (CNN) model, model_CNN, and prints its summary. The compile method configures the model for training by setting the optimizer (Adam), loss function (categorical_crossentropy), and evaluation metric (accuracy). These choices are suitable for multi-class classification tasks with one-hot encoded labels. The summary method is then called to print details of the model's architecture, including layers, output shapes, and parameter counts, giving an overview of the model's structure and complexity.
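The architecture and compile settings described above can be sketched in tf.keras as follows.

The following come from the text:
- the filter counts and two convolutions per block;
- batch normalization, max-pooling, dropout and global average pooling;
- the four-class softmax output;
- the Adam, categorical cross-entropy and accuracy compile settings.

The dense widths (256, 128) and the dropout rates are hypothetical, because the text does not give them.

Two points are worth checking against the notebook:
- The CNN expects 128x128 inputs, while the Step 2 generator yields 224x224 images. The CNN therefore needs generators built with target_size=(128, 128), or a resizing layer.
- The precision, recall and AUC curves plotted later for the CNN exist only if those metrics are added at compile time, for example tf.keras.metrics.Precision(name="precision").

```python
FILTERS = (64, 128, 256, 512)  # one entry per convolutional block, from the text
INPUT_SHAPE = (128, 128, 3)
NUM_CLASSES = 4


def feature_map_sizes(size, n_blocks=len(FILTERS)):
    """Spatial size after each block; every 2x2 max-pool halves height and width."""
    sizes = []
    for _ in range(n_blocks):
        size //= 2
        sizes.append(size)
    return sizes


def build_custom_cnn(dense_units=(256, 128), conv_dropout=0.25, dense_dropout=0.5):
    """Sequential CNN following the description; dense widths and dropout rates are hypothetical."""
    from tensorflow.keras import layers, models  # lazy import: TensorFlow is needed only here

    model = models.Sequential([layers.Input(shape=INPUT_SHAPE)])
    for filters in FILTERS:
        for _ in range(2):  # two convolutional layers per block
            model.add(layers.Conv2D(filters, 3, padding="same", activation="relu"))
            model.add(layers.BatchNormalization())
        model.add(layers.MaxPooling2D(2))
        model.add(layers.Dropout(conv_dropout))
    model.add(layers.GlobalAveragePooling2D())
    for units in dense_units:
        model.add(layers.Dense(units, activation="relu"))
        model.add(layers.BatchNormalization())
        model.add(layers.Dropout(dense_dropout))
    model.add(layers.Dense(NUM_CLASSES, activation="softmax"))
    # Compile settings as described: Adam, categorical cross-entropy, accuracy
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    return model


print(feature_map_sizes(128))  # an 8x8x512 map reaches global average pooling
```

Usage: model_CNN = build_custom_cnn(), followed by model_CNN.summary().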
Train the CNN model
This block of code trains the CNN model using Keras's fit method. The train_generator provides training data in batches, and the model is trained for 40 epochs. The validation_data parameter, set to val_generator, evaluates the model's performance on a validation dataset at the end of each epoch. The class_weight parameter addresses class imbalance by assigning different weights to each class, preventing the model from favoring more frequent classes. The training process returns a history object, which contains details about training and validation loss and accuracy for each epoch, useful for further analysis and visualization.
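The fit call described above can be wrapped in a small helper, so the same settings are reused for both models. train_model is a hypothetical name.

The helper passes these straight to model.fit:
- the training generator and 40 epochs;
- the validation generator;
- the class-weight dictionary, whose keys must be class indices as Keras expects;
- an optional callbacks list.

```python
def train_model(model, train_generator, val_generator, class_weights, epochs=40, callbacks=None):
    """Fit a Keras model on generator data, validating at the end of every epoch."""
    return model.fit(
        train_generator,
        epochs=epochs,
        validation_data=val_generator,
        class_weight=class_weights,  # {class_index: weight}
        callbacks=callbacks,
    )
```

Example: history_CNN = train_model(model_CNN, train_generator, val_generator, class_weights), where class_weights is an assumed name for the index-keyed dictionary from Step 2. history_CNN.history then holds one list per metric, for example history_CNN.history["val_loss"].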
Step 5:
Build a Vision Transformer Model
Understanding the Vision Transformer
ViT has three important aspects: splitting the image into regular-sized patches, applying a linear transformation to them, and adding trainable positional embeddings to the patches to retain spatial information. The resulting sequence is fed to a standard Transformer encoder, whose multi-head self-attention lets every patch attend to every other patch and so captures relationships across the whole image.
Patch Embedding
The standard Transformer used in NLP receives its input as a 1D sequence of token embeddings. To handle 2D images, the image x∈R^{H×W×C} is reshaped into a sequence of flattened 2D patches x_p∈R^{N×(P²·C)}, where (H, W) is the resolution of the original image, C is the number of channels, and (P, P) is the resolution of each image patch. N = HW/P² is then the effective sequence length for the Transformer. The image is split into fixed-size patches: in the image below, the image is 48×48 and the patch size is 16×16, so it is split into N = (48·48)/16² = 9 patches.
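The reshaping can be checked with a short NumPy sketch. extract_patches is a hypothetical helper. It cuts an (H, W, C) image into non-overlapping P x P patches in row-major order and flattens each one to P²·C values.

For the 48x48 example with P = 16, this gives N = 9 patches. For the 224x224 inputs used in this project, it gives N = 224·224/16² = 196 patches of 16·16·3 = 768 values each.

Keras implementations often use tf.image.extract_patches or a strided Conv2D for this step instead.

```python
import numpy as np


def extract_patches(image, patch_size):
    """Split an (H, W, C) image into N = HW / P^2 flattened patches of length P*P*C."""
    h, w, c = image.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "H and W must be multiples of the patch size"
    patches = image.reshape(h // p, p, w // p, p, c)  # (rows, P, cols, P, C)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (rows, cols, P, P, C)
    return patches.reshape(-1, p * p * c)  # one row per patch, row-major patch order


print(extract_patches(np.zeros((48, 48, 3)), 16).shape)    # (9, 768)
print(extract_patches(np.zeros((224, 224, 3)), 16).shape)  # (196, 768)
```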
Linear Transformation/Projection
The patches are then flattened and passed through a trainable linear projection to create patch embeddings. The patch embedding matrix is obtained by multiplying the flattened patches by a trainable embedding weight matrix, which maps each patch of length P²·C to a vector of the model's embedding dimension D.
Positional Embedding
Position embeddings are added to the patch embeddings to retain positional information. The ViT authors explored different 2D-aware variants of position embeddings but found no significant gains over standard 1D position embeddings. The joint embedding serves as input to the Transformer encoder. Each unrolled patch (before the linear projection) is assigned a position index; the authors simply numbered the patches 1, 2, 3, …, N. Each index is associated with a learnable vector, and these vectors are stacked row-wise to form a learnable positional embedding table.
Similar to BERT's [class] token, ViT prepends a learnable embedding to the sequence of embedded patches. The state of this token at the output of the Transformer encoder (z_L^0) serves as the image representation y. During both pre-training and fine-tuning, the classification head is attached to z_L^0.
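Patch embedding, the class token and the position embeddings combine into the encoder input z_0, as sketched below in NumPy. The sketch uses N = 196 patches of 768 values and embedding size D = 128, the values given for model_VIT in this project.

The random initialization (standard deviation 0.02, class token starting at zero) is an assumption for illustration. In the real model, E, the class token and the position table are trainable parameters. The projection usually also has a bias, which is omitted here.

```python
import numpy as np

rng = np.random.default_rng(0)
N, patch_dim, D = 196, 16 * 16 * 3, 128  # 224x224 RGB image, 16x16 patches, embedding size 128

patches = rng.random((N, patch_dim))          # stand-in for the flattened patches x_p
E = rng.normal(0, 0.02, (patch_dim, D))       # projection weights (trainable in a real model)
cls_token = np.zeros((1, D))                  # [class] embedding (trainable in a real model)
pos_table = rng.normal(0, 0.02, (N + 1, D))   # one position row per token, class token included

tokens = patches @ E                          # linear projection: (N, D)
tokens = np.concatenate([cls_token, tokens])  # prepend the class token: (N + 1, D)
z0 = tokens + pos_table                       # add the position embeddings row-wise
print(z0.shape)  # (197, 128); row 0 is the class token whose final state feeds the head
```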
This block of code defines a custom Vision Transformer (ViT) model using TensorFlow and Keras. It includes three key components: TransformerBlock, MultiHeadAttention, and VisionTransformer. The TransformerBlock processes input features with self-attention and feedforward layers. MultiHeadAttention computes attention scores across multiple heads. The VisionTransformer divides images into patches, projects them into a higher-dimensional embedding space, adds the positional and class embeddings, and processes the resulting sequence through the TransformerBlocks. Finally, a multi-layer perceptron (MLP) head produces the class predictions.
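The attention computation inside MultiHeadAttention can be illustrated with a NumPy sketch of standard scaled dot-product multi-head self-attention. Each head computes softmax(QKᵀ/√d_h)V, and the heads are then concatenated and projected. The sketch uses d = 128 and 8 heads, the values given for model_VIT, so each head works in d_h = 16 dimensions.

Function and weight names are hypothetical, and the project's class may differ in detail. tf.keras also ships a built-in tf.keras.layers.MultiHeadAttention layer. In the ViT paper, each Transformer block wraps this computation with layer normalization, residual connections and an MLP.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_self_attention(x, wq, wk, wv, wo, num_heads):
    """Scaled dot-product self-attention over a (tokens, d) sequence split into num_heads heads."""
    n, d = x.shape
    assert d % num_heads == 0, "embedding size must be divisible by the number of heads"
    hd = d // num_heads

    def split(t):  # (n, d) -> (heads, n, hd)
        return t.reshape(n, num_heads, hd).transpose(1, 0, 2)

    q, k, v = split(x @ wq), split(x @ wk), split(x @ wv)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(hd)  # (heads, n, n)
    weights = softmax(scores, axis=-1)               # each row sums to 1
    heads = weights @ v                              # (heads, n, hd)
    merged = heads.transpose(1, 0, 2).reshape(n, d)  # concatenate the heads
    return merged @ wo, weights


rng = np.random.default_rng(0)
d, num_heads = 128, 8
x = rng.standard_normal((197, d))  # 196 patch tokens + 1 class token
w = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4)]
out, attn = multi_head_self_attention(x, *w, num_heads=num_heads)
print(out.shape, attn.shape)  # (197, 128) (8, 197, 197)
```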
This block of code defines and initializes a Vision Transformer model for image classification. It processes input images resized to 224x224 pixels, dividing them into 16x16 patches. The model includes 12 Transformer layers with an embedding size of 128, 8 attention heads, and a feedforward network dimension of 256. It outputs probabilities for 4 classes using an MLP head with a dropout rate of 0.1 for regularization. The VisionTransformer class instance, model_VIT, is created with these parameters, ready for training on a classification task.
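These hyperparameters imply a few derived quantities worth checking:
- 224/16 = 14 patches per side, so N = 14² = 196 patches, plus one class token (197 tokens);
- 128/8 = 16 dimensions per attention head.

ViTConfig is a hypothetical dataclass for recording and validating these values. Its field names are not necessarily the arguments of the project's VisionTransformer class.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ViTConfig:
    image_size: int = 224
    patch_size: int = 16
    num_layers: int = 12
    d_model: int = 128
    num_heads: int = 8
    mlp_dim: int = 256
    num_classes: int = 4
    dropout: float = 0.1

    def __post_init__(self):
        # Reject settings that cannot be split into whole patches or whole heads
        if self.image_size % self.patch_size:
            raise ValueError("image_size must be a multiple of patch_size")
        if self.d_model % self.num_heads:
            raise ValueError("d_model must be divisible by num_heads")

    @property
    def num_patches(self):
        return (self.image_size // self.patch_size) ** 2

    @property
    def head_dim(self):
        return self.d_model // self.num_heads


cfg = ViTConfig()
print(cfg.num_patches, cfg.num_patches + 1, cfg.head_dim)  # 196 197 16
```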
This block of code compiles the Vision Transformer (ViT) model, model_VIT, specifying the training configuration. It uses CategoricalCrossentropy with from_logits=True as the loss function for multi-class classification. from_logits=True tells Keras that the model's last layer returns raw, unnormalized scores (logits); if the MLP head ends with a softmax, from_logits should be False instead. The optimizer is Adam with a learning rate of 1e-4. The model's performance is evaluated with four metrics: accuracy, precision, recall, and AUC, providing a comprehensive assessment of its predictive capabilities.
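The effect of from_logits can be seen numerically. The NumPy sketch below uses hypothetical helper names and toy four-class scores. It shows two things:
- Cross-entropy computed from logits equals cross-entropy computed from the softmax probabilities.
- Treating probabilities as logits applies softmax twice. This is what happens if a softmax head is combined with from_logits=True, and here it gives a different, larger loss.

```python
import numpy as np


def log_softmax(z):
    z = z - z.max()  # shift for numerical stability
    return z - np.log(np.exp(z).sum())


def ce_from_logits(y_onehot, logits):
    """Cross-entropy when the model outputs raw scores (Keras from_logits=True)."""
    return -np.sum(y_onehot * log_softmax(logits))


def ce_from_probs(y_onehot, probs, eps=1e-7):
    """Cross-entropy when the model already outputs probabilities (from_logits=False)."""
    return -np.sum(y_onehot * np.log(np.clip(probs, eps, 1.0)))


y = np.array([0.0, 1.0, 0.0, 0.0])  # the true class is index 1 of 4
logits = np.array([1.0, 3.0, 0.5, -1.0])
probs = np.exp(log_softmax(logits))

print(ce_from_logits(y, logits), ce_from_probs(y, probs))  # the same value twice
print(ce_from_logits(y, probs))  # larger: softmax applied twice (mismatched setting)
```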
Train the model
This block of code trains the Vision Transformer model using the train_generator and val_generator datasets for 40 epochs. It includes early stopping and learning rate scheduling to optimize the training process. The EarlyStopping callback monitors validation loss, halting training if there's no improvement for 10 consecutive epochs, and restores the best weights. The LearningRateScheduler callback adjusts the learning rate dynamically, reducing it exponentially after the 10th epoch for finer tuning. The class_weight parameter handles class imbalances, and the training history is stored in history_model_VIT.
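The two callbacks can be sketched as follows, assuming tf.keras. The early-stopping settings (val_loss, patience 10, restore best weights) come from the text.

The exact decay rule is not given. lr_schedule therefore keeps the rate for the first 10 epochs and then multiplies it by exp(-0.1) each epoch; the 0.1 rate is a hypothetical choice. Keras calls the schedule with the epoch index (starting at 0) and the current learning rate.

```python
import math


def lr_schedule(epoch, lr, start_epoch=10, decay=0.1):
    """Keep lr for the first start_epoch epochs, then multiply it by exp(-decay) every epoch."""
    if epoch < start_epoch:
        return lr
    return lr * math.exp(-decay)


def make_callbacks():
    from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler  # lazy import

    return [
        EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
        LearningRateScheduler(lr_schedule),
    ]


# Learning rate over 15 epochs, starting from the 1e-4 used at compile time
lr, trace = 1e-4, []
for epoch in range(15):
    lr = lr_schedule(epoch, lr)
    trace.append(lr)
print([f"{v:.2e}" for v in trace])
```

With these callbacks, the training call would look like this:

history_model_VIT = model_VIT.fit(train_generator, epochs=40, validation_data=val_generator, class_weight=class_weights, callbacks=make_callbacks())

Here class_weights is an assumed name for the index-keyed weight dictionary.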
This block of code defines a function plot_performance that visualizes a machine learning model's training history using Plotly. It plots accuracy, precision, recall, AUC, and loss metrics for both training and validation datasets over epochs. The function creates a 3x2 grid of subplots with appropriate titles using make_subplots. For each metric, it adds two traces: one for training data and one for validation data, displaying lines and markers. The layout is adjusted for figure dimensions and title, and the legend is shown. The function uses fig.show() to render the interactive Plotly figure, offering a detailed visual assessment of the model's performance during training.
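Below is a sketch of plot_performance using Plotly's make_subplots and graph_objects.Scatter. It assumes a Keras History object whose history dict has the keys accuracy, precision, recall, auc and loss, plus their val_ counterparts.

Those key names are an assumption. They follow the names passed to compile, and Keras may add numeric suffixes (for example precision_1) when the same metric type is created more than once in a session. The five metrics fill the 3x2 grid row by row, leaving the last cell empty. grid_position is a hypothetical helper.

```python
METRICS = ("accuracy", "precision", "recall", "auc", "loss")


def grid_position(i, cols=2):
    """1-based (row, col) of the i-th subplot in a grid filled row by row."""
    return i // cols + 1, i % cols + 1


def plot_performance(history, title):
    """Training vs. validation curves for each metric in a 3x2 Plotly grid."""
    from plotly.subplots import make_subplots  # lazy imports: Plotly is needed only for drawing
    import plotly.graph_objects as go

    hist = history.history  # Keras History: {"loss": [...], "val_loss": [...], ...}
    fig = make_subplots(rows=3, cols=2, subplot_titles=[m.upper() for m in METRICS])
    for i, metric in enumerate(METRICS):
        row, col = grid_position(i)
        epochs = list(range(1, len(hist[metric]) + 1))
        for key, label in ((metric, "train"), ("val_" + metric, "validation")):
            fig.add_trace(
                go.Scatter(x=epochs, y=hist[key], mode="lines+markers", name=f"{label} {metric}"),
                row=row, col=col,
            )
    fig.update_layout(height=900, width=1000, title_text=title, showlegend=True)
    fig.show()
```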
Evaluating the model with test data.
This block of code defines the evaluate_model function to assess a trained machine learning model using a data generator. It calculates true labels (y_true) and predicted labels (y_pred_classes). A confusion matrix and classification report are generated for performance evaluation. The confusion matrix is visualized as an annotated heatmap in Plotly, with true labels on the y-axis and predicted labels on the x-axis. The layout is optimized for readability with titles and margin adjustments. The classification report, detailing precision, recall, and F1-score for each class, is printed. The interactive Plotly heatmap is displayed using fig.show().
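The core of evaluate_model can be sketched without the plotting. The convention matches the heatmap described above (rows are true labels, columns are predictions) and scikit-learn's confusion_matrix. The precision, recall and F1 values are the per-class figures that classification_report prints. Helper names are hypothetical.

One assumption matters in practice. y_true = val_generator.classes lines up with np.argmax(model.predict(val_generator), axis=1) only if the generator was created with shuffle=False, because flow_from_directory shuffles by default.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """cm[i, j] counts samples of true class i predicted as class j (rows = true, cols = predicted)."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def per_class_scores(cm):
    """Precision, recall and F1 per class, using 0 where a denominator is 0."""
    tp = np.diag(cm).astype(float)
    predicted, actual = cm.sum(axis=0), cm.sum(axis=1)
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1


# In the notebook: y_pred_classes = np.argmax(model.predict(val_generator), axis=1)
#                  y_true = val_generator.classes   (requires shuffle=False)
y_true = [0, 0, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2]
cm = confusion_matrix(y_true, y_pred, 3)
print(cm)  # rows (true classes): [1, 1, 0], [0, 2, 0], [0, 0, 1]
print(per_class_scores(cm)[2])  # F1 per class: about 0.667, 0.8, 1.0
```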
This block of code calls the plot_performance function to visualize the training and validation metrics stored in history_CNN. It generates plots for accuracy, precision, recall, AUC, and loss over the training epochs, showing both training and validation values; precision, recall, and AUC are recorded in history_CNN only if the CNN is compiled with those metrics in addition to accuracy. The plots are labeled with the title 'Custom CNN Model', helping to assess the model's performance and track its learning progress during training.
This block of code calls the plot_performance function to visualize the training and validation metrics for the Vision Transformer model using the history_model_VIT. It generates line plots for accuracy, precision, recall, AUC, and loss over the training epochs, showing both training and validation metrics. The plots are labeled with the title 'Vision Transformer', offering a clear visual representation of the model's performance and learning progress during training.
This block of code calls the evaluate_model function to assess the performance of model_CNN using the validation dataset from val_generator. It computes the true labels (y_true) and predicted labels (y_pred_classes). A confusion matrix and a classification report, including precision, recall, and F1-score for each class, are generated. The confusion matrix is visualized as an annotated Plotly heatmap, with true labels on the y-axis and predicted labels on the x-axis. The classification report is printed, providing detailed performance metrics on the validation dataset.
Saving the model checkpoint file.
This block of code calls the evaluate_model function to assess the performance of the Vision Transformer model (model_VIT) using the validation dataset from val_generator. It calculates the true labels (y_true) and predicted labels (y_pred_classes). A confusion matrix is generated to show the relationship between true and predicted labels, and a classification report is produced, including precision, recall, and F1-score for each class. The confusion matrix is displayed as an annotated Plotly heatmap, with true labels on the y-axis and predicted labels on the x-axis. The classification report is printed, providing a detailed evaluation of the model's performance on the validation dataset.
Save transformer model
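The two save steps (a checkpoint for the CNN, then the transformer) can be sketched as below. save_models and the file names are hypothetical.

The CNN is saved as a complete .keras file. The ViT is saved as weights only. This is an assumption, made because a subclassed model with custom layers can be reloaded as a whole only if those layers implement get_config(). The .keras and .weights.h5 suffixes match the formats used by recent Keras versions.

```python
import os


def save_models(model_cnn, model_vit, out_dir):
    """Save the CNN as a full model file and the ViT as weights only (hypothetical file names)."""
    os.makedirs(out_dir, exist_ok=True)
    cnn_path = os.path.join(out_dir, "banana_cnn.keras")
    vit_path = os.path.join(out_dir, "banana_vit.weights.h5")
    model_cnn.save(cnn_path)  # architecture, weights and optimizer state in one file
    # The subclassed ViT is easier to restore by rebuilding it in code and loading weights
    model_vit.save_weights(vit_path)
    return cnn_path, vit_path
```

To reload the models:
- CNN: use tf.keras.models.load_model.
- ViT: rebuild model_VIT with the same constructor arguments, call it once on a dummy batch of shape (1, 224, 224, 3) so its weights are created, then call load_weights.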
Step 6:
This block of code defines a process for predicting an image's class using an ensemble method combining a CNN and a Vision Transformer. The preprocess_image function loads and preprocesses the image to the target size, normalizing pixel values. The predict_with_voting function processes the image for both models, averages their predicted probabilities, and determines the final class. The display_image_with_prediction function shows the input image with the predicted class label. The class_mapping dictionary translates class indices to disease names.
Predict disease from images
This code snippet performs image classification using an ensemble of a CNN and a Vision Transformer and displays the result. The predict_with_voting function is called with the image path, preprocessing the image to the required sizes (128x128 for the CNN and 224x224 for the ViT). It predicts class probabilities with both models, averages them, and determines the final class by taking the argmax of the averaged probabilities. The predicted class is mapped to its corresponding disease name using the class_mapping dictionary. Finally, the display_image_with_prediction function shows the input image with the predicted disease name as the title.
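As described, predict_with_voting is soft voting: the class probabilities from the two models are averaged and the argmax is taken. The sketch below assumes Pillow for loading images. preprocess_image, soft_vote and the placeholder class names are hypothetical; the text does not list the four disease names.

If model_VIT was trained with from_logits=True, pass its raw outputs through softmax before averaging. Otherwise logits and probabilities are mixed on different scales.

```python
import numpy as np


def softmax(z, axis=-1):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def preprocess_image(path, target_size):
    """Load an image as float32 in [0, 1] with a leading batch axis; target_size is (width, height)."""
    from PIL import Image  # lazy import; Pillow is only needed when reading real files

    img = Image.open(path).convert("RGB").resize(target_size)
    return np.asarray(img, dtype=np.float32)[None] / 255.0


def soft_vote(prob_list, class_mapping):
    """Average per-class probabilities from several models; return (label, averaged probabilities)."""
    avg = np.mean(np.stack(prob_list), axis=0)
    return class_mapping[int(np.argmax(avg))], avg


# In the notebook (apply softmax to the ViT output only if its head returns logits):
#   p_cnn = model_CNN.predict(preprocess_image(path, (128, 128)))[0]
#   p_vit = softmax(model_VIT.predict(preprocess_image(path, (224, 224)))[0])
class_mapping = {0: "class_0", 1: "class_1", 2: "class_2", 3: "class_3"}  # placeholder names
p_cnn = np.array([0.5, 0.3, 0.1, 0.1])
p_vit = np.array([0.1, 0.6, 0.2, 0.1])
label, avg = soft_vote([p_cnn, p_vit], class_mapping)
print(label, avg)  # class_1 wins after averaging, although the CNN alone preferred class_0
```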
Conclusion
The Banana Leaf Disease Detection project applies machine learning, computer vision, and remote sensing to the critical need for early and accurate detection of banana leaf diseases. Automated detection not only makes monitoring and managing crop health more efficient but also gives farmers timely and precise information.
Through the integration of advanced technologies, the project aims to reduce the economic losses caused by banana leaf diseases, ensuring higher yields and better quality produce. This technological advancement contributes to the sustainability of banana farming, promoting environmental stewardship by enabling targeted and minimal use of pesticides.
Moreover, because the approach can be scaled, farmers in diverse regions could benefit from it, fostering a more resilient and robust agricultural sector. By equipping farmers with cutting-edge tools and knowledge, the project helps build a future where technology and traditional farming practices work hand in hand to support food security and economic prosperity.