Deep Learning in Action – part 1
This series, "Deep Learning in Action: Cool projects created at MathWorks," aims to give you insight into what we're working on at MathWorks: I'll show some demos, give you access to the code, and maybe even post a video or two. Today's demo is called "Pictionary," and it's the first article in a series of posts, including:

- 3D Point Cloud Segmentation using CNNs
- GPU Coder
- Age Detection
- And maybe a few more!
Demo: Pictionary

Pictionary refers to a game in which one person or team draws an object and the other person or team tries to guess what the object is. The developer of the Pictionary demo is actually... me! This demo came about when a MathWorks developer posted the idea on an internal message board. We already had an example of handwritten digit recognition with the MNIST dataset, but this was a unique spin on that concept. Thus, the idea of creating a Pictionary example was born.

Read the images in the dataset

The first challenge (and honestly, the hardest part of the example) was reading in the images. Each category's file contains many drawings of that object; for example, there's an "ant" category with thousands of hand-drawn ants stored in a JSON file. Each line of the file looks something like this:
Stroke | X Values | Y Values |
1 | 27,17,16,21,34,50,49,34,23,17 | 47,58,73,81,84,67,54,46,47,51 |
2 | 22,0 | 51,18 |
3 | 41,46,43 | 45,11,0 |
4 | 53,65,64,69,91,119,135,148,159,158,149,126,87,68,62 | 68,68,58,51,36,34,38,48,64,78,85,90,90,83,73 |
[Figure: the full drawing alongside a zoomed-in view of the plotted points]
The points of each stroke need to be connected into continuous lines, and an internal Image Processing Toolbox function, intline, does exactly that:

>> help iptui.intline
[X, Y] = intline(X1, X2, Y1, Y2) computes an approximation to the
line segment joining (X1, Y1) and (X2, Y2) with integer coordinates.

We do that for the remaining strokes, and we get... finally! A drawing slightly resembling an ant. Now that we can create images from these x,y points, we can write a function and quickly repeat this for all the ants in the file, and for multiple categories too, as in the sketch below.
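Here's a minimal sketch of how that conversion might look. The file name ant.ndjson, the drawing field, and the 2-by-N stroke layout are my assumptions about the JSON format, so adjust them to match your data.

% A minimal sketch of the stroke-to-image conversion. The file name, the
% 'drawing' field, and the 2-by-N stroke layout are assumptions about the
% JSON format; adjust them to match your data.
fid = fopen('ant.ndjson');
record = jsondecode(fgetl(fid));   % decode one drawing (one line of the file)
fclose(fid);

im = false(256,256);
for k = 1:numel(record.drawing)
    stroke = record.drawing{k};    % row 1 = x values, row 2 = y values
    xs = double(stroke(1,:)) + 1;  % shift 0-based coordinates to 1-based
    ys = double(stroke(2,:)) + 1;
    for p = 1:numel(xs)-1          % connect each pair of consecutive points
        [px, py] = iptui.intline(xs(p), xs(p+1), ys(p), ys(p+1));
        im(sub2ind(size(im), py, px)) = true;
    end
end
imshow(im)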
Now, this dataset assumes that people drew with a pencil or something similarly thin, since the lines are only 1 pixel wide. We can quickly thicken the drawing with image processing tools, like image dilation. I imagined that people would play this on a whiteboard with markers, so training on thicker lines helps with that assumption.

larger_im = imdilate(im2, strel('disk', 3));

And while we're cleaning things up, let's center the image too.
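The centering code isn't shown in the post, so here's one plausible approach (an assumption, not necessarily what the demo uses): find the drawing's bounding box and shift its center to the middle of the frame.

% One way to center the drawing; the demo's actual centering code isn't shown
[r, c] = find(larger_im);                         % locate the "on" pixels
target = (size(larger_im) + 1) / 2;               % frame center, [row col]
rowShift = round(target(1) - mean([min(r) max(r)]));
colShift = round(target(2) - mean([min(c) max(c)]));
centered_im = imtranslate(larger_im, [colShift rowShift]);  % [x y] order

imtranslate fills the vacated pixels with zeros, which matches the black background here.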
For this example, I pulled 5000 training images and 500 test images per category. There are many (many, many!) more example images available in the files, so feel free to increase these numbers if you're so inclined.

Create and train the network

Now that our dataset is ready to go, let's start training. Here's the structure of the network:

layers = [
    imageInputLayer([256 256 1])
    convolution2dLayer(3,16,'Padding',1)
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding',1)
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,64,'Padding',1)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(5)
    softmaxLayer
    classificationLayer];

How did I pick this specific structure? Glad you asked. I stole it from people who had already created a network: this structure is 99% accurate on the MNIST dataset, so I figured it was a good starting point for these handwritten drawings. Here's a handy plot created with this code:
lgraph = layerGraph(layers);
plot(lgraph)
I'll admit, this is a fairly boring layer graph since it's all a straight line, but if you were working with DAG networks, you could easily see the connections of a complicated network. I trained this on a zippy NVIDIA P100 GPU in roughly 20 minutes, with training options along the lines of the sketch below.
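The post doesn't show the training call itself, so here's a minimal sketch. The solver, epoch count, and learning rate are assumptions, not the settings used for the demo, and imgDataTrain and labelsTrain are assumed to hold the 4-D training image array and its categorical labels.

% A minimal training sketch. The options below are assumptions (the post
% doesn't list them); imgDataTrain and labelsTrain are assumed to be the
% 4-D training image array and categorical labels.
opts = trainingOptions('sgdm', ...
    'MaxEpochs',10, ...
    'InitialLearnRate',0.01, ...
    'ExecutionEnvironment','gpu', ...     % the post used an NVIDIA P100
    'Plots','training-progress');
net = trainNetwork(imgDataTrain, labelsTrain, layers, opts);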
The test images I set aside give an accuracy of roughly 90%. For an autonomous driving scenario, I would need to go back and refine the algorithm; for a game of Pictionary, this is a perfectly acceptable number in my opinion.

predLabelsTest = net.classify(uint8(imgDataTest));
testAccuracy = sum(predLabelsTest == labelsTest') / numel(labelsTest)
testAccuracy = 0.8996

Debug the network

Let's drill down into the accuracy to get more insight into the trained network. One way to look at the predictions for specific categories is a confusion matrix. A very simple option is to create a heatmap, which works like a confusion matrix, assuming you have the same number of images in each category; we do: 500 test images per category.
% visualize predicted labels against actual labels
tt = table(predLabelsTest, categorical(labelsTest'),'VariableNames',{'Predicted','Actual'});
figure('name','confusion matrix'); heatmap(tt,'Actual','Predicted');
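If you prefer raw counts over a heatmap, confusionmat (from Statistics and Machine Learning Toolbox) gives the same information. This is a sketch of an alternative view, not code from the original demo.

% Alternative sketch: explicit confusion counts with confusionmat
[cm, order] = confusionmat(categorical(labelsTest'), predLabelsTest);
array2table(cm, 'VariableNames', cellstr(order), 'RowNames', cellstr(order))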
One thing that pops out is that ants and wristwatches tend to confuse the classifier. This seems like reasonable confusion. If we were confusing wine glasses with ants, then we might have a problem.
There are two reasons for error in our Pictionary classifier:
- The person guessing can’t identify the object, or
- The person drawing doesn’t describe the object well enough.
% pick out the cases where the predicted label doesn't match the actual
idx = find(predLabelsTest ~= labelsTest');
loser_ants = idx(idx <= 500);   % ants are the first 500 test images
montage(imgDataTest(:,:,1,loser_ants));
I’m going to go out on a limb and say that at least 18 of these ants shouldn’t be called ants at all. In defense of my classifier, let’s say that you were playing Pictionary, and someone drew this:
% select an image from the bunch
ii = 169;
img = squeeze(uint8(imgDataTest(:,:,1,ii)));
actualLabel = labelsTest(ii);
predictedLabel = net.classify(img);
imshow(img,[]);
title(['Predicted: ' char(predictedLabel) ', Actual: ' char(actualLabel)])
What are the chances you would call that an ant?? If a computer classifies this as an umbrella, is that really an error??

Try the classifier on new images

The whole point of this example was to see how the classifier handles brand-new images in real life. I drew an ant...

...and the trained model can now tell me what it thinks it is. Let's throw in a confidence rating too.
cam = webcam;
s = snapshot(cam);
myDrawing = segmentImage(s(:,:,2));
myDrawing = imresize(myDrawing,[256,256]); % ensure this is the right size for the network
[predval,conf] = net.classify(uint8(myDrawing));
imshow(myDrawing);
title(string(predval) + sprintf(' %.1f%%', max(conf)*100));

I used a segmentation function, segmentImage, created with image processing tools; it finds the object I drew and flips the black to white and the white to black (a hypothetical sketch of it follows below). Looks like my Pictionary skills are good enough!
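For reference, here's a hypothetical sketch of what a function like segmentImage might do. The real one isn't shown in the post; this just mimics the behavior described above.

% Hypothetical sketch of segmentImage; the actual function isn't shown in
% the post, this just reproduces the described behavior.
function BW = segmentImage(grayIm)
    BW = ~imbinarize(grayIm);   % threshold the photo, then invert it so
                                % the (dark) marker strokes become white
    BW = bwareafilt(BW, 1);     % keep only the largest object, the drawing
end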
This code is on the File Exchange, and you can see this example in a webinar I recorded with my colleague Gabriel Ha. Leave me a comment below if you have any questions. Join me next time, when I talk to a MathWorks engineer about using CNNs for point cloud segmentation!

Category: Deep Learning