Artificial Intelligence

Apply machine learning and deep learning

Transfer Learning for Grayscale Images

In computer vision, pretrained models are often adapted to a new task through transfer learning, which involves modifying a pretrained network and retraining it with your data. Most pretrained models are trained on large datasets of color images, such as ImageNet. But how can you use one of these pretrained models when you have grayscale images?
In this blog post, I am going to show you a simple way in MATLAB to adapt grayscale images for use as input to a model pretrained on color images.

What Are Grayscale Images?

Grayscale images are single-channel images where each pixel represents a shade of gray, ranging from black (intensity = 0) to white (intensity = 255 for 8-bit images). Unlike RGB images, which have three color channels (red, green, and blue), grayscale images encode intensity information only.
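To make the single-channel versus three-channel distinction concrete, here is a small base-MATLAB sketch. It uses a synthetic (hypothetical) 28-by-28 image rather than real data; only the array shapes and the uint8 range matter here.

```matlab
% Synthetic 8-bit grayscale image (hypothetical data): one intensity value per pixel
grayImg = uint8(randi([0 255],28,28));
% An RGB image of the same size stores three values (R, G, B) per pixel
rgbImg = zeros(28,28,3,"uint8");
disp(size(grayImg))   % 28 28 (no third dimension, i.e. a single channel)
disp(size(rgbImg))    % 28 28 3
% The uint8 class limits intensities to the range [0, 255]
disp([intmin("uint8") intmax("uint8")])
```

A pretrained network for color images expects the second layout, which is why grayscale images need a conversion step before you can use them as input.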
For simplicity, I am going to demonstrate how to preprocess grayscale images of handwritten digits for retraining a pretrained model. But grayscale images are not only encountered in toy datasets. They are commonly used in medical imaging (e.g., X-rays and MRIs), remote sensing, and document analysis, where color information is less critical.
Figure: Examples of grayscale images: handwritten digits (left) and an X-ray image (right).

Transfer Learning

Here, I show you how to prepare a grayscale image dataset of handwritten digits for retraining a network that was trained on color images.

Load Training Data

Create an image datastore.
dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");
imds = imageDatastore(dataFolder,IncludeSubfolders=true,LabelSource="foldernames");
Get the class names and the number of classes. You need the number of classes to adapt the pretrained network.
classNames = categories(imds.Labels);
numClasses = numel(classNames);

Load Pretrained Network

Load the GoogLeNet network, adapting its output to the number of classes in your data, and get the input size of the network.
net = imagePretrainedNetwork("googlenet",NumClasses=numClasses);
inputSize = networkInputSize(net);
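The class count passed to imagePretrainedNetwork as NumClasses comes from the categorical Labels property of the datastore, which imageDatastore fills from the folder names. A base-MATLAB sketch with hypothetical labels shows what categories and numel return:

```matlab
% Hypothetical labels, like those imageDatastore derives from folder names
labels = categorical(["3";"0";"7";"3";"0"]);
% categories returns the unique class names, sorted, as a cell array
classNames = categories(labels);
% The number of classes sets the output size of the adapted network
numClasses = numel(classNames);
disp(numClasses)   % 3
```

Duplicate labels collapse into one category each, so numClasses counts classes, not images.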

Prepare Data for Training

Split the data into training and validation datasets.
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8,"randomized");
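splitEachLabel splits the files of each label separately, so the training and validation sets keep the same class proportions. The base-MATLAB sketch below illustrates that idea on hypothetical labels (1,000 per class). It is not the implementation of splitEachLabel, and the use of round for non-integer counts is an assumption.

```matlab
% Hypothetical labels: 1,000 images for each of the 10 digit classes
labels = categorical(repelem(0:9,1000)');
p = 0.8;                                  % fraction of each class used for training
isTrain = false(numel(labels),1);
cats = categories(labels);
for k = 1:numel(cats)
    idx = find(labels == cats{k});        % indices of all images of class k
    idx = idx(randperm(numel(idx)));      % shuffle within the class ("randomized")
    nTrain = round(p*numel(idx));
    isTrain(idx(1:nTrain)) = true;        % first 80% of the shuffled class go to training
end
trainLabels = labels(isTrain);
valLabels = labels(~isTrain);
disp(countcats(trainLabels)')             % 800 for every class
```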
Specify the data augmentation options. You might need to change these options, depending on your data. Applying randomized preprocessing operations helps prevent the network from overfitting and memorizing the exact details of the training images.
imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=[-30 30], ...
    RandYTranslation=[-30 30]);
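The augmenter draws a random reflection and random translations for each image it reads. As a rough illustration only (this is not how imageDataAugmenter is implemented), the base-MATLAB sketch below applies a random horizontal reflection and a random zero-filled translation to one hypothetical image; the 3-pixel range is chosen to fit the small example. For digits, a horizontal reflection produces a mirror image that may not resemble the original digit, which is one case where you might change these options.

```matlab
% Hypothetical 28-by-28 image with a bright rectangular stroke
img = zeros(28,28,"uint8");
img(10:18,12:16) = 255;
maxShift = 3;                              % assumed shift range in pixels
% Random horizontal reflection with probability 0.5
if rand < 0.5
    img = fliplr(img);
end
% Random translation with zero fill (pixels shifted out of the frame are lost)
dx = randi([-maxShift maxShift]);
dy = randi([-maxShift maxShift]);
[H,W] = size(img);
rowsSrc = max(1,1-dy):min(H,H-dy);         % source rows that stay inside the frame
colsSrc = max(1,1-dx):min(W,W-dx);         % source columns that stay inside the frame
shifted = zeros(H,W,"like",img);
shifted(rowsSrc+dy,colsSrc+dx) = img(rowsSrc,colsSrc);
```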
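Create augmented image datastores for the training and validation data. Both datastores resize the images to the input size of the network and, because of the ColorPreprocessing option, convert the grayscale images to RGB, which makes them compatible with the pretrained network. Only the training datastore applies the random augmentation.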
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    ColorPreprocessing="gray2rgb",DataAugmentation=imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation, ...
    ColorPreprocessing="gray2rgb");
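By setting ColorPreprocessing="gray2rgb" when you created the augmented image datastores, you converted the single-channel grayscale images into three-channel RGB images that match the input layer of the pretrained network. It’s as easy as that!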
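Conceptually, converting a grayscale image to RGB amounts to copying its single intensity channel into all three channels. The base-MATLAB sketch below assumes that channel replication is what the conversion does (check the augmentedImageDatastore documentation for the exact behavior) and uses a hypothetical image:

```matlab
% Hypothetical single-channel 8-bit image
grayImg = uint8(randi([0 255],28,28));
% Replicate the intensity channel to build a three-channel image
rgbImg = repmat(grayImg,[1 1 3]);   % same result as cat(3,grayImg,grayImg,grayImg)
disp(size(rgbImg))                  % 28 28 3
```

Because the three channels are identical, the network sees no color information, only the same intensity pattern in every channel.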

Retrain Network

Find the last learnable layer of the network. Then freeze the weights of all the other layers so that only the last learnable layer is updated during retraining.
[layerName,learnableNames] = networkHead(net);
net = freezeNetwork(net,LayerNamesToIgnore=layerName);
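Freezing means that the weights of the frozen layers do not change during retraining. One way Deep Learning Toolbox expresses this is a learn-rate factor of 0 (see setLearnRateFactor); I am assuming freezeNetwork works along these lines. The base-MATLAB sketch below uses a plain SGD update with made-up numbers. trainnet with a solver such as Adam uses a different update rule, but a learn-rate factor of 0 should still mean no change.

```matlab
% Simplified SGD update: w = w - learnRate * learnRateFactor * gradient
learnRate = 0.01;
% Hypothetical frozen layer: learn-rate factor 0
wFrozen = [0.5 -1.2];  gFrozen = [0.3 0.1];  factorFrozen = 0;
% Hypothetical trainable head: learn-rate factor 1
wHead = [0.2 0.4];     gHead = [1.0 -2.0];   factorHead = 1;
wFrozen = wFrozen - learnRate*factorFrozen*gFrozen;   % stays [0.5 -1.2]
wHead = wHead - learnRate*factorHead*gHead;           % becomes [0.19 0.42]
```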
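Specify the training options and retrain the network with your input images. The options below are an example; adjust them for your data.
options = trainingOptions("adam",ValidationData=augimdsValidation,Metrics="accuracy");
net = trainnet(augimdsTrain,net,"crossentropy",options);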
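The "crossentropy" loss compares the probabilities that the network outputs with the true classes. The base-MATLAB sketch below computes a simplified version for a hypothetical mini-batch of two observations and three classes; the exact normalization and data layout that trainnet uses may differ.

```matlab
% Hypothetical softmax outputs: rows = observations, columns = classes
Y = [0.7 0.2 0.1;
     0.1 0.8 0.1];
% One-hot targets: observation 1 is class 1, observation 2 is class 2
T = [1 0 0;
     0 1 0];
% Per-observation loss -sum(t .* log(y)), then the average over the mini-batch
perObs = -sum(T .* log(Y), 2);
loss = mean(perObs);
```

Only the probability assigned to the true class contributes, so the loss shrinks as the network grows more confident in the correct digit.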
To learn more about transfer learning with MATLAB, check out the Transfer Learning in 10 Lines of Code video and this example.

Wrapping Up

Leave a comment below to let us know whether you have encountered grayscale images in your workflow, which models you typically use for transfer learning, and which preprocessing techniques you find most useful.