Explainability in Object Detection for MATLAB, TensorFlow, and PyTorch Models
![Get an object detector (from MATLAB, TensorFlow, or PyTorch), use it to detect objects in an image of a road, and explain the detection with DRISE.](https://blogs.mathworks.com/deep-learning/files/2024/05/drise_options-1.png)
You can use D-RISE (via the drise function) to explain the predictions of three types of object detectors:
- A built-in MATLAB object detector.
- A TensorFlow object detector imported into MATLAB.
- A PyTorch object detector used in MATLAB through co-execution.
![Two syntaxes for drise function; syntax for MATLAB object detectors (on the left) and syntax for other object detectors (on the right).](https://blogs.mathworks.com/deep-learning/files/2024/05/drise_syntax.png)
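As a quick orientation, here is a minimal sketch of the two calling patterns, assuming a pretrained vehicle detector is available; vehicleDetectorYOLOv2 (Computer Vision Toolbox) is used here only as a stand-in for the YOLO v2 detector trained for this post.

```matlab
% Minimal sketch of the two drise calling patterns (see assumptions above).
detector = vehicleDetectorYOLOv2();        % stand-in built-in MATLAB detector
I = im2single(imread("testCar.png"));

% Syntax 1: pass a supported MATLAB detector object directly.
scoreMap = drise(detector,I);

% Syntax 2 (for other detectors): pass a custom detection function handle,
% plus the target boxes and labels to explain, as done later in this post
% for the TensorFlow and PyTorch models:
% scoreMap = drise(@(I)customDetector(I),I,targetBoxes,targetLabels);
```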
Read the test image and convert it to single precision.

```matlab
img = imread("testCar.png");
img = im2single(img);
```

Detect vehicles in the test image by using the trained YOLO v2 detector. Note that the detector in this example has been trained to detect only vehicles, whereas the TensorFlow and PyTorch object detectors in the following examples detect all objects in the image. Pass the test image and the detector as input to the detect function, which returns the bounding boxes, detection scores, and labels.
```matlab
[bboxes,scores,labels] = detect(detector,img);
figure
annotatedImage = insertObjectAnnotation(img,"rectangle",bboxes,scores);
imshow(annotatedImage)
```
![Object detection with YOLO v2, detecting two vehicles.](https://blogs.mathworks.com/deep-learning/files/2024/05/object_detection_yolov2.png)
Use D-RISE to compute a saliency map for each detection.

```matlab
scoreMap = drise(detector,img);
```

Plot the saliency map over the image. Areas highlighted in red are more significant in the detection than areas highlighted in blue.
```matlab
tiledlayout(1,2,TileSpacing="tight")
for i = 1:2
    nexttile
    annotatedImage = insertObjectAnnotation(img,"rectangle",bboxes(i,:),scores(i));
    imshow(annotatedImage)
    hold on
    imagesc(scoreMap(:,:,i),AlphaData=0.5)
    title("DRISE Map: Detection " + i)
    hold off
end
colormap jet
```
![Saliency maps for two detected vehicles.](https://blogs.mathworks.com/deep-learning/files/2024/05/saliency_map_yolov2.png)
Import and Initialize Network
Import a pretrained TensorFlow model for object detection. The model is in the SavedModel format.

```matlab
modelFolder = "centernet_resnet50_v2_512x517_coco17";
detector = importNetworkFromTensorFlow(modelFolder);
```

Specify the input size of the imported network; the expected image size is stated in the name of the TensorFlow network. To represent a 2-D image input, the dlarray object must have the data format "SSCB" (spatial, spatial, channel, batch). For more information, see Data Formats for Prediction with dlnetwork. Then, initialize the imported network.
```matlab
input_size = [512 512 3];
detector = detector.initialize(dlarray(ones(512,512,3,1),"SSCB"))
```
```
detector =
  dlnetwork with properties:

         Layers: [1×1 centernet_resnet50_v2_512x517_coco17.kCall11498]
    Connections: [0×2 table]
     Learnables: [388×3 table]
          State: [0×3 table]
     InputNames: {'kCall11498'}
    OutputNames: {'kCall11498/detection_boxes'  'kCall11498/detection_classes'  'kCall11498/detection_scores'  'kCall11498/num_detections'}
    Initialized: 1

  View summary with summary.
```
Detect with Imported Network
The network has four outputs: bounding boxes, classes, scores, and number of detections.

```matlab
mlOutputNames = detector.OutputNames'
```
```
mlOutputNames = 4×1 cell
    'kCall11498/detection_boxes'
    'kCall11498/detection_classes'
    'kCall11498/detection_scores'
    'kCall11498/num_detections'
```

Read the image that you want to use for object detection. Perform object detection on the image.
```matlab
img = imread("testCar.png");
[y1,y2,y3,y4] = detector.predict(dlarray(single(img),"SSCB"));
```
Get Detections with Highest Scores
Create a map of all the network outputs.

```matlab
mlOutputMap = containers.Map;
mlOutputs = {y1,y2,y3,y4};
for i = 1:numel(mlOutputNames)
    opNameStrSplit = strsplit(mlOutputNames{i},'/');
    opName = opNameStrSplit{end};
    mlOutputMap(opName) = mlOutputs{i};
end
```

Get the detections with scores above the threshold thr, and the corresponding class labels.
```matlab
thr = 0.5;
[bboxes,classes,scores,num_box] = bestDetections(img,mlOutputMap,thr);
class_labels = getClassLabels(classes);
```
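The bestDetections and getClassLabels helpers are supporting functions from the original example: the first keeps the confident detections and converts their boxes into pixel coordinates, and the second maps the numeric class IDs to COCO category names. As a rough illustration only, here is a minimal sketch of a bestDetections-style helper, assuming the map stores the dlarray outputs of predict and that detection_boxes holds normalized [ymin xmin ymax xmax] coordinates (the usual TensorFlow convention); the actual supporting file may differ.

```matlab
function [bboxes,classes,scores,num_box] = bestDetections(img,outputMap,thr)
% Sketch of a bestDetections-style helper (assumptions noted above).
    allScores  = squeeze(extractdata(outputMap('detection_scores')));
    allClasses = squeeze(extractdata(outputMap('detection_classes')));
    allBoxes   = squeeze(extractdata(outputMap('detection_boxes')));   % assumed N-by-4

    % Keep detections above the score threshold.
    keep    = allScores > thr;
    scores  = allScores(keep);
    classes = allClasses(keep);
    num_box = nnz(keep);

    % Convert normalized [ymin xmin ymax xmax] boxes to [x y width height]
    % pixel boxes for insertObjectAnnotation.
    [h,w,~] = size(img);
    boxes   = allBoxes(keep,:);
    bboxes  = [boxes(:,2)*w, boxes(:,1)*h, ...
               (boxes(:,4)-boxes(:,2))*w, (boxes(:,3)-boxes(:,1))*h];
end
```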
Visualize Object Detection
Create the labels associated with each of the detected objects.

```matlab
colons = repmat(": ",[1 num_box]);
percents = repmat("%",[1 num_box]);
labels = strcat(class_labels,colons,string(round(scores*100)),percents);
```

Visualize the object detection results with annotations.
```matlab
figure
outputImage = insertObjectAnnotation(img,"rectangle",bboxes,labels,LineWidth=1,Color="green");
imshow(outputImage)
```
![TensorFlow object detector detects three objects in input image.](https://blogs.mathworks.com/deep-learning/files/2024/05/object_detection_for_car_image_TF.png)
Explainability for Object Detector
Explain the predictions of the object detection network using D-RISE. To use D-RISE with the imported TensorFlow network, specify a custom detection function (a sketch of such a wrapper appears at the end of this section).

```matlab
targetBox = bboxes(1,:);
targetLabel = 1;
scoreMap = drise(@(img)customDetector(img),img,targetBox,targetLabel);
```

Plot the results. As mentioned above, areas highlighted in red are more significant in the detection than areas highlighted in blue.
```matlab
figure
annotatedImage = insertObjectAnnotation(img,"rectangle",targetBox,"vehicle");
imshow(annotatedImage)
hold on
imagesc(scoreMap,AlphaData=0.5)
title("DRISE Map: Custom Detector")
hold off
colormap jet
```
![Saliency map created by DRISE for TensorFlow object detector](https://blogs.mathworks.com/deep-learning/files/2024/05/drise_for_car_image.png)
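The customDetector function is a supporting file in the original example; it wraps the full TensorFlow prediction pipeline behind the single-image interface that drise calls on each masked sample. The sketch below is only a plausible shape for that wrapper: it assumes drise expects the handle to return bounding boxes, scores, and labels for the image it is given (check the drise reference page for the exact output requirements) and reuses the helpers shown earlier in this section.

```matlab
function [bboxes,scores,labels] = customDetector(img)
% Plausible sketch of a customDetector-style wrapper for the imported
% TensorFlow network (assumptions noted above; the actual supporting
% file may differ).
    persistent net
    if isempty(net)
        net = importNetworkFromTensorFlow("centernet_resnet50_v2_512x517_coco17");
        net = net.initialize(dlarray(ones(512,512,3,1),"SSCB"));
    end

    % Run the network on the (possibly masked) image that drise passes in.
    [y1,y2,y3,y4] = net.predict(dlarray(single(img),"SSCB"));

    % Reuse the earlier post-processing to keep the confident detections.
    outputMap = containers.Map;
    outputs   = {y1,y2,y3,y4};
    names     = net.OutputNames;
    for i = 1:numel(names)
        parts = strsplit(names{i},'/');
        outputMap(parts{end}) = outputs{i};
    end
    thr = 0.5;
    [bboxes,classes,scores] = bestDetections(img,outputMap,thr);
    labels = getClassLabels(classes);
end
```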
Python Environment
Set up the Python interpreter for MATLAB by using the pyenv function. Specify the version of Python to use.

```matlab
pe = pyenv(Version=".\env\Scripts\python.exe",ExecutionMode="OutOfProcess");
```
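Optionally, you can confirm that the configured environment loads PyTorch before running the detection; for example, this small (assumed, not part of the original post) check queries the installed torch version through pyrun.

```matlab
% Quick sanity check that the configured environment can import PyTorch.
torchVersion = pyrun("import torch; v = torch.__version__","v");
disp("Using PyTorch " + string(torchVersion))
```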
Object Detection
Read the image that you want to use for object detection.

```matlab
img_filename = "testCar.png";
img = imread(img_filename);
```

Perform object detection with a PyTorch model using co-execution.
```matlab
pyrun("from PT_object_detection import loadPTmodel, detectPT")
[model,weights] = pyrun("[a,b] = loadPTmodel()",["a" "b"]);
predictions = pyrun("a = detectPT(b,c,d)","a",b=img,c=model,d=weights);
```

Convert the prediction outputs from Python data types to MATLAB data types.
```matlab
[bboxes,labels,scores] = convertVariables(predictions,imread(img_filename));
```

Get the class labels.
```matlab
class_labels = getClassLabels(labels);
```
Visualization
Create the labels associated with each of the detected objects.

```matlab
num_box = length(scores);
colons = repmat(": ",[1 num_box]);
percents = repmat("%",[1 num_box]);
class_labels1 = strcat(class_labels,colons,string(round(scores'*100)),percents);
```

Visualize the object detection results with annotations.
```matlab
figure
outputImage = insertObjectAnnotation(img,"rectangle",bboxes,class_labels1,LineWidth=1,Color="green");
imshow(outputImage)
```
![Object detected in input image by PyTorch object detector.](https://blogs.mathworks.com/deep-learning/files/2024/05/object_detection_for_car_image_PT.png)
Explainability
Explain the predictions of the PyTorch model using D-RISE. As with the TensorFlow model, specify a custom detection function.

```matlab
targetBbox = bboxes(1,:);
targetLabel = 1;
scoreMap = drise(@(img)customDetector(img),img,targetBbox,targetLabel, ...
    NumSamples=512,MiniBatchSize=8);
```

You can plot the saliency map computed by D-RISE as you previously did for the TensorFlow object detection results.
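For example, a plot analogous to the TensorFlow one above could look like the following sketch, which reuses the variables from the previous code block and assumes class_labels(1) corresponds to the explained detection.

```matlab
% Overlay the D-RISE saliency map on the annotated PyTorch detection,
% mirroring the earlier TensorFlow plot.
figure
annotatedImage = insertObjectAnnotation(img,"rectangle",targetBbox,class_labels(1));
imshow(annotatedImage)
hold on
imagesc(scoreMap,AlphaData=0.5)
title("DRISE Map: PyTorch Detector")
hold off
colormap jet
```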