Post by Dr. Barath Narayanan, University of Dayton Research Institute (UDRI), with co-authors Dr. Russell C. Hardie, University of Dayton (UD), Manawduge Supun De Silva, UD, and Nathaniel K. Kueterman, UD.
Introduction
Diabetic Retinopathy (DR) is one of the leading causes of blindness, affecting over 93 million people across the world. DR is an eye disease associated with diabetes. Detecting and grading DR at an early stage would help prevent permanent vision loss, and automated detection and grading during the retinal screening process would provide a valuable second opinion. In this blog, we implement a simple transfer learning-based approach using a deep Convolutional Neural Network (CNN) to detect DR.
Please cite the following articles if you're using any part of the code for your research:
- Narayanan, B. N., Hardie, R. C., De Silva, M. S., & Kueterman, N. K. (2020). Hybrid machine learning architecture for automated detection and grading of retinal images for diabetic retinopathy. Journal of Medical Imaging, 7(3), 034501.
- Narayanan, B. N., De Silva, M. S., Hardie, R. C., Kueterman, N. K., & Ali, R. (2019). Understanding Deep Neural Network Predictions for Medical Imaging Applications. arXiv preprint arXiv:1912.09621.
The Kaggle blindness detection challenge dataset (APTOS 2019 Dataset) contains separate training and test sets. In this blog, we use only the training set to study and estimate performance. These images were captured at Aravind Eye Hospital, India. The training set contains 3662 images graded by expert clinicians into five categories (Normal, Mild DR, Moderate DR, Severe DR, and Proliferative DR).
Note that this blog focuses solely on detecting DR; you can find more details about our grading architecture in our paper.
Grouping Data by Category
We extract the labels from the Excel sheet and sort the images into two folders, 'No' and 'Yes', since this blog focuses solely on detecting DR.
The helper code for splitting the data into categories is at the end of this post.
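The grouping rule amounts to a single comparison: grade 0 (Normal) goes to 'No' and grades 1 to 4 go to 'Yes'. A minimal base-MATLAB sketch, assuming the grades arrive as a numeric vector; the `diagnosis` values below are made up for illustration.

```matlab
% Hypothetical DR grades (0 = Normal, 1 = Mild, ..., 4 = Proliferative DR)
diagnosis   = [0 2 4 1 0 3];
class_names = {'No','Yes'};
% Any grade above 0 maps to index 2 ('Yes'); grade 0 maps to index 1 ('No')
labels = class_names(1 + (diagnosis > 0));
disp(labels)
```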
Load the Database
Let's begin by loading the database using imageDatastore. It reads images from disk only when they are needed, which keeps memory use low, and it can assign each image a label taken from the name of the folder that contains it.
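% Folder containing the 'No' and 'Yes' subfolders created by the helper code
two_class_datapath = 'Train Dataset Two Classes';
% Label each image by the name of its subfolder
imds=imageDatastore(two_class_datapath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
% Display the number of images in each class
total_split=countEachLabel(imds)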
Visualize the Images
Let's visualize the images and see how they differ across classes. This helps us determine which classification technique could distinguish the two classes and which preprocessing techniques might assist the classification process. Based on the similarities within each class and the differences across classes, we can also decide which CNN architecture to use for the study. In this article, we implement transfer learning using the Inception-v3 architecture. You can read our paper to see the performance of different preprocessing operations and other established architectures.
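% Show 20 randomly chosen images with their class labels
num_images=length(imds.Labels);
perm=randperm(num_images,20);
figure;
for idx=1:20
    subplot(4,5,idx);
    imshow(imread(imds.Files{perm(idx)}));
    % Convert the categorical label to text for the title
    title(char(imds.Labels(perm(idx))))
end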
Training, Testing and Validation
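Let's split the dataset into training, validation, and test sets. First, we split the dataset into 80% (training and validation) and 20% (testing). Then we hold out 10% of the first group for validation. splitEachLabel applies these fractions to each class separately, so every subset keeps roughly the same class balance as the full dataset.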
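% Keep 80% of each class for training and validation; the rest is the test set
train_percent=0.80;
[imdsTrain,imdsTest]=splitEachLabel(imds,train_percent,'randomize');
% Move 10% of each class in the training portion into the validation set
valid_percent=0.1;
[imdsValid,imdsTrain]=splitEachLabel(imdsTrain,valid_percent,'randomize');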
This leaves us with the following counts:
| | Yes | No |
|---|---|---|
| Training Set | 1337 | 1300 |
| Validation Set | 144 | 149 |
| Test Set | 361 | 371 |
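Because the validation set is carved out of the 80% training portion, the effective shares of the full dataset are 72% training, 8% validation, and 20% testing. The short base-MATLAB check below computes these shares from the two split fractions and compares them with the counts in the table.

```matlab
% Split fractions used with splitEachLabel above
train_percent = 0.80;
valid_percent = 0.1;
% Effective share of the whole dataset in each subset
fractions = [train_percent*(1 - valid_percent), ...   % training
             train_percent*valid_percent, ...         % validation
             1 - train_percent];                      % testing
fprintf('Training %.0f%%, validation %.0f%%, test %.0f%%\n', 100*fractions);

% Shares implied by the reported counts ('Yes' + 'No' per subset)
counts = [1337 + 1300, 144 + 149, 361 + 371];
observed = counts / sum(counts);
disp(observed)
```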
Deep Learning Approach
Let's adopt a transfer learning approach to classify the retinal images. In this article, I use Inception-v3 for classification; you could use other transfer learning approaches, as described in the paper, or any other architecture that you think suits this application. My MathWorks blogs on transfer learning using other established networks can be found here: AlexNet, ResNet.
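The training code below passes a layer graph named `incepnet` to trainNetwork, but its construction is not shown. One common way to build it, sketched here under assumptions, is to load the pretrained inceptionv3 network (this needs the Deep Learning Toolbox Model for Inception-v3 Network support package) and replace its 1000-class ImageNet head with a two-class head. The layer names 'predictions' and 'ClassificationLayer_predictions' are the ones we expect in the pretrained model, so confirm them with analyzeNetwork(net); the new layer names 'fc_dr' and 'output_dr' and the learning-rate factors of 10 are hypothetical choices, not the authors' settings.

```matlab
% Pretrained Inception-v3 (ImageNet weights, 299x299x3 input)
net = inceptionv3;
incepnet = layerGraph(net);          % editable copy of the network graph
numClasses = 2;                      % 'No' and 'Yes'

% New fully connected layer for 2 classes (hypothetical name and learn-rate factors)
newFC = fullyConnectedLayer(numClasses, 'Name', 'fc_dr', ...
    'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10);
% Assumed layer names of the pretrained model; check with analyzeNetwork(net)
incepnet = replaceLayer(incepnet, 'predictions', newFC);
incepnet = replaceLayer(incepnet, 'ClassificationLayer_predictions', ...
    classificationLayer('Name', 'output_dr'));
```

The new classification layer takes its class names ('No', 'Yes') from the training labels when trainNetwork runs, so no class list is needed here.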
Training
We use a validation patience of 3 as the stopping criterion. To start, we set 'MaxEpochs' to 2, but this can be tuned based on the training progress. Ideally, the validation performance should be high when the training process stops. We chose a mini-batch size of 32 because of our hardware memory constraints; you could pick a bigger mini-batch size, but make sure to adjust the other parameters (for example, the learning rate) accordingly.
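% Resize training and validation images to the 299x299 input size of Inception-v3
augimdsTrain = augmentedImageDatastore([299 299],imdsTrain);
augimdsValid = augmentedImageDatastore([299 299],imdsValid);
% Adam optimizer; validate every 50 iterations and stop after 2 epochs or when
% the validation loss fails to improve on its best value for 3 validation checks.
% 'ExecutionEnvironment','parallel' requires Parallel Computing Toolbox;
% use 'auto' to train on a single GPU or the CPU instead.
options = trainingOptions('adam','MaxEpochs',2,'MiniBatchSize',32,...
    'Plots','training-progress','Verbose',0,'ExecutionEnvironment','parallel',...
    'ValidationData',augimdsValid,'ValidationFrequency',50,'ValidationPatience',3);
% incepnet: Inception-v3 layer graph with its final layers replaced for the two classes
netTransfer = trainNetwork(augimdsTrain,incepnet,options);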
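To see how these options interact, the base-MATLAB sketch below first counts the validation checks in this run, then replays the patience rule on a made-up sequence of validation losses. It is a simplified re-implementation of the documented behavior ('ValidationPatience' is the number of times the validation loss may be larger than or equal to the previously smallest loss before training stops), not the toolbox's internal code. It assumes that trainNetwork drops the final partial mini-batch of each epoch and that the patience count restarts whenever a new minimum is reached.

```matlab
% --- Part 1: validation checks in this training run ---
numTrain  = 1337 + 1300;   % training images ('Yes' + 'No') from the table
miniBatch = 32;            % 'MiniBatchSize'
maxEpochs = 2;             % 'MaxEpochs'
valFreq   = 50;            % 'ValidationFrequency', in iterations
itersPerEpoch = floor(numTrain / miniBatch);   % assumes the partial mini-batch is dropped
totalIters    = itersPerEpoch * maxEpochs;
numChecks     = floor(totalIters / valFreq);   % checks at iterations 50, 100, 150, ...
fprintf('%d iterations per epoch, %d in total, %d scheduled validation checks\n', ...
    itersPerEpoch, totalIters, numChecks);

% --- Part 2: simplified patience rule on hypothetical validation losses ---
valLoss  = [0.62 0.48 0.41 0.43 0.40 0.42 0.44 0.45 0.39];
patience = 3;              % 'ValidationPatience'
bestLoss = Inf;            % smallest validation loss so far
numWorse = 0;              % checks since the last new minimum
stopIdx  = NaN;            % stays NaN if the rule never fires
for k = 1:numel(valLoss)
    if valLoss(k) < bestLoss
        bestLoss = valLoss(k);
        numWorse = 0;
    else
        numWorse = numWorse + 1;   % loss >= previous smallest loss
        if numWorse >= patience
            stopIdx = k;
            break
        end
    end
end
fprintf('Patience rule stops training at validation check %d\n', stopIdx);
```

Under this counting, two epochs give only three scheduled checks, so the patience of 3 is unlikely to fire before 'MaxEpochs' ends training; it becomes the active stopping rule once 'MaxEpochs' is increased.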
Testing and Performance Evaluation
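% Resize the test images to the 299x299 input size of Inception-v3
augimdsTest = augmentedImageDatastore([299 299],imdsTest);
% Predicted labels and class posteriors (columns ordered 'No', 'Yes')
[predicted_labels,posterior] = classify(netTransfer,augimdsTest);
actual_labels = imdsTest.Labels;
% Overall test accuracy
accuracy = mean(predicted_labels == actual_labels)
figure
plotconfusion(actual_labels,predicted_labels)
title('Confusion Matrix: Inception v3');
% Numeric class indices: 'No' = 1, 'Yes' = 2
test_labels=double(nominal(imdsTest.Labels));
% ROC curve from the posterior of the 'Yes' (DR) class, with class 2 as positive
[fp_rate,tp_rate,T,AUC] = perfcurve(test_labels,posterior(:,2),2);
figure;
plot(fp_rate,tp_rate,'b-');
grid on;
xlabel('False Positive Rate');
ylabel('Detection Rate');
title(sprintf('ROC Curve: Inception v3 (AUC = %.4f)',AUC));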
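plotconfusion and perfcurve summarize the test results graphically. To make explicit what those numbers mean, the base-MATLAB sketch below recomputes accuracy, sensitivity (the "Detection Rate" on the ROC y-axis), specificity, and AUC by hand on six hypothetical test cases. The scores are made up and are not results from this post; the trapezoidal ROC construction assumes all scores are distinct.

```matlab
% Hypothetical ground truth (true = DR present) and posterior of 'Yes'
isPos    = logical([1 1 1 0 0 0]);
scoreYes = [0.9 0.8 0.35 0.6 0.2 0.1];
% For two classes, picking the larger posterior is a 0.5 threshold
predPos  = scoreYes >= 0.5;

% Confusion-matrix counts and derived rates
TP = sum(predPos & isPos);   FN = sum(~predPos & isPos);
FP = sum(predPos & ~isPos);  TN = sum(~predPos & ~isPos);
accuracy    = (TP + TN) / numel(isPos);
sensitivity = TP / (TP + FN);     % true positive (detection) rate
specificity = TN / (TN + FP);     % 1 - false positive rate

% AUC as the area under the ROC curve (trapezoidal rule, distinct scores)
[~, order] = sort(scoreYes, 'descend');
tpr = [0, cumsum(isPos(order))  / sum(isPos)];
fpr = [0, cumsum(~isPos(order)) / sum(~isPos)];
aucTrapz = trapz(fpr, tpr);

% Same AUC as the chance that a random positive outscores a random negative
[P, N] = meshgrid(scoreYes(isPos), scoreYes(~isPos));
aucRank = mean(P(:) > N(:)) + 0.5 * mean(P(:) == N(:));

fprintf('acc %.3f, sens %.3f, spec %.3f, AUC %.4f (trapz) / %.4f (rank)\n', ...
    accuracy, sensitivity, specificity, aucTrapz, aucRank);
```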
Class Activation Mapping Results
We visualize the Class Activation Mapping (CAM) results for this network on different DR cases using the code from this MathWorks example:
https://www.mathworks.com/help/deeplearning/examples/investigate-network-predictions-using-class-activation-mapping.html. CAM highlights the image regions that contributed most to the predicted class, which gives doctors insight into the algorithm's decision.
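For a network that ends in global average pooling followed by a fully connected layer, as Inception-v3 does, the CAM for class c is M_c(x,y) = sum over k of w_k^c * f_k(x,y), where f_k is the k-th feature map of the last convolutional block and w_k^c is the fully connected weight linking channel k to class c. The base-MATLAB sketch below applies that weighted sum to tiny made-up activations and weights. With the trained network, the activations would come from activations(netTransfer, img, layerName) and the weights from the final fully connected layer; the layer name is an assumption to check in analyzeNetwork.

```matlab
% Toy activations: two 2x2 feature maps stacked along the 3rd dimension.
% (For a 299x299 input, Inception-v3's last block gives 8x8x2048 activations.)
A = cat(3, [1 2; 3 4], [0 1; 1 0]);
% Hypothetical fully connected weights for the 'Yes' class, one per channel
w = [0.5 2];
% CAM: weighted sum over channels (implicit expansion of w along dimension 3)
cam = sum(A .* reshape(w, 1, 1, []), 3);     % cam = [0.5 3; 3.5 2]
% Rescale to [0, 1] so the map can be shown as a heat map
camNorm = (cam - min(cam(:))) / (max(cam(:)) - min(cam(:)));
disp(cam)
disp(camNorm)
```

In practice, the normalized map is then resized to the input image size and overlaid on the retinal image, as in the MathWorks example.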
Here are the results obtained for various cases:
Conclusions
In this blog, we have presented a simple deep learning-based classification approach for computer-aided diagnosis (CAD) of DR in retinal images. The Inception-v3 classifier, without any preprocessing, performed well, with an overall accuracy of 98.0% and an AUC of 0.9947 (results may vary because of the random split). In the paper, we studied the performance of various established CNN architectures on the same training and test cases under different preprocessing conditions. Combining the results of various architectures boosts performance in terms of both AUC and overall accuracy. A comprehensive study of these algorithms, in terms of both computation (memory and time) and performance, allows subject matter experts to make an informed choice. In addition, the paper presents our novel architectures for the detection and grading of DR.
About the Author
Dr. Barath Narayanan received his M.S. and Ph.D. degrees in Electrical Engineering from the University of Dayton (UD) in 2013 and 2017, respectively. He currently holds a joint appointment as a Research Scientist at UDRI's Software Systems Group and as an Adjunct Faculty member in the ECE department at UD. His research interests include deep learning, machine learning, computer vision, and pattern recognition.
Helper Code
Code for grouping data by DR category (yes or no)
Download the ZIP files from the website and extract them to a folder called "train_images". Also download the spreadsheet containing the true labels from expert clinicians (train.csv; convert it to .xlsx for this code). We extract the labels from the spreadsheet and sort the images into two folders, 'No' and 'Yes', since this blog focuses solely on detecting DR.
% Folder with the extracted training images and the output folder
datapath='train_images';
two_class_datapath='Train Dataset Two Classes';
class_names={'No','Yes'};
% One subfolder per class (fullfile keeps the paths portable across platforms)
mkdir(fullfile(two_class_datapath,class_names{1}))
mkdir(fullfile(two_class_datapath,class_names{2}))
% Numeric part: DR grades (0-4); text part: image IDs, with a header row
[num_data,text_data]=xlsread('train.xlsx');
train_labels=num_data(:,1);
% Grades 1-4 -> class 2 ('Yes'); grade 0 -> class 1 ('No')
train_labels(train_labels~=0)=2;
train_labels(train_labels==0)=1;
filename=text_data(2:end,1);
for idx=1:length(filename)
    % fprintf('Processing %d among %d files:%s \n',idx,length(filename),filename{idx})
    % Strip any stray single quotes from the image ID
    current_filename=strrep(filename{idx}, char(39), '');
    img=imread(fullfile(datapath,[current_filename '.png']));
    imwrite(img,fullfile(two_class_datapath,class_names{train_labels(idx)},[current_filename '.png']));
end
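If you prefer to skip the .xlsx conversion, the same grouping can be sketched with readtable, which reads train.csv directly. This assumes the CSV columns are named id_code and diagnosis (as in the APTOS 2019 download) and reuses the folder names above; copyfile copies each PNG unchanged instead of decoding and re-encoding it with imread and imwrite.

```matlab
% Sketch: group the training images straight from train.csv.
% Assumes train.csv has columns id_code (image ID) and diagnosis (grade 0-4).
datapath           = 'train_images';
two_class_datapath = 'Train Dataset Two Classes';
class_names        = {'No','Yes'};

% Read id_code as text so no image ID is parsed as a number
opts = detectImportOptions('train.csv');
opts = setvartype(opts, 'id_code', 'char');
T = readtable('train.csv', opts);
isDR = T.diagnosis > 0;                  % grades 1-4 count as DR ('Yes')

% Create the class folders if they do not exist yet
for c = 1:numel(class_names)
    outdir = fullfile(two_class_datapath, class_names{c});
    if ~isfolder(outdir)
        mkdir(outdir);
    end
end

% Copy each PNG into the folder of its class
for idx = 1:height(T)
    fname = [T.id_code{idx} '.png'];
    copyfile(fullfile(datapath, fname), ...
             fullfile(two_class_datapath, class_names{1 + isDR(idx)}, fname));
end
```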