{"id":8412,"date":"2017-02-24T12:37:27","date_gmt":"2017-02-24T17:37:27","guid":{"rendered":"https:\/\/blogs.mathworks.com\/pick\/?p=8412"},"modified":"2018-09-14T06:53:02","modified_gmt":"2018-09-14T10:53:02","slug":"deep-learning-transfer-learning-in-10-lines-of-matlab-code","status":"publish","type":"post","link":"https:\/\/blogs.mathworks.com\/pick\/2017\/02\/24\/deep-learning-transfer-learning-in-10-lines-of-matlab-code\/","title":{"rendered":"Deep Learning: Transfer Learning in 10 lines of MATLAB Code"},"content":{"rendered":"<p><a href=\"https:\/\/www.mathworks.com\/matlabcentral\/profile\/authors\/4291457-avi-nehemiah\">Avi&#8217;s<\/a> pick of the week is <a href=\"https:\/\/www.mathworks.com\/matlabcentral\/fileexchange\/61639-deep-learning--transfer-learning-in-10-lines-of-matlab-code\">Deep Learning: Transfer Learning in 10 lines of MATLAB Code<\/a> by the <a href=\"https:\/\/www.mathworks.com\/matlabcentral\/profile\/authors\/8743315\">MathWorks Deep Learning Toolbox Team<\/a>.<\/p>\n<p>Have you ever wanted to try deep learning to solve a problem but didn&#8217;t go through with it because you didn&#8217;t have enough data or were not comfortable designing deep neural networks? Transfer learning is a very practical way to use deep learning by modifying an existing deep network (usually trained by an expert) to work with your data. This post contains a detailed explanation on how transfer learning works. 
Before you read through the rest of this post, I would highly recommend you watch <strong><a href=\"https:\/\/www.mathworks.com\/videos\/deep-learning-with-matlab-transfer-learning-in-10-lines-of-matlab-code-1487714838381.html\"><span style=\"text-decoration: underline;\">this video<\/span><\/a><\/strong> by Joe Hicklin (pictured below), which illustrates what I will explain in more detail.<\/p>\n<p>&nbsp;<\/p>\n<p><a href=\"https:\/\/blogs.mathworks.com\/pick\/files\/2017-02-17_12-19-12.png\"><img decoding=\"async\" loading=\"lazy\" width=\"854\" height=\"479\" class=\"aligncenter size-full wp-image-8413\" src=\"https:\/\/blogs.mathworks.com\/pick\/files\/2017-02-17_12-19-12.png\" alt=\"\" \/><\/a><\/p>\n<p>The problem I am trying to solve with transfer learning is to distinguish between 5 categories of food &#8211; cupcakes, burgers, apple pie, hot dogs and ice cream. To get started you need two things:<\/p>\n<div>\n<ol>\n<li>Training images of the different types of food we are trying to recognize<\/li>\n<li>A pre-trained deep neural network that we can re-train for our data and task<\/li>\n<\/ol>\n<\/div>\n<h2 id=\"3\">Load Training Images<\/h2>\n<p>I have all my images stored in the &#8220;TrainingData&#8221; folder, with sub-directories corresponding to the different classes. I chose this structure because it allows imageDatastore to use the folder names as labels for the image categories.<\/p>\n<p><a href=\"https:\/\/blogs.mathworks.com\/pick\/files\/Folders.png\"><img decoding=\"async\" loading=\"lazy\" width=\"370\" height=\"187\" class=\"size-full wp-image-8414 aligncenter\" src=\"https:\/\/blogs.mathworks.com\/pick\/files\/Folders.png\" alt=\"\" \/><\/a><\/p>\n<p>To bring the images into MATLAB I use imageDatastore, which is designed to manage large collections of images. 
With a single line of code I can bring all my training data into MATLAB. In my case I have a few thousand images, but the same code would work even if I had millions. Another advantage of imageDatastore is that it supports reading images from disk, network drives, databases and big-data file systems like Hadoop.<\/p>\n<pre class=\"codeinput\">allImages = imageDatastore(<span class=\"string\">'TrainingData'<\/span>, <span class=\"string\">'IncludeSubfolders'<\/span>, true, <span class=\"string\">'LabelSource'<\/span>, <span class=\"string\">'foldernames'<\/span>);\r\n<\/pre>\n<p>I then split the data into two sets: 80% for training and the remaining 20% for testing.<\/p>\n<pre class=\"codeinput\">[trainingImages, testImages] = splitEachLabel(allImages, 0.8, <span class=\"string\">'randomize'<\/span>);\r\n<\/pre>\n<h2 id=\"7\">Load Pre-trained Network (AlexNet)<\/h2>\n<p>My next step is to load a pre-trained model. I&#8217;ll use AlexNet, a deep convolutional neural network that was trained on millions of images to recognize 1000 categories of objects. 
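<\/p>\n<p>A quick aside from me (this sketch is my own addition, not code from the original post): alexnet ships as a free, separately installed support package, and once it is loaded you can query the network&#8217;s input layer to confirm the image size it expects &#8211; the reason the training images get resized to 227&#215;227 later on:<\/p>\n<pre class=\"codeinput\">alex = alexnet;            % requires the AlexNet support package\r\nalex.Layers(1).InputSize   % returns [227 227 3] for AlexNet\r\n<\/pre>\n<p>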
In its early layers, AlexNet has already learned the basic feature extraction needed to distinguish between different categories of images; my goal is to &#8220;transfer&#8221; that learning to my task of categorizing different kinds of food.<\/p>\n<pre class=\"codeinput\">alex = alexnet;\r\n<\/pre>\n<p>Now let&#8217;s take a look at the structure of the AlexNet convolutional neural network.<\/p>\n<pre class=\"codeinput\">layers = alex.Layers\r\n<\/pre>\n<pre class=\"codeoutput\">layers = \r\n\r\n  25x1 Layer array with layers:\r\n\r\n     1   'data'     Image Input                   227x227x3 images with 'zerocenter' normalization\r\n     2   'conv1'    Convolution                   96 11x11x3 convolutions with stride [4  4] and padding [0  0]\r\n     3   'relu1'    ReLU                          ReLU\r\n     4   'norm1'    Cross Channel Normalization   cross channel normalization with 5 channels per element\r\n     5   'pool1'    Max Pooling                   3x3 max pooling with stride [2  2] and padding [0  0]\r\n     6   'conv2'    Convolution                   256 5x5x48 convolutions with stride [1  1] and padding [2  2]\r\n     7   'relu2'    ReLU                          ReLU\r\n     8   'norm2'    Cross Channel Normalization   cross channel normalization with 5 channels per element\r\n     9   'pool2'    Max Pooling                   3x3 max pooling with stride [2  2] and padding [0  0]\r\n    10   'conv3'    Convolution                   384 3x3x256 convolutions with stride [1  1] and padding [1  1]\r\n    11   'relu3'    ReLU                          ReLU\r\n    12   'conv4'    Convolution                   384 3x3x192 convolutions with stride [1  1] and padding [1  1]\r\n    13   'relu4'    ReLU                          ReLU\r\n    14   'conv5'    Convolution                   256 3x3x192 convolutions with stride [1  1] and padding [1  1]\r\n    15   'relu5'    ReLU                          ReLU\r\n    16   'pool5'    Max Pooling  
                 3x3 max pooling with stride [2  2] and padding [0  0]\r\n    17   'fc6'      Fully Connected               4096 fully connected layer\r\n    18   'relu6'    ReLU                          ReLU\r\n    19   'drop6'    Dropout                       50% dropout\r\n    20   'fc7'      Fully Connected               4096 fully connected layer\r\n    21   'relu7'    ReLU                          ReLU\r\n    22   'drop7'    Dropout                       50% dropout\r\n    23   'fc8'      Fully Connected               1000 fully connected layer\r\n    24   'prob'     Softmax                       softmax\r\n    25   'output'   Classification Output         crossentropyex with 'tench', 'goldfish', and 998 other classes\r\n<\/pre>\n<h2 id=\"9\">Modify Pre-trained Network<\/h2>\n<p>AlexNet was trained to recognize 1000 classes; we need it to recognize just 5. To do this I&#8217;m going to replace a couple of layers. Notice in the output below how the structure of the last few layers differs from the original AlexNet.<\/p>\n<pre class=\"codeinput\">layers(23) = fullyConnectedLayer(5);\r\nlayers(25) = classificationLayer\r\n<\/pre>\n<pre class=\"codeoutput\">layers = \r\n\r\n  25x1 Layer array with layers:\r\n\r\n     1   'data'    Image Input                   227x227x3 images with 'zerocenter' normalization\r\n     2   'conv1'   Convolution                   96 11x11x3 convolutions with stride [4  4] and padding [0  0]\r\n     3   'relu1'   ReLU                          ReLU\r\n     4   'norm1'   Cross Channel Normalization   cross channel normalization with 5 channels per element\r\n     5   'pool1'   Max Pooling                   3x3 max pooling with stride [2  2] and padding [0  0]\r\n     6   'conv2'   Convolution                   256 5x5x48 convolutions with stride [1  1] and padding [2  2]\r\n     7   'relu2'   ReLU                          ReLU\r\n     8   'norm2'   Cross Channel Normalization   cross channel normalization with 5 channels per element\r\n     9   'pool2'  
 Max Pooling                   3x3 max pooling with stride [2  2] and padding [0  0]\r\n    10   'conv3'   Convolution                   384 3x3x256 convolutions with stride [1  1] and padding [1  1]\r\n    11   'relu3'   ReLU                          ReLU\r\n    12   'conv4'   Convolution                   384 3x3x192 convolutions with stride [1  1] and padding [1  1]\r\n    13   'relu4'   ReLU                          ReLU\r\n    14   'conv5'   Convolution                   256 3x3x192 convolutions with stride [1  1] and padding [1  1]\r\n    15   'relu5'   ReLU                          ReLU\r\n    16   'pool5'   Max Pooling                   3x3 max pooling with stride [2  2] and padding [0  0]\r\n    17   'fc6'     Fully Connected               4096 fully connected layer\r\n    18   'relu6'   ReLU                          ReLU\r\n    19   'drop6'   Dropout                       50% dropout\r\n    20   'fc7'     Fully Connected               4096 fully connected layer\r\n    21   'relu7'   ReLU                          ReLU\r\n    22   'drop7'   Dropout                       50% dropout\r\n    23   ''        Fully Connected               5 fully connected layer\r\n    24   'prob'    Softmax                       softmax\r\n    25   ''        Classification Output         crossentropyex\r\n<\/pre>\n<h2 id=\"10\">Perform Transfer Learning<\/h2>\n<p>Now that I have modified the network structure, it is time to learn weights for the layers we replaced. For transfer learning we want to change the network only slightly, and how much a network changes during training is controlled by the learning rates. Here we do not modify the learning rates of the original layers, i.e. the ones before the last three; the rates for these layers are already small enough that they don&#8217;t need to be lowered further. 
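<\/p>\n<p>If you want the replaced layers to adapt faster than the untouched ones, you can also raise their per-layer learn-rate factors when you create them. The line below is my own sketch rather than code from the original post, and the factor value of 10 is illustrative:<\/p>\n<pre class=\"codeinput\">layers(23) = fullyConnectedLayer(5, <span class=\"string\">'WeightLearnRateFactor'<\/span>, 10, <span class=\"string\">'BiasLearnRateFactor'<\/span>, 10);\r\n<\/pre>\n<p>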
You could even freeze the weights of these early layers entirely by setting their learning rates to zero.<\/p>\n<pre class=\"codeinput\">opts = trainingOptions(<span class=\"string\">'sgdm'<\/span>, <span class=\"string\">'InitialLearnRate'<\/span>, 0.001, <span class=\"string\">'MaxEpochs'<\/span>, 20, <span class=\"string\">'MiniBatchSize'<\/span>, 64);\r\n<\/pre>\n<p>One of the great things about imageDatastore is that it lets me specify a &#8220;custom&#8221; read function. In this case I am simply resizing the input images to the 227&#215;227 pixels that AlexNet expects, by specifying a function handle with code to read and pre-process each image.<\/p>\n<pre class=\"codeinput\">trainingImages.ReadFcn = @readFunctionTrain;\r\n<\/pre>\n<p>Now let&#8217;s go ahead and train the network. This process usually takes about 5-20 minutes on a desktop GPU &#8211; a great time to grab a cup of coffee.<\/p>\n<pre class=\"codeinput\">myNet = trainNetwork(trainingImages, layers, opts);\r\n<\/pre>\n<pre class=\"codeoutput\">Training on single GPU.\r\nInitializing image normalization.\r\n|=========================================================================================|\r\n|     Epoch    |   Iteration  | Time Elapsed |  Mini-batch  |  Mini-batch  | Base Learning|\r\n|              |              |  (seconds)   |     Loss     |   Accuracy   |     Rate     |\r\n|=========================================================================================|\r\n|            1 |            1 |         2.32 |       1.9052 |       26.56% |       0.0010 |\r\n|            1 |           50 |        42.65 |       0.7895 |       73.44% |       0.0010 |\r\n|            2 |          100 |        83.74 |       0.5341 |       87.50% |       0.0010 |\r\n|            3 |          150 |       124.51 |       0.3321 |       87.50% |       0.0010 |\r\n|            4 |          200 |       165.79 |       0.3374 |       87.50% |       0.0010 |\r\n|            5 |          250 |       208.79 |       
0.2333 |       87.50% |       0.0010 |\r\n|            5 |          300 |       250.70 |       0.1183 |       96.88% |       0.0010 |\r\n|            6 |          350 |       291.97 |       0.1157 |       96.88% |       0.0010 |\r\n|            7 |          400 |       333.00 |       0.1074 |       93.75% |       0.0010 |\r\n|            8 |          450 |       374.26 |       0.0379 |       98.44% |       0.0010 |\r\n|            9 |          500 |       415.51 |       0.0699 |       96.88% |       0.0010 |\r\n|            9 |          550 |       456.80 |       0.1083 |       95.31% |       0.0010 |\r\n|           10 |          600 |       497.80 |       0.1243 |       93.75% |       0.0010 |\r\n|           11 |          650 |       538.83 |       0.0231 |      100.00% |       0.0010 |\r\n|           12 |          700 |       580.26 |       0.0353 |       96.88% |       0.0010 |\r\n|           13 |          750 |       621.47 |       0.0154 |      100.00% |       0.0010 |\r\n|           13 |          800 |       662.39 |       0.0104 |      100.00% |       0.0010 |\r\n|           14 |          850 |       703.69 |       0.0360 |       98.44% |       0.0010 |\r\n|           15 |          900 |       744.72 |       0.0065 |      100.00% |       0.0010 |\r\n|           16 |          950 |       785.74 |       0.0375 |       98.44% |       0.0010 |\r\n|           17 |         1000 |       826.64 |       0.0102 |      100.00% |       0.0010 |\r\n|           17 |         1050 |       867.78 |       0.0026 |      100.00% |       0.0010 |\r\n|           18 |         1100 |       909.37 |       0.0019 |      100.00% |       0.0010 |\r\n|           19 |         1150 |       951.01 |       0.0120 |      100.00% |       0.0010 |\r\n|           20 |         1200 |       992.63 |       0.0009 |      100.00% |       0.0010 |\r\n|           20 |         1240 |      1025.67 |       0.0015 |      100.00% |       0.0010 
|\r\n|=========================================================================================|\r\n<\/pre>\n<h2 id=\"13\">Test Network Performance<\/h2>\n<p>Now let&#8217;s test the performance of our new &#8220;snack recognizer&#8221; on the test set. We&#8217;ll see that the algorithm has an accuracy of over 80%. You can improve this accuracy by adding more training data or tweaking some of the training parameters.<\/p>\n<pre class=\"codeinput\">testImages.ReadFcn = @readFunctionTrain;\r\npredictedLabels = classify(myNet, testImages);\r\naccuracy = mean(predictedLabels == testImages.Labels)\r\n<\/pre>\n<pre class=\"codeoutput\">accuracy =\r\n\r\n    0.8260\r\n\r\n<\/pre>\n<h2 id=\"14\">Try the classifier on live video<\/h2>\n<p>Now that we have a deep-learning-based snack recognizer, I&#8217;d encourage you to grab a snack and try it out yourself.<\/p>\n<p><a href=\"https:\/\/blogs.mathworks.com\/pick\/files\/2017-02-17_12-25-33.png\"><img decoding=\"async\" loading=\"lazy\" width=\"628\" height=\"480\" class=\"size-full wp-image-8416 aligncenter\" src=\"https:\/\/blogs.mathworks.com\/pick\/files\/2017-02-17_12-25-33.png\" alt=\"\" \/><\/a><\/p>\n<p>&nbsp;<\/p>\n<p>Also check out <a href=\"https:\/\/www.mathworks.com\/videos\/deep-learning-in-11-lines-of-matlab-code-1481229977318.html\">this video<\/a> by Joe that shows how to perform image recognition on a live stream from a webcam.<\/p>\n","protected":false},"excerpt":{"rendered":"<div class=\"overview-image\"><img decoding=\"async\"  class=\"img-responsive\" src=\"https:\/\/blogs.mathworks.com\/pick\/files\/2017-02-17_12-19-12.png\" onError=\"this.style.display ='none';\" \/><\/div>\n<p>Avi&#8217;s pick of the week is Deep Learning: Transfer Learning in 10 lines of MATLAB Code by the MathWorks Deep Learning Toolbox Team.<br \/>\nHave you ever wanted to try deep learning to solve a problem&#8230; <a class=\"read-more\"
href=\"https:\/\/blogs.mathworks.com\/pick\/2017\/02\/24\/deep-learning-transfer-learning-in-10-lines-of-matlab-code\/\">read more >><\/a><\/p>\n","protected":false},"author":132,"featured_media":0,"comment_status":"open","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":[],"categories":[16],"tags":[],"_links":{"self":[{"href":"https:\/\/blogs.mathworks.com\/pick\/wp-json\/wp\/v2\/posts\/8412"}],"collection":[{"href":"https:\/\/blogs.mathworks.com\/pick\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/blogs.mathworks.com\/pick\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/blogs.mathworks.com\/pick\/wp-json\/wp\/v2\/users\/132"}],"replies":[{"embeddable":true,"href":"https:\/\/blogs.mathworks.com\/pick\/wp-json\/wp\/v2\/comments?post=8412"}],"version-history":[{"count":13,"href":"https:\/\/blogs.mathworks.com\/pick\/wp-json\/wp\/v2\/posts\/8412\/revisions"}],"predecessor-version":[{"id":10131,"href":"https:\/\/blogs.mathworks.com\/pick\/wp-json\/wp\/v2\/posts\/8412\/revisions\/10131"}],"wp:attachment":[{"href":"https:\/\/blogs.mathworks.com\/pick\/wp-json\/wp\/v2\/media?parent=8412"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/blogs.mathworks.com\/pick\/wp-json\/wp\/v2\/categories?post=8412"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/blogs.mathworks.com\/pick\/wp-json\/wp\/v2\/tags?post=8412"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}