{"id":15248,"date":"2024-06-10T09:26:31","date_gmt":"2024-06-10T13:26:31","guid":{"rendered":"https:\/\/blogs.mathworks.com\/deep-learning\/?p=15248"},"modified":"2024-09-09T09:08:02","modified_gmt":"2024-09-09T13:08:02","slug":"explainability-in-object-detection-for-matlab-tensorflow-and-pytorch-models","status":"publish","type":"post","link":"https:\/\/blogs.mathworks.com\/deep-learning\/2024\/06\/10\/explainability-in-object-detection-for-matlab-tensorflow-and-pytorch-models\/","title":{"rendered":"Explainability in Object Detection for MATLAB, TensorFlow, and PyTorch Models"},"content":{"rendered":"<h6><\/h6>\r\nIn R2024a, <a href=\"https:\/\/www.mathworks.com\/products\/deep-learning-verification-library.html\">Deep Learning Toolbox Verification Library<\/a> introduced the <a href=\"https:\/\/www.mathworks.com\/help\/deeplearning\/ref\/drise.html\">drise<\/a> function. D-RISE is an explainability tool that helps you visualize and understand which parts of an image are most important for object detection. If you need a refresher on what explainable AI is and why it\u2019s important, watch <a href=\"https:\/\/www.youtube.com\/watch?v=It2Q1eK_Klc\">this short video<\/a>.\r\n<h6><\/h6>\r\nD-RISE, proposed in <a href=\"https:\/\/arxiv.org\/abs\/2006.03204\">this paper<\/a>, is a model-agnostic method that doesn\u2019t require knowledge of the inner workings of the object detection model. It produces a saliency map (an image with highlighted areas that most affect the prediction) for a specific image and object detector. 
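To make the idea concrete, here is a minimal Python sketch of the masking recipe behind RISE/D-RISE. This is a hypothetical illustration, not the MATLAB implementation: the full method also uses bilinear upsampling with random shifts and weights each mask by the similarity between the target detection and the detections found on the masked image; the `detect_score` callback below stands in for that weighting.

```python
import numpy as np

def drise_saliency(detect_score, image, num_masks=1000, grid=4, p=0.5, seed=0):
    # Sketch of the D-RISE masking idea: probe the detector with randomly
    # masked copies of the image and accumulate each mask weighted by how
    # strongly the target detection survives. The full method additionally
    # applies random shifts, bilinear upsampling, and a detection-similarity
    # weighting; detect_score is a stand-in for that term.
    rng = np.random.default_rng(seed)
    h, w = image.shape[:2]
    cell_h, cell_w = -(-h // grid), -(-w // grid)  # ceil division
    saliency = np.zeros((h, w))
    for _ in range(num_masks):
        coarse = (rng.random((grid, grid)) < p).astype(float)  # random on/off grid
        mask = np.kron(coarse, np.ones((cell_h, cell_w)))[:h, :w]
        weight = detect_score(image * mask[..., None])  # response on masked image
        saliency += weight * mask
    return saliency / num_masks

# Toy check: a "detector" that responds to the top-left quadrant should
# produce higher saliency there than in the bottom-right quadrant.
img = np.zeros((32, 32, 3))
img[:16, :16, :] = 1.0
sal = drise_saliency(lambda m: float(m[:16, :16].mean()), img)
```

Averaging masks weighted by the detector response is what turns a black-box detector into a per-pixel importance map, which is why the method needs no access to the model internals.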
Because it\u2019s a general and model-agnostic method, it can be applied to different types of object detectors.\r\n<h6><\/h6>\r\n<img decoding=\"async\" loading=\"lazy\" class=\"alignnone wp-image-15419 size-full\" src=\"https:\/\/blogs.mathworks.com\/deep-learning\/files\/2024\/05\/drise_options-1.png\" alt=\"Get an object detector (from MATLAB, TensorFlow, or PyTorch), use it to detect objects in an image of a road, and explain the detection with DRISE.\" width=\"1534\" height=\"471\" \/>\r\n<h6><\/h6>\r\n<strong>Figure:<\/strong> Perform object detection and explain the detection results using D-RISE for MATLAB, imported TensorFlow, and PyTorch models.\r\n<h6><\/h6>\r\n&nbsp;\r\n<h6><\/h6>\r\nIn this blog post, I\u2019ll show you how to use D-RISE to explain object detection results for MATLAB, TensorFlow\u2122, and PyTorch\u00ae models. More specifically, I will walk through how to use D-RISE for these object detectors:\r\n<h6><\/h6>\r\n<ol>\r\n \t<li><a href=\"#DRISE_with_MATLAB\">Built-in MATLAB object detector.<\/a><\/li>\r\n \t<li><a href=\"#DRISE_with_TensorFlow\">TensorFlow object detector that is imported into MATLAB.<\/a><\/li>\r\n \t<li><a href=\"#DRISE_with_PyTorch\">PyTorch object detector that is used in MATLAB with co-execution.<\/a><\/li>\r\n<\/ol>\r\nThe drise function provides two syntaxes for explaining the predictions of built-in MATLAB object detectors and any other type of object detector. 
To use drise with TensorFlow models (even after they have been imported into MATLAB) and with PyTorch models, you must use the custom detection syntax of the function.\r\n<h6><\/h6>\r\n<img decoding=\"async\" loading=\"lazy\" class=\"alignnone wp-image-15254 \" src=\"https:\/\/blogs.mathworks.com\/deep-learning\/files\/2024\/05\/drise_syntax.png\" alt=\"Two syntaxes for drise function; syntax for MATLAB object detectors (on the left) and syntax for other object detectors (on the right).\" width=\"678\" height=\"102\" \/>\r\n<h6><\/h6>\r\n<strong>Figure:<\/strong> Syntaxes of the drise function for built-in MATLAB object detectors and other types of object detectors.\r\n<h6><\/h6>\r\n&nbsp;\r\n<h6><\/h6>\r\nBut don\u2019t fret, I\u2019ll provide you with the necessary code for all options. Check out <a href=\"https:\/\/github.com\/matlab-deep-learning\/object-detection-and-explainability-for-tensorflow-and-pytorch-models\">this GitHub repository<\/a> to get the code for the examples that use D-RISE with TensorFlow and PyTorch object detectors.\r\n<h6><\/h6>\r\n&nbsp;\r\n<h6><\/h6>\r\n<p style=\"font-size: 20px;\"><a name=\"DRISE_with_MATLAB\"><\/a><strong>D-RISE with MATLAB Model<\/strong><\/p>\r\nThis section shows how to use D-RISE with a built-in MATLAB object detector, specifically a <a href=\"https:\/\/www.mathworks.com\/help\/vision\/ref\/trainyolov2objectdetector.html\">YOLO v2<\/a> object detector. You can get the full example from <a href=\"https:\/\/www.mathworks.com\/help\/deeplearning\/ref\/drise.html#mw_08924e21-c9c9-4080-be62-e090a3971359\">here<\/a>.\r\n<h6><\/h6>\r\nRead in a test image from the Caltech Cars data set.\r\n<pre>img = imread(\"testCar.png\");\r\nimg = im2single(img);\r\n<\/pre>\r\n<h6><\/h6>\r\nDetect vehicles in the test image by using the trained YOLO v2 detector. 
Note that the detector in this example has been trained to detect only vehicles, whereas the TensorFlow and PyTorch object detectors in the following examples are detecting all objects in the image.\r\n<h6><\/h6>\r\nPass the test image and the detector as input to the detect function. The detect function returns the bounding boxes and the detection scores.\r\n<pre>[bboxes,scores,labels] = detect(detector,img);\r\nfigure\r\nannotatedImage = insertObjectAnnotation(img,\"rectangle\",bboxes,scores);\r\nimshow(annotatedImage)\r\n<\/pre>\r\n<h6><\/h6>\r\n<img decoding=\"async\" loading=\"lazy\" class=\"alignnone wp-image-15260 size-full\" src=\"https:\/\/blogs.mathworks.com\/deep-learning\/files\/2024\/05\/object_detection_yolov2.png\" alt=\"Object detection with YOLO v2, detecting two vehicles.\" width=\"342\" height=\"230\" \/>\r\n<h6><\/h6>\r\nUse the drise function to create saliency maps explaining the detections made by the YOLO v2 object detector.\r\n<pre>scoreMap = drise(detector,img);\r\n<\/pre>\r\n<h6><\/h6>\r\nPlot the saliency map over the image. 
Areas highlighted in red are more significant in the detection than areas highlighted in blue.\r\n<pre>tiledlayout(1,2,TileSpacing=\"tight\")\r\n\r\nfor i = 1:2\r\n    nexttile\r\n    annotatedImage = insertObjectAnnotation(img,\"rectangle\",bboxes(i,:),scores(i));\r\n    imshow(annotatedImage)\r\n    hold on\r\n    imagesc(scoreMap(:,:,i),AlphaData=0.5)\r\n    title(\"DRISE Map: Detection \" + i)\r\n    hold off\r\nend\r\n\r\ncolormap jet\r\n<\/pre>\r\n<img decoding=\"async\" loading=\"lazy\" class=\"alignnone wp-image-15263 size-full\" src=\"https:\/\/blogs.mathworks.com\/deep-learning\/files\/2024\/05\/saliency_map_yolov2.png\" alt=\"Saliency maps for two detected vehicles.\" width=\"438\" height=\"155\" \/>\r\n<h6><\/h6>\r\nTo see more examples of how to use D-RISE with MATLAB object detectors, see the <a href=\"https:\/\/www.mathworks.com\/help\/deeplearning\/ref\/drise.html\">drise reference page<\/a>.\r\n<h6><\/h6>\r\n&nbsp;\r\n<h6><\/h6>\r\n<p style=\"font-size: 20px;\"><a name=\"DRISE_with_TensorFlow\"><\/a><strong>D-RISE with TensorFlow Model<\/strong><\/p>\r\nThis section shows how to import a TensorFlow model for object detection, how to use the imported model in MATLAB and visualize the detections, and how to use D-RISE to explain the predictions of the model. You can get the code for this example from <a href=\"https:\/\/github.com\/matlab-deep-learning\/object-detection-and-explainability-for-tensorflow-and-pytorch-models\/tree\/main\/ObjectDetectionForTensorFlowModel\">here<\/a>.\r\n<h6><\/h6>\r\n<p style=\"font-size: 16px;\"><strong>Import and Initialize Network<\/strong><\/p>\r\nImport a pretrained TensorFlow model for object detection. The model is in the SavedModel format.\r\n<pre>modelFolder = \"centernet_resnet50_v2_512x517_coco17\";\r\ndetector = importNetworkFromTensorFlow(modelFolder);\r\n<\/pre>\r\n<h6><\/h6>\r\nSpecify the input size of the imported network. 
You can find the expected image size stated in the name of the TensorFlow network. The data format of the\u00a0dlarray\u00a0object must have the dimensions\u00a0\"SSCB\"\u00a0(spatial, spatial, channel, batch) to represent a 2-D image input. For more information, see\u00a0<a href=\"https:\/\/www.mathworks.com\/help\/deeplearning\/ug\/tips-on-importing-models-from-tensorflow-pytorch-and-onnx.html#mw_b48d26e1-7484-4802-aef3-502f61c36795\">Data Formats for Prediction with dlnetwork<\/a>. Then, initialize the imported network.\r\n<pre>input_size = [512 512 3];\r\ndetector = detector.initialize(dlarray(ones(512,512,3,1),\"SSCB\"))\r\n<\/pre>\r\n<pre class=\"brush: python\" style=\"background-color: white; border: white;\">detector = \r\n  dlnetwork with properties:\r\n\r\n         Layers: [1\u00d71 centernet_resnet50_v2_512x517_coco17.kCall11498]\r\n    Connections: [0\u00d72 table]\r\n     Learnables: [388\u00d73 table]\r\n          State: [0\u00d73 table]\r\n     InputNames: {'kCall11498'}\r\n    OutputNames: {'kCall11498\/detection_boxes'  'kCall11498\/detection_classes'  'kCall11498\/detection_scores'  'kCall11498\/num_detections'}\r\n    Initialized: 1\r\n\r\n  View summary with summary.\r\n<\/pre>\r\n<h6><\/h6>\r\n<p style=\"font-size: 16px;\"><strong>Detect with Imported Network<\/strong><\/p>\r\nThe network has four outputs: bounding boxes, classes, scores, and number of detections.\r\n<pre>mlOutputNames = detector.OutputNames'\r\n<\/pre>\r\n<pre class=\"brush: python\" style=\"background-color: white; border: white;\">mlOutputNames = 4\u00d71 cell\r\n'kCall11498\/detection_boxes'  \r\n'kCall11498\/detection_classes'\r\n'kCall11498\/detection_scores' \r\n'kCall11498\/num_detections'   \r\n<\/pre>\r\n<h6><\/h6>\r\nRead the image that you want to use for object detection. 
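As a quick cross-check in NumPy terms (an aside; the workflow itself stays in MATLAB), the \"SSCB\" layout for a single image is simply the H-by-W-by-C array with a trailing singleton batch dimension:

```python
import numpy as np

# dlarray(img, "SSCB") labels dimensions as spatial, spatial, channel,
# batch; for one RGB image the equivalent array shape is (H, W, C, 1).
img = np.zeros((512, 512, 3), dtype=np.float32)  # stand-in for the test image
sscb = img[..., np.newaxis]                      # append the batch dimension
print(sscb.shape)  # (512, 512, 3, 1)
```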
Perform object detection on the image.\r\n<pre>img = imread(\"testCar.png\");\r\n[y1,y2,y3,y4] = detector.predict(dlarray(single(img),\"SSCB\"));\r\n<\/pre>\r\n<h6><\/h6>\r\n<p style=\"font-size: 16px;\"><strong>Get Detections with Highest Scores<\/strong><\/p>\r\nCreate a map of all the network outputs.\r\n<pre>mlOutputMap = containers.Map;\r\nmlOutputs = {y1,y2,y3,y4};\r\nfor i = 1:numel(mlOutputNames)\r\n    opNameStrSplit = strsplit(mlOutputNames{i},'\/');\r\n    opName = opNameStrSplit{end};\r\n    mlOutputMap(opName) = mlOutputs{i};\r\nend\r\n<\/pre>\r\n<h6><\/h6>\r\nGet the detections with scores above the threshold thr, and the corresponding class labels.\r\n<pre>thr = 0.5;\r\n[bboxes,classes,scores,num_box] = bestDetections(img,mlOutputMap,thr);\r\nclass_labels = getClassLabels(classes);\r\n<\/pre>\r\n<h6><\/h6>\r\n<p style=\"font-size: 16px;\"><strong>Visualize Object Detection<\/strong><\/p>\r\nCreate the labels associated with each of the detected objects.\r\n<pre>colons = repmat(\": \",[1 num_box]);\r\npercents = repmat(\"%\",[1 num_box]);\r\nlabels = strcat(class_labels,colons,string(round(scores*100)),percents);\r\n<\/pre>\r\n<h6><\/h6>\r\nVisualize the object detection results with annotations.\r\n<pre>figure\r\noutputImage = insertObjectAnnotation(img,\"rectangle\",bboxes,labels,LineWidth=1,Color=\"green\");\r\nimshow(outputImage)\r\n<\/pre>\r\n<h6><\/h6>\r\n<img decoding=\"async\" loading=\"lazy\" class=\"alignnone wp-image-15275 size-full\" src=\"https:\/\/blogs.mathworks.com\/deep-learning\/files\/2024\/05\/object_detection_for_car_image_TF.png\" alt=\"TensorFlow object detector detects three objects in input image.\" width=\"343\" height=\"233\" \/>\r\n<h6><\/h6>\r\n<p style=\"font-size: 16px;\"><strong>Explainability for Object Detector<\/strong><\/p>\r\nExplain the predictions of the object detection network using D-RISE. 
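Note the box format at this point: the TensorFlow model emits normalized [ymin, xmin, ymax, xmax] boxes (the usual TF Object Detection API convention), while MATLAB functions such as insertObjectAnnotation and drise expect pixel [x, y, width, height] boxes. A helper like bestDetections presumably performs filtering and conversion along these lines (hypothetical Python sketch; the repository does the equivalent in MATLAB, and 0- vs 1-based pixel offsets are glossed over here):

```python
import numpy as np

def filter_and_convert(boxes, scores, img_h, img_w, thr=0.5):
    # Keep detections scoring above thr and convert each normalized
    # [ymin, xmin, ymax, xmax] box to a pixel [x, y, width, height] box.
    # (Exact 0- vs 1-based offset conventions are omitted for clarity.)
    boxes, scores = np.asarray(boxes, float), np.asarray(scores, float)
    keep = scores > thr
    out = []
    for ymin, xmin, ymax, xmax in boxes[keep]:
        out.append([xmin * img_w, ymin * img_h,
                    (xmax - xmin) * img_w, (ymax - ymin) * img_h])
    return np.array(out), scores[keep]

bboxes_px, kept = filter_and_convert(
    [[0.25, 0.25, 0.75, 0.75], [0.0, 0.0, 0.1, 0.1]],
    [0.9, 0.3], img_h=100, img_w=200)
# bboxes_px -> [[50., 25., 100., 50.]], kept -> [0.9]
```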
Specify a custom detection function to use D-RISE with the imported TensorFlow network.\r\n<pre>targetBox = bboxes(1,:);\r\ntargetLabel = 1;\r\nscoreMap = drise(@(img)customDetector(img),img,targetBox,targetLabel);\r\n<\/pre>\r\n<h6><\/h6>\r\nPlot the results. As mentioned above, areas highlighted in red are more significant in the detection than areas highlighted in blue.\r\n<pre>figure\r\nannotatedImage = insertObjectAnnotation(img,\"rectangle\",targetBox,\"vehicle\");\r\nimshow(annotatedImage)\r\nhold on\r\nimagesc(scoreMap,AlphaData=0.5)\r\ntitle(\"DRISE Map: Custom Detector\")\r\nhold off\r\ncolormap jet\r\n<\/pre>\r\n<h6><\/h6>\r\n<img decoding=\"async\" loading=\"lazy\" class=\"alignnone wp-image-15278 size-full\" src=\"https:\/\/blogs.mathworks.com\/deep-learning\/files\/2024\/05\/drise_for_car_image.png\" alt=\"Saliency map created by DRISE for TensorFlow object detector\" width=\"344\" height=\"253\" \/>\r\n<h6><\/h6>\r\n&nbsp;\r\n<h6><\/h6>\r\n<p style=\"font-size: 20px;\"><a name=\"DRISE_with_PyTorch\"><\/a><strong>D-RISE with PyTorch Model<\/strong><\/p>\r\nThis section shows how to perform object detection with a PyTorch model using co-execution, and how to use D-RISE to explain the predictions of the PyTorch model. You can get the code for this example from <a href=\"https:\/\/github.com\/matlab-deep-learning\/object-detection-and-explainability-for-tensorflow-and-pytorch-models\/tree\/main\/ObjectDetectionForPyTorchModel\">here<\/a>.\r\n<h6><\/h6>\r\n<p style=\"font-size: 16px;\"><strong>Python Environment<\/strong><\/p>\r\nSet up the Python interpreter for MATLAB by using the pyenv function. 
Specify the version of Python to use.\r\n<pre>pe = pyenv(Version=\".\\env\\Scripts\\python.exe\",ExecutionMode=\"OutOfProcess\");\r\n<\/pre>\r\n<p style=\"font-size: 16px;\"><strong>Object Detection<\/strong><\/p>\r\nRead the image that you want to use for object detection.\r\n<pre>img_filename = \"testCar.png\";\r\nimg = imread(img_filename);\r\n<\/pre>\r\n<h6><\/h6>\r\nPerform object detection with a PyTorch model using co-execution.\r\n<pre>pyrun(\"from PT_object_detection import loadPTmodel, detectPT\")\r\n[model,weights] = pyrun(\"[a,b] = loadPTmodel()\",[\"a\" \"b\"]);\r\npredictions = pyrun(\"a = detectPT(b,c,d)\",\"a\",b=img,c=model,d=weights);\r\n<\/pre>\r\n<h6><\/h6>\r\nConvert the prediction outputs from Python data types to MATLAB data types.\r\n<pre>[bboxes,labels,scores] = convertVariables(predictions,imread(img_filename));\r\n<\/pre>\r\n<h6><\/h6>\r\nGet the class labels.\r\n<pre>class_labels = getClassLabels(labels);\r\n<\/pre>\r\n<h6><\/h6>\r\n<p style=\"font-size: 16px;\"><strong>Visualization<\/strong><\/p>\r\nCreate the labels associated with each of the detected objects.\r\n<pre>num_box = length(scores);\r\ncolons = repmat(\": \",[1 num_box]);\r\npercents = repmat(\"%\",[1 num_box]);\r\nclass_labels1 = strcat(class_labels,colons,string(round(scores'*100)),percents);\r\n<\/pre>\r\n<h6><\/h6>\r\nVisualize the object detection results with annotations.\r\n<pre>figure\r\noutputImage = insertObjectAnnotation(img,\"rectangle\",bboxes,class_labels1,LineWidth=1,Color=\"green\");\r\nimshow(outputImage)\r\n<\/pre>\r\n<h6><\/h6>\r\n<img decoding=\"async\" loading=\"lazy\" class=\"alignnone wp-image-15290 size-full\" src=\"https:\/\/blogs.mathworks.com\/deep-learning\/files\/2024\/05\/object_detection_for_car_image_PT.png\" alt=\"Object detected in input image by PyTorch object detector.\" width=\"346\" height=\"233\" \/>\r\n<h6><\/h6>\r\n<p style=\"font-size: 16px;\"><strong>Explainability<\/strong><\/p>\r\nExplain the predictions of the PyTorch model 
using D-RISE. Specify a custom detection function to use D-RISE.\r\n<pre>targetBbox = bboxes(1,:);\r\ntargetLabel = 1;\r\nscoreMap = drise(@(img)customDetector(img),img,targetBbox,targetLabel,...\r\n    NumSamples=512,MiniBatchSize=8);\r\n<\/pre>\r\n<h6><\/h6>\r\nYou can plot the saliency map computed by D-RISE as you previously did for the object detection results for the TensorFlow model.\r\n<p style=\"text-align: right; font-size: xx-small; font-weight: lighter; font-style: italic; color: gray;\"><a href=\"https:\/\/github.com\/matlab-deep-learning\/object-detection-and-explainability-for-tensorflow-and-pytorch-models\"><span style=\"font-size: x-small; font-style: italic;\">Get\r\nthe MATLAB code<\/span><\/a><\/p>\r\n\r\n<h6><\/h6>\r\n&nbsp;\r\n<h6><\/h6>","protected":false},"excerpt":{"rendered":"<div class=\"overview-image\"><img src=\"https:\/\/blogs.mathworks.com\/deep-learning\/files\/2024\/05\/drise_options-1.png\" class=\"img-responsive attachment-post-thumbnail size-post-thumbnail wp-post-image\" alt=\"\" decoding=\"async\" loading=\"lazy\" \/><\/div><p>\r\nIn R2024a, Deep Learning Toolbox Verification Library introduced the d-rise function. D-RISE is an explainability tool that helps you visualize and understand\u00a0 which parts are important for object... 
<a class=\"read-more\" href=\"https:\/\/blogs.mathworks.com\/deep-learning\/2024\/06\/10\/explainability-in-object-detection-for-matlab-tensorflow-and-pytorch-models\/\">read more >><\/a><\/p>","protected":false},"author":194,"featured_media":15419,"comment_status":"open","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":[],"categories":[9,5,66,39,59,45,42],"tags":[],"_links":{"self":[{"href":"https:\/\/blogs.mathworks.com\/deep-learning\/wp-json\/wp\/v2\/posts\/15248"}],"collection":[{"href":"https:\/\/blogs.mathworks.com\/deep-learning\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/blogs.mathworks.com\/deep-learning\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/blogs.mathworks.com\/deep-learning\/wp-json\/wp\/v2\/users\/194"}],"replies":[{"embeddable":true,"href":"https:\/\/blogs.mathworks.com\/deep-learning\/wp-json\/wp\/v2\/comments?post=15248"}],"version-history":[{"count":24,"href":"https:\/\/blogs.mathworks.com\/deep-learning\/wp-json\/wp\/v2\/posts\/15248\/revisions"}],"predecessor-version":[{"id":16064,"href":"https:\/\/blogs.mathworks.com\/deep-learning\/wp-json\/wp\/v2\/posts\/15248\/revisions\/16064"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/blogs.mathworks.com\/deep-learning\/wp-json\/wp\/v2\/media\/15419"}],"wp:attachment":[{"href":"https:\/\/blogs.mathworks.com\/deep-learning\/wp-json\/wp\/v2\/media?parent=15248"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/blogs.mathworks.com\/deep-learning\/wp-json\/wp\/v2\/categories?post=15248"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/blogs.mathworks.com\/deep-learning\/wp-json\/wp\/v2\/tags?post=15248"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}