Deep Learning in 11 Lines of MATLAB Code
Avi’s pick of the week is Deep Learning in 11 Lines of MATLAB Code by the MathWorks Deep Learning Toolbox Team. This post is a follow-up to this post by Jiro and provides a more detailed explanation.
If you are interested in learning more about deep learning, or in trying out some of the latest deep learning research in MATLAB, this blog post will walk you through the first steps. The entry on File Exchange provides everything you need to download one of the most popular deep neural networks and use it to classify images from a live webcam feed. Before you read the rest of this post, I highly recommend watching this video by Joe Hicklin (pictured below), which illustrates what I will explain in more detail.
This post covers how to download a pre-trained deep convolutional neural network and use it to classify images in a live video stream.
Load Pre-Trained Network
The pre-trained model I’ll use in this post is known as AlexNet. It was trained on over a million images to recognize objects from 1000 different categories. AlexNet is available as a support package in MATLAB; you can learn more about the AlexNet support package from this blog post. Now let’s load the pre-trained network into MATLAB.
nnet = alexnet; % Will prompt support package install if unavailable
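Once the network is loaded, you can read its expected input size and its list of classes directly from the layers instead of hard-coding them. The sketch below is a minimal example, assuming a release in which the final classification layer exposes a `Classes` property (older releases call it `ClassNames`).

```matlab
nnet = alexnet;                          % load the pre-trained network (as above)

inputSize  = nnet.Layers(1).InputSize;   % size expected by the 'data' input layer
classNames = nnet.Layers(end).Classes;   % categories known to the 'output' layer

fprintf('Input size: %s\n', mat2str(inputSize));
fprintf('Number of classes: %d\n', numel(classNames));
```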
View Network Structure
Now let’s take a look at the structure of the AlexNet model. Notice that the first layer requires a 227×227 RGB image, and the last layer classifies images into 1000 object categories.
nnet.Layers
ans = 
  25x1 Layer array with layers:

     1   'data'     Image Input                   227x227x3 images with 'zerocenter' normalization
     2   'conv1'    Convolution                   96 11x11x3 convolutions with stride [4 4] and padding [0 0]
     3   'relu1'    ReLU                          ReLU
     4   'norm1'    Cross Channel Normalization   cross channel normalization with 5 channels per element
     5   'pool1'    Max Pooling                   3x3 max pooling with stride [2 2] and padding [0 0]
     6   'conv2'    Convolution                   256 5x5x48 convolutions with stride [1 1] and padding [2 2]
     7   'relu2'    ReLU                          ReLU
     8   'norm2'    Cross Channel Normalization   cross channel normalization with 5 channels per element
     9   'pool2'    Max Pooling                   3x3 max pooling with stride [2 2] and padding [0 0]
    10   'conv3'    Convolution                   384 3x3x256 convolutions with stride [1 1] and padding [1 1]
    11   'relu3'    ReLU                          ReLU
    12   'conv4'    Convolution                   384 3x3x192 convolutions with stride [1 1] and padding [1 1]
    13   'relu4'    ReLU                          ReLU
    14   'conv5'    Convolution                   256 3x3x192 convolutions with stride [1 1] and padding [1 1]
    15   'relu5'    ReLU                          ReLU
    16   'pool5'    Max Pooling                   3x3 max pooling with stride [2 2] and padding [0 0]
    17   'fc6'      Fully Connected               4096 fully connected layer
    18   'relu6'    ReLU                          ReLU
    19   'drop6'    Dropout                       50% dropout
    20   'fc7'      Fully Connected               4096 fully connected layer
    21   'relu7'    ReLU                          ReLU
    22   'drop7'    Dropout                       50% dropout
    23   'fc8'      Fully Connected               1000 fully connected layer
    24   'prob'     Softmax                       softmax
    25   'output'   Classification Output         crossentropyex with 'tench', 'goldfish', and 998 other classes
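The sizes in this listing fit together: each convolution or pooling layer maps an input width `in` to `floor((in + 2*pad - kernel)/stride) + 1`, while the ReLU, normalization, and dropout layers leave the size unchanged. The plain-MATLAB sketch below (no toolboxes needed) traces the 227-pixel input through the strides and padding shown above; `outSize` is a helper name introduced here for illustration. It shows why 'pool5' produces 6x6x256 = 9216 values for 'fc6'.

```matlab
% Spatial output size of a convolution or pooling layer (square kernels assumed)
outSize = @(in, kernel, stride, pad) floor((in + 2*pad - kernel) / stride) + 1;

s = 227;                      % 'data'  : 227x227x3 input
s = outSize(s, 11, 4, 0);     % 'conv1' : 11x11, stride 4, padding 0  -> 55
s = outSize(s, 3, 2, 0);      % 'pool1' : 3x3 max pool, stride 2      -> 27
s = outSize(s, 5, 1, 2);      % 'conv2' : 5x5, stride 1, padding 2    -> 27
s = outSize(s, 3, 2, 0);      % 'pool2' : 3x3 max pool, stride 2      -> 13
s = outSize(s, 3, 1, 1);      % 'conv3' : 3x3, stride 1, padding 1    -> 13
s = outSize(s, 3, 1, 1);      % 'conv4' : 3x3, stride 1, padding 1    -> 13
s = outSize(s, 3, 1, 1);      % 'conv5' : 3x3, stride 1, padding 1    -> 13
s = outSize(s, 3, 2, 0);      % 'pool5' : 3x3 max pool, stride 2      -> 6

numFeatures = s * s * 256;    % conv5 has 256 filters, so 256 channels reach fc6
fprintf('pool5 output: %dx%dx256 = %d values\n', s, s, numFeatures);
```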
Classify Test Images
Before we connect a webcam and try classifying images on a live stream, let’s try classifying a single test image.
I = imread('peppers.png');
imshow(I)
title('Input Image')
To classify the input image, we use the ‘classify’ function. Since AlexNet expects a 227×227 input image, let’s resize the input image first.
Iin = imresize(I,[227 227]);
figure
label = classify(nnet, Iin);   % Classify the picture
imshow(Iin)                    % Show the picture
title(char(label))             % Show the label
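Hard-coding `[227 227]` works for AlexNet, but reading the size from the input layer keeps the same code usable if you swap in a different network. The sketch below assumes the AlexNet support package is installed; the grayscale check is an addition made here, because AlexNet expects three color channels. Note that `imresize` with a target size stretches the image rather than preserving its aspect ratio.

```matlab
nnet = alexnet;                          % pre-trained network, as loaded earlier
inputSize = nnet.Layers(1).InputSize;    % [227 227 3] for AlexNet

I = imread('peppers.png');
if size(I, 3) == 1
    I = repmat(I, [1 1 3]);              % replicate a grayscale image into 3 channels
end
Iin = imresize(I, inputSize(1:2));       % stretch to the network's height and width

label = classify(nnet, Iin);             % categorical class label
disp(label)
```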
You’ll see that the network correctly classifies the image as a bell pepper. You can also see the prediction score, or confidence, for each of the 1000 classes by requesting the scores as the second output of the classify function.
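As a minimal sketch of that second output, the example below asks `classify` for the scores and prints the five most likely classes. It assumes the columns of `scores` follow the order of the output layer’s `Classes` property (named `ClassNames` in older releases).

```matlab
nnet = alexnet;                               % pre-trained network, as loaded earlier
Iin  = imresize(imread('peppers.png'), [227 227]);

[label, scores] = classify(nnet, Iin);        % scores: 1x1000 row of softmax probabilities
classNames = nnet.Layers(end).Classes;        % class order matches the columns of scores

[sortedScores, idx] = sort(scores, 'descend');
for k = 1:5                                   % print the five most likely classes
    fprintf('%-20s %.3f\n', char(classNames(idx(k))), sortedScores(k));
end
```

Because the final 'prob' layer is a softmax, the 1000 scores sum to 1, and the top entry corresponds to the label returned in the first output.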
Try on Live Video
Now that we’ve tried the deep learning image classifier on a single image, let’s do the same with a live video stream. We connect to a webcam using the webcam object in MATLAB, which requires the MATLAB Support Package for USB Webcams.
camera = webcam;               % Connect to camera
picture = snapshot(camera);    % Take picture
imshow(picture)                % Show picture

% Run live video capture and classification in loop
for n = 1:100
    picture = snapshot(camera);               % Take picture
    picture = imresize(picture,[227,227]);    % Resize picture
    label = classify(nnet, picture);          % Classify picture
    imshow(picture);                          % Show picture
    title(char(label))                        % Show label
    drawnow
end
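The loop above stops after 100 frames. One possible variation, sketched below under the assumption that the AlexNet and USB webcam support packages are installed and a camera is attached, runs until you close the figure, updates the displayed image in place instead of calling `imshow` on every frame, adds the top score to the title, and releases the camera at the end. The variable names `fig` and `hImg` are introduced here for illustration.

```matlab
nnet = alexnet;                               % pre-trained network, as loaded earlier
camera = webcam;                              % connect to the first available webcam
inputSize = nnet.Layers(1).InputSize;

fig  = figure;
hImg = imshow(zeros(inputSize, 'uint8'));     % placeholder frame, updated in place below

while ishandle(fig)                           % loop until the user closes the figure
    picture = imresize(snapshot(camera), inputSize(1:2));
    [label, scores] = classify(nnet, picture);
    if ~ishandle(fig)                         % the figure may have closed during classify
        break
    end
    hImg.CData = picture;                     % cheaper than a fresh imshow call
    title(hImg.Parent, sprintf('%s (%.2f)', char(label), max(scores)))
    drawnow
end

clear camera                                  % release the webcam
```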