Medical Image Segmentation With UNET
Have you ever wondered how doctors diagnose conditions so precisely from medical images? It's not alchemy. They increasingly rely on sophisticated tools such as U-Net, a deep learning architecture designed for medical image segmentation. It puts more power in doctors' hands, helping them reach fast and accurate treatment decisions. And it's simply awesome!
Here in this project, we explore the workings of U-Net and employ it in MRI, CT, and X-ray images. Enjoy the trip through data, coding, and highly advanced medical technology that is greatly helping people.
Project Overview
This is an interesting project that we have taken on as a challenge within the medical field. The task we address is medical image segmentation: accurately marking objects such as tumors and organs in MRI, CT, and X-ray images using the U-Net model.
The U-Net architecture is well suited to this task because of its two-part (encoder-decoder) design. It segments images at the pixel level while preserving fine detail and spatial resolution. This project shows how to work with medical images, train the U-Net model, and run it on the datasets.
Here’s what we'll cover:
- Different image preprocessing techniques
- U-Net model structure and function
- Model training and testing
- The challenges we faced and how to solve them.
Prerequisites
Before embarking on this project, ensure that you possess the following foundational components:
- An understanding of Python programming and of Google Colab
- Basic knowledge of deep learning and medical images
- Comfort with TensorFlow, Keras, NumPy, OpenCV, and Matplotlib for handling data, building models, and visualizing data and model performance
- Familiarity with semantic segmentation and its role in areas like medical imaging and diagnosis
- Comfort with evaluation metrics, specifically Mean Intersection over Union (IoU)
- Access to a Jupyter notebook or Google Colab environment
Approach
In this project, we take a step-by-step approach to medical image segmentation using the U-Net model. First, the images are loaded and preprocessed so that they are fit for model training.
Then, we design a custom data generator, which lets us work with large datasets by loading one batch at a time. We also use flipping and rotation augmentations to further strengthen training. Next, we build the U-Net architecture. It consists of an encoder that downscales the image content and a decoder that restores it at every pixel.
We train the model with Keras, using callbacks that save only the best model and adjust the learning rate. As training proceeds, metrics such as accuracy and mean IoU are observed to evaluate the model. After training, the U-Net is used to predict segmentation masks. The predicted masks are then overlaid on the original images to see how well the model localizes areas of interest in the medical scans. Finally, the Mean Intersection over Union (IoU) is computed to assess the predictions for the various classes.
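The flipping and rotation augmentations mentioned above are not shown in the code later in this article, so here is a minimal sketch of how they could look. It assumes square NumPy arrays; the helper name `augment_pair` is hypothetical. The key point is that the image and its mask must receive exactly the same transform, otherwise the labels no longer line up with the pixels.

```python
import numpy as np

def augment_pair(image, mask, rng):
    """Apply the same random flip and 90-degree rotation to an image and its mask."""
    # Horizontal flip with probability 0.5
    if rng.random() < 0.5:
        image = image[:, ::-1]
        mask = mask[:, ::-1]
    # Rotate by k * 90 degrees, with k drawn from {0, 1, 2, 3}
    k = int(rng.integers(0, 4))
    image = np.rot90(image, k, axes=(0, 1))
    mask = np.rot90(mask, k, axes=(0, 1))
    return image, mask

# Toy 4x4 single-channel image whose mask marks pixels with value > 7
rng = np.random.default_rng(0)
image = np.arange(16, dtype=np.float32).reshape(4, 4, 1)
mask = (image > 7).astype(np.uint8)
aug_image, aug_mask = augment_pair(image, mask, rng)
```

Because both arrays go through the same flip and rotation, the augmented mask still marks exactly the pixels of the augmented image that are greater than 7.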
Workflow and Methodology
The overall workflow of this project includes
- Data Collection: In this project, we collect publicly available data containing images and masks.
- Data Preprocessing: Next, we resize the images, convert them to the appropriate color space (HSV, RGB, or grayscale), and normalize them to improve model performance.
- Model Design: U-Net architecture is designed to perform image segmentation. The encoder is responsible for capturing features, while the decoder works to reconstruct the image at a pixel level.
- Training: Training the U-Net model using the prepared training dataset. The model is evaluated with a validation set to fine-tune values and prevent overfitting.
- Evaluation: We test the model on unseen data to assess its ability to accurately detect diseases. IoU is used for performance evaluation.
- Visualization: Overlay the predicted segmentation masks onto the original medical images to facilitate easier interpretation of the results.
The methodology involves
- Data Preprocessing: First, images and their corresponding masks are resized to the input size expected by the U-Net architecture. Then pixel values are scaled to the standardized range of 0-1 for uniformity.
- Model Architecture: We implemented the U-Net architecture, which suits this task because it preserves the spatial resolution of the input, and that matters for detailed segmentation.
- Metrics: We applied the Mean IoU metric to check that each region in the medical images was correctly segmented.
- Visualization: We showed the segmentation results by placing the predicted mask on top of the original image.
Data Collection
First, gather a set of RGB images together with their masks. Preprocessing steps such as image resizing can also improve model performance.
Data Preparation
- Resizing: Every image and mask is resized into 128x128 dimensions.
- Normalization: Images are normalized by dividing pixel values by 255 so that they are scaled to a range between 0 and 1.
- Color Conversion: Depending on the dataset, images are converted to different color spaces like HSV, RGB, or grayscale for optimal performance.
- Mask Encoding: Each distinct RGB pixel value in a mask is mapped to an integer class index, so that the masks can be used as class labels.
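The resizing and normalization steps above can be sketched with NumPy alone. This is an illustration under assumptions, not the project's pipeline: the tutorial itself resizes with `tf.image.resize`, while the hypothetical helpers `resize_nearest` and `preprocess` below use simple index selection, which may pick slightly different source pixels.

```python
import numpy as np

def resize_nearest(img, size):
    """Nearest-neighbour resize of an (H, W, C) array to size=(new_h, new_w)."""
    h, w = img.shape[:2]
    rows = np.arange(size[0]) * h // size[0]  # source row for each output row
    cols = np.arange(size[1]) * w // size[1]  # source column for each output column
    return img[rows][:, cols]

def preprocess(img, size=(128, 128)):
    """Resize, then scale 8-bit pixel values into the range [0, 1]."""
    return resize_nearest(img, size).astype(np.float32) / 255.0

# Random stand-in for a 256x256 RGB scan
raw = np.random.default_rng(0).integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
x = preprocess(raw)
```

Nearest-neighbour sampling matters most for masks, because interpolating between two label colors would create colors that belong to no class.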
Data Preparation Workflow
- The images and masks are imported from the dataset.
- Images and masks are rescaled to a suitable size.
- The pixel values are adjusted to a target range.
- The segmentation mask labels are transformed into integers.
- The pre-processed images and masks are then passed to a custom data generator to facilitate training efficiently.
Code Explanation
STEP 1:
Connecting Google Drive
You can mount your Google Drive in a Google Colab notebook. This makes it easy to view files saved in Google Drive. In Colab, you can change and analyze data. You can also train models.
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
STEP 2:
Install Necessary libraries
Install libraries such as TensorFlow, Keras, and utils. They are used for numerical operations, image processing, machine learning, and visualization.
!pip install keras
!pip install utils
!pip install tensorflow
STEP 3:
Import Necessary libraries
Import the necessary libraries, such as NumPy, TensorFlow, and Matplotlib. They handle the numerical work, help build and train the model, and let us visualize the results.
import numpy as np
from tensorflow.keras.utils import Sequence
import cv2
import tensorflow as tf
import pickle
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import sklearn
from sklearn.cluster import KMeans
from tensorflow.keras.layers import *
from tensorflow.keras import models
from tensorflow.keras.callbacks import *
import glob2
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
from tensorflow.keras.metrics import MeanIoU
STEP 4:
Defines utility functions cvtColor and func.
Here are two utility functions. The first ('cvtColor') sets every value below 255 to 0, so only values of exactly 255 are kept. The second ('func') applies 'cvtColor' to every pixel in an image.
# Zero out every value below 255 (a threshold on the pixel, not a color space conversion)
def cvtColor(x):
    x[x < 255] = 0
    return x
# Apply cvtColor to each pixel in the image
def func(img):
    d = list(map(lambda x: cvtColor(x), img.reshape(-1,3)))
    return np.array(d).reshape(*img.shape[:-1], 3)
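Mapping a Python function over every pixel can be slow on larger masks. As a hedged alternative, the same thresholding can be written as a single boolean-mask assignment. The name `func_vectorized` is hypothetical, and unlike the loop above it works on a copy rather than modifying the input.

```python
import numpy as np

def func_vectorized(img):
    """Zero every channel value below 255; values equal to 255 are kept."""
    out = img.copy()      # leave the caller's array untouched
    out[out < 255] = 0
    return out

img = np.array([[[255, 255, 255], [12, 200, 255]],
                [[0, 0, 0], [254, 255, 3]]], dtype=np.uint8)
out = func_vectorized(img)
```

Each channel is thresholded independently, so a pixel such as (12, 200, 255) becomes (0, 0, 255) rather than pure black.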
DataGenerator Class for Batch Processing and Preprocessing
The DataGenerator class generates data batches for model training, with its constructor initializing arguments like data filenames, input and batch sizes, shuffle options, color mode, encoding dictionary, and optional processing functions. The processing method encodes masks using the provided dictionary. The __len__ method calculates the number of batches per epoch based on dataset and batch size, while the __getitem__ method retrieves a subset of filenames by batch index and uses data_generation to load and preprocess images and masks. The on_epoch_end method updates indices and shuffles them after each epoch. Finally, the data_generation method handles loading and preprocessing images and masks, adjusts sizes, manages color modes (HSV, RGB, grayscale), applies optional processing, and normalizes pixel values for the current batch.
class DataGenerator(Sequence):
    def __init__(self, all_filenames, input_size=(128, 128), batch_size=8, shuffle=True, seed=123, encode: dict = None, color_mode='hsv', function=None) -> None:
        super(DataGenerator, self).__init__()
        # Check that the encoding dictionary is provided
        assert encode is not None, 'The encode dictionary must be provided'
        # Check that the color mode is valid
        assert color_mode in ('hsv', 'rgb', 'gray')
        # Initialize instance variables
        self.all_filenames = all_filenames
        self.input_size = input_size
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.color_mode = color_mode
        self.encode = encode
        self.function = function
        # Set random seed for shuffling
        np.random.seed(seed)
        # Shuffle the data at the start
        self.on_epoch_end()
    def processing(self, mask):
        # Encode mask based on the provided dictionary
        d = list(map(lambda x: self.encode[tuple(x)], mask.reshape(-1, 3)))
        return np.array(d).reshape(*self.input_size, 1)
    def __len__(self):
        # Calculate the number of batches per epoch
        return int(np.floor(len(self.all_filenames) / self.batch_size))
    def __getitem__(self, index):
        # Generate one batch of data
        indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        all_filenames_temp = [self.all_filenames[k] for k in indexes]
        X, Y = self.__data_generation(all_filenames_temp)
        return X, Y
    def on_epoch_end(self):
        # Update indexes after each epoch
        self.indexes = np.arange(len(self.all_filenames))
        if self.shuffle:
            np.random.shuffle(self.indexes)
    def __data_generation(self, all_filenames_temp):
        # Generates data containing batch_size samples
        # Initialize arrays for images and masks
        batch = len(all_filenames_temp)
        if self.color_mode == 'gray':
            X = np.empty(shape=(batch, *self.input_size, 1))
        else:
            X = np.empty(shape=(batch, *self.input_size, 3))
        Y = np.empty(shape=(batch, *self.input_size, 1))
        # Iterate over the (image, mask) filename pairs in the current batch
        for i, (fn, label_fn) in enumerate(all_filenames_temp):
            # Load and preprocess image
            img = cv2.imread(fn)
            if self.color_mode == 'hsv':
                img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            elif self.color_mode == 'rgb':
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            elif self.color_mode == 'gray':
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                img = tf.expand_dims(img, axis=2)
            img = tf.image.resize(img, self.input_size, method='nearest')
            img = tf.cast(img, tf.float32)
            img /= 255.
            # Load and preprocess mask
            # Read the mask with 3 channels: a single-channel read (flag 0) would make COLOR_BGR2RGB fail
            mask = cv2.imread(label_fn)
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
            mask = tf.image.resize(mask, self.input_size, method='nearest')
            mask = np.array(mask)
            if self.function:
                mask = self.function(mask)
            mask = self.processing(mask)
            mask = tf.cast(mask, tf.float32)
            # Assign images and masks to the arrays
            X[i,] = img
            Y[i,] = mask
        return X, Y
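To make the `__len__` arithmetic concrete, here is a small sketch with assumed numbers (10 samples and a batch size of 4, neither taken from the project's dataset). The helper name `batches_per_epoch` is hypothetical and simply mirrors the floor division used in the class.

```python
import math

def batches_per_epoch(num_samples, batch_size):
    # Same arithmetic as DataGenerator.__len__: an incomplete final batch is dropped
    return math.floor(num_samples / batch_size)

num_samples, batch_size = 10, 4
n_batches = batches_per_epoch(num_samples, batch_size)
unused = num_samples - n_batches * batch_size  # samples skipped in this epoch
```

With shuffling enabled, the skipped samples differ from epoch to epoch, so every file is still likely to be seen over the course of training.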
Converts pixel-wise labels and saves the mapping in a pickle file
The 'encode_label' function flattens a mask into a 2D array of pixel values and collects the unique colors. It then builds an encoder dictionary that maps each color to an integer label, and saves that dictionary to a pickle file so the mapping can be reused later.
def encode_label(mask):
    # input (batch, rows, cols, channels)
    # Initialize an empty list to store the pixel colors
    label = []
    # Iterate over each pixel in the mask
    for i in mask.reshape(-1, 3):
        # Convert each pixel to a tuple and append it to the label list
        label.append(tuple(i))
    # Convert the list of tuples to a set to get unique labels
    label = set(label)
    # Create an encoder dictionary where keys are unique labels and values are their indices
    encoder = dict((j, i) for i, j in enumerate(label)) # key is tuple
    # Save the encoder dictionary to a pickle file
    with open('label.pickle', 'wb') as handle:
        pickle.dump(encoder, handle, protocol=pickle.HIGHEST_PROTOCOL)
    # Return the encoder dictionary
    return encoder
# Print the function reference (not calling the function)
print(encode_label)
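As an illustration of the mapping that `encode_label` produces, the sketch below builds a similar dictionary with `np.unique`. It is an alternative under assumptions, not the tutorial's function: `build_encoder` is a hypothetical name, it does not write a pickle file, and it assigns indices in sorted color order.

```python
import numpy as np

def build_encoder(mask):
    """Map each unique RGB colour in `mask` to an integer class index."""
    colours = np.unique(mask.reshape(-1, 3), axis=0)  # unique rows, sorted
    return {tuple(int(c) for c in colour): idx for idx, colour in enumerate(colours)}

# Toy 2x2 mask: one white pixel on a black background
mask = np.zeros((2, 2, 3), dtype=np.uint8)
mask[0, 0] = (255, 255, 255)
encoder = build_encoder(mask)
```

A sorted order makes the class indices reproducible between runs, which helps when a saved model is later paired with a regenerated mapping.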
Decodes model predictions back to pixel-wise labels using the saved mapping.
The function converts predicted class probabilities into label indices, maps each index back to its color, reshapes the result into an image with 3 channels, and returns that image.
def decode_label(predict, label):
    # Convert predicted values to labels using argmax along the channel axis
    predict = np.argmax(predict, axis=3)
    # Map label indices to label values using the provided label dictionary
    # (flatten to 1D so that each element is a scalar that int() can convert)
    d = list(map(lambda x: label[int(x)], predict.reshape(-1)))
    # Reshape the decoded labels into an image shape with 3 channels
    img = np.array(d).reshape(*predict.shape, 3)
    # Return the decoded image
    return img
# Print the function reference (not calling the function)
print(decode_label)
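The decoding step can also be expressed as a lookup table indexed by the argmax result. The sketch below assumes the label dictionary maps consecutive class indices to RGB lists; `decode_with_lut` is a hypothetical name and the two-class probabilities are made up for the example.

```python
import numpy as np

def decode_with_lut(probs, label):
    """probs: (batch, H, W, classes); label: {class_index: [r, g, b]}."""
    class_ids = np.argmax(probs, axis=-1)                 # (batch, H, W)
    lut = np.zeros((max(label) + 1, 3), dtype=np.uint8)   # one colour row per class
    for idx, colour in label.items():
        lut[idx] = colour
    return lut[class_ids]                                 # (batch, H, W, 3)

label = {0: [0, 0, 0], 1: [255, 255, 255]}
probs = np.array([[[[0.9, 0.1], [0.2, 0.8]]]])  # shape (1, 1, 2, 2)
decoded = decode_with_lut(probs, label)
```

The first pixel is most likely class 0 and decodes to black, while the second is most likely class 1 and decodes to white.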
STEP 5:
Data Loading
This function loads and preprocesses a selection of masks for label encoding, then uses the 'encode_label' function to build the label dictionary. It creates instances of the DataGenerator class for both training and validation data, so that the data is shuffled and preprocessed consistently.
def DataLoader(all_train_filename, all_mask, all_valid_filename=None, input_size=(128, 128), batch_size=4, shuffle=True, seed=123, color_mode='hsv', function=None):
    # Randomly select a subset of masks for encoding labels
    mask_folder = sklearn.utils.shuffle(all_mask, random_state=47)[:16]
    # Load and resize the masks
    mask = [tf.image.resize(cv2.cvtColor(cv2.imread(img), cv2.COLOR_BGR2RGB), input_size, method='nearest') for img in mask_folder]
    mask = np.array(mask)
    # Apply preprocessing function to masks if provided
    if function:
        mask = function(mask)
    # Encode the masks to create the label dictionary
    encode = encode_label(mask)
    # Create DataGenerator for training data
    train = DataGenerator(all_train_filename, input_size, batch_size, shuffle, seed, encode, color_mode, function)
    # If validation filenames are provided, create DataGenerator for validation data
    if all_valid_filename is None:
        return train, None
    else:
        valid = DataGenerator(all_valid_filename, input_size, batch_size, shuffle, seed, encode, color_mode, function)
        return train, valid
# Print the function reference (not calling the function)
print(DataLoader)
Downsampling U-Net model block
The down_block function is a downsampling operation intended for the U-Net model. It takes in a tensor and applies two convolutional layers with Batch Normalization and Leaky ReLU activations before performing an optional max pooling to reduce the spatial dimensions by half. It produces the downsampled output tensor along with the input tensor for skip connections. These are useful since they keep important information for the later parts of the network.
def down_block(x, filters, use_maxpool=True):
    # Apply two convolutional layers with specified filters and LeakyReLU activation
    x = Conv2D(filters, 3, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = Conv2D(filters, 3, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    # Optionally apply MaxPooling2D with a stride of (2, 2)
    if use_maxpool:
        return MaxPooling2D(strides=(2, 2))(x), x
    else:
        return x
# Print the function reference (not calling the function)
print(down_block)
Defines an upsampling block in the U-Net model.
The function defines an upsampling block. It upsamples the input feature map, concatenates it with the matching feature map from the encoder (the skip connection), applies two rounds of convolution, batch normalization, and LeakyReLU activation, and returns the output feature map.
def up_block(x, y, filters):
    # Upsample the input feature map
    x = UpSampling2D()(x)
    # Concatenate the upsampled feature map with the corresponding feature map from the contracting path
    x = Concatenate(axis=3)([x, y])
    # Apply two convolutional layers with specified filters and LeakyReLU activation
    x = Conv2D(filters, 3, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = Conv2D(filters, 3, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    # Return the output feature map
    return x
# Print the function reference (not calling the function)
print(up_block)
Defines the U-Net model architecture using TensorFlow/Keras.
The Unet function builds the U-Net model for image segmentation. It first defines the input layer and then builds the encoding path with downsampling blocks capturing the features and minimizing the loss of spatial information. Then, a decoding path is built where upsampling blocks are used to bring in the information from the encoding path using skip connections.
Dropout is applied to the final decoder feature map for regularization, and a 1x1 convolution with softmax activation then produces the per-pixel class predictions. Finally, the function prints a summary of the model and returns it.
def Unet(input_size=(128, 128, 3), *, classes, dropout):
    # Define the number of filters for each downsampling level
    filters = [64, 128, 256, 512, 1024]
    # Input layer (named inputs to avoid shadowing the built-in input function)
    inputs = Input(shape=input_size)
    # Encoding path
    # Apply down_block for each downsampling level and store the skip connections
    x, temp1 = down_block(inputs, filters[0])
    x, temp2 = down_block(x, filters[1])
    x, temp3 = down_block(x, filters[2])
    x, temp4 = down_block(x, filters[3])
    x = down_block(x, filters[4], use_maxpool=False) # Last down_block without max pooling
    # Decoding path
    # Apply up_block for each upsampling level using the stored skip connections
    x = up_block(x, temp4, filters[3])
    x = up_block(x, temp3, filters[2])
    x = up_block(x, temp2, filters[1])
    x = up_block(x, temp1, filters[0])
    # Apply dropout
    x = Dropout(dropout)(x)
    # Output layer
    output = Conv2D(classes, 1, activation='softmax')(x)
    # Define and summarize the model
    model = models.Model(inputs, output, name='unet')
    model.summary()
    # Return the model
    return model
# Print the function reference (not calling the function)
print(Unet)
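To see how the spatial size and channel count evolve through the network, the following sketch traces the block outputs in plain Python, without building any Keras layers. It assumes the 128-pixel input and the filter list used above; `trace_unet_shapes` is a hypothetical helper written only for this illustration.

```python
def trace_unet_shapes(size=128, filters=(64, 128, 256, 512, 1024)):
    """Return (stage, height/width, channels) for each block's convolution output."""
    shapes = []
    # Encoder: every block except the last is followed by 2x2 max pooling
    for level, f in enumerate(filters):
        shapes.append((f"down{level + 1}", size, f))
        if level < len(filters) - 1:
            size //= 2
    # Decoder: each block upsamples by 2 and reuses the matching encoder filter count
    for level, f in enumerate(reversed(filters[:-1])):
        size *= 2
        shapes.append((f"up{level + 1}", size, f))
    return shapes

shapes = trace_unet_shapes()
for stage, side, channels in shapes:
    print(f"{stage}: {side}x{side}x{channels}")
```

The bottleneck works on an 8x8 map with 1024 channels, and the last decoder block returns to the full 128x128 resolution, which is why the input side length should be divisible by 16.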
Implement Mean Intersection over Union (IoU) metrics.
The class m_iou can be used as a convenient code structure for the evaluation of segmentation models in terms of mean IoU and class IoU. It helps to assess the accuracy of any pixel-wise classification task.
from tensorflow.keras.metrics import MeanIoU
import numpy as np
class m_iou():
    def __init__(self, classes: int) -> None:
        # Initialize the number of classes
        self.classes = classes
    def mean_iou(self, y_true, y_pred):
        # Compute mean IoU metric using Keras's MeanIoU
        y_pred = np.argmax(y_pred, axis=3)
        miou_keras = MeanIoU(num_classes=self.classes)
        miou_keras.update_state(y_true, y_pred)
        return miou_keras.result().numpy()
    def miou_class(self, y_true, y_pred):
        # Compute IoU for each class
        y_pred = np.argmax(y_pred, axis=3)
        miou_keras = MeanIoU(num_classes=self.classes)
        miou_keras.update_state(y_true, y_pred)
        values = np.array(miou_keras.get_weights()).reshape(self.classes, self.classes)
        for i in range(self.classes):
            class_iou = values[i, i] / (sum(values[i, :]) + sum(values[:, i]) - values[i, i])
            print(f'IoU for class {str(i + 1)} is: {class_iou}')
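The per-class formula used in `miou_class` (diagonal entry divided by row sum plus column sum minus the diagonal) can be checked by hand with a NumPy-only sketch. The helper `per_class_iou` and the tiny label arrays are illustrative assumptions, not part of the project code.

```python
import numpy as np

def per_class_iou(y_true, y_pred, num_classes):
    # Confusion matrix: rows are true classes, columns are predicted classes
    idx = y_true.ravel() * num_classes + y_pred.ravel()
    cm = np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    intersection = np.diag(cm)
    union = cm.sum(axis=0) + cm.sum(axis=1) - intersection
    # Classes absent from both arrays get NaN instead of a division by zero
    return np.where(union > 0, intersection / np.maximum(union, 1), np.nan)

y_true = np.array([[0, 0], [1, 1]])
y_pred = np.array([[0, 1], [1, 1]])
iou = per_class_iou(y_true, y_pred, num_classes=2)
mean_iou = float(np.nanmean(iou))
```

Here class 0 has one correct pixel out of a union of two (IoU 0.5) and class 1 has two correct pixels out of a union of three (IoU about 0.67), so the mean IoU is roughly 0.58.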
Plots training and validation loss, accuracy, and mean IoU over epochs.
This function shows a model's training history. This code plots training and validation loss, accuracy, and mean IoU over epochs. It provides insight into the model's training progress.
import matplotlib.pyplot as plt
def show_history(history, validation: bool = False):
    if validation:
        # If validation data is available, plot training and validation metrics
        # Loss plot
        fig, axes = plt.subplots(figsize=(20, 5))
        axes.plot(history.epoch, history.history['loss'], color='r', label='Train')
        axes.plot(history.epoch, history.history['val_loss'], color='b', label='Val')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Loss')
        axes.legend()
        # Accuracy plot
        fig, axes = plt.subplots(figsize=(20, 5))
        axes.plot(history.epoch, history.history['acc'], color='r', label='Train')
        axes.plot(history.epoch, history.history['val_acc'], color='b', label='Val')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('Acc')
        axes.legend()
        # Mean IoU plot
        fig, axes = plt.subplots(figsize=(20, 5))
        axes.plot(history.epoch, history.history['mean_iou'], color='r', label='Train')
        axes.plot(history.epoch, history.history['val_mean_iou'], color='b', label='Val')
        axes.set_xlabel('Epoch')
        axes.set_ylabel('MeanIoU')
        axes.legend()
    else:
        # If no validation data is available, plot only training metrics
        fig, axes = plt.subplots(1, 3, figsize=(20, 5))
        # Loss plot
        axes[0].plot(history.epoch, history.history['loss'])
        axes[0].set_title('Train')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        # Accuracy plot
        axes[1].plot(history.epoch, history.history['acc'])
        axes[1].set_title('Train')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Acc')
        # Mean IoU plot
        axes[2].plot(history.epoch, history.history['mean_iou'])
        axes[2].set_title('Train')
        axes[2].set_xlabel('Epoch')
        axes[2].set_ylabel('MeanIoU')
Model performance inference function creation.
The predict function processes an input image, normalizes it, and uses a trained U-Net model to generate a segmentation mask. It then prepares an image for visualization. Then it combines the original and predicted outputs and returns both for further analysis.
def predict(model, image_test, label, color_mode, size):
    # Read and preprocess the test image based on the specified color mode and size
    image = cv2.imread(image_test)
    if color_mode == 'hsv':
        image_cvt = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    elif color_mode == 'rgb':
        image_cvt = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    elif color_mode == 'gray':
        image_cvt = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        image_cvt = tf.expand_dims(image_cvt, axis=2)
    image_cvt = tf.image.resize(image_cvt, size, method='nearest')
    image_cvt = tf.cast(image_cvt, tf.float32)
    image_norm = image_cvt / 255.
    image_norm = tf.expand_dims(image_norm, axis=0)
    # Make prediction using the model
    new_image = model(image_norm)
    # Decode the predicted labels and combine with the original image for visualization
    image_decode = decode_label(new_image, label)
    predict_img = tf.cast(tf.image.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), size, method='nearest'), tf.float32) * 0.7 + image_decode * 0.3
    # Return the blended image (original plus decoded mask) and the raw prediction
    return np.floor(predict_img)[0].astype('int'), new_image
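The 0.7 and 0.3 weights in `predict` implement a simple alpha blend of the scan and the decoded mask. The sketch below isolates that step in NumPy; `overlay` is a hypothetical helper and the pixel values are made up so the arithmetic is easy to follow.

```python
import numpy as np

def overlay(image, mask_rgb, alpha=0.7):
    """Blend image and colour mask: alpha * image + (1 - alpha) * mask."""
    blended = alpha * image.astype(np.float64) + (1.0 - alpha) * mask_rgb.astype(np.float64)
    return np.clip(np.floor(blended), 0, 255).astype(np.uint8)

# Mid-gray 2x2 image with a single white mask pixel in the top-left corner
image = np.full((2, 2, 3), 128, dtype=np.uint8)
mask_rgb = np.zeros((2, 2, 3), dtype=np.uint8)
mask_rgb[0, 0] = 255
blended = overlay(image, mask_rgb)
```

Pixels outside the mask are darkened to 70% of their brightness, while masked pixels are lifted toward the mask color, which is what makes the segmented region stand out.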
Displays the original image, ground truth mask, and predicted mask.
The function takes in input size, color mode, trained model, label encoding information, optional ground truth mask, and preprocessing function. It preprocesses the image, creates a prediction, and displays the original image, model prediction, and ground truth mask if available. It also calculates and outputs the class-wise Intersection over Union.
def show_example(image, mask, model, label, inp_size, color_mode, function):
    # Note: this function relies on the global objects m (m_iou) and train_data (DataGenerator)
    # Read and preprocess the input image
    img = cv2.imread(image)
    img = tf.image.resize(img, inp_size, method='nearest')
    # Generate prediction using the provided model
    pred, _pred = predict(model, image, label, color_mode, inp_size)
    if mask is not None:
        # If ground truth mask is provided, visualize ground truth and prediction
        # Read and preprocess the ground truth mask
        msk = cv2.imread(mask)
        msk = tf.image.resize(msk, inp_size, method='nearest')
        # Apply preprocessing function to the mask if provided
        if function:
            msk = tf.convert_to_tensor(function(msk.numpy()))
        # Compute and print class-wise IoU
        m.miou_class(train_data.processing(msk.numpy()), _pred)
        # Create a visualization of original image, ground truth, and prediction
        ground_truth = np.floor(img.numpy() * 0.7 + msk.numpy() * 0.3).astype('int')
        fig, axes = plt.subplots(1, 3, figsize=(12, 3))
        axes[0].imshow(img)
        axes[0].set_title('Original Image')
        axes[1].set_title('Ground truth')
        axes[1].imshow(ground_truth)
        axes[2].set_title('Prediction')
        axes[2].imshow(pred)
    else:
        # If no ground truth mask is provided, visualize only the prediction
        fig, axes = plt.subplots(1, 2, figsize=(12, 3))
        axes[0].imshow(img)
        axes[0].set_title('Original Image')
        axes[1].set_title('Prediction')
        axes[1].imshow(pred)
STEP 6:
Load data and ready for training
The code lists the file paths for images and masks, pairs them, and shuffles the pairs. It then splits the data into training (80%) and validation (20%) sets. The data generators themselves are created in the next step.
# Get a list of all image file paths
images = sorted(glob2.glob('/content/drive/seg/segmentation_datasets/image/*.png'))
# Get a list of all mask file paths
masks = sorted(glob2.glob('/content/drive/seg/segmentation_datasets/mask/*.png'))
# Combine image and mask file paths into tuples
data = list(zip(images, masks))
# Shuffle the list of tuples using a random seed for reproducibility
data = shuffle(data, random_state=42)
# Determine the index to split the data into training and validation sets
split = int(0.8 * len(data))
# Divide the data into training and validation sets
all_train_filenames = data[:split]
all_valid_filenames = data[split:]
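Pairing images and masks with `zip` over two sorted lists only works if both folders contain matching files in the same order. A quick sanity check can catch a misaligned dataset before training. The sketch below assumes that an image and its mask share the same file stem, which may not hold for every dataset; `check_pairs` and the file names are hypothetical.

```python
import os

def check_pairs(pairs):
    """Return the (image, mask) pairs whose file stems differ."""
    def stem(path):
        return os.path.splitext(os.path.basename(path))[0]
    return [(img, msk) for img, msk in pairs if stem(img) != stem(msk)]

# Hypothetical file names, used only to illustrate the check
pairs = [
    ("image/scan_001.png", "mask/scan_001.png"),
    ("image/scan_002.png", "mask/scan_003.png"),
]
mismatched = check_pairs(pairs)
print(f"{len(mismatched)} mismatched pair(s): {mismatched}")
```

An empty result means every image was zipped with the mask of the same name, so the shuffle and split operate on correctly aligned pairs.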
STEP 7:
U-Net model building and saving
In this code, the 'DataLoader' function creates the training and validation data generators. The U-Net model is then instantiated by calling the Unet function with the specified input size, number of classes, and dropout rate.
# Create training and validation data generators using DataLoader function
train_data, valid_data = DataLoader(all_train_filenames, masks, all_valid_filenames, (128, 128), 8, True, 47, 'gray', function=func)
# Define input size for the UNet model
inp_size = (128, 128, 1)
# Instantiate the UNet model
unet = Unet(inp_size, classes=2, dropout=0.3)
m = m_iou(2)
Save the model
In this code, the ModelCheckpoint callback saves the best model seen during training (judged by validation mean IoU) to the specified file path. The ReduceLROnPlateau callback lowers the learning rate when the monitored loss stops improving.
# Define a ModelCheckpoint callback to save the best model during training
checkpoint = ModelCheckpoint('/content/drive/aionlinecourse/unet.h5',
monitor='val_mean_iou', # Monitor validation mean IoU
save_best_only=True, # Save only the best model
verbose=1, # Verbosity mode
mode='max') # Maximizing mode for validation mean IoU
# Define a ReduceLROnPlateau callback to adjust learning rate during training
lr_R = ReduceLROnPlateau(monitor='loss', # Monitor loss
patience=4, # Number of epochs with no improvement before adjusting learning rate
verbose=1, # Verbosity mode
factor=0.3, # Factor by which the learning rate will be reduced
min_lr=0.00001) # Minimum learning rate
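To get a feel for what these settings do, the sketch below lists the learning rate after each successive reduction, assuming a starting rate of 0.001 (the Adam setting used in the training step of this tutorial). It models only the update rule `max(lr * factor, min_lr)`, not the patience logic; `reduced_learning_rates` is a hypothetical helper.

```python
def reduced_learning_rates(initial_lr=0.001, factor=0.3, min_lr=0.00001, reductions=5):
    """List the learning rate after each successive plateau-triggered reduction."""
    rates = [initial_lr]
    for _ in range(reductions):
        # Multiply by the factor, but never go below the configured minimum
        rates.append(max(rates[-1] * factor, min_lr))
    return rates

rates = reduced_learning_rates()
for step, lr in enumerate(rates):
    print(f"after {step} reduction(s): {lr:.7f}")
```

With a factor of 0.3, the rate reaches the 0.00001 floor on the fourth reduction, so further plateaus no longer change it.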
Train the model
The U-Net model is compiled with the SparseCategoricalCrossentropy() loss and the Adam optimizer. It is then trained for 100 epochs, with verbosity mode set to 1 for detailed output.
# Compile the UNet model
unet.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(), # Sparse categorical crossentropy loss
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), # Adam optimizer with specified learning rate
metrics=[m.mean_iou, 'acc'], # Metrics: custom mean IoU and accuracy
run_eagerly=True) # Run eagerly so the NumPy-based mean_iou metric can execute (also eases debugging)
# Train the UNet model
history = unet.fit(train_data, # Training data generator
validation_data=valid_data, # Validation data generator
epochs=100, # Number of epochs
verbose=1, # Verbosity mode
callbacks=[checkpoint, lr_R]) # Callbacks: ModelCheckpoint and ReduceLROnPlateau
STEP 8:
Evaluation of model
In this code, we look at the U-Net training results. We plot the loss, accuracy, and mean IoU curves to check model performance.
show_history(history, True)
Loading Label Encoding Information
The 'with' statement opens the file 'label.pickle' in binary read mode and ensures it is closed properly even if an error occurs. The 'pickle.load()' function deserializes the saved dictionary, which maps each label color to its encoded class index. A new dictionary, 'encode', is then created by swapping keys and values, so that each class index maps back to its color.
import pickle
# Open the pickle file for reading
with open('label.pickle', 'rb') as handle:
    # Load the dictionary from the pickle file
    k = pickle.load(handle)
# Create a new dictionary 'encode' by swapping the keys and values of the loaded dictionary
encode = dict((j, list(i)) for i, j in k.items())
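The key-value swap is easier to picture with concrete contents. The dictionary below is an assumed example of what `label.pickle` might hold for a two-class mask (it is not read from the actual file), and `decode_table` is a hypothetical name for the inverted mapping.

```python
# Assumed example contents: colour tuple -> class index
saved = {(0, 0, 0): 0, (255, 255, 255): 1}

# Inverted mapping: class index -> colour as a list, as expected by decode_label
decode_table = {idx: list(colour) for colour, idx in saved.items()}
print(decode_table)
```

The inversion is only safe because every color received a distinct index, so no two keys collapse onto the same value.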
Loading Trained Model for evaluation
We load our best-trained model to check the model performance. After loading the model, we analyze the image.
from tensorflow.keras.models import load_model
# Load the best saved model from the HDF5 file
# Note: this path must point to the file written by the ModelCheckpoint callback during training
model = load_model('/content/drive/MyDrive/U_net/unet.h5', custom_objects={'mean_iou': m.mean_iou})
Output visualization
The call below unpacks one (image, mask) pair from all_train_filenames and passes its elements as separate arguments to the show_example function.
show_example(*all_train_filenames[6], model, encode, (128,128), 'gray', func)
The show_example function is redefined below with an additional panel. It visualizes the original image, the ground truth mask, and the predicted mask, plus the model's raw output for class 1. It also computes and displays class-wise IoU metrics for better evaluation of the segmentation quality.
def show_example(image, mask, model, label, inp_size, color_mode, function):
    # Read and resize the original image
    img = cv2.imread(image)
    img = tf.image.resize(img, inp_size, method='nearest')
    # Generate predictions using the model
    pred, _pred = predict(model, image, label, color_mode, inp_size)
    # If a ground truth mask is provided, process and visualize it
    if mask is not None:
        msk = cv2.imread(mask)
        msk = tf.image.resize(msk, inp_size, method='nearest')
        # Apply preprocessing function to the mask if provided
        if function:
            msk = tf.convert_to_tensor(function(msk.numpy()))
        # Compute and display class-wise IoU
        m.miou_class(train_data.processing(msk.numpy()), _pred)
        # Create a visualization with original image, ground truth, prediction, and prediction for class 1
        ground_truth = np.floor(img.numpy() * 0.7 + msk.numpy() * 0.3).astype('int')
        fig, axes = plt.subplots(1, 4, figsize=(15, 5))
        axes[0].imshow(img)
        axes[0].set_title('Original Image')
        axes[1].set_title('Ground truth')
        axes[1].imshow(ground_truth)
        axes[2].set_title('Prediction')
        axes[2].imshow(pred)
        axes[3].set_title('Prediction (Class 1)')
        axes[3].imshow(_pred[0, :, :, 1], cmap='viridis', alpha=0.7)
    else:
        # If no ground truth mask is provided, display only the original image and prediction
        fig, axes = plt.subplots(1, 2, figsize=(12, 3))
        axes[0].imshow(img)
        axes[0].set_title('Original Image')
        axes[1].set_title('Prediction')
        axes[1].imshow(pred)
    plt.show()
# Now, use a loop to display multiple examples
for example_filename in all_train_filenames[:5]:
    show_example(*example_filename, model, encode, (128,128), 'gray', func)
Project Conclusion
In this project, we embarked on an exciting journey into the world of medical image segmentation using the powerful U-Net architecture! We transformed complex medical images like MRIs, CT scans, and X-rays into clearer insights. This can help doctors diagnose and treat patients more effectively.
With our custom data generator, we ensured our model was well fed with data, while the U-Net architecture worked its magic to identify key anatomical structures and abnormalities. We kept track of our model's performance using metrics like Mean IoU and accuracy. Finally, the training plots made it easy to cheer on our model as it improved over time.
The result is a fantastic tool that paves the way for early disease detection and precise treatment planning, ultimately enhancing patient care! This project not only showcases the power of deep learning in healthcare but also sets the stage for future innovations that could revolutionize how we approach medical diagnostics. Let's continue pushing the boundaries of what's possible with technology and healthcare!
Challenges and Troubleshooting
As you dive deep into this innovative project, you might face some waves of challenges. Don’t worry! Here are some common challenges along with solutions to help you.
Challenge: Data Quality and Diversity
Solution: Use data augmentation techniques to artificially expand the dataset.
Challenge: Overfitting
Solution: Implement regularization techniques such as dropout layers and early stopping during training, and use a validation set to monitor performance.
Challenge: Computational Resources
Solution: Use cloud-based services like Google Colab or AWS that offer powerful GPU options. Consider using model checkpoints to save progress and resume training as needed.
Challenge: Model Evaluation
Solution: Use multiple evaluation metrics, such as Mean IoU and the Dice coefficient, to assess model performance and ensure that the model generalizes well to unseen data.
Challenge: Interpretability of Results
Solution: Implement visualization techniques to overlay predicted masks on original images, helping stakeholders better understand model outputs and their clinical implications.
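The Dice coefficient mentioned above is not implemented elsewhere in this article, so here is a minimal NumPy sketch of a per-class version. It assumes integer label maps of equal shape; `dice_per_class` and the toy arrays are hypothetical and are not taken from the project.

```python
import numpy as np

def dice_per_class(y_true, y_pred, num_classes):
    """Dice = 2 * |A intersect B| / (|A| + |B|), computed separately for each class."""
    scores = []
    for c in range(num_classes):
        t = (y_true == c)
        p = (y_pred == c)
        denom = int(t.sum()) + int(p.sum())
        # A class missing from both arrays has no defined score
        scores.append(2.0 * int(np.logical_and(t, p).sum()) / denom if denom else float("nan"))
    return scores

y_true = np.array([[0, 0], [1, 1]])
y_pred = np.array([[0, 1], [1, 1]])
dice = dice_per_class(y_true, y_pred, num_classes=2)
```

For the same prediction, Dice is never lower than IoU (here about 0.67 and 0.8 per class), so the two metrics should be compared only with themselves across experiments.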
FAQs
Question 1: What are the goals of the Medical Image Segmentation project?
Answer: This project aims to segment medical images (MRI, CT, and X-ray) with the U-Net architecture, with the goal of improving diagnostic accuracy and supporting treatment planning.
Question 2: Why is U-Net applied in medical image segmentation?
Answer: U-Net was developed specifically for biomedical image segmentation. Its encoder-decoder structure lets the model capture both local and global image features, which is especially important for detecting anatomical structures and pathologies.
Question 3: Why would healthcare professionals be interested in this project?
Answer: This project reduces both the human input and the time needed to segment medical images. Automated methods help healthcare practitioners diagnose patients quickly and accurately.
Question 4: Which kinds of images can be used with the U-Net model?
Answer: The model can segment many kinds of medical images, including MRI, CT, and X-ray, which makes it flexible across different kinds of medical imaging.
Question 5: Can this segmentation model be applied to other medical specializations?
Answer: Yes. The U-Net model shown in this article can be applied to other medical imaging applications such as tumor detection, organ segmentation, or the analysis of other diseases.