Transfer Learning for Grayscale Images
What Are Grayscale Images?
Grayscale images are single-channel images in which each pixel represents a shade of gray, ranging from black (intensity = 0) to white (intensity = 255 for 8-bit images). Unlike RGB images, which have three color channels (red, green, and blue), grayscale images encode only intensity information. For simplicity, I demonstrate how to preprocess grayscale images of handwritten digits for retraining a pretrained model. Grayscale images are not limited to toy datasets, though: they are common in medical imaging (e.g., X-rays and MRIs), remote sensing, and document analysis, where color information is less critical.
Examples of grayscale images: handwritten digits (left) and an X-ray image (right).
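In MATLAB, the difference between the two image types shows up in the array dimensions: a grayscale image is an H-by-W matrix, whereas an RGB image is an H-by-W-by-3 array. The following minimal sketch assumes Image Processing Toolbox is installed, because the two sample images it reads (cameraman.tif and peppers.png) ship with that toolbox.

```matlab
% Read a grayscale and an RGB sample image (assumed: Image Processing Toolbox sample images)
grayImg = imread("cameraman.tif");   % single-channel 8-bit image
rgbImg  = imread("peppers.png");     % three-channel 8-bit image

% A grayscale image has no third dimension, so size(...,3) returns 1
numChannelsGray = size(grayImg,3);
numChannelsRGB  = size(rgbImg,3);

% 8-bit intensities are stored as uint8 values between 0 (black) and 255 (white)
fprintf("Grayscale: %d channel(s), class %s, range [%d, %d]\n", ...
    numChannelsGray,class(grayImg),min(grayImg(:)),max(grayImg(:)));
fprintf("RGB:       %d channel(s), class %s\n",numChannelsRGB,class(rgbImg));
```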
Transfer Learning
Here, I show you how to prepare a grayscale image dataset of handwritten digits for retraining a network that was pretrained on color images. Such networks expect three-channel RGB input, so single-channel grayscale images must be converted before they can be used.
Load Pretrained Network
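Load the GoogLeNet network and get the input size of the network. The NumClasses argument sets the number of output classes of the network; numClasses is computed from the training data in the Load Training Data step, so run that step first.
net = imagePretrainedNetwork("googlenet",NumClasses=numClasses);
inputSize = networkInputSize(net);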
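Comparing the network input size with the size of one digit image makes the mismatch concrete. This sketch loads GoogLeNet without the NumClasses argument purely for inspection (it assumes the Deep Learning Toolbox Model for GoogLeNet Network support package is installed) and reads one image from the digit dataset used in this post. The names netInspect, imdsInspect, and digitImg are illustrative only.

```matlab
% Load GoogLeNet only to inspect its expected input size
netInspect = imagePretrainedNetwork("googlenet");
netInputSize = networkInputSize(netInspect);   % [height width channels]

% Read one image from the digit data set that ships with Deep Learning Toolbox
dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");
imdsInspect = imageDatastore(dataFolder,IncludeSubfolders=true);
digitImg = readimage(imdsInspect,1);
digitSize = size(digitImg,[1 2 3]);            % [height width channels]

fprintf("Network expects %d x %d x %d input\n",netInputSize);
fprintf("Digit image is  %d x %d x %d\n",digitSize);
% Here both the spatial size and the channel count differ, so the images
% must be resized and converted to three channels before training
```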
Load Training Data
Create an image datastore. With LabelSource="foldernames", each image is labeled with the name of the subfolder that contains it.
dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");
imds = imageDatastore(dataFolder,IncludeSubfolders=true,LabelSource="foldernames");
Get information about the classes.
classNames = categories(imds.Labels);
numClasses = numel(classNames);
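Before splitting the data, it can help to check how many images each class contains, because a strongly imbalanced dataset can bias the retrained classifier toward the larger classes. A minimal sketch using countEachLabel, repeating the datastore creation so that it runs on its own; the imbalanceRatio metric is just one simple, illustrative way to summarize the counts.

```matlab
dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");
imds = imageDatastore(dataFolder,IncludeSubfolders=true,LabelSource="foldernames");

% One row per class: the label (folder name) and the number of images
labelCounts = countEachLabel(imds);
disp(labelCounts)

% Ratio of largest to smallest class; values close to 1 indicate a balanced dataset
imbalanceRatio = max(labelCounts.Count) / min(labelCounts.Count)
```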
Prepare Data for Training
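Split the data into training and validation datasets. The splitEachLabel function randomly assigns 80% of the images of each class to the training set and the remaining 20% to the validation set.
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8,"randomized");
Specify the data augmentation options. Applying randomized preprocessing operations helps prevent the network from overfitting and memorizing the exact details of the training images. You might need to change these options, depending on your data. For digits, for example, horizontal reflection can turn a digit into a shape that is not a valid digit, and translation ranges are specified in pixels, so check them against the size of your images.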
imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=[-30 30], ...
    RandYTranslation=[-30 30]);
Create augmented image datastores for the training and validation data. These datastores transform batches of images so that they are compatible with the deep neural network: they resize the images to the network input size and convert them to RGB. Only the training datastore applies the random augmentation; the validation images are resized and converted but not augmented.
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    ColorPreprocessing="gray2rgb",DataAugmentation=imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation, ...
    ColorPreprocessing="gray2rgb");
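Because the right augmentation depends on your data, it is worth looking at a few augmented images before training. The sketch below applies the same augmenter to one digit several times using the augment method of imageDataAugmenter, then shows the original and augmented versions side by side. It assumes Image Processing Toolbox is available for montage. Note that it augments the original-size image directly, so it approximates, rather than reproduces, what the augmented datastore feeds to the network.

```matlab
dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");
imds = imageDatastore(dataFolder,IncludeSubfolders=true,LabelSource="foldernames");

% Same augmentation options as in the post
imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=[-30 30], ...
    RandYTranslation=[-30 30]);

% Apply a random transformation to the same image several times
img = readimage(imds,1);
augmented = cell(1,4);
for k = 1:numel(augmented)
    augmented{k} = augment(imageAugmenter,img);
end

% Original image first, followed by four random augmentations
montage([{img},augmented])
```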
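When you created the augmented image datastores, you specified ColorPreprocessing="gray2rgb", which converts each single-channel grayscale image into a three-channel RGB image that matches the input of the pretrained network. It’s as easy as that!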
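To confirm what the network will actually receive, you can preview a batch from an augmented datastore and check its size and channels. This sketch assumes a 224-by-224 output size (the GoogLeNet input size) instead of loading the network. It also assumes that the gray2rgb conversion copies the intensity channel into all three color channels; the final check tests that assumption rather than relying on it.

```matlab
dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");
imds = imageDatastore(dataFolder,IncludeSubfolders=true,LabelSource="foldernames");

% Output size assumed to match GoogLeNet's 224-by-224 input
outputSize = [224 224];
augimds = augmentedImageDatastore(outputSize,imds,ColorPreprocessing="gray2rgb");

% preview returns a table; the "input" variable holds the preprocessed images
batch = preview(augimds);
x = batch.input{1};
disp(size(x))    % height, width, channels after preprocessing

% Check whether all three channels are identical copies of one another
channelsIdentical = isequal(x(:,:,1),x(:,:,2),x(:,:,3))
```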
Retrain Network
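Find the last learnable layer of the network. Freeze the weights of the rest of the network, keeping the last learnable layer unfrozen so that only it is updated during training.
[layerName,learnableNames] = networkHead(net);
net = freezeNetwork(net,LayerNamesToIgnore=layerName);
Specify training options and retrain the network with your input images. The options below are an example starting point, not tuned settings.
options = trainingOptions("adam", ...
    ValidationData=augimdsValidation,Metrics="accuracy",Plots="training-progress");
net = trainnet(augimdsTrain,net,"crossentropy",options);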
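After training, a quick way to check the retrained network is to classify the validation images and compute the accuracy. Here is a minimal sketch written as a hypothetical helper function, validationAccuracy, so that it does not depend on workspace variables. It uses minibatchpredict and scores2label, and you would call it with the net, augimdsValidation, imdsValidation, and classNames variables created above.

```matlab
function accuracy = validationAccuracy(net,augimdsValidation,imdsValidation,classNames)
% validationAccuracy (hypothetical helper): fraction of validation images classified correctly.
%   net:               trained network returned by trainnet
%   augimdsValidation: augmented datastore with the validation images
%   imdsValidation:    original datastore holding the true labels
%   classNames:        class names in the order used by the network outputs

% Classification scores for every validation image (one row per image)
scores = minibatchpredict(net,augimdsValidation);

% Convert the scores to categorical labels
YPred = scores2label(scores,classNames);

% Compare with the true labels, which are in the same order as the datastore
accuracy = mean(YPred == imdsValidation.Labels);
end
```

For example: accuracy = validationAccuracy(net,augimdsValidation,imdsValidation,classNames).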
To learn more about transfer learning with MATLAB, check out the Transfer Learning in 10 Lines of Code video and this example.
Wrapping Up
Leave a comment below to let us know whether you have encountered grayscale images in your workflow, which models you typically use for transfer learning, and which preprocessing techniques you find most useful.