Artificial Intelligence

Apply machine learning and deep learning

Neural Network Feature Visualization

Visualization of the data and the semantic content learned by a network

This post comes from Maria Duarte Rosa, who is going to talk about different ways to visualize features learned by networks.
Today, we'll look at two methods for gaining insight into a network: k-nearest neighbors and t-SNE, both described in detail below.
Semantic Clustering using t-SNE

Visualization of a trained network using t-SNE

Dataset and Model

For both of these exercises, we'll use ResNet-18 and our favorite food dataset, which you can download here. (Please be aware that this is a very large download. We're using it for example purposes only, since food is relevant to everyone! This code should work with any other dataset you wish.)
The network has been retrained to identify the 5 categories of objects in the data: salad, pizza, fries, burger, and sushi.
Next, we want to visualize the network and understand the features it uses to classify the data. The following two methods visualize the high-level features of a network and give insight beyond accuracy alone.

k-nearest neighbors search

A nearest neighbor search is a type of optimization problem where the goal is to find the points in a space that are closest (or most similar) to a given point. A k-nearest neighbors search identifies the k closest neighbors of a point in feature space. Closeness is generally defined using a distance metric such as the Euclidean distance or, more generally, the Minkowski distance (the Euclidean distance is the Minkowski distance of order p = 2). The more similar two points are, the smaller this distance should be. This technique is often used as a machine learning classification method, but it can also be used to visualize data and the high-level features of a neural network, which is what we're going to do.
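To make the idea concrete, here is a minimal brute-force sketch of a k-nearest neighbors search written with base MATLAB operations only. It computes the Minkowski distance of order p from one query point to every row of a data matrix (p = 2 gives the Euclidean distance) and keeps the k smallest. The toy matrix X, the query q, and the variable names are illustrative assumptions, not part of the original example; createns and knnsearch, used below, do this job for many queries at once.

```matlab
% Toy data set: each row is one point in a 2-D feature space (hypothetical values)
X = [0 0;
     1 0;
     0 2;
     3 3;
     1 1];
q = [0.9 0.2];   % query point
k = 3;           % number of neighbors to return
p = 2;           % Minkowski order: p = 2 is Euclidean, p = 1 is city block

% Minkowski distance from q to every row of X (implicit expansion of X - q)
d = sum(abs(X - q).^p, 2).^(1/p);

% Sort the distances in ascending order and keep the k closest rows
[dSorted, order] = sort(d);
idxKnn = order(1:k)';   % indices of the nearest neighbors, closest first
dKnn = dSorted(1:k)';   % corresponding distances
disp(idxKnn)
```

Here the three nearest rows are 2, 5 and 1. The cost of this exhaustive approach grows with the number of stored points for every query; for high-dimensional features such as the 512-dimensional pool5 output used below, createns by default builds an exhaustive searcher, which performs essentially this computation.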
Let's start with 5 test images from the food dataset:
idxTest = [394 97 996 460 737];
% Tile the 5 test images in a single column and display them
im = imtile(imdsTest.Files(idxTest), 'ThumbnailSize', [100 100], 'GridSize', [5 1]);
figure; imshow(im);
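and look for the 10 nearest neighbors of these images in the training data, in pixel space. The code below extracts the activations of the network's input layer, 'data', for all training and test images; at this layer the "features" are essentially the pixel values of the resized images. It then finds the training images that are closest, in Euclidean distance, to our chosen test images.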
% Get the features (aka activations) of the input layer, i.e. the pixel values
dataTrainFS = activations(netFood, imdsTrainAu, 'data', 'OutputAs', 'rows');
imgFeatSpaceTest = activations(netFood, imdsTestAu, 'data', 'OutputAs', 'rows');
dataTestFS = imgFeatSpaceTest(idxTest,:);   % keep only the 5 chosen test images
% Create a KNN searcher and search for the 10 nearest neighbors of each test image
Mdl = createns(dataTrainFS, 'Distance', 'euclidean');
idxKnn = knnsearch(Mdl, dataTestFS, 'K', 10);
Searching for similarities in pixel space does not generally return any meaningful information about the semantic content of the image but only similarities in pixel intensity and color distribution. The 10 nearest neighbors in the data (pixel) space do not necessarily correspond to the same class as the test image. There is no "learning" taking place.
Take a look at the 4th row:
The image of the fries is yellow and brighter at the top, and dark at the bottom. Most of its nearest neighbors in pixel space are images from any class that share the same intensity and color pattern (somewhat brighter at the top and dark at the bottom). Let's compare this with a search in feature space: we pass the images through the network and look for the 10 nearest neighbors using, as features, the output of the network's final average pooling layer, pool5.
% Get the features from pool5, the average pooling layer before the final fully connected layer
dataTrainFS = activations(netFood, imdsTrainAu, 'pool5', 'OutputAs', 'rows');
imgFeatSpaceTest = activations(netFood, imdsTestAu, 'pool5', 'OutputAs', 'rows');
dataTestFS = imgFeatSpaceTest(idxTest,:);
% Create a KNN searcher and search for the 10 nearest neighbors in feature space
Mdl = createns(dataTrainFS, 'Distance', 'euclidean');
idxKnn = knnsearch(Mdl, dataTestFS, 'K', 10);

The first column (highlighted) is the test image, the remaining columns are the 10 nearest neighbors

Now we can see that color and intensity no longer matter; what matters are the higher-level features of the objects in the image. The nearest neighbors are now images of the same class. These results show that the features from the deep neural network contain information about the semantic content of the images. In other words, the network learned to discriminate between classes by learning high-level, object-specific features, similar to those that allow humans to distinguish hamburgers from pizzas or Caesar salads from sushi.
K-NN: What can we learn from this?
This can confirm what we expect to see from the network, or simply offer a new way to look at it. If the training accuracy of the network is high but the nearest neighbors in feature space (assuming the features are the output of one of the final layers of the network) are not objects from the same class, this may indicate that the network has not captured any semantic knowledge related to the classes and might instead have learned to classify based on some artifact of the training data.
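One hypothetical way to turn this check into a number is to measure, for each test image, the fraction of its k nearest neighbors that share its class. The sketch below assumes idxKnn is the matrix of neighbor indices returned by knnsearch (one row per query) and uses small made-up numeric labels, trainLabels and testLabels, in place of the real classes; categorical labels such as imdsTrain.Labels could be converted to numeric category codes with double().

```matlab
% Hypothetical numeric class labels (1 = salad, 2 = pizza, 3 = fries) for 6 training images
trainLabels = [1; 1; 2; 2; 3; 3];
% Hypothetical labels of 2 query (test) images
testLabels = [1; 2];
% Neighbor indices in the format returned by knnsearch: one row per query, k = 3 columns
idxKnn = [1 2 3;
          5 6 3];

% Look up the class of every neighbor; the result has the same shape as idxKnn
neighborLabels = trainLabels(idxKnn);
% Compare each neighbor's class with its query's class (implicit expansion over columns)
sameClass = neighborLabels == testLabels;
% Fraction of neighbors from the same class, per query and averaged over all queries
agreementPerQuery = mean(sameClass, 2);
overallAgreement = mean(agreementPerQuery);
disp(agreementPerQuery)
```

With pool5 features, a low agreement despite high training accuracy would be the warning sign described above; in pixel space, low agreement is expected.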

Semantic clustering with t-SNE

t-Distributed Stochastic Neighbor Embedding (t-SNE) is a non-linear dimensionality reduction technique that embeds high-dimensional data in a lower-dimensional space. (Typically we choose the lower-dimensional space to have two or three dimensions, since this makes it easy to plot and visualize.) This lower-dimensional space is estimated in such a way that it preserves similarities from the high-dimensional space. In other words, two similar objects have a high probability of being nearby in the lower-dimensional space, while two dissimilar objects should be represented by distant points. This technique can be used to visualize deep neural network features.
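For reference, the standard t-SNE formulation makes "preserves similarities" precise. Pairwise similarities between the high-dimensional points x_i are expressed as probabilities using Gaussian kernels, similarities between the low-dimensional points y_i use a heavy-tailed Student-t kernel, and the embedding is found by minimizing the Kullback-Leibler divergence between the two distributions. Here N is the number of points and σ_i is a per-point bandwidth chosen to match a user-set perplexity (exposed in MATLAB's tsne through the 'Perplexity' name-value argument).

```latex
p_{j|i} = \frac{\exp\left(-\lVert x_i - x_j \rVert^2 / 2\sigma_i^2\right)}
               {\sum_{k \neq i} \exp\left(-\lVert x_i - x_k \rVert^2 / 2\sigma_i^2\right)},
\qquad
p_{ij} = \frac{p_{j|i} + p_{i|j}}{2N}

q_{ij} = \frac{\left(1 + \lVert y_i - y_j \rVert^2\right)^{-1}}
              {\sum_{k \neq l} \left(1 + \lVert y_k - y_l \rVert^2\right)^{-1}}

C = \mathrm{KL}(P \,\Vert\, Q) = \sum_{i \neq j} p_{ij} \log \frac{p_{ij}}{q_{ij}}
```

Because the Student-t kernel decays slowly with distance, moderately dissimilar points can be placed far apart in the embedding, which helps separate clusters in two or three dimensions.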
Let's apply this technique to the training images of the dataset and compute two-dimensional and three-dimensional embeddings of the data.
As in the k-NN example, we'll visualize both the original data (pixel space) and the output of the final average pooling layer, pool5.
% Class labels of the training images, used to color the plots
% (assumes imdsTrainAu was built from imdsTrain, so the image order matches)
labels = imdsTrain.Labels;
layers = {'data', 'pool5'};
for k = 1:length(layers)
   dataTrainFS = activations(netFood, imdsTrainAu, layers{k}, 'OutputAs', 'rows');
   AlltSNE2dim(:,:,k) = tsne(dataTrainFS);                      % 2-D embedding
   AlltSNE3dim(:,:,k) = tsne(dataTrainFS, 'NumDimensions', 3);  % 3-D embedding
end

% 2-D embeddings: input layer (left) vs. pool5 (right)
figure;
subplot(1,2,1);gscatter(AlltSNE2dim(:,1,1), AlltSNE2dim(:,2,1), labels);
title(sprintf('Semantic clustering - %s layer', layers{1}));
subplot(1,2,2);gscatter(AlltSNE2dim(:,1,end), AlltSNE2dim(:,2,end), labels);
title(sprintf('Semantic clustering - %s layer', layers{end}));

% 3-D embeddings; color the points by the numeric category codes of the labels
figure;
subplot(1,2,1);scatter3(AlltSNE3dim(:,1,1),AlltSNE3dim(:,2,1),AlltSNE3dim(:,3,1), 20, double(labels));
title(sprintf('Semantic clustering - %s layer', layers{1}));
subplot(1,2,2);scatter3(AlltSNE3dim(:,1,end),AlltSNE3dim(:,2,end),AlltSNE3dim(:,3,end), 20, double(labels));
title(sprintf('Semantic clustering - %s layer', layers{end}));
In both the two- and three-dimensional plots of the 'data' layer, the points are scattered across the whole space with no visible structure. When we plot the embedding of the 'pool5' output, however, the picture is very different: the points form clear clusters according to the semantic content of the images, and these clusters correspond to the 5 classes in the data. This means that the high-level representations learned by the network contain discriminative information about the objects in the images, which allows the network to predict the class of the object accurately.
In addition to the information that these visualizations provide about the network, they can also be useful for inspecting the data itself. For example, let's look at a few images that fall in the wrong cluster and see if we can get some insight into why the network may have mispredicted them.

Examples of images in the wrong semantic cluster

Let's take a closer look at the 2D image of the pool5 layer, and zoom in on a few of the misclassified images.
im = imread(imdsTrain.Files{1619});
figure;imshow(im);title('Hamburger that looks like a salad');
A hamburger in the salad cluster. Unlike other hamburger images, there is a significant amount of salad in the photo and no bread/bun.
im = imread(imdsTrain.Files{125});
figure;imshow(im);title('Caesar salad that looks like a hamburger');
A salad in the hamburger cluster. This may be because the image contains a bun or bread in the background.
im = imread(imdsTrain.Files{3000});
figure;imshow(im);title('Sushi that looks like a hamburger');
A sushi image in the hamburger cluster. Perhaps it has features that resemble something one could find in a burger?
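Instead of zooming in by eye, one could flag candidate images automatically. A simple heuristic, not the method used for the examples above, is to look at each point's nearest neighbors in the 2-D embedding and flag it when the most common label among those neighbors differs from its own. The sketch uses base MATLAB only, with a tiny made-up embedding Y and numeric labels lab; with the real data, Y would be AlltSNE2dim(:,:,end) and lab could be double(labels).

```matlab
% Hypothetical 2-D embedding with two tight clusters of points
clusterA = [0 0; 0.1 0; 0 0.1; 0.1 0.1];
clusterB = [5 5; 5.1 5; 5 5.1; 5.1 5.1];
Y = [clusterA; clusterB];
% Numeric class labels; point 4 sits in cluster A but is labeled class 2
lab = [1; 1; 1; 2; 2; 2; 2; 2];
k = 3;   % number of neighbors consulted for each point

n = size(Y, 1);
flagged = false(n, 1);
for i = 1:n
    d = sqrt(sum((Y - Y(i,:)).^2, 2));       % Euclidean distance to every point
    d(i) = Inf;                              % exclude the point itself
    [~, order] = sort(d);
    neighborClass = mode(lab(order(1:k)));   % most common label among the k neighbors
    flagged(i) = neighborClass ~= lab(i);    % flag points sitting in another class's cluster
end
disp(find(flagged)')
```

On the real data, each flagged index i could be inspected with imread(imdsTrain.Files{i}), as in the examples above. Distances in a t-SNE embedding are only locally meaningful, so flagged points are candidates for inspection rather than confirmed errors.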
Finally, I think it is interesting to visualize the t-SNE embedding for all the layers of the network: the data starts as randomly scattered points and gradually becomes clustered by class as it moves deeper into the network.
You can download the code using the small "Get the MATLAB code" link below. You'll need to bring your own pretrained network and dataset, since they are not included.
Hopefully you find these visualizations interesting and useful! Have a question for Maria? Leave a comment below!

Copyright 2018 The MathWorks, Inc.
Get the MATLAB code


