Artificial Intelligence

Apply machine learning and deep learning

Explainability in Object Detection for MATLAB, TensorFlow, and PyTorch Models

In R2024a, the Deep Learning Toolbox Verification Library introduced the drise function. D-RISE is an explainability tool that helps you visualize and understand which parts of an image are important for object detection. If you need a refresher on what explainable AI is and why it’s important, watch this short video.
D-RISE, proposed in this paper, is a model-agnostic method: it doesn’t require knowledge of the inner workings of the object detection model. Given a specific image and object detector, it produces a saliency map (an image in which the areas that most affect the prediction are highlighted). Because it treats the detector as a black box, it can be applied to different types of object detectors.
Figure: Perform object detection and explain the detection results using D-RISE for MATLAB, imported TensorFlow, and PyTorch models.
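To make the idea concrete, here is a minimal NumPy sketch of the D-RISE aggregation step, written in Python (the language used for the PyTorch co-execution later in this post). It is not the drise implementation. It assumes a caller-supplied function target_similarity(masked_image) that returns how strongly the detector still finds the target object in a masked image. In the D-RISE paper, that weight is the best match between the target detection and the detections on the masked image (box IoU times class-probability similarity times objectness). The mask generator is simplified: it repeats a coarse random grid (nearest-neighbour upsampling) and crops it at a random offset, where the paper uses bilinear upsampling. The names generate_masks and drise_saliency are hypothetical.

```python
import numpy as np


def generate_masks(num_masks, image_shape, grid_size=16, p=0.5, seed=0):
    """Random masks in the style of RISE/D-RISE (simplified).

    Each mask starts as a coarse (grid_size+1) x (grid_size+1) grid of
    Bernoulli(p) cells, is upsampled by nearest-neighbour repetition, and is
    cropped at a random offset so that cell edges do not always line up.
    """
    rng = np.random.default_rng(seed)
    h, w = image_shape
    cell_h = int(np.ceil(h / grid_size))
    cell_w = int(np.ceil(w / grid_size))
    masks = np.empty((num_masks, h, w), dtype=np.float32)
    for k in range(num_masks):
        grid = (rng.random((grid_size + 1, grid_size + 1)) < p).astype(np.float32)
        upsampled = np.kron(grid, np.ones((cell_h, cell_w), dtype=np.float32))
        dy = rng.integers(0, cell_h)
        dx = rng.integers(0, cell_w)
        masks[k] = upsampled[dy:dy + h, dx:dx + w]
    return masks


def drise_saliency(image, target_similarity, num_masks=512, grid_size=16, p=0.5, seed=0):
    """Weighted sum of masks, each weighted by how well the target survives masking."""
    h, w = image.shape[:2]
    masks = generate_masks(num_masks, (h, w), grid_size, p, seed)
    saliency = np.zeros((h, w), dtype=np.float64)
    for mask in masks:
        # Broadcast the 2-D mask over the channel dimension of an H x W x C image
        masked = image * (mask[..., np.newaxis] if image.ndim == 3 else mask)
        saliency += target_similarity(masked) * mask
    return saliency / num_masks
```

For example, with a toy similarity such as `lambda x: x[8:16, 8:16].mean()` on a 32-by-32 image of ones, the saliency is on average larger inside that 8-by-8 region than elsewhere. Regions whose visibility drives the score get higher values, and the heatmaps in this post show this kind of map.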
 
In this blog post, I’ll show you how to use D-RISE to explain object detection results for MATLAB, TensorFlow™, and PyTorch® models. More specifically, I will walk through how to use D-RISE for these object detectors:
  1. Built-in MATLAB object detector.
  2. TensorFlow object detector that is imported into MATLAB.
  3. PyTorch object detector that is used in MATLAB with co-execution.
The drise function provides two syntaxes: one for built-in MATLAB object detectors and one for any other type of object detector. To use drise with TensorFlow models (including imported ones) and PyTorch models, you must use the custom detection syntax, which takes a detection function handle together with the target bounding box and label.
Figure: Syntaxes of the drise function for built-in MATLAB object detectors and other types of object detectors.
 
But don’t fret, I’ll provide you with the necessary code for all options. Check out this GitHub repository to get the code for the examples that use D-RISE with TensorFlow and PyTorch object detectors.
 

D-RISE with MATLAB Model

This section shows how to use D-RISE with a built-in MATLAB object detector, more specifically a YOLO v2 object detector. You can get the full example from here.
Read in a test image from the Caltech Cars data set.
img = imread("testCar.png");
img = im2single(img);
Detect vehicles in the test image by using the trained YOLO v2 detector. Note that the detector in this example has been trained to detect only vehicles, whereas the TensorFlow and PyTorch object detectors in the following examples detect objects of multiple classes.
Pass the test image and the detector as input to the detect function. The detect function returns the bounding boxes, the detection scores, and the class labels.
[bboxes,scores,labels] = detect(detector,img);
figure
annotatedImage = insertObjectAnnotation(img,"rectangle",bboxes,scores);
imshow(annotatedImage)
Figure: The YOLO v2 object detector detects two vehicles.
Use the drise function to create saliency maps explaining the detections made by the YOLO v2 object detector.
scoreMap = drise(detector,img);
Plot the saliency map over the image. Areas highlighted in red are more significant in the detection than areas highlighted in blue.
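% Show one tile per detection: the detection box with its D-RISE map overlaid
numDetections = size(bboxes,1);
tiledlayout(1,numDetections,TileSpacing="tight")

for i = 1:numDetections
    nexttile
    annotatedImage = insertObjectAnnotation(img,"rectangle",bboxes(i,:),scores(i));
    imshow(annotatedImage)
    hold on
    % Overlay the saliency map for detection i with 50% transparency
    imagesc(scoreMap(:,:,i),AlphaData=0.5)
    title("DRISE Map: Detection " + i)
    hold off
end

colormap jet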
Figure: Saliency maps for the two detected vehicles.
To see more examples of how to use D-RISE with MATLAB object detectors, see the drise reference page.
 

D-RISE with TensorFlow Model

This section shows how to import a TensorFlow model for object detection, how to use the imported model in MATLAB and visualize the detections, and how to use D-RISE to explain the predictions of the model. You can get the code for this example from here.

Import and Initialize Network

Import a pretrained TensorFlow model for object detection. The model is in the SavedModel format.
modelFolder = "centernet_resnet50_v2_512x512_coco17";
detector = importNetworkFromTensorFlow(modelFolder);
Specify the input size of the imported network. The expected image size (512-by-512) is stated in the name of the TensorFlow network. To represent a 2-D image input, the dlarray object must have the data format "SSCB" (spatial, spatial, channel, batch). For more information, see Data Formats for Prediction with dlnetwork. Then, initialize the imported network with a dummy input of that size.
input_size = [512 512 3];
detector = detector.initialize(dlarray(ones([input_size 1]),"SSCB"))
detector = 
  dlnetwork with properties:

         Layers: [1×1 centernet_resnet50_v2_512x512_coco17.kCall11498]
    Connections: [0×2 table]
     Learnables: [388×3 table]
          State: [0×3 table]
     InputNames: {'kCall11498'}
    OutputNames: {'kCall11498/detection_boxes'  'kCall11498/detection_classes'  'kCall11498/detection_scores'  'kCall11498/num_detections'}
    Initialized: 1

  View summary with summary.
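The "SSCB" label only names the dimensions. The array itself is laid out like any MATLAB image, height × width × channels, with a trailing batch dimension of 1 (which MATLAB drops when displaying the size, so ones(512,512,3,1) has the same size as ones(512,512,3)). TensorFlow models, by contrast, usually take batch-first NHWC input. The NumPy sketch below contrasts the two layouts for a single 512-by-512 RGB image. The helper names to_sscb and to_nhwc are hypothetical and for illustration only.

```python
import numpy as np


def to_sscb(img_hwc):
    """Add a trailing batch axis: H x W x C -> H x W x C x 1 (MATLAB "SSCB" order)."""
    return img_hwc[..., np.newaxis]


def to_nhwc(img_hwc):
    """Add a leading batch axis: H x W x C -> 1 x H x W x C (TensorFlow NHWC order)."""
    return img_hwc[np.newaxis, ...]


dummy = np.ones((512, 512, 3), dtype=np.float32)  # like ones(512,512,3) in MATLAB
print(to_sscb(dummy).shape)  # (512, 512, 3, 1)
print(to_nhwc(dummy).shape)  # (1, 512, 512, 3)
```

Only the position of the batch axis differs. The pixel data is the same in both cases.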

Detect with Imported Network

The network has four outputs: bounding boxes, classes, scores, and number of detections.
mlOutputNames = detector.OutputNames'
mlOutputNames = 4×1 cell
'kCall11498/detection_boxes'  
'kCall11498/detection_classes'
'kCall11498/detection_scores' 
'kCall11498/num_detections'   
Read the image that you want to use for object detection. Perform object detection on the image.
img = imread("testCar.png");
[y1,y2,y3,y4] = detector.predict(dlarray(single(img),"SSCB"));

Get Detections with Highest Scores

Create a map of all the network outputs, keyed by the output name without the kCall11498/ prefix.
mlOutputMap = containers.Map;
mlOutputs = {y1,y2,y3,y4};
for i = 1:numel(mlOutputNames)
    opNameStrSplit = strsplit(mlOutputNames{i},'/');
    opName = opNameStrSplit{end};
    mlOutputMap(opName) = mlOutputs{i};
end
Get the detections with scores above the threshold thr, and the corresponding class labels.
thr = 0.5;
[bboxes,classes,scores,num_box] = bestDetections(img,mlOutputMap,thr);
class_labels = getClassLabels(classes);
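The helpers bestDetections and getClassLabels are not listed in this post. As a rough guide to what such helpers have to do, here is a hypothetical Python sketch (the names strip_output_names, best_detections, get_class_labels, and class_names are illustrative, not from the original code). It assumes the TensorFlow Object Detection API output conventions: detection_boxes holds normalized [ymin xmin ymax xmax] coordinates, detection_classes holds COCO class IDs, and each output carries a leading batch dimension. The sketch keeps detections scoring above the threshold and converts their boxes to the [x y w h] pixel format that insertObjectAnnotation expects.

```python
import numpy as np

# Tiny, illustrative subset of the COCO label map (IDs as used by the
# TensorFlow Object Detection API); a real helper would cover all classes.
class_names = {1: "person", 3: "car", 6: "bus", 8: "truck"}


def strip_output_names(names, outputs):
    """Map 'kCall11498/detection_boxes' -> 'detection_boxes', and so on."""
    return {name.split("/")[-1]: out for name, out in zip(names, outputs)}


def best_detections(image_shape, output_map, thr=0.5):
    """Keep detections with score > thr and convert boxes to pixel [x, y, w, h]."""
    h, w = image_shape[:2]
    boxes = np.asarray(output_map["detection_boxes"], dtype=float).reshape(-1, 4)
    classes = np.asarray(output_map["detection_classes"]).reshape(-1).astype(int)
    scores = np.asarray(output_map["detection_scores"], dtype=float).reshape(-1)
    keep = scores > thr
    ymin, xmin, ymax, xmax = boxes[keep].T  # normalized corner coordinates
    bboxes = np.stack([xmin * w, ymin * h, (xmax - xmin) * w, (ymax - ymin) * h], axis=1)
    return bboxes, classes[keep], scores[keep], int(keep.sum())


def get_class_labels(classes):
    """Look up class names, falling back to the numeric ID."""
    return [class_names.get(int(c), f"class {int(c)}") for c in classes]
```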

Visualize Object Detection

Create the labels associated with each of the detected objects.
colons = repmat(": ",[1 num_box]);
percents = repmat("%",[1 num_box]);
labels = strcat(class_labels,colons,string(round(scores*100)),percents);
Visualize the object detection results with annotations.
figure
outputImage = insertObjectAnnotation(img,"rectangle",bboxes,labels,LineWidth=1,Color="green");
imshow(outputImage)
Figure: The TensorFlow object detector detects three objects in the input image.

Explainability for Object Detector

Explain the predictions of the object detection network using D-RISE. Specify a custom detection function to use D-RISE with the imported TensorFlow network.
targetBox = bboxes(1,:);
targetLabel = 1;
scoreMap = drise(@(img)customDetector(img),img,targetBox,targetLabel);
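With the custom syntax, drise only sees what customDetector returns, so each masked image's detections must be scored against targetBox and targetLabel. This post does not describe how drise computes that score internally. As an illustration only, the hypothetical Python sketch below uses a simplified version of the similarity from the D-RISE paper: box IoU multiplied by the detection score, restricted to detections with the target label. The paper instead multiplies IoU by the cosine similarity of class-probability vectors and by an objectness score. The best match over all detections serves as the weight of that mask. Boxes use the [x y w h] format of targetBox, treated as continuous coordinates.

```python
def iou_xywh(a, b):
    """Intersection over union of two boxes given as [x, y, width, height]."""
    ax, ay, aw, ah = a
    bx, by, bw, bh = b
    inter_w = max(0.0, min(ax + aw, bx + bw) - max(ax, bx))
    inter_h = max(0.0, min(ay + ah, by + bh) - max(ay, by))
    inter = inter_w * inter_h
    union = aw * ah + bw * bh - inter
    return inter / union if union > 0 else 0.0


def target_similarity(target_box, target_label, bboxes, scores, labels):
    """Best IoU x score over detections that share the target label (0 if none)."""
    best = 0.0
    for box, score, label in zip(bboxes, scores, labels):
        if label == target_label:
            best = max(best, iou_xywh(target_box, box) * float(score))
    return best
```

A detection that disappears or moves away when part of the image is masked drives this weight toward zero. Masks that hide important pixels therefore contribute little to the saliency map.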
Plot the results. As mentioned above, areas highlighted in red are more significant in the detection than areas highlighted in blue.
figure
annotatedImage = insertObjectAnnotation(img,"rectangle",targetBox,"vehicle");
imshow(annotatedImage)
hold on
imagesc(scoreMap,AlphaData=0.5)
title("DRISE Map: Custom Detector")
hold off
colormap jet
Figure: Saliency map created by D-RISE for the TensorFlow object detector.
 

D-RISE with PyTorch Model

This section shows how to perform object detection with a PyTorch model using co-execution, and how to use D-RISE to explain the predictions of the PyTorch model. You can get the code for this example from here.

Python Environment

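Set up the Python interpreter for MATLAB by using the pyenv function. Specify the Python executable to use (in this example, the one in a local virtual environment folder named env) and run Python out of process.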
pe = pyenv(Version=".\env\Scripts\python.exe",ExecutionMode="OutOfProcess");

Object Detection

Read the image that you want to use for object detection.
img_filename = "testCar.png";
img = imread(img_filename);
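Perform object detection with the PyTorch model using co-execution. Import the loadPTmodel and detectPT functions from the PT_object_detection Python module, load the model and its weights, and then call detectPT on the image.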
pyrun("from PT_object_detection import loadPTmodel, detectPT")
[model,weights] = pyrun("[a,b] = loadPTmodel()",["a" "b"]);
predictions = pyrun("a = detectPT(b,c,d)","a",b=img,c=model,d=weights);
Convert the prediction outputs from Python data types to MATLAB data types.
[bboxes,labels,scores] = convertVariables(predictions,img);
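The convertVariables helper is not listed in this post either. Suppose the PyTorch model is a torchvision detection model (an assumption; this post does not name the model). Its per-image prediction is then a dictionary of boxes in [x1 y1 x2 y2] pixel corner format, integer labels, and scores. The hypothetical sketch below covers only the box-format part of such a conversion. It assumes the tensors were already turned into NumPy arrays (for example with tensor.cpu().numpy()) and returns boxes in the [x y w h] format used by insertObjectAnnotation and drise. It ignores the offset between zero-based and MATLAB's one-based pixel coordinates.

```python
import numpy as np


def convert_predictions(prediction, score_thr=0.0):
    """Convert [x1, y1, x2, y2] boxes to [x, y, w, h] and drop low-scoring detections."""
    boxes = np.asarray(prediction["boxes"], dtype=float).reshape(-1, 4)
    labels = np.asarray(prediction["labels"]).reshape(-1).astype(int)
    scores = np.asarray(prediction["scores"], dtype=float).reshape(-1)
    keep = scores >= score_thr
    x1, y1, x2, y2 = boxes[keep].T
    bboxes = np.stack([x1, y1, x2 - x1, y2 - y1], axis=1)
    return bboxes, labels[keep], scores[keep]
```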
Get the class labels.
class_labels = getClassLabels(labels);

Visualization

Create the labels associated with each of the detected objects.
num_box = length(scores);
colons = repmat(": ",[1 num_box]);
percents = repmat("%",[1 num_box]);
class_labels1 = strcat(class_labels,colons,string(round(scores'*100)),percents);
Visualize the object detection results with annotations.
figure
outputImage = insertObjectAnnotation(img,"rectangle",bboxes,class_labels1,LineWidth=1,Color="green");
imshow(outputImage)
Figure: Object detection results for the input image from the PyTorch object detector.

Explainability

Explain the predictions of the PyTorch model using D-RISE. Specify a custom detection function to use D-RISE.
targetBbox = bboxes(1,:);
targetLabel = 1;
scoreMap = drise(@(img)customDetector(img),img,targetBbox,targetLabel,...
    NumSamples=512,MiniBatchSize=8);
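D-RISE needs one detector evaluation per random mask, so NumSamples largely sets the cost: fewer samples run faster but give a noisier saliency map. As its name suggests, MiniBatchSize controls how many masked images are processed together. The hypothetical Python helper below (not part of drise) checks the arithmetic: 512 samples in mini-batches of 8 means 64 batches.

```python
def mask_batches(num_samples, mini_batch_size):
    """Yield (start, stop) index ranges that split num_samples masks into mini-batches."""
    for start in range(0, num_samples, mini_batch_size):
        yield start, min(start + mini_batch_size, num_samples)


batches = list(mask_batches(512, 8))
print(len(batches))  # 64
print(batches[-1])   # (504, 512)
```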
You can plot the saliency map computed by D-RISE in the same way as for the TensorFlow model.

 