Data Type Conversion Between MATLAB and Python: What’s New in R2024a
When combining MATLAB with Python® to create deep learning workflows, data type conversion between the two frameworks can be time consuming and sometimes perplexing. I've certainly spent time figuring out how to make MATLAB data compatible with a Python-based model and vice versa. So, you can understand why I am so excited about the two new data type conversions introduced in MATLAB R2024a.
In the following table you can see the new data type conversions. And it came as no surprise to me that they made the list of Mike Croucher's favorite R2024a updates.
Performing the data type conversions is straightforward, and you can learn all the details in these documentation topics: Use Python Pandas DataFrames in MATLAB and Use Python Dictionaries in MATLAB. So, in this blog post, instead of showing you the code for every possible conversion, I am going to walk you through an object detection example that presents a use case for converting data from a Python dictionary to a MATLAB structure.
Python Data Type | MATLAB Data Type
Pandas DataFrame | MATLAB table
Python dictionary | MATLAB dictionary
Python dictionary | MATLAB structure
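Before diving into the example, here is a minimal sketch of what the new conversions look like. The variable names and values are placeholders I made up for illustration, and the DataFrame lines assume pandas is installed in the Python environment you configure below.

% Python dictionary -> MATLAB structure or MATLAB dictionary (both new in R2024a)
pd = py.dict(pyargs("label","traffic light","score",0.97));
s = struct(pd)        % MATLAB structure with fields label and score
d = dictionary(pd)    % MATLAB dictionary with keys "label" and "score"

% pandas DataFrame -> MATLAB table (new in R2024a)
df = py.pandas.DataFrame(py.dict(pyargs( ...
    "id",py.list({1,2,3}), ...
    "label",py.list({"car","person","dog"}))));
T = table(df)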
Object Detection Example
This example calls a PyTorch model from MATLAB to detect objects in the input image. When you call the PyTorch model, you want to (1) convert MATLAB data to a format that the model can process and (2) convert the model outputs (a Python data type) to a MATLAB data type that you can use in MATLAB, for visualization as shown here, but also as an input to the next steps in your workflow.

Python Environment
Set up the Python interpreter for MATLAB by using the pyenv function.

pe = pyenv(Version=".\env\Scripts\python.exe",ExecutionMode="OutOfProcess");
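If you want to confirm that MATLAB found the interpreter (an optional check; the path above is specific to my machine), you can inspect the properties of the returned environment object:

pe.Version
pe.Status      % reported as NotLoaded until the first Python call loads the interpreter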
Python Code for Object Detection
The following Python code is saved in the PT_object_detection.py file, which you can call from MATLAB to perform object detection with a PyTorch model.

import torch
import torchvision as vision
import numpy

def loadPTmodel():
    # Initialize the model with the best available weights
    weights = vision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = vision.models.detection.fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.95)
    model.eval()
    return model, weights

def detectPT(img, model, weights):
    # Convert the input image to a tensor and reorder the dimensions
    # to the layout that the model expects.
    X = numpy.asarray(img)
    X_torch1 = torch.from_numpy(numpy.copy(X))
    if X_torch1.ndim == 3:
        X_torch = torch.permute(X_torch1, (2, 0, 1))
    elif X_torch1.ndim == 4:
        X_torch = torch.permute(X_torch1, (3, 2, 0, 1))
    # Initialize the inference transforms
    preprocess = weights.transforms()
    # Apply the inference preprocessing transforms
    batch = [preprocess(X_torch)]
    # Run inference
    if X_torch.ndim == 3:
        prediction = model(batch)[0]
    elif X_torch.ndim == 4:
        prediction = model(list(batch[0]))
    return prediction
Object Detection
Read the image that you want to use for object detection.

img_filename = "testCar.png";
img = imread(img_filename);

Perform object detection with a PyTorch model using co-execution.
pyrun("from PT_object_detection import loadPTmodel, detectPT") [model,weights] = pyrun("[a,b] = loadPTmodel()",["a" "b"]); predictions_pt = pyrun("a = detectPT(b,c,d)","a",b=img,c=model,d=weights);The output of the object detector is a Python dictionary.
class(predictions_pt)
ans = 'py.dict'
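As a quick aside, R2024a also lets you convert the same py.dict to a MATLAB dictionary. Converting to a structure suits this workflow better, but the dictionary conversion is just as easy (a minimal sketch; the values remain PyTorch tensor objects until you convert them, as shown in the next section):

d = dictionary(predictions_pt);
keys(d)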
Convert Predictions
Convert the object detection results from a Python dictionary to a MATLAB structure.

predictions = struct(predictions_pt)
predictions = struct with fields:
    boxes: [1×1 py.torch.Tensor]
    labels: [1×1 py.torch.Tensor]
    scores: [1×1 py.torch.Tensor]

The data variables in the structure predictions are PyTorch tensors. Convert the variables into MATLAB arrays.
predictions.boxes = double(predictions.boxes.detach().numpy);
predictions.labels = double(predictions.labels.tolist)';
predictions.scores = double(predictions.scores.tolist)';
predictions
predictions = struct with fields:
    boxes: [4×4 double]
    labels: [4×1 double]
    scores: [4×1 double]

The bounding boxes require further processing to align with the input image. A bounding box is an axis-aligned rectangle defined in spatial coordinates as an M-by-4 numeric matrix with rows of the form [x y w h] (a short illustrative example follows the list), where:
- M is the number of axis-aligned rectangles.
- x and y specify the upper-left corner of the rectangle.
- w specifies the width of the rectangle, which is its length along the x-axis.
- h specifies the height of the rectangle, which is its length along the y-axis.
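For instance, this matrix (illustrative values only, not taken from the example image) describes two such rectangles:

boxes = [ 50  80 120  60;    % upper-left corner at (50,80), 120 pixels wide, 60 pixels tall
         200 150  40  90];   % upper-left corner at (200,150), 40 pixels wide, 90 pixels tall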
predictions.boxes = cat(2,predictions.boxes(:,1:2)+1,predictions.boxes(:,3:4)/2);

Get the class labels. The PyTorch model was trained on the COCO data set.
class_labels = getClassLabels(predictions.labels);
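The getClassLabels helper is not shown in this post. A minimal, hypothetical sketch of what such a helper might look like is below; it assumes you have saved the COCO category names to a text file named cocoClasses.txt, one name per line, in the order used by the torchvision model.

function class_labels = getClassLabels(label_indices)
% Hypothetical helper: map the numeric COCO label indices returned by the
% PyTorch detector to human-readable class names.
% Assumes cocoClasses.txt lists the COCO category names, one per line.
% Depending on whether your list includes the background class, you may
% need to offset the indices by 1.
cocoNames = readlines("cocoClasses.txt");
class_labels = reshape(cocoNames(label_indices),1,[]);
end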
Visualization
Create the labels associated with each of the detected objects.

num_box = length(predictions.scores);
colons = repmat(": ",[1 num_box]);
percents = repmat("%",[1 num_box]);
class_labels1 = strcat(class_labels,colons,string(round(predictions.scores'*100)),percents);

Visualize the object detection results with annotations.
figure
outputImage = insertObjectAnnotation(img,...
    "rectangle",predictions.boxes,class_labels1,LineWidth=1,Color="green");
imshow(outputImage)

As you can see, the object detection was successful with a high degree of confidence.