What’s New in Interoperability with TensorFlow and PyTorch
Export to TensorFlow
The Deep Learning Toolbox Converter for TensorFlow Models support package now lets you export networks from MATLAB to TensorFlow by using the new exportNetworkToTensorFlow function. There are many reasons to be excited about this function, and the steps below show how to use it to export a pretrained network and load it in Python.
Load a pretrained network. The Pretrained Deep Neural Networks documentation page lists the ways to get a pretrained network. Alternatively, you can create your own network.
    net = darknet19;
Export the network net to TensorFlow. The exportNetworkToTensorFlow function saves the TensorFlow model in the Python package DarkNet19.
    exportNetworkToTensorFlow(net,"DarkNet19")
The DarkNet19 package contains four files:
- The __init__.py file, which defines the DarkNet19 folder as a regular Python package.
- The model.py file, which contains the code that defines the untrained TensorFlow-Keras model.
- The README.txt file, which provides instructions on how to load the TensorFlow model and save it in HDF5 or SavedModel format.
- The weights.h5 file which contains the model weights in HDF5 format.
Figure: The exported TensorFlow model is saved in the regular Python package DarkNet19.
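As a quick sanity check on the exported package, the short Python sketch below looks for the four files listed above inside a package folder. The helper name missing_package_files is hypothetical and is not part of the support package; the file names come from the list above, and the folder location is an assumption.

```python
from pathlib import Path

# File names taken from the list above; the helper itself is hypothetical.
EXPECTED_FILES = ("__init__.py", "model.py", "README.txt", "weights.h5")


def missing_package_files(package_dir):
    """Return the expected files that are absent from package_dir."""
    root = Path(package_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    # Assumes the exported DarkNet19 folder sits in the current directory.
    missing = missing_package_files("DarkNet19")
    print("Missing files:", missing if missing else "none")
```

An empty result means the folder has the layout described above; it does not verify that the weights match the model definition.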
Load the exported TensorFlow model from the DarkNet19 package.
    import DarkNet19
    model = DarkNet19.load_model()
Save the exported model in the SavedModel format.
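One way to do the save step is sketched below in Python. It assumes a TensorFlow 2.x Keras model, where calling model.save with a folder path (no file extension) writes the SavedModel format. The helper name save_as_savedmodel and the output folder name DarkNet19_savedmodel are hypothetical. Newer Keras 3 releases, as far as I know, reject extension-free paths in save and provide model.export for SavedModel output, so adjust the call if you use them.

```python
from pathlib import Path

# Suffixes that TF 2.x Keras treats as single-file formats, not a SavedModel folder.
SINGLE_FILE_SUFFIXES = {".h5", ".hdf5", ".keras"}


def save_as_savedmodel(model, out_dir="DarkNet19_savedmodel"):
    """Save a Keras model to a folder; the default folder name is hypothetical.

    Assumes a TensorFlow 2.x Keras model, where model.save(path) with an
    extension-free path writes a SavedModel folder.
    """
    out_path = Path(out_dir)
    if out_path.suffix.lower() in SINGLE_FILE_SUFFIXES:
        raise ValueError(f"{out_dir!r} looks like a single-file format, not a folder")
    model.save(str(out_path))
    return out_path


# Usage, assuming the DarkNet19 package is on the Python path:
#   import DarkNet19
#   model = DarkNet19.load_model()
#   save_as_savedmodel(model)
```

The suffix check only guards against accidentally writing HDF5 when a SavedModel folder was intended.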
Import from PyTorch
In R2022b we introduced the Deep Learning Toolbox Converter for PyTorch Models support package. This initial release supports importing image classification models. Support for other model types will be added in future releases. Use the importNetworkFromPyTorch function to import a PyTorch model. Make sure that the PyTorch model that you are importing is pretrained and traced. Here I show you how to import an image classification model from PyTorch and initialize it.
Load a pretrained image classification model from the TorchVision library.
    import torch
    from torchvision import models
    model = models.mnasnet1_0(pretrained=True)
Trace the PyTorch model. For more information on how to trace a PyTorch model, see the Torch documentation: Tracing a function.
    X = torch.rand(1,3,224,224)
    traced_model = torch.jit.trace(model.forward,X)
Save the PyTorch model.
    traced_model.save("traced_mnasnet1_0.pt")
Import the PyTorch model into MATLAB by using the importNetworkFromPyTorch function. The function imports the model as an uninitialized dlnetwork object without an input layer.
    net = importNetworkFromPyTorch("traced_mnasnet1_0.pt");
Specify the input size of the imported network and create an image input layer. Then, add the image input layer to the imported network and initialize the network by using the addInputLayer function (also new in R2022b).
    InputSize = [224 224 3];
    InputLayer = imageInputLayer(InputSize,Normalization="none");
    net = addInputLayer(net,InputLayer,Initialize=true);
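Note how the two shapes relate: the tracing input torch.rand(1,3,224,224) is ordered batch, channels, height, width, while imageInputLayer takes [height width channels]. The small Python sketch below makes that mapping explicit; the helper name matlab_input_size is hypothetical and only illustrates the reordering.

```python
def matlab_input_size(nchw_shape):
    """Map a PyTorch (N, C, H, W) shape to MATLAB's [H W C] image input size."""
    if len(nchw_shape) != 4:
        raise ValueError("expected a 4-D (N, C, H, W) shape")
    _, channels, height, width = nchw_shape
    return [height, width, channels]


if __name__ == "__main__":
    # Shape of the example input used for tracing above.
    print(matlab_input_size((1, 3, 224, 224)))  # [224, 224, 3]
```

The batch dimension is dropped because the image input layer describes a single observation.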
Interoperability Capabilities Summary
The interoperability support packages allow you to connect Deep Learning Toolbox with TensorFlow, PyTorch, and ONNX. Use the import and export functions to access models available in open-source repositories and collaborate with colleagues who work in other deep learning frameworks. More information:
- To find all the available import and export functions (and their documentation), go to Deep Learning Import and Export.
- To learn more about how to import and export networks, see Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX.
- For answers to common questions about importing models, see Tips on Importing Models from TensorFlow, PyTorch, and ONNX.
- If you are working just in MATLAB, you can probably find a suitable network in our constantly updated model repository: MATLAB Deep Learning Model Hub.
- Check out our previous blog post Importing Models from TensorFlow, PyTorch, and ONNX; you will find useful tips on importing and an example you can download (with a focus on importing from TensorFlow).