Image-to-Image Regression
Today I'd like to talk about the basic concepts of setting up a network to train on an image-to-image regression problem.
This demo came about for two reasons:
- There are quite a few questions on MATLAB Answers about image-to-image deep learning problems.
- I’m planning a future in-depth post with an image processing/deep learning expert, where we’ll be getting into the weeds on regression, and it would be good to understand the basics to keep up with him.
In classification problems, the task is: given an image, predict which category an object belongs to.
In regression problems, there are no longer discrete categories. The output can be a continuous value: for example, given an image, output its rotation angle. Along the same lines, given an image, predict a new image! To learn more about the concept of image-to-image deep learning, we can start with a simple example in the documentation: https://www.mathworks.com/help/deeplearning/examples/remove-noise-from-color-image-using-pretrained-neural-network.html This is a great introduction to the topic, and the example explains it well. Plus, if you're trying to denoise an image, this example solves the problem, so you're done! However, the goal of this post is to understand how to create a custom deep learning algorithm from scratch. The hardest part is getting the data set up; everything else should be reasonably straightforward.
All About Datastores
Datastores deserve a post of their own, but let me just say: if you can appreciate and master datastores, you can conquer the world. At a high level, datastores make sense: they are an efficient way of bringing in data for deep learning (and other) applications. You don't have to deal with memory management, and deep learning functions know how to accept datastores as input. This is all good. But that leaves the question: "How do I get datastores to work for image-to-image deep learning training data?" Great question!
randomPatchExtractionDatastore
I'm going to recommend a handy function called randomPatchExtractionDatastore, which is what I use in the example below. We're not exactly short with our naming convention here, but the name tells you exactly what you're getting! Extracting random patches from your images is a great way to generate more training samples, especially if you're low on data. The network needs enough samples to train accurately, so we cut the images into smaller pieces and give the network more examples to learn from. The function takes an input datastore, a corresponding response (output) datastore, and a patch size, and it extracts each input patch and its response patch from the same location in the paired images.
The code:
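Our problem is going to be image deblurring, and we're going to set this up from scratch. I have a perfect final image, I blur the image, and I put all of my data into individual folders.
% Clean (original) training images
imagesDir = '.';
trainImagesDir = fullfile(imagesDir,'iaprtc12','images','02');
exts = {'.jpg','.bmp','.png'};
trainImages = imageDatastore(trainImagesDir,'FileExtensions',exts);
% createTrainingSet and matRead are helper functions from the full code (not shown in this post):
% createTrainingSet writes blurred copies of the training images to blurredDir as .mat files,
% and matRead reads one of those .mat files back in as an image.
% Both datastores must list their files in the same order so that input i pairs with response i.
blurredDir = createTrainingSet(trainImages);
blurredImages = imageDatastore(blurredDir,'FileExtensions','.mat','ReadFcn',@matRead);
The blurred image is my input, and the perfect/original image is my output. This felt backwards, but I reminded myself: I want the network to see a blurry image and output the clean image as the final result.
Visualize the input and output images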
ii = 1; % index of the image pair to display (any valid index works)
im_orig = trainImages.readimage(ii);
im_blurred = blurredImages.readimage(ii);
imshow(im_orig); title('Clean Image - Final Result');
figure; imshow(im_blurred); title('Blurred Image - Input');
Set up data augmentation for even more variety of training images.
augmenter = imageDataAugmenter( ...
    'RandRotation',@()randi([0,1],1)*90, ...
    'RandXReflection',true);
This randomly rotates each image by either 0 or 90 degrees and randomly reflects it left to right. Then the random patch extraction datastore combines the input and output images into the form the trainNetwork command expects.
miniBatchSize = 64;
patchSize = [40 40];
patchds = randomPatchExtractionDatastore(blurredImages,trainImages,patchSize, ...
    'PatchesPerImage',64, ...
    'DataAugmentation',augmenter);
patchds.MiniBatchSize = miniBatchSize;
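To see what randomPatchExtractionDatastore actually hands to trainNetwork, here is a small self-contained sketch. It is not the code from this post: it builds a toy paired dataset from the sample image peppers.png (standing in for the iaprtc12 images), blurs it with the same motion-blur PSF the post applies to its test image later on (an assumption about what createTrainingSet does), and previews the resulting patches. The folder names under tempdir are hypothetical.

```matlab
% Sketch: build a tiny paired dataset and preview the patches that
% randomPatchExtractionDatastore produces. Requires Image Processing Toolbox.
% Folder names are hypothetical; peppers.png ships with MATLAB.
cleanDir = fullfile(tempdir, 'patchDemoClean');
blurDir  = fullfile(tempdir, 'patchDemoBlurred');
if ~exist(cleanDir, 'dir'), mkdir(cleanDir); end
if ~exist(blurDir, 'dir'), mkdir(blurDir); end

% Assumed blur: the same motion PSF the post uses on its test image.
PSF = fspecial('motion', 21, 11);
I = imread('peppers.png');
for k = 1:3
    clean = circshift(I, 100*k, 2);                 % vary the image a little
    blurred = imfilter(clean, PSF, 'conv', 'circular');
    name = sprintf('img_%04d.png', k);              % identical names keep pairs aligned
    imwrite(clean, fullfile(cleanDir, name));       % PNG is lossless
    imwrite(blurred, fullfile(blurDir, name));
end

% Input = blurred images, response = clean images, as in the post.
cleanDS = imageDatastore(cleanDir);
blurDS  = imageDatastore(blurDir);
demods  = randomPatchExtractionDatastore(blurDS, cleanDS, [40 40], ...
    'PatchesPerImage', 8);

% preview returns a table with InputImage and ResponseImage columns.
batch = preview(demods);
disp(batch.Properties.VariableNames)

% Show one input patch next to its matching response patch.
imshowpair(batch.InputImage{1}, batch.ResponseImage{1}, 'montage');
```

Each row of the table pairs a blurred 40x40 patch with the clean patch taken from the same location, which is the input/response format trainNetwork uses for image-to-image regression.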
Network layers
To set up an image-to-image regression network, let's start with a set of layers that is almost right for our example. Computer Vision Toolbox has a function, unetLayers, that quickly sets up the layers of a U-Net semantic segmentation network.
% 40x40 RGB input patches, 3 output channels (one per color channel of the clean image)
lgraph = unetLayers([40 40 3],3,'EncoderDepth',3);
We have to alter this slightly by adding an L2 loss layer: remove the last two layers (softmax and pixel classification) and replace them with a regression layer.
lgraph = lgraph.removeLayers('Softmax-Layer');
lgraph = lgraph.removeLayers('Segmentation-Layer');
lgraph = lgraph.addLayers(regressionLayer('Name','regressionLayer'));
lgraph = lgraph.connectLayers('Final-ConvolutionLayer','regressionLayer');
The Deep Network Designer app (deepNetworkDesigner) can also remove and connect layers for you interactively.
Set the training parameters
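The 'sgdm' options that follow use a piecewise schedule: starting from initLearningRate = 0.1, the learning rate is multiplied by learningRateFactor = 0.1 every 10 epochs ('LearnRateDropPeriod'). This quick sketch computes the rate implied for each epoch using plain MATLAB arithmetic, assuming each drop takes effect at the start of epochs 11, 21, and so on.

```matlab
% Sketch: learning rate per epoch implied by a piecewise schedule
% (InitialLearnRate 0.1, LearnRateDropFactor 0.1, LearnRateDropPeriod 10).
initLearningRate   = 0.1;
learningRateFactor = 0.1;
dropPeriod         = 10;
maxEpochs          = 100;

epochs = 1:maxEpochs;
% Number of drops already applied when each epoch starts.
numDrops = floor((epochs - 1) / dropPeriod);
lr = initLearningRate * learningRateFactor .^ numDrops;

% Print the rate at the first epoch of each 10-epoch block.
fprintf('epoch %3d: learning rate %g\n', ...
    [epochs(1:dropPeriod:end); lr(1:dropPeriod:end)]);
```

Under this schedule the rate is already 1e-5 by epoch 41 and reaches 1e-10 for the last ten epochs, so a large share of the 100 epochs take very small steps. That is worth keeping in mind when choosing MaxEpochs together with the drop period and factor.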
maxEpochs = 100;
epochIntervals = 1;
initLearningRate = 0.1;
learningRateFactor = 0.1;
l2reg = 0.0001;
options = trainingOptions('sgdm', ...
    'Momentum',0.9, ...
    'InitialLearnRate',initLearningRate, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',10, ...
    'LearnRateDropFactor',learningRateFactor, ...
    'L2Regularization',l2reg, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'GradientThresholdMethod','l2norm', ...
    'Plots','training-progress', ...
    'GradientThreshold',0.01);
And train:
modelDateTime = datestr(now,'dd-mmm-yyyy-HH-MM-SS');
net = trainNetwork(patchds,lgraph,options);
save(['trainedNet-' modelDateTime '-Epoch-' num2str(maxEpochs*epochIntervals) ...
    '-ScaleFactors-' num2str(234) '.mat'],'net','options');
(...8 hours later...) I came back this morning and... I have a fully trained network! Now, the quality may not be the best for deblurring images, because my main intention was to show how to set up the training images and the network. But I have a network that really tries.
Show the original image and the blurred image.
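% testImages is a datastore of held-out test images (not shown in this post);
% randi(400) picks one of its first 400 images
testImage = testImages.readimage(randi(400));
% Blur the test image with a motion-blur point spread function
LEN = 21;
THETA = 11;
PSF = fspecial('motion', LEN, THETA);
blurredImage = imfilter(testImage, PSF, 'conv', 'circular');
imshow(blurredImage); title('Blurry Image');
figure; imshow(testImage); title('Original Image');
... and create a 'deblurred' image from the network: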
% Read the output of the regression layer added earlier ('regressionLayer')
Ideblurred = activations(net,blurredImage,'regressionLayer');
figure; imshow(Ideblurred)
% Rescale the raw network output to [0,1] and convert to uint8 for display
Iapprox = rescale(Ideblurred);
Iapprox = im2uint8(Iapprox);
imshow(Iapprox)
title('Denoised Image')
Update: I know the title says "denoised" rather than "deblurred"; I copied the code from another example and forgot to change it. Keep in mind, the quality of the network was not the point, though now I'm very curious to keep working on and improving this network. That's all for today! I hope you found this useful. I had a great time playing in MATLAB, and I hope you do too.
UPDATE: I changed a few training parameters and ran the network again. If you're planning on running this code, I would highly suggest training with these parameters:
options = trainingOptions('adam','InitialLearnRate',1e-4,'MiniBatchSize',64, ...
    'Shuffle','never','MaxEpochs',50, ...
    'Plots','training-progress');
The results are much better.
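The post judges the results visually. One hedged way to put numbers on "better" is to score each result against the clean image with psnr and ssim from Image Processing Toolbox. The sketch below uses peppers.png as a stand-in test image and, because it has no trained network available, swaps in classical Wiener deconvolution (deconvwnr) with the known PSF as a non-deep-learning baseline. The noise-to-signal ratio of 0.001 is an arbitrary assumption, and SSIM is computed on grayscale versions to keep the call simple.

```matlab
% Sketch: score deblurring results against the clean reference image.
% Requires Image Processing Toolbox. peppers.png stands in for a test image.
testImage = imread('peppers.png');
PSF = fspecial('motion', 21, 11);                     % same blur as the post
blurredImage = imfilter(testImage, PSF, 'conv', 'circular');

% Classical baseline (not the post's method): Wiener deconvolution with the
% known PSF. The noise-to-signal ratio below is an assumed value.
nsr = 0.001;
wienerImage = deconvwnr(im2double(blurredImage), PSF, nsr);
wienerImage = im2uint8(wienerImage);                  % clips to [0, 255]

% Higher PSNR (dB) and SSIM closer to 1 mean closer to the clean image.
psnrBlurred = psnr(blurredImage, testImage);
psnrWiener  = psnr(wienerImage, testImage);
ssimBlurred = ssim(rgb2gray(blurredImage), rgb2gray(testImage));
ssimWiener  = ssim(rgb2gray(wienerImage), rgb2gray(testImage));
fprintf('Blurred: PSNR %.2f dB, SSIM %.3f\n', psnrBlurred, ssimBlurred);
fprintf('Wiener : PSNR %.2f dB, SSIM %.3f\n', psnrWiener, ssimWiener);

% To score the network output from the post the same way, use e.g.
% psnr(Iapprox, testImage) and ssim(rgb2gray(Iapprox), rgb2gray(testImage)).
```

Comparing the network against a classical baseline on the same blur gives a reference point for whether the learned model is actually adding value.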
Copyright 2018 The MathWorks, Inc.