Preprocess Image

Read the image you want to classify. Resize the image to the input size of the network.
imgOriginal = imread("banana.png");
InputSize = [224 224 3];
img = imresize(imgOriginal,InputSize(1:2));

You must preprocess the image in the same way as the training data. For more information, see Input Data Preprocessing. Rescale the image. Then, normalize the image by subtracting the mean of the training images and dividing by their standard deviation.
imgProcessed = rescale(img,0,1);
meanIm = [0.485 0.456 0.406];
stdIm = [0.229 0.224 0.225];
imgProcessed = (imgProcessed - reshape(meanIm,[1 1 3]))./reshape(stdIm,[1 1 3]);

Permute the image data from the Deep Learning Toolbox dimension ordering (HWCN) to the PyTorch dimension ordering (NCHW), where H is the height of the images, W is the width of the images, C is the number of channels, and N is the number of observations. This step is necessary only when you use the image for prediction with a PyTorch model, that is, before you import the model into MATLAB.
imgForTorch = permute(imgProcessed,[4 3 1 2]);

For more information on input dimension data ordering for different deep learning platforms, see Input Dimension Ordering.
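To make the arithmetic and the dimension reordering concrete, here is a minimal pure-Python sketch of the same three steps: min-max rescaling, per-channel normalization, and HWC-to-NCHW reordering. It is an illustration, not the code used in this post. Nested lists stand in for arrays, the helper names and the tiny 1-by-2 test image are hypothetical, and the rescaling assumes that rescale(img,0,1) maps the smallest pixel value to 0 and the largest to 1.

```python
# Pure-Python illustration of the preprocessing steps; nested lists stand in for arrays.
MEAN = [0.485, 0.456, 0.406]  # per-channel mean, values taken from the text above
STD = [0.229, 0.224, 0.225]   # per-channel standard deviation, values from the text above


def rescale_min_max(image):
    """Linearly map pixel values so the smallest becomes 0 and the largest becomes 1."""
    values = [v for row in image for pixel in row for v in pixel]
    lo, hi = min(values), max(values)
    if hi == lo:
        raise ValueError("cannot rescale a constant image")
    return [[[(v - lo) / (hi - lo) for v in pixel] for pixel in row] for row in image]


def normalize(image):
    """Subtract the channel mean and divide by the channel standard deviation."""
    return [
        [[(v - MEAN[c]) / STD[c] for c, v in enumerate(pixel)] for pixel in row]
        for row in image
    ]


def hwc_to_nchw(image):
    """Reorder one H x W x C image into a 1 x C x H x W batch."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    chw = [
        [[image[h][w][c] for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]
    return [chw]  # wrap in a list to add the batch (N) dimension


# Hypothetical 1 x 2 RGB image, used only to exercise the functions.
tiny = [[[0, 128, 255], [255, 0, 64]]]
batch = hwc_to_nchw(normalize(rescale_min_max(tiny)))
shape = (len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))
print(shape)  # (1, 3, 1, 2)
```

Indexing the result as batch[n][c][h][w] mirrors the NCHW ordering that the PyTorch models expect.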
Install Python and Libraries

You might have multiple versions of Python installed on your desktop. For example, a MacBook has a pre-installed Python version 2.7, which is likely not the version you want to use. So, it is good practice to create a virtual environment for your project to control the Python version and the libraries that you are using.

The following commands show how you can set up a virtual environment on a MacBook. If you are using a Windows machine, the commands might be slightly different. Go to your working folder. Create and activate the Python virtual environment env in your working folder.
python3.10 -m venv env
source env/bin/activate

Install the necessary Python libraries for this example. Check the installed versions of the libraries.
pip3 install numpy torch torchvision
python3 -m pip show numpy torch torchvision

For reference, we used:
- Python 3.10.8
- numpy 1.23.4
- torch 1.13.0
- torchvision 0.14.0
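If you prefer to check the environment from inside Python, the following sketch reports whether the interpreter is running in a virtual environment and which versions of the three libraries it can see. It uses only the Python standard library. The function name installed_versions is a hypothetical helper, and a package that is not installed is reported as None rather than raising an error.

```python
import sys
from importlib import metadata


def installed_versions(packages):
    """Return a dict mapping each package name to its version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# In a venv, sys.prefix points at the environment and differs from sys.base_prefix.
in_virtual_env = sys.prefix != sys.base_prefix
print(f"Python {sys.version.split()[0]}, virtual environment: {in_virtual_env}")

report = installed_versions(["numpy", "torch", "torchvision"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```

Run the script with the interpreter from the env folder so that it inspects the same environment that MATLAB will use.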
Set up the Python interpreter for MATLAB by using the pyenv function.

pe = pyenv(ExecutionMode="OutOfProcess",Version="./env/bin/python3.10");

Now, you are ready to call Python from MATLAB.
Explore PyTorch Models

Get three pretrained PyTorch models (VGG, MobileNet v2, and MNASNet) from the torchvision library. For more information on each model and how to load it, see torchvision.models. You can access Python libraries directly from MATLAB by adding the py. prefix to the Python name. For more information on how to access Python libraries, see Getting Started: Access Python Modules from MATLAB.
model1 = py.torchvision.models.vgg16(pretrained=true);
model2 = py.torchvision.models.mobilenet_v2(pretrained=true);
model3 = py.torchvision.models.mnasnet1_0(pretrained=true);

Convert the image to a tensor in order to classify the image with a PyTorch model.
X = py.numpy.asarray(imgForTorch);
X_torch = py.torch.from_numpy(X).float();

To find the fastest PyTorch model by calling Python from MATLAB, predict the image classification label multiple times for each of the models. We ran the speed test on all three models, but we show here only how to compute the average prediction time for the MNASNet model.
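As an aside, the measurement pattern itself (call the model repeatedly, record each elapsed time, then average) can be sketched in plain Python. This variant is an assumption-laden illustration, not the test we ran: it discards a few warm-up calls, on the assumption that the first calls may include one-time setup cost, and it reports the median next to the mean. The benchmark function, the run counts, and the fake_model stand-in are all hypothetical.

```python
import statistics
import time


def benchmark(predict, runs=30, warmup=3):
    """Call predict() repeatedly and return (mean, median) time in seconds."""
    for _ in range(warmup):
        predict()  # discarded: may include lazy initialization or caching
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        predict()
        timings.append(time.perf_counter() - start)
    return statistics.mean(timings), statistics.median(timings)


def fake_model():
    # Hypothetical stand-in for a real prediction call such as model(X_torch).
    return sum(i * i for i in range(10_000))


mean_s, median_s = benchmark(fake_model)
print(f"mean {mean_s:.6f} s, median {median_s:.6f} s")
```

The MATLAB loop that follows applies the same idea with tic and toc.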
N = 30;
T = zeros(1,N);
for i = 1:N
    tic
    model3(X_torch);
    T(i) = toc;
end
mean(T)
ans = 0.1096

This simple test showed that MNASNet is the fastest of the three models at prediction. With co-execution, you can quickly run different tests on PyTorch models to find the model that best suits your application and workflow.

To import the PyTorch model into MATLAB, you must first trace the model and save it, which you can also do by co-executing Python from MATLAB. Execute Python statements in the Python interpreter directly from MATLAB by using the pyrun function. The pyrun function is a stateful interface between MATLAB and Python, so variables created in one pyrun call remain available in later calls. Trace the fastest of the three models, and then save the traced model. For more information on how to trace a PyTorch model, see Torch documentation: Tracing a function.
pyrun("import torch;X_rnd = torch.rand(1,3,224,224)")
pyrun("traced_model = torch.jit.trace(model3.forward,X_rnd)",model3=model3)
pyrun("traced_model.save('traced_mnasnet1_0.pt')")
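The three calls above work as a sequence only because state persists between them: X_rnd from the first call is used in the second, and traced_model from the second is used in the third. The toy Python class below models that idea with one shared namespace. It is a conceptual sketch only and makes no claim about how pyrun is implemented. The class name StatefulRunner and its methods are hypothetical.

```python
class StatefulRunner:
    """Toy model of a stateful 'run code' interface: one namespace reused across calls."""

    def __init__(self):
        self.namespace = {}

    def run(self, code, **inputs):
        # In this toy, keyword inputs are simply added to the shared namespace
        # before the code runs, so the code can refer to them by name.
        self.namespace.update(inputs)
        exec(code, self.namespace)

    def get(self, name):
        return self.namespace[name]


runner = StatefulRunner()
runner.run("x = [1, 2, 3]")                      # creates x
runner.run("total = sum(x) * scale", scale=10)   # reuses x, receives scale as input
print(runner.get("total"))  # 60
```

A stateless interface would fail on the second call, because x would no longer exist.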
Import PyTorch Network

Import the MNASNet model by using the importNetworkFromPyTorch function.
net = importNetworkFromPyTorch("traced_mnasnet1_0.pt");

The importNetworkFromPyTorch function was introduced in MATLAB R2022b as part of the Deep Learning Toolbox Converter for PyTorch Models support package. For more information, read our previous blog post What’s New in Interoperability with TensorFlow and PyTorch.

The function imports the model as an uninitialized dlnetwork object. Create an image input layer. Then, add the image input layer to the imported network and initialize the network by using the addInputLayer function.
inputLayer = imageInputLayer(InputSize,Normalization="none");
net = addInputLayer(net,inputLayer,Initialize=true);

This example showed a simple image classification workflow. By converting a PyTorch or TensorFlow model into a MATLAB network, you gain access to all the deep learning workflows that MATLAB supports for building complete AI systems. For more information on working with imported models versus co-execution, see the ‘Comparison of capabilities for working with deep learning models in MATLAB’ table in our previous blog post: Importing Models from TensorFlow, PyTorch, and ONNX (Summary section).
Conclusion

Key takeaways were presented for each step of the workflow. If we had to pick the three key takeaways…
- Image Classification in MATLAB Using TensorFlow
- Hyperparameter Tuning in MATLAB using Experiment Manager & TensorFlow
- PyTorch and TensorFlow Co-Execution for Training a Speech Command Recognition System
- Importing Models from TensorFlow, PyTorch, and ONNX
- What’s New in Interoperability with TensorFlow and PyTorch