Artificial Intelligence

Apply machine learning and deep learning

Deep Learning in Action – part 1

Hello everyone! Allow me to quickly introduce myself. My name is Johanna, and Steve has allowed me to take over the blog from time to time to talk about deep learning.
Today I’d like to kick off a series called:

“Deep Learning in Action: Cool projects created at MathWorks”

This series aims to give you insight into what we’re working on at MathWorks: I’ll show some demos, give you access to the code, and maybe even post a video or two. Today’s demo is called “Pictionary”, and it’s the first article in a series of posts, including:
 
  • 3D Point Cloud Segmentation using CNNs
  • GPU Coder
  • Age Detection
  • And maybe a few more!
 
Demo: Pictionary
Pictionary refers to a game in which one person/team draws an object and the other person/team tries to guess what the object is.
The developer of the Pictionary demo is actually … me! This demo came about when a MathWorks developer posted on an internal message board. We already had an example of handwritten digit classification with the MNIST dataset, but this was a unique spin on that concept. Thus, the idea of creating a Pictionary example was born.
Read the images in the dataset
The first challenge (and honestly, the hardest part of the example) was reading in the images. Each file contains many drawings of one object category; for example, there’s an “ant” category with thousands of hand-drawn ants stored in a JSON file. Each line of the file looks something like this:
{"word":"ant","countrycode":"US","timestamp":"2017-03-27 00:14:57.31033 UTC","recognized":true,"key_id":"5421013154136064","drawing":[[[27,17,16,21,34,50,49,34,23,17],[47,58,73,81,84,67,54,46,47,51]],[[22,0],[51,18]],[[41,46,43],[45,11,0]],[[53,65,64,69,91,119,135,148,159,158,149,126,87,68,62],[68,68,58,51,36,34,38,48,64,78,85,90,90,83,73]],[[161,175],[70,69]],[[180,177,176,187,206,226,244,250,250,245,233,207,188,180,180],[68,67,61,50,42,40,48,58,72,80,87,89,83,76,71]],[[73,61],[85,113]],[[95,94],[88,126]],[[140,157],[90,118]],[[199,201,208],[90,116,122]],[[234,242,255],[89,105,112]]]}
 
Can you see the image? Me neither. The drawing is stored as x,y coordinates of the points along each stroke. If we pull out the x,y points from the file, we can see the drawing start to take shape. Here are the first four of its 11 strokes:
Stroke | X values | Y values
1 | 27,17,16,21,34,50,49,34,23,17 | 47,58,73,81,84,67,54,46,47,51
2 | 22,0 | 51,18
3 | 41,46,43 | 45,11,0
4 | 53,65,64,69,91,119,135,148,159,158,149,126,87,68,62 | 68,68,58,51,36,34,38,48,64,78,85,90,90,83,73
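To pull these values out in MATLAB, one option is jsondecode (available since R2016b), which turns each line of the file into a struct. Here is a minimal sketch using the ant line above, shortened to its first four strokes. Because the strokes have different lengths, the drawing field comes back as a cell array with one 2-by-N matrix per stroke (x values in row 1, y values in row 2). If every stroke in a drawing happened to have the same number of points, jsondecode would return a single numeric array instead, so a loader for the whole file should handle that case too.

```matlab
% One line of the ant file, shortened to its first four strokes
jsonLine = ['{"word":"ant","countrycode":"US","recognized":true,' ...
    '"key_id":"5421013154136064","drawing":[' ...
    '[[27,17,16,21,34,50,49,34,23,17],[47,58,73,81,84,67,54,46,47,51]],' ...
    '[[22,0],[51,18]],' ...
    '[[41,46,43],[45,11,0]],' ...
    '[[53,65,64,69,91,119,135,148,159,158,149,126,87,68,62],' ...
    '[68,68,58,51,36,34,38,48,64,78,85,90,90,83,73]]]}'];

sample  = jsondecode(jsonLine);  % struct: word, countrycode, recognized, key_id, drawing
strokes = sample.drawing;        % cell array: strokes{k} is 2-by-N (row 1 = x, row 2 = y)

% Print each stroke's points, like the table above
for k = 1:numel(strokes)
    fprintf('Stroke %d  X: %s  Y: %s\n', k, ...
        mat2str(strokes{k}(1, :)), mat2str(strokes{k}(2, :)));
end
```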
 
The file captures individual “strokes”, i.e., what was drawn without lifting the pen. Let’s take stroke 1. Plotted on the image, its X and Y values look like this:

[Figure: stroke 1 X,Y values from the input file plotted in pink. Left: full image; right: same image, zoomed in.]
Then we play a quick game of “connect the dots”, and we get our first stroke of the drawing. Connecting the dots is fairly easy in MATLAB with a function called iptui.intline:
>> help iptui.intline
[X, Y] = intline(X1, X2, Y1, Y2) computes an approximation to the line segment joining 
(X1, Y1) and (X2, Y2) with integer coordinates.
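If you’d rather not depend on iptui.intline, the same idea takes only a few lines of base MATLAB. This is a stand-in, not intline’s actual algorithm: it takes one sample per pixel along the longer axis of the segment and rounds to integer coordinates, so consecutive points are never more than one pixel apart. The sketch connects the first two points of stroke 1 from the table.

```matlab
% First two points of stroke 1: (27,47) and (17,58)
x1 = 27; y1 = 47;
x2 = 17; y2 = 58;

% One sample per pixel along the longer axis, so no gaps appear in the line
n  = max(abs(x2 - x1), abs(y2 - y1)) + 1;
xs = round(linspace(x1, x2, n));   % integer x coordinates along the segment
ys = round(linspace(y1, y2, n));   % integer y coordinates along the segment

disp([xs; ys]);
```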

X,Y values plotted (pink) and "strokes" connecting them (yellow)

We do that for the remaining strokes, and we get:

The yellow coloring is just for visual emphasis. The actual images will have a black background and white drawing.

Finally! A drawing slightly resembling an ant. Now that we can create images from these x,y points, we can wrap the steps in a function and quickly repeat them for every ant in the file, and for other categories too.
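A loop over strokes and over consecutive points within each stroke turns one drawing into one image. Here is a minimal base-MATLAB sketch, assuming the coordinates run from 0 to 255 as they do in the sample line (so a 256-by-256 image holds every point), and using rounded linspace samples in place of iptui.intline. It draws white strokes on a black background to match the format of the training images. The strokes variable holds the first four strokes from the table; in practice you would fill it from each line of the file.

```matlab
% First four strokes of the ant drawing (row 1 = x, row 2 = y)
strokes = cell(4, 1);
strokes{1} = [27 17 16 21 34 50 49 34 23 17; 47 58 73 81 84 67 54 46 47 51];
strokes{2} = [22 0; 51 18];
strokes{3} = [41 46 43; 45 11 0];
strokes{4} = [53 65 64 69 91 119 135 148 159 158 149 126 87 68 62; ...
              68 68 58 51 36 34 38 48 64 78 85 90 90 83 73];

imgSize = 256;                    % assumes coordinates lie in 0..255
img = zeros(imgSize, 'uint8');    % black background

for k = 1:numel(strokes)
    x = strokes{k}(1, :);
    y = strokes{k}(2, :);
    for p = 1:max(numel(x) - 1, 1)
        q = min(p + 1, numel(x));     % q == p for a single-point stroke
        n = max(abs(x(q) - x(p)), abs(y(q) - y(p))) + 1;
        xs = round(linspace(x(p), x(q), n));
        ys = round(linspace(y(p), y(q), n));
        % +1 converts 0-based drawing coordinates to 1-based (row, column) indices
        img(sub2ind(size(img), ys + 1, xs + 1)) = 255;   % white drawing
    end
end
% View the result with imshow(img), as elsewhere in this post
```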
Now, the images we just created assume that people drew with a pencil or something similarly thin, since the lines are only 1 pixel thick. We can quickly change the thickness of the drawing with image processing tools such as image dilation. I imagined that people would be playing this on a whiteboard with markers, so training on thicker lines should help the network handle those drawings.
larger_im = imdilate(im2,strel('Disk',3));
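To see what the dilation step does, here is a toolbox-free sketch on a synthetic one-pixel-wide diagonal line. It builds a flat disk of radius 3 and turns a pixel on if any drawn pixel lies within that disk, which is what binary dilation with a symmetric structuring element computes. Note that strel('Disk',3) uses its own approximation of a disk, so its exact footprint may differ slightly from the Euclidean disk used here.

```matlab
im2 = logical(eye(64));              % synthetic stand-in: a 1-pixel-wide diagonal line

r = 3;                               % disk radius, as in strel('Disk',3)
[X, Y] = meshgrid(-r:r);
disk = (X.^2 + Y.^2) <= r^2;         % 7x7 logical mask of a Euclidean disk

% A pixel turns on if the disk centered on it overlaps the line: binary dilation
larger_im = conv2(double(im2), double(disk), 'same') > 0;

fprintf('Pixels on in row 32: %d before, %d after\n', ...
    nnz(im2(32, :)), nnz(larger_im(32, :)));
```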
And while we’re cleaning things up, let’s center the image too.
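The centering code itself isn’t shown in this post, so here is one possible approach, sketched on a synthetic image: find the bounding box of the drawn pixels and shift the image so the center of the box lands on the center of the image. circshift wraps pixels around the edges, which is harmless as long as the drawing fits inside the image.

```matlab
% Synthetic stand-in: a small drawing in the top-left corner of a 256x256 image
im = false(256, 256);
im(10:40, 20:30) = true;

% Bounding box of the drawn pixels
[r, c] = find(im);
boxCenter = [(min(r) + max(r)) / 2, (min(c) + max(c)) / 2];
imgCenter = (size(im) + 1) / 2;          % [128.5 128.5] for a 256x256 image

% Shift rows and columns so the two centers (nearly) coincide
centered = circshift(im, round(imgCenter - boxCenter));
```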
For this example, I pulled 5000 training images and 500 test images. There are many (many many!) more example images available in the files, so feel free to increase these numbers if you’re so inclined.
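The post doesn’t say how the 5000 training and 500 test images were chosen. One way to make sure no drawing ends up in both sets is to take both from a single random permutation, as in this sketch. numDrawings is a hypothetical count of the drawings available in one category’s file.

```matlab
numDrawings = 100000;    % hypothetical: number of drawings in one category's file
numTrain = 5000;
numTest  = 500;

rng(0);                                                % reproducible selection
order    = randperm(numDrawings, numTrain + numTest);  % distinct random line numbers
trainIdx = order(1:numTrain);
testIdx  = order(numTrain+1:end);
```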
Create and train the network
Now that our dataset is ready to go, let’s start training. Here’s the structure of the network:
layers = [
    imageInputLayer([256 256 1])

    convolution2dLayer(3,16,'Padding',1)
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,32,'Padding',1)
    batchNormalizationLayer
    reluLayer

    maxPooling2dLayer(2,'Stride',2)

    convolution2dLayer(3,64,'Padding',1)
    batchNormalizationLayer
    reluLayer

    fullyConnectedLayer(5)
    softmaxLayer
    classificationLayer];
How did I pick this specific structure? Glad you asked. I stole it from other people who had already created a network. This specific network structure is 99% accurate on the MNIST dataset, so I figured it was a good starting point for these handwritten drawings.
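As a sanity check on the structure, you can trace the feature-map size through the layers by hand. This sketch assumes both pooling layers use a stride of 2 and applies the standard output-size formula floor((in + 2*padding - filter)/stride) + 1: the 3x3 convolutions with padding 1 keep the size, and each 2x2 pooling layer halves it, so fullyConnectedLayer(5) sees a 64x64x64 input.

```matlab
outSize = @(in, filt, pad, stride) floor((in + 2*pad - filt) / stride) + 1;

sz = 256;                      % imageInputLayer([256 256 1])
sz = outSize(sz, 3, 1, 1);     % conv 3x3, 16 filters, padding 1 -> 256
sz = outSize(sz, 2, 0, 2);     % max pool 2x2, stride 2          -> 128
sz = outSize(sz, 3, 1, 1);     % conv 3x3, 32 filters, padding 1 -> 128
sz = outSize(sz, 2, 0, 2);     % max pool 2x2, stride 2          -> 64
sz = outSize(sz, 3, 1, 1);     % conv 3x3, 64 filters, padding 1 -> 64

fcInputs  = sz * sz * 64;      % values flowing into fullyConnectedLayer(5)
fcWeights = fcInputs * 5;      % learnable weights in that layer (plus 5 biases)
fprintf('%d inputs, %d weights in the fully connected layer\n', fcInputs, fcWeights);
```

By comparison, the three convolutional layers hold about 23,000 weights and biases combined (160 + 4,640 + 18,496), so most of this network’s parameters sit in the final fully connected layer.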
Here’s a handy plot created with this code:
lgraph = layerGraph(layers);
plot(lgraph)

 

I’ll admit, this is a fairly boring layer graph since it’s all a straight line, but if you were working with DAG networks, you could easily see the connections of a complicated network.

I trained this on a zippy NVIDIA P100 GPU in roughly 20 minutes. On the test images I set aside, the network is roughly 90% accurate. For an autonomous driving scenario, I would need to go back and refine the algorithm. For a game of Pictionary, this is a perfectly acceptable number in my opinion.
predLabelsTest = net.classify(uint8(imgDataTest));
testAccuracy = sum(predLabelsTest == labelsTest') / numel(labelsTest)
testAccuracy = 0.8996
Debug the network
Let’s drill down into the accuracy to give more insight into the trained network.
One way to look at the predictions for specific categories is to create a confusion matrix. A very simple option is to create a heatmap of predicted versus actual labels. This works similarly to a confusion matrix, assuming you have the same number of images in each category, which we do: 500 test images per category.
% visualize where the predicted label doesn't match the actual

tt = table(predLabelsTest, categorical(labelsTest'),'VariableNames',{'Predicted','Actual'});
figure('name','confusion matrix'); heatmap(tt,'Actual','Predicted');
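Under the hood, the heatmap counts (Predicted, Actual) pairs. The sketch below computes the same counts with base MATLAB on a few made-up labels, with rows as predicted classes and columns as actual classes to match the heatmap’s axes, and then divides the diagonal by the column sums to get per-class accuracy.

```matlab
% Made-up labels for six test drawings
actual    = {'ant'; 'ant'; 'umbrella'; 'wristwatch'; 'ant'; 'wristwatch'};
predicted = {'ant'; 'wristwatch'; 'umbrella'; 'wristwatch'; 'ant'; 'ant'};

% Map every label to a class number (classes come back sorted)
[classes, ~, idx] = unique([actual; predicted]);
n = numel(actual);
actualIdx    = idx(1:n);
predictedIdx = idx(n+1:end);

% counts(i,j) = number of drawings of actual class j predicted as class i
counts = accumarray([predictedIdx actualIdx], 1, [numel(classes) numel(classes)]);

% Fraction of each actual class that was predicted correctly
perClassAccuracy = diag(counts)' ./ sum(counts, 1);
```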
One thing that pops out is that ants and wristwatches tend to confuse the classifier. This seems like reasonable confusion. If we were confusing wine glasses with ants, then we might have a problem.
There are two possible reasons for an error in our Pictionary classifier:
  1. The person guessing can’t identify the object, or
  2. The person drawing doesn’t draw the object well enough.
If we expect a classifier to be 100% accurate, we are assuming that the person drawing never does a poor job in any of the object categories. Highly unlikely.
Drilling down even further, let’s look at the 67 ants that were misclassified.
% pick out the test images where the predicted label doesn't match the actual

idx = find(predLabelsTest ~= labelsTest');
loser_ants = idx(idx <= 500);   % assumes the ants are test images 1 through 500

montage(imgDataTest(:,:,1,loser_ants));
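Selecting by index range relies on the ants occupying a known block of the test set. Selecting by label instead works regardless of how the test images are ordered. Here is a small sketch on made-up labels stored as cell arrays of character vectors; the post’s own label variables may be a different type, so the comparison would need adapting.

```matlab
% Made-up actual and predicted labels for five test drawings
actual    = {'ant'; 'umbrella'; 'ant'; 'ant'; 'wristwatch'};
predicted = {'ant'; 'umbrella'; 'wristwatch'; 'umbrella'; 'wristwatch'};

% Indices of drawings that are ants but were predicted as something else
isWrong    = ~strcmp(predicted, actual);
loser_ants = find(isWrong & strcmp(actual, 'ant'));
```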
I’m going to go out on a limb and say that at least 18 of these ants shouldn’t be called ants at all. In defense of my classifier, let’s say that you were playing Pictionary, and someone drew this:
% select an image from the bunch 
ii = 169
img = squeeze(uint8(imgDataTest(:,:,1,ii)));

actualLabel = labelsTest(ii);
predictedLabel = net.classify(img);

imshow(img,[]);
title(['Predicted: ' char(predictedLabel) ', Actual: '  char(actualLabel)])
 
What are the chances you would call that an ant?? If a computer classifies this as an umbrella, is that really an error??
Try the classifier on new images
The whole point of this example was to see how it would do on new images in real life. I drew an ant...

My very own ant drawing

...and the trained model can now tell me what it thinks it is. Let’s throw in a confidence rating too.
s = snapshot(webcam); 
myDrawing = segmentImage(s(:,:,2));

myDrawing = imresize(myDrawing,[256,256]); % ensure this is the right size for processing

[predval,conf] = net.classify(uint8(myDrawing));
imshow(myDrawing);
title(string(predval)+ sprintf(' %.1f%%',max(conf)*100));
segmentImage is a segmentation function I built with image processing techniques: it finds the object I drew and flips black to white and white to black, so the drawing matches the white-on-black training images.
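The segmentImage code isn’t shown here. As a rough, hypothetical stand-in, the sketch below thresholds a synthetic green channel with a fixed value, keeps the dark marker strokes as white (true) pixels on a black background, and crops to the drawing’s bounding box. The threshold of 100 is an assumption; real webcam images would likely need something more robust, such as an adaptive threshold.

```matlab
% Synthetic stand-in for the webcam's green channel:
% light background (200) with dark marker strokes (30)
g = uint8(200 * ones(120, 160));
g(40:80, 60:63)  = 30;        % a vertical stroke
g(58:61, 40:100) = 30;        % a horizontal stroke

% Dark pixels are ink: keep them as white (true) on a black (false) background
ink = g < 100;                % fixed threshold: an assumption

% Crop to the bounding box of the drawing
[r, c] = find(ink);
myDrawing = ink(min(r):max(r), min(c):max(c));
```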
Looks like my Pictionary skills are good enough!
This code is on File Exchange, and you can see this example in a webinar I recorded with my colleague Gabriel Ha. Leave me a comment below if you have any questions. Join me next time, when I talk to a MathWorks engineer about using CNNs for point cloud segmentation!