Explainability in Object Detection for MATLAB, TensorFlow, and PyTorch Models
In R2024a, Deep Learning Toolbox Verification Library introduced the drise function. D-RISE is an explainability tool that helps you visualize and understand which parts of an image are important for an object detection. If you need a refresher on what explainable AI is and why it’s important, watch this short video.
D-RISE, proposed in this paper, is a model-agnostic method: it doesn’t require knowledge of the inner workings of the object detection model. Given a specific image and object detector, it produces a saliency map (an image with highlighted areas that most affect the prediction). Because the method treats the detector as a black box, it can be applied to different types of object detectors.
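To make the idea concrete, here is a minimal sketch of the D-RISE procedure in plain Python. It is not the implementation behind the drise function. It assumes blocky random masks (the paper uses smoothly upsampled, randomly shifted masks), leaves out the objectness term of the similarity score, and uses a hypothetical toy detector and image. Each masked image is weighted by how well the detector still reproduces the target detection (box IoU times cosine similarity of the class probabilities), and the saliency map is the weighted sum of the masks.

```python
import random

def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def cosine(p, q):
    """Cosine similarity of two class-probability vectors."""
    dot = sum(x * y for x, y in zip(p, q))
    norm = (sum(x * x for x in p) * sum(y * y for y in q)) ** 0.5
    return dot / norm if norm > 0 else 0.0

def random_mask(height, width, cell, keep_prob, rng):
    """Coarse random binary grid, upsampled by nearest neighbor (simplification)."""
    rows, cols = -(-height // cell), -(-width // cell)  # ceiling division
    grid = [[1.0 if rng.random() < keep_prob else 0.0 for _ in range(cols)]
            for _ in range(rows)]
    return [[grid[y // cell][x // cell] for x in range(width)] for y in range(height)]

def drise_sketch(image, detect, target_box, target_probs,
                 num_masks=500, cell=2, keep_prob=0.5, seed=0):
    """Weighted sum of random masks; detect(image) returns [(box, class_probs), ...]."""
    rng = random.Random(seed)
    h, w = len(image), len(image[0])
    saliency = [[0.0] * w for _ in range(h)]
    for _ in range(num_masks):
        mask = random_mask(h, w, cell, keep_prob, rng)
        masked = [[image[y][x] * mask[y][x] for x in range(w)] for y in range(h)]
        # Weight of this mask: best match between the target and any detection
        # found on the masked image (0 if nothing is detected).
        weight = max((iou(target_box, box) * cosine(target_probs, probs)
                      for box, probs in detect(masked)), default=0.0)
        for y in range(h):
            for x in range(w):
                saliency[y][x] += weight * mask[y][x]
    return saliency

# Hypothetical toy setup: an 8x8 grayscale image with a bright 4x4 "object".
image = [[1.0 if 2 <= y < 6 and 2 <= x < 6 else 0.2 for x in range(8)]
         for y in range(8)]
OBJECT_BOX = (2, 2, 6, 6)

def toy_detect(img):
    """Stand-in detector: reports the object only if enough of it is visible."""
    visible = sum(img[y][x] for y in range(2, 6) for x in range(2, 6)) / 16
    return [(OBJECT_BOX, [0.9, 0.1])] if visible >= 0.7 else []

saliency = drise_sketch(image, toy_detect, OBJECT_BOX, [0.9, 0.1])
inside = sum(saliency[y][x] for y in range(2, 6) for x in range(2, 6)) / 16
outside = (sum(map(sum, saliency)) - inside * 16) / 48
print("mean saliency inside the object:", inside, "outside:", outside)
```

In this toy setup, pixels inside the object should tend to score higher than pixels outside it, because masks that hide the object get zero weight. The only thing the method needs from the detector is its output on masked images, which is why the same approach works for MATLAB, TensorFlow, and PyTorch models.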
Figure: Perform object detection and explain the detection results using D-RISE for MATLAB, imported TensorFlow, and PyTorch models.
In this blog post, I’ll show you how to use D-RISE to explain object detection results for MATLAB, TensorFlow™, and PyTorch® models. More specifically, I will walk through how to use D-RISE for these object detectors:
- Built-in MATLAB object detector.
- TensorFlow object detector that is imported into MATLAB.
- PyTorch object detector that is used in MATLAB with co-execution.
    img = imread("testCar.png");
    img = im2single(img);

Detect vehicles in the test image by using the trained YOLO v2 detector. Note that the detector in this example has been trained to detect only vehicles, whereas the TensorFlow and PyTorch object detectors in the following examples detect all objects in the image. Pass the detector and the test image as input to the detect function. The detect function returns the bounding boxes, the detection scores, and the labels.

    [bboxes,scores,labels] = detect(detector,img);
    figure
    annotatedImage = insertObjectAnnotation(img,"rectangle",bboxes,scores);
    imshow(annotatedImage)

Use the drise function to create saliency maps explaining the detections made by the YOLO v2 object detector.

    scoreMap = drise(detector,img);

Plot the saliency map over the image. Areas highlighted in red are more significant in the detection than areas highlighted in blue.

    tiledlayout(1,2,TileSpacing="tight")
    for i = 1:2
        nexttile
        annotatedImage = insertObjectAnnotation(img,"rectangle",bboxes(i,:),scores(i));
        imshow(annotatedImage)
        hold on
        imagesc(scoreMap(:,:,i),AlphaData=0.5)
        title("DRISE Map: Detection " + i)
        hold off
    end
    colormap jet

To see more examples of how to use D-RISE with MATLAB object detectors, see the drise reference page.

This section shows how to import a TensorFlow model for object detection, how to use the imported model in MATLAB and visualize the detections, and how to use D-RISE to explain the predictions of the model. You can get the code for this example from here.
Import and Initialize Network
Import a pretrained TensorFlow model for object detection. The model is in the SavedModel format.

    modelFolder = "centernet_resnet50_v2_512x512_coco17";
    detector = importNetworkFromTensorFlow(modelFolder);

Specify the input size of the imported network. You can find the expected image size (512x512) in the name of the TensorFlow network. The input dlarray object must have the data format "SSCB" (spatial, spatial, channel, batch) to represent a 2-D image input. For more information, see Data Formats for Prediction with dlnetwork. Then, initialize the imported network.

    input_size = [512 512 3];
    detector = detector.initialize(dlarray(ones(512,512,3,1),"SSCB"))

    detector =
      dlnetwork with properties:
             Layers: [1×1 centernet_resnet50_v2_512x512_coco17.kCall11498]
        Connections: [0×2 table]
         Learnables: [388×3 table]
              State: [0×3 table]
         InputNames: {'kCall11498'}
        OutputNames: {'kCall11498/detection_boxes' 'kCall11498/detection_classes' 'kCall11498/detection_scores' 'kCall11498/num_detections'}
        Initialized: 1
      View summary with summary.
Detect with Imported Network
The network has four outputs: bounding boxes, classes, scores, and number of detections.

    mlOutputNames = detector.OutputNames'

    mlOutputNames = 4×1 cell
        'kCall11498/detection_boxes'
        'kCall11498/detection_classes'
        'kCall11498/detection_scores'
        'kCall11498/num_detections'

Read the image that you want to use for object detection. Perform object detection on the image.

    img = imread("testCar.png");
    [y1,y2,y3,y4] = detector.predict(dlarray(single(img),"SSCB"));
Get Detections with Highest Scores
Create a map of all the network outputs.

    mlOutputMap = containers.Map;
    mlOutputs = {y1,y2,y3,y4};
    for i = 1:numel(mlOutputNames)
        opNameStrSplit = strsplit(mlOutputNames{i},'/');
        opName = opNameStrSplit{end};
        mlOutputMap(opName) = mlOutputs{i};
    end

Get the detections with scores above the threshold thr, and the corresponding class labels.

    thr = 0.5;
    [bboxes,classes,scores,num_box] = bestDetections(img,mlOutputMap,thr);
    class_labels = getClassLabels(classes);
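The bestDetections helper is not listed in this post. As a rough illustration of what such a step involves, here is a minimal Python sketch. It assumes the detector returns boxes as normalized [ymin, xmin, ymax, xmax] values, which is the usual convention for TensorFlow Object Detection API models, and converts them to the [x, y, width, height] pixel layout that MATLAB annotation functions expect. The function name and signature are hypothetical, and MATLAB's 1-based pixel indexing is ignored here.

```python
def best_detections(boxes, classes, scores, image_size, threshold=0.5):
    """Keep detections scoring above threshold and convert their boxes.

    boxes:      iterable of normalized (ymin, xmin, ymax, xmax) tuples (assumed layout)
    image_size: (height, width) of the image in pixels
    returns:    list of ((x, y, w, h), class_id, score) tuples in pixel units
    """
    height, width = image_size
    kept = []
    for (ymin, xmin, ymax, xmax), cls, score in zip(boxes, classes, scores):
        if score <= threshold:
            continue  # discard low-confidence detections
        box = (xmin * width, ymin * height,
               (xmax - xmin) * width, (ymax - ymin) * height)
        kept.append((box, int(cls), score))
    return kept

# Example with made-up values: one confident detection, one weak detection.
detections = best_detections(
    boxes=[(0.25, 0.5, 0.75, 1.0), (0.0, 0.0, 1.0, 1.0)],
    classes=[3, 1],
    scores=[0.9, 0.2],
    image_size=(100, 200),
)
print(detections)
```

Only the first detection survives the 0.5 threshold in this example; the second is dropped because its score is too low.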
Visualize Object Detection
Create the labels associated with each of the detected objects.

    colons = repmat(": ",[1 num_box]);
    percents = repmat("%",[1 num_box]);
    labels = strcat(class_labels,colons,string(round(scores*100)),percents);

Visualize the object detection results with annotations.

    figure
    outputImage = insertObjectAnnotation(img,"rectangle",bboxes,labels,LineWidth=1,Color="green");
    imshow(outputImage)
Explainability for Object Detector
Explain the predictions of the object detection network using D-RISE. Specify a custom detection function to use D-RISE with the imported TensorFlow network.

    targetBox = bboxes(1,:);
    targetLabel = 1;
    scoreMap = drise(@(img)customDetector(img),img,targetBox,targetLabel);

Plot the results. As mentioned above, areas highlighted in red are more significant in the detection than areas highlighted in blue.

    figure
    annotatedImage = insertObjectAnnotation(img,"rectangle",targetBox,"vehicle");
    imshow(annotatedImage)
    hold on
    imagesc(scoreMap,AlphaData=0.5)
    title("DRISE Map: Custom Detector")
    hold off
    colormap jet

This section shows how to perform object detection with a PyTorch model using co-execution, and how to use D-RISE to explain the predictions of the PyTorch model. You can get the code for this example from here.
Python Environment
Set up the Python interpreter for MATLAB by using the pyenv function. Specify the version of Python to use, here by the path to the Python executable.

    pe = pyenv(Version=".\env\Scripts\python.exe",ExecutionMode="OutOfProcess");
Object Detection
Read the image that you want to use for object detection.

    img_filename = "testCar.png";
    img = imread(img_filename);

Perform object detection with a PyTorch model using co-execution.

    pyrun("from PT_object_detection import loadPTmodel, detectPT")
    [model,weights] = pyrun("[a,b] = loadPTmodel()",["a" "b"]);
    predictions = pyrun("a = detectPT(b,c,d)","a",b=img,c=model,d=weights);

Convert the prediction outputs from Python data types to MATLAB data types.

    [bboxes,labels,scores] = convertVariables(predictions,imread(img_filename));

Get the class labels.

    class_labels = getClassLabels(labels);
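The convertVariables helper is not listed in this post either. One thing such a conversion typically has to handle is the box layout: torchvision detection models return boxes as [x1, y1, x2, y2] corner coordinates, while MATLAB annotation functions expect [x, y, width, height]. Here is a minimal sketch of that step on the Python side, assuming the predictions have already been turned into plain lists in a dictionary with the keys boxes, labels, and scores. The function names are hypothetical.

```python
def corners_to_xywh(box):
    """Convert an (x1, y1, x2, y2) corner box to (x, y, width, height)."""
    x1, y1, x2, y2 = box
    return (x1, y1, x2 - x1, y2 - y1)

def convert_predictions(pred):
    """Split a prediction dictionary into parallel lists of boxes, labels, scores.

    pred is assumed to hold plain Python lists under "boxes", "labels", "scores".
    """
    bboxes = [corners_to_xywh(b) for b in pred["boxes"]]
    labels = [int(l) for l in pred["labels"]]
    scores = [float(s) for s in pred["scores"]]
    return bboxes, labels, scores

# Example with made-up values for a single detection.
bboxes, labels, scores = convert_predictions(
    {"boxes": [[10, 20, 110, 70]], "labels": [3], "scores": [0.9]}
)
print(bboxes, labels, scores)
```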
Visualization
Create the labels associated with each of the detected objects.

    num_box = length(scores);
    colons = repmat(": ",[1 num_box]);
    percents = repmat("%",[1 num_box]);
    class_labels1 = strcat(class_labels,colons,string(round(scores'*100)),percents);

Visualize the object detection results with annotations.

    figure
    outputImage = insertObjectAnnotation(img,"rectangle",bboxes,class_labels1,LineWidth=1,Color="green");
    imshow(outputImage)
Explainability
Explain the predictions of the PyTorch model using D-RISE. Specify a custom detection function to use D-RISE.

    targetBbox = bboxes(1,:);
    targetLabel = 1;
    scoreMap = drise(@(img)customDetector(img),img,targetBbox,targetLabel,...
        NumSamples=512,MiniBatchSize=8);

You can plot the saliency map computed by D-RISE as you previously did for the object detection results for the TensorFlow model.