Hello Everyone! It's Johanna, and Steve has allowed me to take over the blog from time to time to talk about deep learning.
I'm back for another episode of "Deep Learning in Action: Cool projects created at MathWorks." This series aims to give you insight into what we’re working on at MathWorks.
Today’s demo is called "Wheel of Fortune" or, alternatively, "Do you sign MATLAB?", and it’s the third article in the series.
The developer of the demo is Joshua Wang, who led a team that participated in a MathWorks Hack Day, a fun day on which developers at MathWorks get 24 hours to work on a MATLAB-related project of their choice. The team decided to work on a sign language project, and I was drawn to this example because #1) this demo uses images, #2) this demo uses deep learning, and #3) this demo uses MATLAB.
When I reached out to Josh initially, I got this response...
…Not only is it cool that MathWorks tools made it possible to do all of this in a day (our coding all happened on a Wednesday), but it certainly ties in well with our social mission.
I was intrigued by how they got this up and running in under 24 hours, so I asked to see the code.
After viewing and running the code, I found that a significant portion of the work went into a nice MATLAB user interface that looks like Wheel of Fortune, complete with a spinning wheel and the ability to play the game against an opponent. See the game in action here:
User interfaces in MATLAB are great, but they aren't unique to deep learning. So for the remainder of this post, I want to walk through the deep learning portion of the application: how they built the CNN to recognize the letters. I'll ask Josh a few questions, and you'll have a chance to ask Josh and the team your own questions in the comments section.
Demo: Sign Language in MATLAB
The core of this demo is a CNN that determines which letter, A through Z, is being signed. Here is one randomly chosen sample image for each letter:
>> samples = splitEachLabel(imds,1,'randomized');
>> montage(samples)
These images are from a training dataset that can be downloaded from GitHub here.
Deep Learning Code
This section walks through the code to create and train the network in 4 parts:
- Load the dataset
- Load the network
- Modify the network
- Set training options
%% Load Data
imds = imageDatastore('dataset', ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
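To see what the 70/30 split does, here is a minimal sketch that builds a tiny hypothetical dataset (two letter folders, A and B, with ten random images each, written to a temporary folder named asl_demo) and then uses countEachLabel to check how many images per label land in each set. The folder name and image contents are assumptions for illustration only; with the real data you would point imageDatastore at the downloaded 'dataset' folder instead.

```matlab
% Hypothetical stand-in dataset: folders A and B, ten random 32x32 RGB images each
root = fullfile(tempdir,'asl_demo');
letters = {'A','B'};
for k = 1:numel(letters)
    folder = fullfile(root,letters{k});
    if ~exist(folder,'dir')
        mkdir(folder);
    end
    for n = 1:10
        img = uint8(randi(255,32,32,3));
        imwrite(img,fullfile(folder,sprintf('img%02d.png',n)));
    end
end

% Same calls as in the post: labels come from the folder names
imds = imageDatastore(root,'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');

% The split is applied per label, so each letter keeps the 70/30 ratio
disp(countEachLabel(imdsTrain))
disp(countEachLabel(imdsValidation))
```

Because the split is done per label, every letter contributes the same fraction of its images to training, which keeps the class balance of the validation set close to that of the full dataset.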
%% Load network
net = inceptionv3();
lgraph = layerGraph(net);
figure('Units','normalized','Position',[0.1 0.1 0.8 0.8]);
plot(lgraph)
Yikes! Inception-v3 has a complicated structure.
Note: if you don't have Inception-v3 downloaded, simply typing
>> inceptionv3
on the command line will provide a link to download the model.
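A list of all available models, including the new ONNX model converter, can be found here: https://www.mathworks.com/solutions/deep-learning/models.html
Next, replace the final layers so that the network's output matches the number of classes in the dataset. Because Inception-v3 is a DAG network rather than a simple chain of layers, you remove the old output layers, add the new ones, and then connect them to the rest of the graph, so it's worth verifying that the network is reconnected correctly.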
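%% Edit the architecture
inputSize = net.Layers(1).InputSize;   % [299 299 3] for Inception-v3
% Remove Inception-v3's original 1000-class output layers
lgraph = removeLayers(lgraph, {'predictions','predictions_softmax','ClassificationLayer_predictions'});
numClasses = numel(categories(imdsTrain.Labels));   % one class per letter
% New output layers; the higher learn rate factors let the new fully connected layer learn faster
newLayers = [
    fullyConnectedLayer(numClasses,'Name','fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10)
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph,newLayers);
% Reconnect: attach the new head to the global average pooling layer
lgraph = connectLayers(lgraph,'avg_pool','fc');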
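Here is a minimal sketch of the same head replacement, followed by a quick connection check, on a small hypothetical layer graph whose last layers reuse Inception-v3's layer names ('avg_pool', 'predictions', 'predictions_softmax', 'ClassificationLayer_predictions'). The toy layers and input size are assumptions, and the sketch assumes a release that includes globalAveragePooling2dLayer. On the real network, analyzeNetwork(lgraph) or plot(lgraph) gives a visual check as well.

```matlab
% Hypothetical toy graph standing in for Inception-v3 (same names for the final layers)
lgraph = layerGraph([
    imageInputLayer([64 64 3],'Name','input')
    convolution2dLayer(3,8,'Padding','same','Name','conv')
    reluLayer('Name','relu')
    globalAveragePooling2dLayer('Name','avg_pool')
    fullyConnectedLayer(1000,'Name','predictions')
    softmaxLayer('Name','predictions_softmax')
    classificationLayer('Name','ClassificationLayer_predictions')]);

% Swap the 1000-class head for a 26-class head (one class per letter A-Z)
numClasses = 26;
lgraph = removeLayers(lgraph,{'predictions','predictions_softmax','ClassificationLayer_predictions'});
newLayers = [
    fullyConnectedLayer(numClasses,'Name','fc','WeightLearnRateFactor',10,'BiasLearnRateFactor',10)
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph,newLayers);   % the three new layers are chained to each other
lgraph = connectLayers(lgraph,'avg_pool','fc');

% Verify the rewiring: the connection table should contain avg_pool -> fc
disp(lgraph.Connections)
isConnected = any(strcmp(lgraph.Connections.Source,'avg_pool') & ...
                  strcmp(lgraph.Connections.Destination,'fc'));
fprintf('avg_pool -> fc connected: %d\n',isConnected)
```

Checking the Connections table directly is handy when the plot of a large DAG such as Inception-v3 is too dense to read.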
Next, freeze the early layers, set up data augmentation, and specify the training options:
% Freeze the first 110 layers so only the later layers are retrained
layers = lgraph.Layers;
connections = lgraph.Connections;
layers(1:110) = freezeWeights(layers(1:110));
lgraph = createLgraphUsingConnections(layers,connections);

% Random left-right reflection and translation of up to 30 pixels in x and y
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
% Resize all images to the network input size; augment only the training set
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'Verbose',true);
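freezeWeights and createLgraphUsingConnections are not built-in toolbox functions; they are supporting functions distributed with MathWorks' transfer learning examples. The sketch below inlines the same idea on a small hypothetical layer graph (three frozen layers instead of 110): set every property whose name ends in LearnRateFactor to zero, then rebuild the layer graph from the modified layers and the original connection table. Treat it as an approximation of those helpers, not their exact source.

```matlab
% Hypothetical small graph standing in for Inception-v3
lgraph = layerGraph([
    imageInputLayer([32 32 3],'Name','input')
    convolution2dLayer(3,8,'Padding','same','Name','conv1')
    reluLayer('Name','relu1')
    fullyConnectedLayer(26,'Name','fc')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classoutput')]);

layers = lgraph.Layers;
connections = lgraph.Connections;
numFrozen = 3;   % the post freezes layers(1:110) of Inception-v3

% Freeze: a learn rate factor of 0 means the optimizer never updates that parameter
for ii = 1:numFrozen
    props = properties(layers(ii));
    for p = 1:numel(props)
        if endsWith(props{p},'LearnRateFactor')
            layers(ii).(props{p}) = 0;
        end
    end
end

% Rebuild: add the modified layers, then restore every original connection
newGraph = layerGraph();
for ii = 1:numel(layers)
    newGraph = addLayers(newGraph,layers(ii));
end
for c = 1:height(connections)
    newGraph = connectLayers(newGraph,connections.Source{c},connections.Destination{c});
end
```

The rebuild step is needed because lgraph.Layers returns a copy of the layers; editing that copy does not change the original graph.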
I'll be honest: I'm still not in love with augmentedImageDatastore, since I don't love the extra lines of code in what would otherwise be a simple, easy-to-read section. But it's growing on me in this example, because it applies random transformations such as translation, reflection, and scaling to the training images, so the network sees more varied examples. (This code uses reflection and translation; scaling is available through the 'RandScale' option of imageDataAugmenter.) It also resizes all the images to the input size the network requires.
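A minimal sketch of what augmentedImageDatastore produces, using hypothetical stand-in data (eight random 200x200 RGB images labeled A or B) in place of the sign language images. Each mini-batch comes back as a table whose images are resized to 299x299, and the number of observations does not grow, because the random transformations are applied on the fly each time a batch is read rather than stored as additional images.

```matlab
% Hypothetical stand-in data: eight random 200x200 RGB images with labels A/B
X = uint8(randi(255,200,200,3,8));
Y = categorical(repmat({'A';'B'},4,1));

% Same augmentation settings as in the post
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);

% Resize to Inception-v3's 299x299 input and augment on the fly
augimds = augmentedImageDatastore([299 299],X,Y,'DataAugmentation',imageAugmenter);
batch = read(augimds);   % table with variables 'input' and 'response'

disp(size(batch.input{1}))      % each image is 299x299x3
disp(augimds.NumObservations)   % still 8 observations
```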
Finally, train the network.
net = trainNetwork(augimdsTrain,lgraph,options);
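To get a feel for how 'MiniBatchSize' and 'MaxEpochs' drive training time, here is the iteration arithmetic under a hypothetical dataset size of 70 training images per letter (for example, 100 images per letter split 70/30). trainNetwork discards the final partial mini-batch in each epoch, so the number of iterations per epoch is the floor of the training-set size divided by the mini-batch size.

```matlab
% Hypothetical training-set size: 70 images per letter x 26 letters
numTrain = 70*26;
miniBatchSize = 10;   % 'MiniBatchSize' in the post
maxEpochs = 6;        % 'MaxEpochs' in the post

% The last partial mini-batch of each epoch is discarded
itersPerEpoch = floor(numTrain/miniBatchSize);
totalIters = itersPerEpoch*maxEpochs;
fprintf('%d iterations per epoch, %d in total\n',itersPerEpoch,totalIters)
% prints: 182 iterations per epoch, 1092 in total
```

With a small mini-batch of 10, each epoch involves many weight updates, so the iteration count (and training time) grows linearly with both the training-set size and the number of epochs.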
Note: I trained the network on my old Tesla K40 GPU, and it took roughly 1 hour 15 minutes. I cut the training set size significantly because training appeared to be taking longer than I'd like, so I'd expect it to take even longer with the full training set.
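Once the network is trained, a minimal sketch of computing overall and per-class validation accuracy is shown below; per-class numbers make it easy to spot letters the model confuses, such as the M and N pair Josh mentions in the Q&A. The labels here are hypothetical; with the trained network you would use YPred = classify(net,augimdsValidation) and YTrue = imdsValidation.Labels instead.

```matlab
% Hypothetical labels standing in for imdsValidation.Labels and the classify output
YTrue = categorical({'M';'N';'M';'N';'A';'A'});
YPred = categorical({'M';'M';'M';'N';'A';'A'});

% Overall accuracy: fraction of predictions that match the true label
accuracy = mean(YPred == YTrue);

% Per-class accuracy: restrict to the images of each true letter
classes = categories(YTrue);
perClass = zeros(numel(classes),1);
for k = 1:numel(classes)
    idx = YTrue == classes{k};
    perClass(k) = mean(YPred(idx) == YTrue(idx));
end

fprintf('Overall accuracy: %.3f\n',accuracy)
disp(table(classes,perClass))
```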
Q&A with Josh
1. First, I have to ask: what is quality engineering? What do you do?
Quality Engineering at MathWorks is a group of software engineers who build the infrastructure and comprehensive test environment to support and champion MathWorks’ primary goal of delivering bug-free, feature-rich software to our customers. Specifically, I work on the web and cloud services that power MathWorks’ online offerings like MATLAB Online, MATLAB Mobile, and MATLAB Grader.
2. What is your relationship to deep learning? Do you work on deep learning in your role, are you just interested in it outside of work, or both?
I don’t directly work on deep learning in my role at MathWorks. However, recent advances in machine learning have great potential to transform how customers use our products in an increasingly connected world, and our hack day project was designed to demonstrate one way we could use deep learning to make scientific computing more intuitive, contextual, and accessible.
3. Whose idea was a sign language project, and why?
Making MATLAB accessible to anyone, including those with impaired hearing, is an important part of our mission to provide the ultimate computing environment for technical computation, visualization, design, simulation, and implementation. In addition, making MATLAB more accessible via gesture control would accelerate the pace of engineering and science by enabling its use in environments where a person may not have easy access to traditional input mechanisms, such as an operating room or factory floor.
4. Why did you choose to transfer learn on Inception-v3?
Inception-v3 has higher accuracy than models like GoogLeNet and is easily available in MATLAB, with examples.
5. What was your validation accuracy approximately?
For someone unfamiliar with American Sign Language, probably around 70%. Certain letters, like M and N, were very similar and therefore more difficult for our model to distinguish. Our training data set was also fairly homogeneous: images of the right hand of a single person in front of the same background. Given that this was a single-day effort, we didn’t spend a lot of time tuning the model or improving the training data.
6. What was your rationale for 6 epochs? Was it training time, or accuracy? If you had more time, would you train longer?
I don’t think we had any real reason for this, nor did everyone even realize we had chosen 6 epochs; we just took the defaults that were available in MATLAB or the GitHub project that we used for an initial prototype.
Thanks to the team for the demo (the entire team consists of Anil Patro, Oral Dalay, Harshad Tambekar, Krishan Sharma, Rohit Kudva, Michael Broshi, and Sara Burke), and thanks to Josh for taking the time to walk me through it! I hope you enjoyed it as well. Anything else you'd like to ask the team? Leave a comment below!
(Hope they don't mind me putting in a picture of the team. Congrats on a job well done!)