Data Type Conversion Between MATLAB and Python: What’s New in R2024a
Python Data Type | MATLAB Data Type
Pandas DataFrame | MATLAB table
Python dictionary | MATLAB dictionary
Python dictionary | MATLAB structure
Object Detection Example
This example calls a PyTorch model from MATLAB to detect objects in an input image. Calling the model involves two conversions: (1) converting MATLAB data to a format that the model can process and (2) converting the model outputs (Python data types) to MATLAB data types that you can use for visualization, as shown here, or as input to the next steps in your workflow.
Python Environment
Set up the Python interpreter for MATLAB by using the pyenv function.
pe = pyenv(Version=".\env\Scripts\python.exe",ExecutionMode="OutOfProcess");
Python Code for Object Detection
The following Python code is saved in the PT_object_detection.py file, which you can call from MATLAB to perform object detection with a PyTorch model.
import torch
import torchvision as vision
import numpy

def loadPTmodel():
    # Initialize model with the best available weights
    weights = vision.models.detection.FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = vision.models.detection.fasterrcnn_resnet50_fpn_v2(weights=weights,box_score_thresh=0.95)
    model.eval()
    return model, weights

def detectPT(img,model,weights):
    # Reshape image and convert to a tensor.
    X = numpy.asarray(img)
    X_torch1 = torch.from_numpy(numpy.copy(X))
    if X_torch1.ndim==3:
        X_torch = torch.permute(X_torch1,(2,0,1))
    elif X_torch1.ndim==4:
        X_torch = torch.permute(X_torch1,(3,2,0,1))
    # Initialize the inference transforms
    preprocess = weights.transforms()
    # Apply inference preprocessing transforms
    batch = [preprocess(X_torch)]
    # Use the model
    if X_torch.ndim==3:
        prediction = model(batch)[0]
    elif X_torch.ndim==4:
        prediction = model(list(batch[0]))
    return prediction
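The two axis orders passed to torch.permute follow from the layouts involved: MATLAB stores an RGB image as height × width × channels (and a batch as height × width × channels × N), while torchvision models expect channels × height × width (or N × channels × height × width). The sketch below is a pure-Python illustration of that reordering, not part of the original example; the helper name permuted_shape and the 480 × 640 image size are assumptions made for the demonstration.

```python
def permuted_shape(shape, order):
    """Return the shape produced by reordering axes the way torch.permute does.

    Output axis i takes its size from input axis order[i].
    """
    if sorted(order) != list(range(len(shape))):
        raise ValueError("order must be a permutation of the axis indices")
    return tuple(shape[axis] for axis in order)


# Single image: H x W x C -> C x H x W
single = permuted_shape((480, 640, 3), (2, 0, 1))

# Batch of 8 images: H x W x C x N -> N x C x H x W
batch = permuted_shape((480, 640, 3, 8), (3, 2, 0, 1))

print(single)  # (3, 480, 640)
print(batch)   # (8, 3, 480, 640)
```

Checking shapes this way is a quick sanity test before passing real image data to the model.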
Object Detection
Read the image that you want to use for object detection.
img_filename = "testCar.png";
img = imread(img_filename);
Perform object detection with a PyTorch model using co-execution.
pyrun("from PT_object_detection import loadPTmodel, detectPT")
[model,weights] = pyrun("[a,b] = loadPTmodel()",["a" "b"]);
predictions_pt = pyrun("a = detectPT(b,c,d)","a",b=img,c=model,d=weights);
The output of the object detector is a Python dictionary.
class(predictions_pt)
ans = 'py.dict'
Convert Predictions
Convert the object detection results from a Python dictionary to a MATLAB structure.
predictions = struct(predictions_pt)
predictions = struct with fields:
    boxes: [1×1 py.torch.Tensor]
    labels: [1×1 py.torch.Tensor]
    scores: [1×1 py.torch.Tensor]
The data variables in the structure predictions are PyTorch tensors. Convert them into MATLAB arrays.
predictions.boxes = double(predictions.boxes.detach().numpy);
predictions.labels = double(predictions.labels.tolist)';
predictions.scores = double(predictions.scores.tolist)';
predictions
predictions = struct with fields:
    boxes: [4×4 double]
    labels: [4×1 double]
    scores: [4×1 double]
The bounding boxes require further processing to align with the input image. In MATLAB, a bounding box is an axis-aligned rectangle defined in spatial coordinates as an M-by-4 numeric matrix with rows of the form [x y w h], where:
- M is the number of axis-aligned rectangles.
- x and y specify the upper-left corner of the rectangle.
- w specifies the width of the rectangle, which is its length along the x-axis.
- h specifies the height of the rectangle, which is its length along the y-axis.
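As a rough illustration of the [x y w h] convention above, the following sketch converts a box given by two zero-based corner points, [x1, y1, x2, y2] (the format that torchvision documents for the boxes output of its detection models), into a one-based [x y w h] row. The function name and the sample coordinates are hypothetical.

```python
def corners_to_xywh(box, one_based=True):
    """Convert [x1, y1, x2, y2] corner coordinates to [x, y, w, h].

    When one_based is True, the upper-left corner is shifted by 1 to match
    one-based pixel indexing; width and height are unaffected by the shift.
    """
    x1, y1, x2, y2 = box
    offset = 1 if one_based else 0
    return [x1 + offset, y1 + offset, x2 - x1, y2 - y1]


# Hypothetical detections in zero-based corner format
boxes = [[10.0, 20.0, 110.0, 70.0], [0.0, 0.0, 50.0, 25.0]]
converted = [corners_to_xywh(b) for b in boxes]
print(converted)  # [[11.0, 21.0, 100.0, 50.0], [1.0, 1.0, 50.0, 25.0]]
```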
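The torchvision model returns each box as zero-based corner coordinates [x1 y1 x2 y2]. Shift the upper-left corner by 1 for MATLAB indexing and compute the width and height as the differences between the corners.
predictions.boxes = cat(2,predictions.boxes(:,1:2)+1,predictions.boxes(:,3:4)-predictions.boxes(:,1:2));
Get the class labels. The PyTorch model was trained on the COCO data set.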
class_labels = getClassLabels(predictions.labels);
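The helper getClassLabels is not listed in this post. One plausible way to implement such a lookup is sketched in Python below, assuming that the numeric label IDs index into a category list like the one torchvision exposes as weights.meta["categories"]. The list is truncated to its first few entries for illustration, and the function name labels_to_names is hypothetical.

```python
# First entries of the COCO category list as ordered in torchvision
# (truncated for illustration; index 0 is the background class).
COCO_CATEGORIES_HEAD = [
    "__background__", "person", "bicycle", "car", "motorcycle",
    "airplane", "bus", "train", "truck",
]


def labels_to_names(label_ids, categories):
    """Map numeric label IDs to category names, using 'unknown' for out-of-range IDs."""
    names = []
    for label_id in label_ids:
        index = int(label_id)  # labels may arrive as floats after conversion
        if 0 <= index < len(categories):
            names.append(categories[index])
        else:
            names.append("unknown")
    return names


print(labels_to_names([3.0, 3.0, 8.0, 1.0], COCO_CATEGORIES_HEAD))
# ['car', 'car', 'truck', 'person']
```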
Visualization
Create the labels associated with each of the detected objects.
num_box = length(predictions.scores);
colons = repmat(": ",[1 num_box]);
percents = repmat("%",[1 num_box]);
class_labels1 = strcat(class_labels,colons,string(round(predictions.scores'*100)),percents);
Visualize the object detection results with annotations.
figure
outputImage = insertObjectAnnotation(img,...
    "rectangle",predictions.boxes,class_labels1,LineWidth=1,Color="green");
imshow(outputImage)
The annotated image shows each detected object with its class label and confidence score. Because the model was loaded with box_score_thresh=0.95, every reported detection has a score above 95%.