Transfer Learning for Grayscale Images
In computer vision, pretrained models are often adapted to the task at hand through transfer learning. Transfer learning involves modifying a pretrained network and retraining it with your data. Most pretrained models are trained on large datasets of color images, like ImageNet. But how can you use one of these pretrained models when you have grayscale images?
In this blog post, I am going to show you a simple way in MATLAB to quickly adapt your grayscale images so that they can be used as input to a model pretrained on color images.
Figure: Collage of grayscale images of handwritten digits
What Are Grayscale Images?
Grayscale images are single-channel images where each pixel represents a shade of gray, ranging from black (intensity = 0) to white (intensity = 255 for 8-bit images). Unlike RGB images, which have three color channels (red, green, and blue), grayscale images encode intensity information only. For simplicity, I am going to demonstrate how to preprocess grayscale images of handwritten digits for retraining a pretrained model. But grayscale images are not only encountered in toy datasets. They are commonly used in medical imaging (e.g., X-rays and MRIs), remote sensing, and document analysis, where color information is less critical.
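To make the difference concrete, here is a small sketch in base MATLAB that builds a synthetic 8-bit grayscale image and an empty RGB array of the same height and width, then compares their sizes. The variable names and the synthetic gradient are my own illustrative assumptions, not part of the digit dataset.

```matlab
% Synthetic 8-bit grayscale image: a horizontal gradient from black to white.
% (Illustrative data only; not taken from the digit dataset.)
grayImg = repmat(uint8(linspace(0,255,64)),64,1);

% A grayscale image is a 2-D array: height-by-width, one intensity per pixel.
graySize = size(grayImg);
grayRange = [min(grayImg(:)) max(grayImg(:))];

% An RGB image of the same height and width has a third dimension of size 3.
rgbImg = zeros(64,64,3,"uint8");
rgbSize = size(rgbImg);

fprintf("Grayscale: %s, channels = %d\n",mat2str(graySize),size(grayImg,3));
fprintf("RGB:       %s, channels = %d\n",mat2str(rgbSize),size(rgbImg,3));
```

The missing third dimension is exactly what a network trained on color images will complain about, which is why the images need a small preprocessing step before retraining.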
Transfer Learning
Here, I show you how to prepare a grayscale image dataset of handwritten digits for retraining a network that was trained on color images.
Load Training Data
Create an image datastore.
    dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");
    imds = imageDatastore(dataFolder,IncludeSubfolders=true,LabelSource="foldernames");
Get information on the classes.
    classNames = categories(imds.Labels);
    numClasses = numel(classNames);
Load Pretrained Network
Load the GoogLeNet network, adapted to the number of classes in your data, and get the input size of the network. The data is loaded first because the call below needs numClasses.
    net = imagePretrainedNetwork("googlenet",NumClasses=numClasses);
    inputSize = networkInputSize(net);
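Before preparing the data, it can help to confirm that the images really are single-channel and to see how they compare with the input size the network expects. The following is a minimal sketch that assumes imds and inputSize from the steps above are already in your workspace; the other variable names are illustrative.

```matlab
% Assumes imds and inputSize from the previous steps are in the workspace.
img = readimage(imds,1);             % read the first image in the datastore
imgSize = size(img);

% size(img,3) is 1 for a grayscale image and 3 for an RGB image.
numChannels = size(img,3);

fprintf("Image size: %s (%d channel(s))\n",mat2str(imgSize),numChannels);
fprintf("Network input size: %s\n",mat2str(inputSize));

% Number of images per digit class.
labelCounts = countEachLabel(imds);
disp(labelCounts)
```

If the image size and channel count differ from the network input size, the images need to be resized and converted before training, which is what the next section takes care of.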
Prepare Data for Training
Split the data into training and validation datasets.
    [imdsTrain,imdsValidation] = splitEachLabel(imds,0.8,"randomized");
Specify the data augmentation options. You might need to change these options, depending on your data. Applying randomized preprocessing operations helps prevent the network from overfitting and memorizing the exact details of the training images.
    imageAugmenter = imageDataAugmenter( ...
        RandXReflection=true, ...
        RandXTranslation=[-30 30], ...
        RandYTranslation=[-30 30]);
Create augmented image datastores for the training and validation data. The augmented image datastores return transformed batches of the training and validation data, and the transformations make the input images compatible with the deep neural network. Only the training datastore applies the random augmentations. The validation datastore just resizes the images and converts them to RGB.
    augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
        ColorPreprocessing="gray2rgb",DataAugmentation=imageAugmenter);
    augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation, ...
        ColorPreprocessing="gray2rgb");
When you created the augmented image datastores, you set ColorPreprocessing to "gray2rgb", which tells the datastores to convert the grayscale input images to three-channel RGB images. It’s as easy as that!
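Conceptually, this kind of grayscale-to-RGB conversion amounts to copying the single intensity channel into all three color channels. Here is a minimal sketch of that idea in base MATLAB. It uses a random image as stand-in data, and it illustrates the concept rather than the datastore's internal implementation, which also handles the resizing for you.

```matlab
% Illustrative 28-by-28 grayscale image (random intensities, not real data).
grayImg = randi([0 255],28,28,"uint8");

% Copy the single channel into the red, green, and blue channels.
rgbImg = cat(3,grayImg,grayImg,grayImg);

% The result has three channels, and all three are identical to the
% original, so the image still looks gray when displayed.
sameChannels = isequal(rgbImg(:,:,1),grayImg) && ...
    isequal(rgbImg(:,:,2),grayImg) && ...
    isequal(rgbImg(:,:,3),grayImg);

fprintf("Size before: %s, after: %s, channels identical: %d\n", ...
    mat2str(size(grayImg)),mat2str(size(rgbImg)),sameChannels);
```

No new information is added to the image. The conversion only changes the array shape so that it matches what a network pretrained on color images expects.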
Retrain Network
Find the last learnable layer of the network. Freeze the weights of the network, keeping the last learnable layer unfrozen.
    [layerName,learnableNames] = networkHead(net);
    net = freezeNetwork(net,LayerNamesToIgnore=layerName);
Retrain the network with your input images and training options. Here, options is a training options object that you create beforehand with the trainingOptions function.
    net = trainnet(augimdsTrain,net,"crossentropy",options);
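The trainnet call needs an options variable that exists before the call runs. As a minimal sketch, you could create it with trainingOptions as shown here. The solver and all hyperparameter values are my own assumptions for illustration, not settings from this post, so treat them as a starting point and tune them for your data.

```matlab
% Assumed solver and hyperparameter values for illustration only.
% Assumes augimdsValidation from the data preparation step is in the workspace.
options = trainingOptions("adam", ...
    InitialLearnRate=1e-4, ...              % small rate, common for fine-tuning
    MaxEpochs=5, ...
    MiniBatchSize=64, ...
    ValidationData=augimdsValidation, ...   % monitor accuracy on held-out data
    ValidationFrequency=50, ...
    Metrics="accuracy", ...
    Plots="training-progress", ...
    Verbose=false);
```

Run this before calling trainnet. Passing the validation datastore lets you watch for overfitting in the training progress plot while the last layer is retrained.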
To learn more about transfer learning with MATLAB, check out the Transfer Learning in 10 Lines of Code video and this example.
Wrapping Up
Leave a comment below to discuss if you have encountered grayscale images in your workflow, what models you typically use for transfer learning, and what preprocessing techniques you find the most useful.