Banana Leaf Disease Detection using Vision Transformer model
Banana cultivation is a significant agricultural activity in many tropical and subtropical regions, providing a vital source of income and nutrition. However, the health and productivity of banana plants are frequently threatened by various leaf diseases. These diseases can drastically reduce yield and quality, leading to substantial economic losses for farmers and impacting food security.
Traditional methods of disease detection in banana plants rely heavily on manual inspection by experts, which can be time-consuming, labor-intensive, and prone to human error. Additionally, by the time symptoms are visually noticeable, the disease may have already spread extensively, making effective management and control more challenging.
To address these challenges, the Banana Leaf Disease Detection project leverages advanced technologies such as machine learning, computer vision, and remote sensing. By utilizing these innovative approaches, the project aims to develop an automated and efficient system for early detection and diagnosis of banana leaf diseases.
Project Overview
Banana Leaf Disease Detection using Vision Transformer and CNN models aims to address a critical agricultural challenge. Banana plants are an essential source of food and income in tropical regions, but diseases can seriously harm their health and productivity. This project uses advanced technology to develop an efficient and accurate system for detecting diseases in banana leaves early.
The combination of Vision Transformer and CNN models offers a powerful approach to solving the problem. Traditional methods of disease detection are time-consuming and labor-intensive. By applying machine learning and computer vision, this project automates disease detection, reducing the reliance on manual inspection. The system helps farmers detect diseases at an early stage, leading to better disease management and increased crop yields.
Prerequisites
Before starting this project, it is important to have a basic understanding of machine learning and computer vision. Familiarity with deep learning frameworks such as TensorFlow and Keras is essential, as they will be used to build the models. A working knowledge of Python programming is required, as it will be used to implement the entire workflow. Experience with data preprocessing, model evaluation, and optimization is also beneficial.
You should have access to a dataset of banana leaf images, labeled with different types of diseases. The dataset will be used to train and validate the models. You will also need access to Google Colab or a similar platform for running the models, as well as the necessary packages such as TensorFlow, Keras, and scikit-learn for building and training the models.
Approach
The approach to this project combines advanced machine learning techniques with practical agricultural applications. The project uses a hybrid model, combining the strengths of Vision Transformers and Convolutional Neural Networks (CNNs). Vision Transformers excel at capturing long-range dependencies in image data, while CNNs are powerful for extracting local features. By integrating these two models, the project creates a more robust system for banana leaf disease detection.
The project uses image data of banana leaves, which are preprocessed and fed into both models. Each model predicts the presence of a particular disease, and the final prediction is made using a voting system. This ensures that the predictions are more accurate, as it takes into account the strengths of both models.
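The voting step can be sketched as follows. This is a minimal illustration that assumes soft voting (a weighted average of the two models' class probabilities); the text above does not state which voting rule is used, and `soft_vote`, `weight_a`, and the sample arrays are hypothetical names and values chosen for the example.

```python
import numpy as np

def soft_vote(prob_a, prob_b, weight_a=0.5):
    """Blend two (n_samples, n_classes) probability arrays and pick a class.

    Assumption: both inputs hold probabilities (each row sums to 1).
    weight_a is a hypothetical knob for trusting one model more than the other.
    """
    prob_a = np.asarray(prob_a, dtype=float)
    prob_b = np.asarray(prob_b, dtype=float)
    blended = weight_a * prob_a + (1.0 - weight_a) * prob_b
    return blended.argmax(axis=1), blended

# Made-up probabilities for two leaves and four classes
cnn_probs = np.array([[0.70, 0.10, 0.10, 0.10],
                      [0.20, 0.30, 0.40, 0.10]])
vit_probs = np.array([[0.60, 0.20, 0.10, 0.10],
                      [0.10, 0.60, 0.20, 0.10]])
labels, blended = soft_vote(cnn_probs, vit_probs)
print(labels)
```

For the second leaf the two models disagree, and the blended probabilities decide the outcome. A model that outputs raw logits would need a softmax applied before blending.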
Workflow and Methodology
The Overall Workflow of this Project Includes:
- Data Collection and Preparation: The first step involves collecting a dataset of banana leaf images, categorized by the type of disease present. After that, the dataset is prepared for model training.
- Model Building: Two models are built in parallel: a Vision Transformer model and a CNN model. Each model is trained separately to predict banana leaf diseases.
- Model Training: Both models are trained using the prepared dataset. Class weights are adjusted to account for any imbalances in the dataset.
- Model Evaluation: After training, both models are evaluated on a separate validation dataset to check their accuracy.
- Voting Mechanism: The final step involves using an ensemble voting mechanism to combine the predictions of both models and make a more accurate decision.
The Methodology Involves:
- Data Augmentation: Data augmentation techniques are applied to increase the diversity of the training data.
- Model Optimization: Learning rate schedules and early stopping mechanisms are implemented to ensure that the models are optimized during training.
- Performance Evaluation: The models' performance is evaluated using metrics such as precision, recall, accuracy, and AUC (Area Under the Curve).
Data Collection
The success of any machine learning project largely depends on the quality of the data used for training. For this project, a high-quality dataset of banana leaf images is essential. The dataset must include images of healthy banana leaves as well as leaves affected by various diseases like cordana, pestalotiopsis, and sigatoka.
Data Preparation
Once the dataset is collected, it must be carefully prepared before training the models. This involves splitting the dataset into training and validation sets, with 80% of the data used for training and 20% reserved for validation. It is important to ensure that the dataset is balanced so that all disease categories are represented equally.
Data augmentation techniques are used to increase the size of the dataset. These techniques include rotating, flipping, and zooming the images to create new variations. This helps to improve the robustness of the model by exposing it to a wider variety of inputs during training.
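To make the idea concrete, here is a small NumPy-only sketch of flipping, rotating, and zooming an image. It is an illustration under simplifying assumptions, not the pipeline used in this project: rotations are limited to multiples of 90 degrees, zoom is a central crop resized with nearest-neighbour sampling, and `augment` is a hypothetical helper name.

```python
import numpy as np

def augment(image, rng):
    """Return a randomly flipped, rotated, and zoomed copy of a square HxWxC image."""
    out = image
    if rng.random() < 0.5:
        out = out[:, ::-1, :]                        # horizontal flip
    out = np.rot90(out, k=int(rng.integers(0, 4)))   # rotate by 0, 90, 180 or 270 degrees
    # Zoom in: crop a central window, then resize back with nearest-neighbour sampling
    h, w = out.shape[:2]
    zoom = rng.uniform(1.0, 1.2)
    ch, cw = int(h / zoom), int(w / zoom)
    top, left = (h - ch) // 2, (w - cw) // 2
    crop = out[top:top + ch, left:left + cw]
    rows = np.arange(h) * ch // h                    # source row for each output row
    cols = np.arange(w) * cw // w                    # source column for each output column
    return crop[rows][:, cols]

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))                    # stand-in for a normalized leaf image
augmented = augment(image, rng)
print(augmented.shape)
```

Library-based augmentation typically supports arbitrary rotation angles and interpolation; the sketch only shows why each call produces a new variation of the same leaf.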
Data Preparation Workflow
- Resizing: All images are resized to a common size of 224x224 pixels to ensure consistency.
- Normalization: The pixel values of the images are normalized to lie between 0 and 1. This helps the model converge faster during training.
- Data Augmentation: Techniques such as rotation, zoom, and flipping are applied to increase the size of the dataset.
- Splitting the Data: The dataset is split into training and validation sets in an 80:20 ratio.
- Handling Class Imbalance: Class weights are calculated and applied to handle any imbalances in the dataset.
Code Explanation
Step 1:
Import and install the necessary packages.
!pip install tensorflow
!pip install keras
Importing Libraries
# General utilities
import warnings
import numpy as np
import pandas as pd
from PIL import Image
# Plotting
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.subplots as sp
import plotly.graph_objects as go
import plotly.figure_factory as ff
from plotly.subplots import make_subplots
# Image filters and scikit-learn helpers
from scipy.ndimage import sobel
from skimage.filters import scharr
from sklearn.utils import class_weight
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score
# TensorFlow / Keras (imported through tensorflow.keras throughout to avoid mixing packages)
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Input, Conv2D, BatchNormalization, MaxPooling2D, Dropout,
    GlobalAveragePooling2D, Flatten, Dense, Activation,
    LayerNormalization, Rescaling,
)
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from tensorflow.keras.applications import ResNet50, InceptionV3
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.metrics import Precision, Recall, AUC
from tensorflow.keras.optimizers import Adam
warnings.filterwarnings('ignore')
You can mount your Google Drive in a Google Colab notebook with this piece of code. This makes it easy to view files stored in Google Drive for data manipulation, analysis, and training models in the Colab environment.
#Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
Step 2:
Data collection and preparation
We split the dataset into training and validation sets, using 80% of the data for training and the remaining 20% for validation. This block of code loads and preprocesses the image dataset from Google Drive for Keras model training. It sets up an ImageDataGenerator that rescales pixel values to [0, 1] and reserves 20% of the images for validation. No on-the-fly augmentation is configured in the generator itself; the folder name AugmentedSet suggests the images were augmented beforehand. Using flow_from_directory, it generates batches of 32 images resized to 224x224 pixels with categorical (one-hot) labels for multi-class classification.
# Read dataset from Google Drive
data_dir = '/content/drive/MyDrive/banana leaf/AugmentedSet'
datagen = ImageDataGenerator(validation_split=0.2, rescale=1./255)
# Create training and validation generators
train_generator = datagen.flow_from_directory(
    data_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='training'
)
val_generator = datagen.flow_from_directory(
    data_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='validation',
    # Keep a fixed order so predictions line up with val_generator.classes during evaluation
    shuffle=False
)
This block of code calculates and displays class weights to address class imbalance in a dataset and prints class labels for a sample batch. It uses class_weight.compute_class_weight from sklearn to calculate balanced class weights based on the training data, storing them in a dictionary with class indices as keys. The mapping of class names to numerical labels is printed, and a new dictionary, class_weights_with_labels, maps class names to their corresponding weights. Finally, it prints class names for a sample batch from the training generator by decoding the class indices, providing a quick check of class labels and class distribution.
# Compute class weights
class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(train_generator.classes),
    y=train_generator.classes
)
# Keras expects a dict that maps class index -> weight
class_weights = {i: class_weights[i] for i in range(len(class_weights))}
# Print the class indices to see the mapping
print("Class indices:", train_generator.class_indices)
# Map class indices to class names and display weights
class_labels = {v: k for k, v in train_generator.class_indices.items()}
class_weights_with_labels = {class_labels[i]: class_weights[i] for i in class_weights}
print("Class weights with labels:")
for class_name, weight in class_weights_with_labels.items():
    print(f"Class: {class_name}, Weight: {weight:.4f}")
# You can also print the class labels for a sample batch
x, y = next(train_generator)
print("Sample batch class names:", [class_labels[np.argmax(label)] for label in y])
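For intuition, the 'balanced' heuristic used by scikit-learn gives each class the weight n_samples / (n_classes * count_of_that_class). The sketch below reproduces that formula with NumPy; `balanced_class_weights` and the label counts are hypothetical, and it assumes integer labels 0..K-1 where every class appears at least once.

```python
import numpy as np

def balanced_class_weights(y):
    """weight_c = n_samples / (n_classes * count_c) for integer labels 0..K-1."""
    y = np.asarray(y)
    counts = np.bincount(y)                      # images per class
    weights = len(y) / (len(counts) * counts)
    return {int(c): float(w) for c, w in enumerate(weights)}

# Hypothetical dataset: 100, 50, 25 and 25 images in classes 0..3
y = np.repeat([0, 1, 2, 3], [100, 50, 25, 25])
weights = balanced_class_weights(y)
print(weights)
```

Rare classes receive larger weights, so their misclassifications contribute more to the loss during training.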
Step 3:
Display the images of every class
This block of code visualizes a batch of images from the training dataset with their class labels using Matplotlib and Seaborn. It sets the plot style to "dark" with sns.set(style="dark") and fetches a batch of images and labels from train_generator. A function plot_images is defined to display these images in a 4x8 grid, removing axes and setting each subplot's title to the corresponding class name (decoded from the one-hot encoded label). Finally, it calls plot_images to show the images and their labels, providing a visual check of the batch.
sns.set(style="dark")
# Fetch a batch of images and labels
images, labels = next(train_generator)
# Function to plot images with class names
def plot_images(images_arr, labels_arr, class_labels):
    fig, axes = plt.subplots(4, 8, figsize=(20, 10))
    axes = axes.flatten()
    for img, lbl, ax in zip(images_arr, labels_arr, axes):
        ax.imshow(img)
        ax.axis('off')
        # Decode the one-hot label back to its class name
        class_name = class_labels[np.argmax(lbl)]
        ax.set_title(class_name)
    plt.tight_layout()
    plt.show()
# Plot images with class names
plot_images(images, labels, class_labels)
Plot the Sobel-filtered grayscale images
This block of code defines and uses two functions to visualize a batch of Sobel-filtered images with their class labels. The apply_sobel function converts each image to grayscale, applies the Sobel operator along both image axes, and combines the two results into a gradient-magnitude (edge) image. The plot_sobel_images function takes the original images, the Sobel-filtered images, the labels, and a class index-to-name mapping. It creates a 4x8 grid of subplots using Matplotlib, displaying each Sobel-filtered image in grayscale with its corresponding class name as the title, and turns off the axes for a cleaner look. The function is then called to show the Sobel-filtered images along with their class labels.
def apply_sobel(images_arr):
    sobel_images = []
    for img in images_arr:
        # Convert RGB to grayscale with the usual luminance weights
        gray_img = np.dot(img[...,:3], [0.2989, 0.5870, 0.1140])
        # Derivatives along the two image axes (rows, then columns)
        sobel_x = sobel(gray_img, axis=0, mode='constant')
        sobel_y = sobel(gray_img, axis=1, mode='constant')
        # Gradient magnitude
        sobel_img = np.hypot(sobel_x, sobel_y)
        sobel_images.append(sobel_img)
    return np.array(sobel_images)
sobel_images = apply_sobel(images)
def plot_sobel_images(original_images, sobel_images, labels_arr, class_labels):
    fig, axes = plt.subplots(4, 8, figsize=(20, 10))
    axes = axes.flatten()
    for orig_img, sob_img, lbl, ax in zip(original_images, sobel_images, labels_arr, axes):
        ax.imshow(sob_img, cmap='gray')
        ax.axis('off')
        class_name = class_labels[np.argmax(lbl)]
        ax.set_title(class_name, color='black')
    plt.tight_layout()
    plt.show()
plot_sobel_images(images, sobel_images, labels, class_labels)
Choosing the AI Model
The combination of Vision Transformer and CNN models offers significant advantages for banana leaf disease detection. One of the key benefits is higher accuracy, as the hybrid approach leverages the strengths of both models. Vision Transformers are excellent at capturing long-range patterns, allowing the model to generalize well across varying leaf conditions. On the other hand, CNNs excel at processing local features, making the detection process faster and more efficient. This combination provides a robust solution for early disease detection, ensuring precise and timely identification of diseases.
Other models such as ResNet, Inception, and EfficientNet could also be considered for this task. ResNet is effective at capturing detailed patterns but focuses mainly on local features, which may miss the broader context of leaf conditions. Inception networks excel at identifying small details but may struggle with understanding the larger image context. EfficientNet is computationally efficient and can provide good accuracy but lacks the capability to handle large-scale patterns as effectively as the Vision Transformer. Overall, the chosen combination of Vision Transformer and CNN provides a more comprehensive and accurate approach compared to these alternatives.
Step 4:
Build Custom CNN Model
Convolutional Neural Network
A Convolutional Neural Network (CNN) processes an input image by first applying a series of convolutional layers, each consisting of a small filter (kernel) that slides over the image to extract features, followed by a ReLU activation function to introduce non-linearity. These layers are interspersed with pooling layers that reduce the dimensionality of the feature maps, making the computation more efficient while retaining essential information.
After several convolution and pooling operations, the resulting feature maps are flattened into a single vector and passed through fully connected layers, which act like a traditional neural network. The final output layer uses a SoftMax activation function to produce a probabilistic distribution over the possible classes, such as identifying whether the image is of a horse, zebra, or dog. This combination of feature extraction and classification allows CNNs to effectively recognize and categorize images.
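The building blocks described above can be illustrated with a toy NumPy sketch. It is for intuition only and is not the Keras implementation: `conv2d_valid` uses a single channel, no padding, and stride 1 (strictly a cross-correlation, as in most deep learning libraries), and the input image, kernel, and scores are made-up values.

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Slide a kernel over a 2-D image (no padding, stride 1)."""
    kh, kw = kernel.shape
    oh, ow = image.shape[0] - kh + 1, image.shape[1] - kw + 1
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def relu(x):
    return np.maximum(x, 0.0)

def max_pool2x2(x):
    """Keep the largest value in each non-overlapping 2x2 window."""
    h, w = x.shape[0] // 2 * 2, x.shape[1] // 2 * 2
    return x[:h, :w].reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

def softmax(z):
    e = np.exp(z - np.max(z))       # subtract the max for numerical stability
    return e / e.sum()

image = np.arange(36, dtype=float).reshape(6, 6)   # brightness increases left to right
kernel = np.array([[-1.0, 0.0, 1.0]] * 3)          # simple horizontal-gradient filter
features = max_pool2x2(relu(conv2d_valid(image, kernel)))
probs = softmax(np.array([2.0, 1.0, 0.1]))         # scores for three hypothetical classes
print(features.shape, probs.round(3))
```

A 6x6 input shrinks to 4x4 after the 3x3 convolution and to 2x2 after pooling, which is the dimensionality reduction the text refers to.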
This block of code defines a convolutional neural network (CNN) model using the Sequential API from Keras, designed for image classification tasks. The model starts with an input layer for 224x224 RGB images, matching the size produced by the data generators, followed by four blocks of convolutional layers with ReLU activation and batch normalization to stabilize learning. Each block has two convolutional layers, followed by a max-pooling layer to downsample the feature maps and a dropout layer to reduce overfitting. The number of filters increases progressively (64, 128, 256, 512) to capture more complex features. After the convolutional blocks, a global average pooling layer collapses the spatial dimensions, followed by fully connected dense layers with dropout and batch normalization to further reduce overfitting and improve learning stability. The final dense layer with a softmax activation outputs probabilities for four classes, making the model suitable for multi-class classification tasks.
model_CNN = Sequential([
    # Input size must match target_size=(224, 224) used in flow_from_directory
    Input(shape=(224, 224, 3)),
    # Block 1: 64 filters
    Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.3),
    # Block 2: 128 filters
    Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.4),
    # Block 3: 256 filters
    Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.4),
    # Block 4: 512 filters
    Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.5),
    # Classification head
    GlobalAveragePooling2D(),
    Dropout(0.5),
    Dense(512, activation='relu'),
    BatchNormalization(),
    Dropout(0.5),
    Dense(256, activation='relu'),
    BatchNormalization(),
    Dropout(0.5),
    Dense(4, activation='softmax')
])
This block of code compiles the Convolutional Neural Network (CNN) model, model_CNN, and prints its summary. The compile method configures the model for training by setting the optimizer (Adam with a learning rate of 1e-4), the loss function (categorical_crossentropy), and the evaluation metrics (accuracy, precision, recall, and AUC). These choices are suitable for multi-class classification tasks with one-hot encoded labels. The summary method is then called to print details of the model's architecture, including layers, output shapes, and parameter counts, giving an overview of the model's structure and complexity.
model_CNN.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=[
        'accuracy',
        Precision(name='precision'),
        Recall(name='recall'),
        AUC(name='auc')
    ]
)
# Print the model summary
model_CNN.summary()
Train the CNN model
This block of code trains the CNN model using Keras's fit method. The train_generator provides training data in batches, and the model is trained for 40 epochs. The validation_data parameter, set to val_generator, evaluates the model's performance on a validation dataset at the end of each epoch. The class_weight parameter addresses class imbalance by assigning different weights to each class, preventing the model from favoring more frequent classes. The training process returns a history object, which contains details about training and validation loss and accuracy for each epoch, useful for further analysis and visualization.
history_CNN = model_CNN.fit(train_generator, epochs=40, validation_data=val_generator, class_weight=class_weights)
Step 5:
Build a Vision Transformer Model
Understanding the Vision Transformer
ViT has three important steps: splitting the image into fixed-size patches, applying a linear projection to each flattened patch, and adding positional embeddings so that spatial information is retained. The resulting sequence is the trainable input to the network. The network itself is a standard Transformer encoder, which uses multi-head self-attention to relate every patch to every other patch.
Patch Embedding
In NLP, the standard Transformer receives a 1D sequence of token embeddings as input. To handle 2D images, we reshape the image x ∈ R^{H×W×C} into a sequence of flattened 2D patches x_p ∈ R^{N×(P²·C)}, where (H, W) is the resolution of the original image, C is the number of channels, and (P, P) is the resolution of each image patch. N = HW/P² is then the effective sequence length for the Transformer. The image is split into fixed-size patches; in the example image the patch size is 16×16 and the image is 48×48, which gives N = 9 patches.
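The reshaping into patches can be checked with a few lines of NumPy. This is an illustrative sketch (the helper name `image_to_patches` is hypothetical); it uses this project's settings of a 224x224 RGB image and 16x16 patches, for which N = 224·224/16² = 196 and each flattened patch has 16·16·3 = 768 values.

```python
import numpy as np

def image_to_patches(image, patch_size):
    """Reshape an (H, W, C) image into (N, P*P*C) flattened patches."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size             # patches per column / row
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)            # (gh, gw, P, P, C)
    return patches.reshape(gh * gw, patch_size * patch_size * c)

image = np.random.default_rng(0).random((224, 224, 3))    # stand-in for a leaf image
patches = image_to_patches(image, 16)
print(patches.shape)
```

Patches are ordered row by row, so the first row of the result is the top-left 16x16 block of the image.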
Linear Transformation/Projection
The patches are then flattened and passed through an embedding layer to create patch embeddings. The patch embedding matrix is obtained by multiplying the flattened patches by a trainable embedding weight matrix.
Positional Embedding
Position embeddings are added to the patch embeddings to retain positional information. The ViT authors report that 2D-aware variants of position embeddings gave no significant gains over standard 1D position embeddings. The joint embedding serves as input to the Transformer encoder. Each unrolled patch (before linear projection) is assigned a position index; the authors simply number the patches 1, 2, 3, 4, ... up to the number of patches. Each index corresponds to a learnable vector, and these vectors are stacked row-wise to form a learnable positional embedding table. Similar to BERT's [CLS] token, a learnable embedding is prepended to the sequence of embedded patches, and its state at the output of the Transformer encoder (z_L^0) serves as the image representation y. During both pre-training and fine-tuning, the classification head is attached to z_L^0.
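The projection, class token, and positional embedding steps can be traced through their shapes with NumPy. This is a sketch under stated assumptions: the weights are random stand-ins for parameters that would be learned, the sizes (196 patches of 768 values, embedding size 128) mirror the configuration used later in this project, and all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, patch_dim, d_model = 196, 768, 128        # 224x224x3 image, 16x16 patches

patches = rng.random((num_patches, patch_dim))          # stand-in for flattened patches
W_embed = rng.normal(0, 0.02, (patch_dim, d_model))     # trainable projection (random here)
cls_token = rng.normal(0, 0.02, (1, d_model))           # learnable class embedding
pos_embed = rng.normal(0, 0.02, (num_patches + 1, d_model))  # one row per position

tokens = patches @ W_embed                      # (196, 128) patch embeddings
tokens = np.concatenate([cls_token, tokens])    # (197, 128) class token goes first
encoder_input = tokens + pos_embed              # (197, 128) input to the encoder
print(encoder_input.shape)
```

The sequence grows from 196 to 197 tokens because of the prepended class token, which is why the positional table has num_patches + 1 rows.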
This block of code defines a custom Vision Transformer (ViT) model using TensorFlow and Keras. It includes three key components: TransformerBlock, MultiHeadAttention, and VisionTransformer. The TransformerBlock processes input features with self-attention and feedforward layers. MultiHeadAttention computes attention scores across multiple heads. The VisionTransformer divides images into patches, projects them into higher dimensions with positional and class embeddings, and processes them through TransformerBlocks. Finally, a multi-layer perceptron (MLP) head outputs predictions for image classification tasks. This model efficiently handles image classification using advanced deep learning techniques.
class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, feedforward_dim, dropout=0.1, regularizer=None):
        super(TransformerBlock, self).__init__()
        self.multiheadselfattention = MultiHeadAttention(embed_dim, num_heads, regularizer)
        # Position-wise feedforward network
        self.ffn = tf.keras.Sequential(
            [
                Dense(feedforward_dim, activation="relu", kernel_regularizer=regularizer),
                Dense(embed_dim, kernel_regularizer=regularizer),
            ]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
    def call(self, inputs, training=None):
        # Normalize, attend, then add the residual connection
        out1 = self.layernorm1(inputs)
        attention_output = self.multiheadselfattention(out1, training=training)
        attention_output = self.dropout1(attention_output, training=training)
        # Note: layernorm1 is applied a second time here, so its weights are shared
        out2 = self.layernorm1(inputs + attention_output)
        ffn_output = self.ffn(out2, training=training)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out2 + ffn_output)
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, regularizer=None):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        assert self.embed_dim % self.num_heads == 0
        # Dimension handled by each attention head
        self.projection_dim = self.embed_dim // self.num_heads
        self.query_dense = Dense(self.embed_dim, kernel_regularizer=regularizer)
        self.key_dense = Dense(self.embed_dim, kernel_regularizer=regularizer)
        self.value_dense = Dense(self.embed_dim, kernel_regularizer=regularizer)
        self.combine_heads = Dense(self.embed_dim, kernel_regularizer=regularizer)
    def self_attention(self, query, key, value):
        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V,
        # where d_k is the per-head dimension. The scores are divided by sqrt(d_k);
        # dividing by d ** -0.5 would multiply them by sqrt(d) instead.
        q_kt = tf.matmul(query, key, transpose_b=True)
        scale = tf.math.sqrt(tf.cast(self.projection_dim, tf.float32))
        normalized_score = q_kt / scale
        softmax_wts = tf.nn.softmax(normalized_score, axis=-1)
        output = tf.matmul(softmax_wts, value)
        return output, softmax_wts
    def separate_heads(self, x, batch_size):
        # (batch, seq, embed_dim) -> (batch, heads, seq, projection_dim)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    def call(self, inputs, training=None):
        batch_size = tf.shape(inputs)[0]
        query = self.query_dense(inputs)
        key = self.key_dense(inputs)
        value = self.value_dense(inputs)
        query = self.separate_heads(query, batch_size)
        key = self.separate_heads(key, batch_size)
        value = self.separate_heads(value, batch_size)
        attention, weights = self.self_attention(query, key, value)
        # Merge the heads back into a single embed_dim-sized vector per token
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention, (batch_size, -1, self.embed_dim))
        output = self.combine_heads(concat_attention)
        return output
class VisionTransformer(tf.keras.Model):
    def __init__(self, image_size, patch_size, num_layers, num_classes, d_model, num_heads, mlp_dim, channels=3, dropout=0.1, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)
        self.patch_size = patch_size
        num_patches = (image_size // patch_size) ** 2
        self.patch_dim = channels * patch_size ** 2
        self.num_layers = num_layers
        self.d_model = d_model
        # No Rescaling layer here: the data generators already scale pixels to [0, 1]
        # with rescale=1./255, so rescaling again would shrink inputs to [0, 1/255].
        # Learnable positional embeddings: one per patch plus one for the class token
        self.pos_emb = self.add_weight(name="positional_emb", shape=(1, num_patches + 1, d_model))
        self.cls_emb = self.add_weight(name="cls_embedding", shape=(1, 1, d_model))
        # Linear projection of flattened patches
        self.patch_proj = Dense(d_model)
        self.enc_layers = [
            TransformerBlock(d_model, num_heads, mlp_dim, dropout)
            for _ in range(num_layers)
        ]
        # Classification head; the last layer returns logits (no softmax)
        self.mlp_head = tf.keras.Sequential(
            [
                Dense(mlp_dim, activation="gelu"),
                Dropout(dropout),
                Dense(num_classes),
            ]
        )
    def extract_patches(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # (batch, rows, cols, patch_dim) -> (batch, num_patches, patch_dim)
        patches = tf.reshape(patches, [batch_size, -1, self.patch_dim])
        return patches
    def call(self, x, training=None):
        # x is expected to hold pixel values already scaled to [0, 1]
        batch_size = tf.shape(x)[0]
        patches = self.extract_patches(x)
        x = self.patch_proj(patches)
        cls_emb = tf.broadcast_to(self.cls_emb, [batch_size, 1, self.d_model])
        x = tf.concat([cls_emb, x], axis=1)
        x = x + self.pos_emb
        for layer in self.enc_layers:
            x = layer(x, training=training)
        # First (cls token) is used for classification
        x = self.mlp_head(x[:, 0], training=training)
        return x
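Scaled dot-product attention, the core operation inside a multi-head attention layer, can be illustrated in isolation with NumPy as softmax(QKᵀ/√d_k)V. The sketch below handles a single head with random inputs; the function name and the sizes (5 tokens, 16 dimensions per head, which is 128/8 for the configuration used in this project) are chosen for illustration.

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """q, k, v: (seq_len, d_k) arrays for a single attention head."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)                         # similarity of every token pair
    scores = scores - scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)          # softmax over the keys
    return weights @ v, weights

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(5, 16)) for _ in range(3))
output, weights = scaled_dot_product_attention(q, k, v)
print(output.shape, weights.sum(axis=-1))
```

Each row of `weights` sums to 1, so every output token is a weighted average of the value vectors of all tokens, which is how distant patches can influence each other.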
This block of code defines and initializes a Vision Transformer model for image classification. It processes input images resized to 224x224 pixels, dividing them into 16x16 patches (196 patches per image). The model includes 12 Transformer layers with an embedding size of 128, 8 attention heads, and a feedforward network dimension of 256. Its MLP head, with a dropout rate of 0.1 for regularization, outputs one raw score (logit) per class for the 4 classes; the softmax is applied inside the loss function rather than in the model. The VisionTransformer class instance, model_VIT, is created with these parameters, ready for training on a classification task.
# Vision Transformer model
image_size = 224
patch_size = 16
num_layers = 12
num_classes = 4
d_model = 128
num_heads = 8
mlp_dim = 256
dropout = 0.1
model_VIT = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_classes=num_classes,
d_model=d_model,
num_heads=num_heads,
mlp_dim=mlp_dim,
dropout=dropout
)
This block of code compiles the Vision Transformer (ViT) model, model_VIT, specifying the training configuration. It uses CategoricalCrossentropy as the loss function with from_logits=True, because the model outputs logits rather than probabilities. The optimizer is Adam with a learning rate of 1e-4. The model's performance is tracked using four metrics: accuracy, precision, recall, and AUC. Precision and recall use a default threshold of 0.5 that assumes probability inputs, so their values on raw logits should be interpreted with care.
# Compile the model
model_VIT.compile(
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
metrics=[
"accuracy",
Precision(name='precision'),
Recall(name='recall'),
AUC(name='auc')
]
)
Train the model
This block of code trains the Vision Transformer model using the train_generator and val_generator datasets for 40 epochs. It includes early stopping and learning rate scheduling to optimize the training process. The EarlyStopping callback monitors validation loss, halting training if there's no improvement for 10 consecutive epochs, and restores the best weights. The LearningRateScheduler callback adjusts the learning rate dynamically, reducing it exponentially after the 10th epoch for finer tuning. The class_weight parameter handles class imbalances, and the training history is stored in history_model_VIT.
# Train the model
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
def lr_schedule(epoch, lr):
    # Keep the initial rate through epoch index 10, then decay by exp(-0.1) each epoch
    if epoch > 10:
        lr = lr * float(tf.math.exp(-0.1))
    return lr
lr_scheduler = LearningRateScheduler(lr_schedule)
history_model_VIT = model_VIT.fit(
    train_generator,
    epochs=40,
    validation_data=val_generator,
    class_weight=class_weights,
    # Pass both callbacks; early_stopping has no effect unless it is listed here
    callbacks=[early_stopping, lr_scheduler]
)
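To see what the schedule does over a full run, the same rule can be replayed in plain Python. This sketch assumes all 40 epochs run (no early stop) and that the scheduler receives 0-based epoch indices together with the current learning rate, as Keras's LearningRateScheduler does; `math.exp` stands in for `tf.math.exp`.

```python
import math

def lr_schedule(epoch, lr):
    """Keep lr fixed through epoch index 10, then shrink it by exp(-0.1) per epoch."""
    if epoch > 10:
        lr = lr * math.exp(-0.1)
    return lr

lr = 1e-4                     # initial Adam learning rate
rates = []
for epoch in range(40):       # epoch indices 0..39
    lr = lr_schedule(epoch, lr)
    rates.append(lr)

print(f"epoch 10: {rates[10]:.2e}, epoch 39: {rates[-1]:.2e}")
```

The rate stays at 1e-4 for the first 11 epochs and is then multiplied by exp(-0.1) (about 0.905) on each of the remaining 29 epochs, ending near 5.5e-6.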
This block of code defines a function plot_performance that visualizes a machine learning model's training history using Plotly. It plots accuracy, precision, recall, AUC, and loss metrics for both training and validation datasets over epochs. The function creates a 3x2 grid of subplots with appropriate titles using make_subplots. For each metric, it adds two traces: one for training data and one for validation data, displaying lines and markers. The layout is adjusted for figure dimensions and title, and the legend is shown. The function uses fig.show() to render the interactive Plotly figure, offering a detailed visual assessment of the model's performance during training.
def plot_performance(history, title):
    metrics = ['accuracy', 'precision', 'recall', 'auc']
    fig = make_subplots(
        rows=3, cols=2,
        subplot_titles=[f'{title} {metric.capitalize()}' for metric in metrics] + [f'{title} Loss']
    )
    # The four metrics fill rows 1-2, left to right
    for i, metric in enumerate(metrics):
        row = (i // 2) + 1
        col = (i % 2) + 1
        fig.add_trace(
            go.Scatter(
                x=list(range(len(history.history[metric]))),
                y=history.history[metric],
                mode='lines+markers',
                name=f'Train {metric}',
                marker=dict(color='blue')
            ),
            row=row, col=col
        )
        fig.add_trace(
            go.Scatter(
                x=list(range(len(history.history[f'val_{metric}']))),
                y=history.history[f'val_{metric}'],
                mode='lines+markers',
                name=f'Validation {metric}',
                marker=dict(color='orange')
            ),
            row=row, col=col
        )
    # Loss is added once, after the loop, in the first cell of row 3
    fig.add_trace(
        go.Scatter(
            x=list(range(len(history.history['loss']))),
            y=history.history['loss'],
            mode='lines+markers',
            name='Train Loss',
            marker=dict(color='blue')
        ),
        row=3, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=list(range(len(history.history['val_loss']))),
            y=history.history['val_loss'],
            mode='lines+markers',
            name='Validation Loss',
            marker=dict(color='orange')
        ),
        row=3, col=1
    )
    fig.update_layout(height=800, width=1200, title_text=title, showlegend=True)
    fig.show()
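One practical caveat when reusing plot_performance for two models: Keras derives history keys from the names of the metric objects, so if both models are built in the same session with unnamed Precision, Recall, or AUC metrics, the second model's keys may come out as precision_1, recall_1, and auc_1, and history.history['precision'] would then raise a KeyError. The sketch below shows one way to look up such keys. The helper resolve_metric_key is hypothetical (it is not part of the original code), and the sample dictionary is made up for illustration.

```python
import re


def resolve_metric_key(history_dict, base_name):
    """Return the key equal to base_name or base_name followed by _<number>.

    Hypothetical helper: it only inspects the keys of a history dictionary.
    """
    pattern = re.compile(rf"^{re.escape(base_name)}(_\d+)?$")
    matches = sorted(key for key in history_dict if pattern.match(key))
    if not matches:
        raise KeyError(f"No history entry found for metric '{base_name}'")
    return matches[0]


# Made-up history dictionary, shaped like history.history of a second model
sample_history = {
    "loss": [0.9, 0.6],
    "accuracy": [0.60, 0.75],
    "precision_1": [0.58, 0.74],
    "val_precision_1": [0.55, 0.70],
}

precision_key = resolve_metric_key(sample_history, "precision")
val_precision_key = resolve_metric_key(sample_history, "val_precision")
print(precision_key, val_precision_key)
```

Inside plot_performance, the resolved key could be used in place of the literal metric name. Giving each metric an explicit name when the model is compiled is an alternative that avoids the lookup altogether.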
Evaluating the model with validation data.
This block of code defines the evaluate_model function to assess a trained model using a data generator. It takes the true labels (y_true) from the generator and derives the predicted labels (y_pred_classes) from the model's output probabilities. Because generator.classes lists labels in file order, the generator must be created with shuffle=False for the two to line up. A confusion matrix and classification report are then generated. The confusion matrix is visualized as an annotated heatmap in Plotly, with true labels on the y-axis and predicted labels on the x-axis, and the layout is given titles and wider margins for readability. The classification report, detailing precision, recall, and F1-score for each class, is printed after the heatmap is displayed with fig.show().
def evaluate_model(model, generator):
    # generator.classes is in file order, so the generator must use shuffle=False
    y_true = generator.classes
    y_pred = model.predict(generator)
    y_pred_classes = np.argmax(y_pred, axis=1)
    class_names = list(generator.class_indices.keys())
    cm = confusion_matrix(y_true, y_pred_classes)
    cr = classification_report(y_true, y_pred_classes, target_names=class_names)
    # Create a confusion matrix with Plotly
    z = cm
    x = class_names
    y = class_names
    # Cell annotations: the raw counts as strings
    z_text = [[str(count) for count in row] for row in z]
    fig = ff.create_annotated_heatmap(z, x=x, y=y, annotation_text=z_text, colorscale='Blues')
    # Add title
    fig.update_layout(title_text='Confusion Matrix',
                      xaxis_title='Predicted',
                      yaxis_title='True',
                      yaxis=dict(tickmode='array', tickvals=list(range(len(y))), ticktext=y),
                      xaxis=dict(tickmode='array', tickvals=list(range(len(x))), ticktext=x))
    # Adjust margins to make room for y-axis title
    fig.update_layout(margin=dict(t=100, l=200))
    fig.show()
    print('Classification Report:')
    print(cr)
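To make the confusion matrix and the report easier to read, here is a minimal pure-Python sketch of how the counts and per-class scores are derived. It is an illustration only, not the scikit-learn implementation; the function names and the toy labels are made up.

```python
def confusion_counts(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_label, predicted_label in zip(y_true, y_pred):
        matrix[true_label][predicted_label] += 1
    return matrix


def per_class_scores(matrix):
    """Return a (precision, recall, f1) tuple for each class."""
    scores = []
    for k in range(len(matrix)):
        true_positives = matrix[k][k]
        predicted_as_k = sum(row[k] for row in matrix)  # column sum
        actually_k = sum(matrix[k])                     # row sum
        precision = true_positives / predicted_as_k if predicted_as_k else 0.0
        recall = true_positives / actually_k if actually_k else 0.0
        denominator = precision + recall
        f1 = 2 * precision * recall / denominator if denominator else 0.0
        scores.append((precision, recall, f1))
    return scores


# Toy labels for four classes (made up for illustration)
y_true = [0, 0, 1, 1, 2, 2, 3, 3]
y_pred = [0, 1, 1, 1, 2, 3, 3, 3]

matrix = confusion_counts(y_true, y_pred, num_classes=4)
scores = per_class_scores(matrix)
for class_index, (precision, recall, f1) in enumerate(scores):
    print(class_index, round(precision, 2), round(recall, 2), round(f1, 2))
```

Reading along a row shows where images of one true class ended up; reading down a column shows which true classes were mistaken for that prediction.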
This block of code calls the plot_performance function to visualize the training and validation metrics stored in history_CNN. It generates plots for accuracy, precision, recall, AUC, and loss over the training epochs, showing both training and validation values. The plots are labeled with the title 'Custom CNN Model', which helps in assessing the model's performance and tracking its learning progress during training.
plot_performance(history_CNN, 'Custom CNN Model')
This block of code calls the plot_performance function to visualize the training and validation metrics for the Vision Transformer model using the history_model_VIT. It generates line plots for accuracy, precision, recall, AUC, and loss over the training epochs, showing both training and validation metrics. The plots are labeled with the title 'Vision Transformer', offering a clear visual representation of the model's performance and learning progress during training.
plot_performance(history_model_VIT, 'Vision Transformer')
This block of code calls the evaluate_model function to assess the performance of model_CNN using the validation dataset from val_generator. It computes the true labels (y_true) and predicted labels (y_pred_classes). A confusion matrix and a classification report, including precision, recall, and F1-score for each class, are generated. The confusion matrix is visualized as an annotated Plotly heatmap, with true labels on the y-axis and predicted labels on the x-axis. The classification report is printed, providing detailed performance metrics on the validation dataset.
evaluate_model(model_CNN, val_generator)
Saving the trained CNN model to a file.
model_CNN.save('model_CNN.h5')
This block of code calls the evaluate_model function to assess the performance of the Vision Transformer model (model_VIT) using the validation dataset from val_generator. It calculates the true labels (y_true) and predicted labels (y_pred_classes). A confusion matrix is generated to show the relationship between true and predicted labels, and a classification report is produced, including precision, recall, and F1-score for each class. The confusion matrix is displayed as an annotated Plotly heatmap, with true labels on the y-axis and predicted labels on the x-axis. The classification report is printed, providing a detailed evaluation of the model's performance on the validation dataset.
evaluate_model(model_VIT, val_generator)
Save transformer model
model_VIT.save('vit_model.keras')
Step 6:
This block of code defines a process for predicting an image's class using an ensemble method combining a CNN and a Vision Transformer. The preprocess_image function loads and preprocesses the image to the target size, normalizing pixel values. The predict_with_voting function processes the image for both models, averages their predicted probabilities, and determines the final class. The display_image_with_prediction function shows the input image with the predicted class label. The class_mapping dictionary translates class indices to disease names.
def preprocess_image(image_path, target_size):
    # Load, resize, add a batch dimension, and scale pixels to [0, 1]
    img = load_img(image_path, target_size=target_size)
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = img_array / 255.0
    return img_array

# Class indices mapped to disease names
class_mapping = {
    0: "cordana",
    1: "healthy",
    2: "pestalotiopsis",
    3: "sigatoka"
}

def predict_with_voting(image_path):
    input_size_CNN = (128, 128)  # Target size for CNN model
    input_size_VIT = (224, 224)  # Target size for ViT model
    img_CNN = preprocess_image(image_path, input_size_CNN)
    img_VIT = preprocess_image(image_path, input_size_VIT)
    pred_CNN = model_CNN.predict(img_CNN)
    pred_VIT = model_VIT.predict(img_VIT)
    # Soft voting: average the two probability vectors, then take the argmax
    final_pred_prob = (pred_CNN + pred_VIT) / 2
    final_pred_class = np.argmax(final_pred_prob, axis=1)
    final_pred_disease = [class_mapping[class_index] for class_index in final_pred_class]
    return final_pred_disease

def display_image_with_prediction(image_path, predicted_class):
    img = load_img(image_path)
    plt.imshow(img)
    plt.title(f"Predicted Class: {predicted_class[0]}")
    plt.axis('off')
    plt.show()
Predict disease from images
This block of code classifies a single image using the ensemble of the CNN and the Vision Transformer and displays the result. The predict_with_voting function is called with the image path and preprocesses the image to the required sizes (128x128 for the CNN and 224x224 for the ViT). It predicts class probabilities with both models, averages them, and determines the final class by taking the argmax of the averaged probabilities. The predicted class is mapped to its corresponding disease name using the class_mapping dictionary. Finally, the display_image_with_prediction function shows the input image with the predicted disease name as the title.
image_path = '/content/drive/MyDrive/banana leaf/Datasets/validation/cordana/167.jpeg'
predicted_class = predict_with_voting(image_path)
display_image_with_prediction(image_path, predicted_class)
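The averaging step can be traced by hand on a single image. The following is a small pure-Python sketch with made-up probabilities in which the two models disagree; soft_vote is a hypothetical name used only for this illustration and mirrors the (pred_CNN + pred_VIT) / 2 step above.

```python
CLASS_NAMES = ["cordana", "healthy", "pestalotiopsis", "sigatoka"]


def soft_vote(prob_a, prob_b):
    """Average two probability vectors and return (best_index, averaged)."""
    averaged = [(a + b) / 2 for a, b in zip(prob_a, prob_b)]
    best_index = max(range(len(averaged)), key=averaged.__getitem__)
    return best_index, averaged


# Made-up softmax outputs for one image
pred_cnn = [0.40, 0.45, 0.10, 0.05]  # CNN leans slightly towards "healthy"
pred_vit = [0.70, 0.10, 0.10, 0.10]  # ViT is confident in "cordana"

best_index, averaged = soft_vote(pred_cnn, pred_vit)
print(CLASS_NAMES[best_index], [round(p, 3) for p in averaged])
```

Here the CNN alone would have picked "healthy", but the more confident ViT prediction dominates the average. This is the main difference from majority voting on hard labels, where confidence is ignored.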
Project Conclusion
Banana Leaf Disease Detection using Vision Transformer and CNN models demonstrates the potential of machine learning in solving real-world agricultural problems. The project's innovative approach, which combines two powerful models, creates an efficient system for detecting banana leaf diseases. By automating the process of disease detection, the system saves time and labor, enabling farmers to take timely action and protect their crops.
Challenges and Troubleshooting
- Implementing this project comes with its own set of challenges, especially in data preprocessing and model optimization. One of the most common is an imbalanced dataset. If the dataset contains significantly more images of healthy leaves than diseased leaves, the model may become biased toward predicting healthy leaves. To counter this, class weights are passed to training so that errors on under-represented classes contribute more to the loss.
- Another challenge is the complexity of training deep learning models such as Vision Transformers. These models require significant computational resources, which can slow down the training process. To address this, early stopping and learning rate scheduling techniques are employed to optimize training. Monitoring the model's performance during training and adjusting the hyperparameters can also help improve the model's accuracy.
- Additionally, fine-tuning both the CNN and Vision Transformer models can be tricky, especially when integrating them into a voting ensemble. Ensuring that both models contribute meaningfully to the final prediction requires careful tuning of model weights and hyperparameters.
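One way to explore how much each model should contribute is to replace the plain average with a weighted one and compare a few candidate weights on validation data. The sketch below is an assumption-laden illustration, not the project's method: the original code uses equal weights, every function name here is hypothetical, and the probabilities are made up.

```python
def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


def weighted_average(prob_a, prob_b, weight_a):
    """Blend two probability vectors; weight_a applies to the first model."""
    return [weight_a * a + (1 - weight_a) * b for a, b in zip(prob_a, prob_b)]


def ensemble_accuracy(probs_a, probs_b, labels, weight_a):
    correct = 0
    for prob_a, prob_b, label in zip(probs_a, probs_b, labels):
        if argmax(weighted_average(prob_a, prob_b, weight_a)) == label:
            correct += 1
    return correct / len(labels)


def best_weight(probs_a, probs_b, labels, candidates):
    """Return the first candidate weight with the highest accuracy."""
    return max(candidates, key=lambda w: ensemble_accuracy(probs_a, probs_b, labels, w))


# Made-up validation probabilities for four images (one per class)
val_probs_cnn = [
    [0.90, 0.05, 0.03, 0.02],  # correct and confident
    [0.40, 0.30, 0.20, 0.10],  # wrong, but not confident
    [0.10, 0.10, 0.70, 0.10],
    [0.10, 0.10, 0.20, 0.60],
]
val_probs_vit = [
    [0.30, 0.40, 0.20, 0.10],  # wrong, but not confident
    [0.05, 0.90, 0.03, 0.02],  # correct and confident
    [0.10, 0.10, 0.60, 0.20],
    [0.10, 0.20, 0.10, 0.60],
]
val_labels = [0, 1, 2, 3]
candidates = [0.0, 0.25, 0.5, 0.75, 1.0]

for w in candidates:
    print(w, ensemble_accuracy(val_probs_cnn, val_probs_vit, val_labels, w))
chosen = best_weight(val_probs_cnn, val_probs_vit, val_labels, candidates)
```

In this toy data each model alone misclassifies one image, while any blend of the two gets all four right. The weight should be selected on validation data only, so that the final test evaluation stays unbiased.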
FAQ
- What is the main goal of this project?
  Answer: The goal is to develop a system that can accurately detect diseases in banana leaves using machine learning models. The project leverages Vision Transformer and CNN models to improve accuracy.
- Why use both CNN and Vision Transformer models?
  Answer: CNNs are excellent at detecting local patterns in images, while Vision Transformers excel at capturing global features. Combining both models provides a more comprehensive analysis of the image data.
- What types of banana leaf diseases are detected?
  Answer: The models are trained to classify leaves into four classes: three diseases (cordana, pestalotiopsis, and sigatoka) and healthy leaves.
- How is data collected for this project?
  Answer: A dataset of banana leaf images is collected, with images labeled according to the type of disease present.
- What techniques are used to handle class imbalance in the dataset?
  Answer: Class weights are calculated and applied during training so that under-represented classes contribute proportionally more to the loss.
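As an illustration of the last answer, class weights are often computed with the "balanced" heuristic, n_samples / (n_classes * class_count), which is the formula scikit-learn's compute_class_weight applies for class_weight='balanced'. Whether the project used exactly this formula is an assumption; the sketch below uses made-up label counts and a hypothetical helper name.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {
        class_index: n_samples / (n_classes * count)
        for class_index, count in sorted(counts.items())
    }


# Made-up label distribution: class 1 is heavily over-represented
labels = [0] * 50 + [1] * 200 + [2] * 50 + [3] * 100

class_weights_example = balanced_class_weights(labels)
print(class_weights_example)
```

A dictionary of this shape, mapping class index to weight, is what Keras expects for the class_weight argument of fit. Rare classes receive weights above 1 and common classes receive weights below 1.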