Image Segmentation using Mask R-CNN with PyTorch

Mask R-CNN is employed to build a deep-learning model for detecting brain tumors. The project's main focus is automatically detecting and segmenting tumors in medical images so that diagnostics and treatment planning can benefit significantly. Applying computer vision in this way enhances the accuracy and efficiency of identifying brain tumors.

Project Overview

This project builds a deep-learning model using Mask R-CNN for brain tumor detection and segmentation. The model is fine-tuned on a dedicated dataset of brain scans with tumor annotations, which allows it to accurately detect and segment tumor regions. By applying state-of-the-art computer vision techniques, the model produces fine-grained segmentation masks and bounding boxes for the tumor regions in medical images. Together, these automate the tumor detection process, reduce manual effort, and improve early-stage diagnostic capability. The project addresses the pressing need of healthcare professionals for an efficient, reliable tool for analyzing medical images to support clinical decisions.

Prerequisites

  • Understanding of how deep learning and neural networks work.
  • Proficiency in Python and tools such as PyTorch and torchvision.
  • Prior experience with image processing and computer vision methods.
  • Understanding of how Mask R-CNN performs object detection and segmentation.
  • Familiarity with dataset handling, data preprocessing, and image augmentation.
  • Familiarity with building, training, optimizing, and evaluating models.
  • Awareness of how a GPU is used for training and inference (if one is available).
  • Experience with Jupyter Notebooks or Google Colab for running deep learning models.
  • Knowledge of Matplotlib or other tools for visualizing model results.

Approach

The project fine-tunes a pre-trained Mask R-CNN network on a brain tumor dataset to detect and segment tumor regions in medical images. The first step is to preprocess the dataset, which includes normalization and tensor conversion. The model uses a ResNet-50 backbone with a feature pyramid network, and its classification and mask prediction layers are reconfigured to work with a single class: tumor. Training consists of feeding images through the model, computing the loss, and optimizing with gradient descent. Safeguards such as gradient clipping are used to avoid problems like exploding gradients. The model's performance is evaluated on validation images, and predictions are visualized with segmentation masks and bounding boxes. The result is a model that detects and segments tumors well, providing a strong methodology for medical image analysis.

Workflow and Methodology

Workflow

  • Load and preprocess brain tumor data from the Training and Validation folders.
  • Use a pre-trained Mask R-CNN model and adapt its layers to the brain tumor detection task.
  • Define transformations that convert images to tensors and normalize them.
  • Train the model on the training dataset, applying safeguards such as gradient clipping.
  • Validate performance on the validation dataset.
  • Apply the trained model to predict tumor masks and bounding boxes for test images.
  • Visualize predictions by overlaying masks and bounding boxes on the original images.
  • Refine the model further, based on the evaluation results, to improve performance and accuracy.

Methodology

  • Use a pre-trained Mask R-CNN with custom layers fine-tuned for brain tumor detection.
  • Convert images into tensors and normalize them to promote better training and generalization.
  • Train with standard optimization techniques such as SGD with learning rate scheduling.
  • Apply augmentation techniques and transformations to avoid overfitting and increase robustness.
  • Validate on the validation set to measure accuracy, loss, and detection quality.
  • Visualize segmentation masks and bounding boxes on images to examine outputs and detection errors.

Data Collection and Preparation

Data collection

The brain tumor dataset is available on Kaggle. A Kaggle dataset can be accessed conveniently and securely from within Google Colab after configuring your Kaggle credentials, which prevents exposing sensitive information. The notebook collects the Kaggle API key and username securely and assigns them as environment variables. This enables Kaggle's CLI command (!kaggle datasets download -d ammarnassanalhajali/brain-tumor), which authenticates the user and downloads the dataset straight into Colab.
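
A minimal sketch of this credential setup, assuming it runs in Colab; the getpass prompts are illustrative, and the unzip target mirrors the dataset path used later in this project:

import os
from getpass import getpass

# Collect the Kaggle credentials without echoing them in the notebook output
os.environ['KAGGLE_USERNAME'] = getpass('Kaggle username: ')
os.environ['KAGGLE_KEY'] = getpass('Kaggle API key: ')

# Authenticate and download the dataset straight into Colab
!kaggle datasets download -d ammarnassanalhajali/brain-tumor
!unzip -q brain-tumor.zip -d /content/brain-tumor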

Data Preparation

Data preparation workflow

  • Load the images from the respective folders using appropriate indexing.
  • Apply image transformations, such as resizing, normalization, and conversion to tensor format.
  • Extract and process the annotations, including masks and bounding boxes, for each image.
  • Ensure all masks are properly aligned with the images and in the correct format (binary masks).
  • Split the dataset into batches using a DataLoader, preparing it for efficient training.
  • Verify data integrity by checking for any missing or corrupted images and annotations (a quick integrity-check sketch follows this list).
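
A minimal integrity-check sketch, assuming the Training folder path used elsewhere in this project; PIL's verify() raises an exception for truncated or corrupted files:

import os
from PIL import Image

def check_image_integrity(image_folder):
    """Report images that cannot be opened or that fail PIL's integrity check."""
    bad_files = []
    for name in os.listdir(image_folder):
        if not name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        path = os.path.join(image_folder, name)
        try:
            with Image.open(path) as img:
                img.verify()  # Raises an exception for corrupted files
        except Exception as exc:
            bad_files.append((name, str(exc)))
    print(f"Checked {image_folder}: {len(bad_files)} corrupted file(s)")
    return bad_files

# Example usage (path assumed from the rest of this project)
# check_image_integrity('/content/brain-tumor/Training')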

Code Explanation

STEP 1:

Mounting Google Drive

This code mounts Google Drive in your Colab environment, making files stored in your Drive available for use in the notebook.

from google.colab import drive
drive.mount('/content/drive')

Library Installation

This installs the essential Python libraries: torch, torchvision, and torchaudio for deep learning, and matplotlib, opencv-python, and pycocotools for image processing and visualization. This prepares the environment for model training and data manipulation.

!pip install torch torchvision torchaudio
!pip install matplotlib opencv-python pycocotools

Import Libraries

The code below imports several libraries. The libraries used are cv2, torch, matplotlib, and PIL for image processing and other data manipulation functions. It also imports libraries for model building and advanced image functions like torchvision and skimage.

import os
import sys
import cv2
import json
import torch
import shutil
import random
import matplotlib
import torchvision
import numpy as np
import skimage.draw
from PIL import Image
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from IPython.display import clear_output
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader

Cloning and Cleaning Up Mask R-CNN Repository

This code clones the Mask R-CNN repository from GitHub and removes its .git folder so that committing the kernel will not throw an error. It also removes the repository's images and assets folders, preventing unwanted images from being displayed at the bottom of the notebook.

!git clone https://www.github.com/matterport/Mask_RCNN.git
!rm -rf Mask_RCNN/.git  # to prevent an error when the kernel is committed
!rm -rf Mask_RCNN/images Mask_RCNN/assets  # to prevent displaying images at the bottom of a kernel

Random Image Display in Grid

This code randomly chooses 6 images from the training folder and arranges them in a grid using Matplotlib, with the axes hidden. The pictures are displayed in rows of 3 for clear visibility.

# Define the training folder
train_image_folder = '/content/brain-tumor/Training'  # Change this to your path

# Get image filenames from the training folder
image_files = [f for f in os.listdir(train_image_folder) if f.endswith(('.jpg', '.jpeg', '.png'))]

# Select 6 random images
num_images = 6  # Set the number of images you want to display
selected_images = random.sample(image_files, min(num_images, len(image_files)))

def display_images_in_grid(image_folder, selected_images, images_per_row=3):
    # Calculate the number of rows needed
    rows = (len(selected_images) // images_per_row) + (1 if len(selected_images) % images_per_row else 0)
    # Create a figure with subplots
    fig, axes = plt.subplots(rows, images_per_row, figsize=(15, 5 * rows))
    # Flatten axes for easy indexing
    axes = axes.flatten()
    # Display the images
    for i, image_file in enumerate(selected_images):
        image_path = os.path.join(image_folder, image_file)
        img = Image.open(image_path)
        axes[i].imshow(img)
        axes[i].axis('off')  # Hide the axes
        axes[i].set_title(image_file)
    # Hide any unused subplots
    for i in range(len(selected_images), len(axes)):
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()

display_images_in_grid(train_image_folder, selected_images, images_per_row=3)

Defining BrainTumorDataset Class

This class loads brain tumor images and preprocesses them for model training. It reads images along with their annotations, builds masks from the polygonal annotations, and computes bounding boxes for each tumor region. Masks are stacked in the (N, H, W) layout that torchvision's Mask R-CNN expects. Each sample is returned as a dictionary containing the image and its target (masks, labels, boxes), and optional transformations are applied to the image.

class BrainTumorDataset(Dataset):
    def __init__(self, dataset_dir, subset, transforms=None):
        """
        Args:
            dataset_dir (string): Directory with all the images and annotations.
            subset (string): 'Training', 'Validation', or 'Test'
            transforms (callable, optional): Optional transform to be applied on a sample.
        """
        self.dataset_dir = dataset_dir
        self.subset = subset
        self.transforms = transforms
        # Add the class 'tumor' with class_id 1
        self.class_names = ["tumor"]
        self.class_id = 1
        # Load annotations and keep only images that actually have annotated regions
        self.annotations = json.load(open(os.path.join(dataset_dir, f'annotations_{subset}.json')))
        self.annotations = list(self.annotations.values())
        self.images = [a for a in self.annotations if a['regions']]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        a = self.images[idx]
        # Get the image and its annotations
        image_path = os.path.join(self.dataset_dir, self.subset, a['filename'])
        image = Image.open(image_path).convert("RGB")
        width, height = image.size
        # Load the polygons for the mask
        polygons = [r['shape_attributes'] for r in a['regions']]
        masks = np.zeros((height, width, len(polygons)), dtype=np.uint8)
        for i, p in enumerate(polygons):
            # Clip polygon points to be within image bounds
            clipped_y = np.clip(p['all_points_y'], 0, height - 1)
            clipped_x = np.clip(p['all_points_x'], 0, width - 1)
            rr, cc = skimage.draw.polygon(clipped_y, clipped_x)
            masks[rr, cc, i] = 1
        num_objs = masks.shape[-1]
        # The class ids are all 1 (tumor)
        labels = torch.ones((num_objs,), dtype=torch.int64)
        # Bounding boxes derived from the masks (required in the Mask R-CNN targets)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[:, :, i] > 0)
            ymin, xmin = np.min(pos[0]), np.min(pos[1])
            ymax, xmax = np.max(pos[0]), np.max(pos[1])
            boxes.append([xmin, ymin, xmax, ymax])
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # Convert masks to a tensor in (N, H, W) layout, as torchvision's Mask R-CNN expects
        masks = torch.as_tensor(masks, dtype=torch.uint8).permute(2, 0, 1)
        # Convert the image to a numpy array (transforms such as ToTensor accept arrays)
        image = np.array(image)
        # If transforms are specified, apply them to the image
        if self.transforms:
            image = self.transforms(image)
        # Create the sample dictionary with the image and annotations
        sample = {'image': image, 'target': {'masks': masks, 'labels': labels, 'boxes': boxes}}
        return sample

Visualize Images along with Masks

This function overlays a generated mask on the original image on a given axis. It shows the image and applies the mask with transparency, using the jet colormap to improve visibility. The axes are hidden for a clearer view.

def visualize_mask_and_image(image, mask, ax):
    """
    Visualizes the original image and the generated mask on a given axis.
    Args:
    - image: The original image.
    - mask: The generated mask for the tumor region.
    - ax: The axis on which to plot the image and mask.
    """
    ax.imshow(image)
    ax.imshow(mask, alpha=0.5, cmap='jet')  # Overlay the mask with transparency
    ax.axis('off')

Visualizing Random Samples with Masks

This function randomly samples from the training dataset and visualizes each sample alongside its generated mask. It shows the original image and the corresponding mask side by side for easy comparison, with the mask layered transparently over the original image.

def visualize_random_samples(dataset, num_samples=3):
    """
    Visualizes random samples from the dataset along with their generated masks.
    Args:
    - dataset: The dataset object that contains images and annotations.
    - num_samples: The number of random samples to visualize.
    """
    # Select random samples from the dataset
    selected_indices = random.sample(range(len(dataset)), num_samples)
    # Set up the figure for displaying images and masks side by side
    fig, axes = plt.subplots(num_samples, 2, figsize=(12, 4 * num_samples))
    for idx, axis in zip(selected_indices, axes):
        # Get the sample
        sample = dataset[idx]
        image = sample['image']
        mask = sample['target']['masks'][0]  # First mask in the (N, H, W) layout
        # Visualize the original image in the first column
        axis[0].imshow(image)
        axis[0].set_title(f"Original Image {idx+1}")
        axis[0].axis('off')
        # Visualize the mask in the second column
        visualize_mask_and_image(image, mask, axis[1])
        axis[1].set_title(f"Generated Mask {idx+1}")
    # Display the grid
    plt.tight_layout()
    plt.show()

# Example usage to visualize 3 random samples
dataset = BrainTumorDataset(dataset_dir='/content/brain-tumor', subset='Training')
visualize_random_samples(dataset, num_samples=3)

Load Pre-trained Mask R-CNN Model

This loads a pre-trained Mask R-CNN model with a ResNet-50 backbone and Feature Pyramid Network (FPN) from torchvision, ready to be fine-tuned on the dataset for tumor detection in medical images.

# Load pre-trained Mask R-CNN model from torchvision
model = models.detection.maskrcnn_resnet50_fpn(pretrained=True)

Updating the Model for the Brain Tumor Dataset

This code alters the classifier of the pre-trained Mask R-CNN model to suit the brain tumor dataset, which has a single foreground class (tumor). The number of output classes is therefore 2, including the background. The box predictor is replaced to match this new classification setting.

# Modify the classifier to suit the brain tumor dataset (only 1 class: tumor)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes=2)

Adapting the Mask Predictor for the Brain Tumor Dataset

This code modifies the mask predictor of the Mask R-CNN model to fit the brain tumor dataset. It keeps the number of input channels for the mask predictor, sets the hidden layer size to 256, and sets the number of output classes to 2, covering background and tumor.

in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(in_features_mask, 256, num_classes=2)

Defining the Image Transformations

This code defines a sequence of transformations that first convert images into tensors and then normalize them with predefined mean and standard deviation values. This kind of preprocessing is typically applied to images before feeding them to deep learning networks.

transform = transforms.Compose([
    transforms.ToTensor(),  # Convert to tensor before normalizing
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
])

Load the Brain Tumor Datasets

This code creates the training and validation datasets using the BrainTumorDataset class and applies the transformations defined previously (conversion to tensor and normalization) to both sets.

DATASET_DIR = '/content/brain-tumor'
# Load dataset
dataset_train = BrainTumorDataset(DATASET_DIR, 'Training', transforms=transform)
dataset_val = BrainTumorDataset(DATASET_DIR, 'Validation', transforms=transform)

Set up the DataLoader for Training

This code defines a DataLoader for the training dataset with a batch size of one, shuffling the data and using a custom collate_fn (here, the identity function) so that each sample's dictionary structure is preserved instead of being stacked into tensors.

# Create DataLoader for training
train_loader = DataLoader(dataset_train, batch_size=1, shuffle=True, collate_fn=lambda x: x)

Configuring the device for the model

This code configures the device for running the model. It checks whether a GPU (CUDA) is available and uses it if so; otherwise, it falls back to the CPU. Finally, the model is moved onto the chosen device for efficient computation.

# Set up device for model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Assign Optimizer and Learning Rate Scheduler

This code sets up the optimizer, using SGD with a learning rate of 0.0005, a momentum of 0.9, and weight decay for regularization. It then defines a learning rate scheduler that reduces the learning rate by a factor of 0.1 every 3 epochs.

# Set up optimizer and learning rate scheduler
optimizer = torch.optim.SGD(model.parameters(), lr=0.0005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

Check Batches for NaN or Infinity Values

This function checks whether the image, mask, label, or bounding-box tensors in a batch contain NaN or infinity values. It iterates through the batch and prints a message whenever such a value is found, which helps catch data issues before model training.

# Function to check for NaN or infinity values in the batch
def check_for_nans(batch):
    for sample in batch:  # Loop over the batch
        if torch.isnan(sample['image']).any():  # Check if any image tensor contains NaN values
            print("NaN found in image")
        if torch.isinf(sample['image']).any():  # Check if any image tensor contains Inf values
            print("Infinity found in image")
        if torch.isnan(sample['target']['masks']).any():  # Check if any mask tensor contains NaN values
            print("NaN found in masks")
        if torch.isinf(sample['target']['masks']).any():  # Check if any mask tensor contains Inf values
            print("Infinity found in masks")
        if torch.isnan(sample['target']['labels']).any():  # Check if any labels tensor contains NaN values
            print("NaN found in labels")
        if torch.isinf(sample['target']['labels']).any():  # Check if any labels tensor contains Inf values
            print("Infinity found in labels")
        if torch.isnan(sample['target']['boxes']).any():  # Check if any boxes tensor contains NaN values
            print("NaN found in boxes")
        if torch.isinf(sample['target']['boxes']).any():  # Check if any boxes tensor contains Inf values
            print("Infinity found in boxes")

Training the Model with Gradient Clipping

This code trains the model for 7 epochs. It checks for NaN or infinity values in the data, sends images and target data to the device (GPU or CPU), and performs the forward pass. It computes the loss, performs the backward pass with gradient clipping to avoid exploding gradients, and updates the model parameters. The learning rate scheduler is stepped after each epoch.

# Training loop
num_epochs = 7
for epoch in range(num_epochs):
    model.train()
    for i, batch in enumerate(train_loader):
        check_for_nans(batch)  # Check for NaN or Inf in data
        # Move the image tensor to the device
        images = [x['image'].to(device) for x in batch]
        # Move each component of the target dictionary to the device
        targets = []
        for x in batch:
            target = x['target']
            target = {
                'masks': target['masks'].to(device),
                'labels': target['labels'].to(device),
                'boxes': target['boxes'].to(device)
            }
            targets.append(target)
        # Forward pass
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        # Backward pass
        optimizer.zero_grad()
        losses.backward()
        # Clip gradients to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {losses.item()}")
    lr_scheduler.step()
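
Validation Loss Check

The workflow calls for validating on the validation dataset; the visual checks appear in the next sections, and the sketch below adds an optional numeric validation-loss pass as an assumption of this write-up, not part of the original notebook. Because torchvision's Mask R-CNN returns its loss dictionary only in training mode, the model is kept in train mode while gradients are disabled:

# Validation-loss sketch (assumed; reuses the dataset and collate_fn defined above)
val_loader = DataLoader(dataset_val, batch_size=1, shuffle=False, collate_fn=lambda x: x)

def evaluate_loss(model, loader, device):
    model.train()  # The loss dict is only produced in train mode
    total, count = 0.0, 0
    with torch.no_grad():  # No gradient tracking during evaluation
        for batch in loader:
            images = [x['image'].to(device) for x in batch]
            targets = [{k: v.to(device) for k, v in x['target'].items()} for x in batch]
            loss_dict = model(images, targets)
            total += sum(loss for loss in loss_dict.values()).item()
            count += 1
    return total / max(count, 1)

print(f"Validation loss: {evaluate_loss(model, val_loader, device):.4f}")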

Model Inference & Visualization

This code sets the model to evaluation mode (eval()) and defines a function to predict and display tumor masks for samples from the validation dataset. It runs the model without gradients, predicts the tumor masks, and visualizes the first mask overlaid on the original image. The first image of the validation dataset is then fed to the model for prediction.

# Switch to inference mode
model.eval()

def predict_and_plot(dataset, idx):
    sample = dataset[idx]
    image = sample['image'].unsqueeze(0).to(device)  # Add a batch dimension
    with torch.no_grad():
        prediction = model(image)
    masks = prediction[0]['masks'] > 0.5  # Threshold mask probabilities into binary masks
    boxes = prediction[0]['boxes']
    plt.figure(figsize=(10, 10))
    plt.imshow(image[0].cpu().numpy().transpose(1, 2, 0))
    if masks.shape[0] > 0:
        # Overlay the first predicted mask with transparency
        plt.imshow(masks[0, 0].cpu().numpy(), cmap='jet', alpha=0.5)
    plt.title("Predicted Tumor Masks")
    plt.axis('off')
    plt.show()

# Test on some images
predict_and_plot(dataset_val, 0)

Show Original Image and Segmentation Mask

This code loads the sample at index 6 from the training dataset and extracts the image, masks, class labels, and bounding boxes, displaying the image and mask side by side. The original image appears on the left, while the segmentation mask appears on the right as a green overlay on a black background. The green region indicates tumor areas.

# Load the image and mask using dataset indexing
image_data = dataset_train[6]  # Get the data for the sample at index 6
image = image_data['image'].numpy().transpose(1, 2, 0)  # Convert image tensor to numpy (H, W, C)
mask = image_data['target']['masks'].numpy()  # Masks in (N, H, W) layout
class_ids = image_data['target']['labels']  # Get the labels (class_ids)
boxes = image_data['target']['boxes'].numpy()  # Get the bounding boxes

# Create a figure with 1 row and 2 columns (side by side)
fig, ax = plt.subplots(1, 2, figsize=(12, 6))

# --- Display the Original Image ---
ax[0].imshow(image)
ax[0].set_title("Original Image")
ax[0].axis('off')  # Hide axes

# --- Display the Segmentation Mask with Black Background and Green Segmentation ---
# Create an all-black image
black_image = np.zeros_like(image)
# Combine the masks for multiple objects into one
combined_mask = mask.max(axis=0)  # Merge all object masks along the instance axis
# Show the black background
ax[1].imshow(black_image)
# Overlay the segmentation mask in green (fully opaque)
ax[1].imshow(combined_mask, cmap='Greens', alpha=1.0)
ax[1].set_title("Segmentation Mask in Green")
ax[1].axis('off')  # Hide axes

# Show the plot
plt.tight_layout()
plt.show()

Inference and Visualization for Predicted Masks with Bounding Boxes

This code runs inference on a sample image from the validation dataset, extracts the predicted masks and bounding boxes, and visualizes them. It shows the original image alongside an overlay of the predicted tumor mask with a bounding box around the tumor region, which allows the model's detection ability to be assessed.

import matplotlib.pyplot as plt
import torch
import matplotlib.patches as patches

# Switch to inference mode
model.eval()

def predict_and_plot(dataset, idx, device):
    sample = dataset[idx]
    image = sample['image'].unsqueeze(0).to(device)  # Add batch dimension and move to device
    with torch.no_grad():
        prediction = model(image)
    # Get the predicted mask(s) and apply a threshold for binary masks
    predicted_masks = prediction[0]['masks'] > 0.5
    # Extract the mask for the first object, if there are multiple
    mask = predicted_masks[0, 0].cpu().numpy()  # First mask, channel dimension removed
    boxes = prediction[0]['boxes'].cpu().numpy()  # Get the bounding boxes

    # Plot the image and the predicted mask
    plt.figure(figsize=(10, 10))

    # Display the original image
    plt.subplot(1, 2, 1)
    plt.imshow(image[0].cpu().numpy().transpose(1, 2, 0))  # Convert from CHW to HWC format for display
    plt.title("Original Image")
    plt.axis('off')

    # Display the predicted mask overlay
    plt.subplot(1, 2, 2)
    plt.imshow(image[0].cpu().numpy().transpose(1, 2, 0))  # Show the original image first
    plt.imshow(mask, cmap='Reds', alpha=0.5)  # Overlay the predicted mask in red
    # Add the bounding boxes on top of the image
    for box in boxes:
        rect = patches.Rectangle(
            (box[0], box[1]), box[2] - box[0], box[3] - box[1],
            linewidth=2, edgecolor='yellow', facecolor='none'
        )
        plt.gca().add_patch(rect)  # Add the rectangle on top of the image
    plt.title("Predicted Tumor Mask with Bounding Box")
    plt.axis('off')
    plt.show()

# Test on some images
predict_and_plot(dataset_val, 0, device)

Conclusion

This project demonstrated the feasibility of using Mask R-CNN for brain tumor detection and segmentation. We fine-tuned a pre-trained model on a tailored dataset and successfully detected tumor areas in medical images. The trained model was validated on the validation set, and the results were visualized using segmentation masks and bounding boxes. This approach shows how deep learning can serve as a practical tool in medical image analysis, helping doctors reach an early diagnosis. With further refinement and adaptation, the model could be used in real-world clinical settings to automate and simplify the diagnostic process.

Challenges New Coders Might Face

  • Challenge: Data Quality and Annotation Errors
    Solution: Carefully review and clean the dataset to ensure accurate annotations, and consider manual validation of a subset of the images.

  • Challenge: Insufficient Data
    Solution: Use data augmentation techniques, such as rotation, flipping, and scaling, to artificially increase the dataset size and diversity (see the sketch after this list).

  • Challenge: Overfitting Model
    Solution: Implement regularization methods, such as dropout and weight decay, and ensure proper validation using a separate dataset.

  • Challenge: Long Training Time
    Solution: Use a pre-trained model, fine-tune it on your dataset, and apply techniques like early stopping to reduce unnecessary training time.
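
A minimal augmentation sketch using torchvision transforms is shown below; the specific transforms and parameters are illustrative assumptions rather than the project's original pipeline. Photometric transforms such as ColorJitter leave the masks and boxes valid, whereas geometric transforms (rotation, flipping, scaling) would also have to be applied to the masks and boxes:

from torchvision import transforms

# Illustrative image-only augmentation pipeline (an assumption, not from the original notebook)
augment = transforms.Compose([
    transforms.ToPILImage(),                               # The dataset yields numpy arrays
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Photometric change: masks stay valid
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Usage: pass it to the training dataset in place of the plain transform
# dataset_train = BrainTumorDataset(DATASET_DIR, 'Training', transforms=augment)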

FAQ

Question 1: What is a Mask R-CNN, and how does it apply to image segmentation?
Answer: Mask R-CNN is a state-of-the-art deep learning model for object detection and instance segmentation. In this project, it is used to detect and segment brain tumors in medical images, producing bounding boxes along with segmentation masks for tumor regions.

Question 2: Which dataset is used for the model training?
Answer: A Kaggle brain tumor dataset is used, containing brain scan images with tumor regions annotated. The dataset is split into training and validation subsets, which supports both model training and evaluation.

Question 3: What is the advantage of using a pre-trained model like Mask R-CNN?
Answer: Using a pre-trained Mask R-CNN speeds up convergence and improves performance. The model has already learned general visual features from large datasets, and these features can be fine-tuned on the brain tumor dataset for tumor identification.

Question 4: What type of preprocessing is performed on the images for the model?
Answer: Images are resized, normalized, and converted into tensors. This ensures the model receives data in a consistent format that matches its requirements.

Question 5: What is the role of data augmentation in image segmentation?
Answer: Data augmentation artificially increases the size of the dataset and thus helps improve the generalization of the model. Techniques like image rotation, flipping, and scaling are used to prevent overfitting and increase overall performance.
