
Visualizing Convolution Neural Networks using Pytorch

 4 years ago
source link: https://www.tuicool.com/articles/rmq6Vvr

A Convolutional Neural Network (CNN) is another type of neural network. It enables machines to interpret visual data and perform tasks such as image classification, image recognition, object detection and instance segmentation. However, neural network models are often called ‘black box’ models because it is quite difficult to understand how a model learns the complex dependencies present in its input. It is also difficult to analyze why a given prediction is made during inference.

In this article, we will look at two different visualization techniques:

  1. Visualizing learned filter weights.
  2. Performing occlusion experiments on the image.

These methods help us understand what a filter learns, what kinds of images cause certain neurons to fire, and how good the hidden representations of the input image are.

Citation Note: The content and structure of this article are based on the deep learning lectures from One-Fourth Labs (PadhAI). If you are interested, check out their course.

Receptive Field of Neuron

Before we visualize how a Convolutional Neural Network works, we will discuss the receptive field of the filters in a CNN.

Consider a two-layer Convolutional Neural Network that uses 3x3 filters throughout, where Layer 1 is the input image. The center pixel marked in yellow in Layer 2 is the result of applying a 3x3 convolution (stride = 1) to the 3x3 neighborhood around the center pixel in Layer 1. Similarly, the center pixel in Layer 3 is the result of applying a convolution to the 3x3 neighborhood around the center pixel in Layer 2.


The receptive field of a neuron is the region of the input image that can influence that neuron, i.e., the set of pixels in the original image that affect the value of the neuron in a given convolution layer.

The central pixel in Layer 3 depends on a 3x3 neighborhood in the previous layer (Layer 2). These 9 adjacent pixels (marked in pink) in Layer 2, including the central pixel, together correspond to a 5x5 region in Layer 1. As we go deeper into the network, neurons have larger receptive fields, i.e., each one depends on a larger region of the original image.
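The growth of the receptive field can be computed layer by layer. The sketch below uses the standard recurrence for stacked convolutions: a layer with kernel size k adds (k - 1) times the current "jump" (the spacing, in input pixels, between adjacent neurons of that layer) to the receptive field, and its stride multiplies the jump. The function name receptive_field and the list of (kernel_size, stride) pairs are illustrative choices, not part of the original notebook.

```python
def receptive_field(layers):
    """Return the receptive field (in input pixels) after each convolution layer.

    `layers` is a list of (kernel_size, stride) pairs, ordered from the input outward.
    """
    rf = 1       # a single input pixel "sees" only itself
    jump = 1     # distance in input pixels between adjacent neurons of the current layer
    sizes = []
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump   # each extra kernel tap reaches `jump` pixels further
        jump *= stride                   # striding spreads neighboring neurons further apart
        sizes.append(rf)
    return sizes

# Two stacked 3x3 convolutions with stride 1, as in the figure
print(receptive_field([(3, 1), (3, 1)]))  # [3, 5]
# A third 3x3 layer would see a 7x7 region of the input
print(receptive_field([(3, 1)] * 3))      # [3, 5, 7]
```

The second result matches the 5x5 region described above; strides greater than 1 make the receptive field grow even faster.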


From the image above, we can observe that the highlighted pixel in the second convolution layer has a large receptive field with respect to the original input image.

Visualizing CNN

To visualize how a CNN works, we will explore two commonly used methods for understanding how a neural network learns complex relationships:

  1. Filter visualization with a pre-trained model.
  2. Occlusion analysis with a pre-trained model.

Run this notebook in Colab

All the code discussed in this article is available on my GitHub. You can open the Jupyter Notebook from GitHub directly in Colab, which runs on a Google virtual machine, so no local setup is needed. Click here if you just want to quickly open the notebook and follow along with this tutorial.

Don’t forget to upload the input images folder (which can be downloaded from the GitHub repo) to Google Colab before executing the code.

Visualize Input Images

In this article, we will use a small subset of the ImageNet dataset, which has 1000 categories, to visualize the filters of the model. The dataset can be downloaded from my GitHub repo.

To visualize the dataset, we will implement a custom function called imshow.

The imshow function takes two arguments: the image as a tensor and the title of the image. First, it inverts the normalization of the image using the ImageNet mean and standard deviation values. It then uses matplotlib to display the image.
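The original imshow code is not reproduced in this text, so here is a minimal sketch of the same idea. It assumes the image is a (C, H, W) CPU tensor or NumPy array normalized with the standard torchvision ImageNet statistics (mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]). The helper name denormalize is hypothetical, and matplotlib is imported inside imshow so the de-normalization step can be used on its own.

```python
import numpy as np

# Standard ImageNet channel statistics used with torchvision's pre-trained models
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def denormalize(img):
    """Undo ImageNet normalization on a (C, H, W) image; return (H, W, C) values in [0, 1]."""
    img = np.asarray(img, dtype=np.float64)    # accepts NumPy arrays and CPU torch tensors
    img = img.transpose(1, 2, 0)               # (C, H, W) -> (H, W, C), the layout matplotlib expects
    img = img * IMAGENET_STD + IMAGENET_MEAN   # inverse of (x - mean) / std, applied per channel
    return np.clip(img, 0.0, 1.0)              # keep values inside the displayable range

def imshow(img, title=None):
    """Display a normalized (C, H, W) image with an optional title (hypothetical sketch)."""
    import matplotlib.pyplot as plt            # lazy import: plotting is only needed here
    plt.imshow(denormalize(img))
    if title is not None:
        plt.title(title)
    plt.axis("off")
    plt.show()
```

Clipping guards against small numerical overshoots after de-normalization, which matplotlib would otherwise clip with a warning.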


Sample Input Image

Filter Visualization

By visualizing the filters of a trained model, we can understand how a CNN learns the complex spatial pixel dependencies present in an image.

What does a filter capture?

Consider a 2D input of size 4x4 to which we apply a 2x2 filter (marked in red), starting from the top-left corner of the image. As we slide the kernel over the image from left to right and top to bottom, we get an output that is smaller than the input: with stride 1 and no padding, the output is 3x3.
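A minimal NumPy sketch of this sliding-window operation, written with explicit loops for clarity rather than speed. The function name conv2d_valid and the example input and filter values are hypothetical; like PyTorch's Conv2d, it computes cross-correlation (the kernel is not flipped).

```python
import numpy as np

def conv2d_valid(image, kernel, stride=1):
    """Slide `kernel` over `image` with no padding; each output is a patch-kernel dot product."""
    ih, iw = image.shape
    kh, kw = kernel.shape
    oh = (ih - kh) // stride + 1   # output height of a "valid" (unpadded) convolution
    ow = (iw - kw) // stride + 1   # output width
    out = np.zeros((oh, ow))
    for i in range(oh):            # top to bottom
        for j in range(ow):        # left to right
            patch = image[i * stride:i * stride + kh, j * stride:j * stride + kw]
            out[i, j] = np.sum(patch * kernel)   # dot product of the patch and the filter
    return out

image = np.arange(16, dtype=float).reshape(4, 4)   # a 4x4 input
kernel = np.array([[1.0, 0.0], [0.0, -1.0]])        # a 2x2 filter
print(conv2d_valid(image, kernel).shape)            # (3, 3)
```

In general, for an n x n input and a k x k filter with stride s and no padding, the output side length is (n - k) // s + 1.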


The output of each convolution step (such as h₁₄) is the dot product of the input patch, flattened into a vector, and the filter's weight vector. The dot product of two vectors equals the product of their magnitudes and the cosine of the angle between them, so for vectors of fixed magnitude it is proportional to that cosine.
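Written out for the 2x2 filter above, with the input patch X and the filter W each flattened into a 4-dimensional vector and the bias term ignored (a simplification assumed here), the output is:

```latex
h_{14} = W^{\top} X = \sum_{i=1}^{4} w_i x_i = \lVert W \rVert \, \lVert X \rVert \cos\theta
```

For a fixed filter and a fixed patch magnitude, h₁₄ therefore grows with cos θ and is largest at θ = 0.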


During convolution, certain parts of the input image, such as a portion containing a dog’s face, might give a high value when a filter is applied to them. In the example above, in what situations will the output h₁₄ be high?

The output h₁₄ is high when the cosine of the angle between the vectors is high; the largest possible value of the cosine is 1. A cosine of 1 means the angle between the vectors is 0°, i.e., the input vector X (the image patch) and the weight vector W point in the same direction. In that case, the neuron fires maximally.


For input patches of a given magnitude, the neuron h₁₄ fires maximally when the input X (the image patch being convolved) is a positive multiple of the unit vector in the direction of the filter vector W.
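A small NumPy check of this claim: among input patches with the same magnitude, the one pointing along W produces the largest dot product, which is the Cauchy-Schwarz inequality at work. The random filter, the patch magnitude of 2.0 and the 1000 random candidate patches are arbitrary illustration choices.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=4)      # a flattened 2x2 filter W (random, for illustration)
norm_x = 2.0                # every candidate input patch X gets this same magnitude

def response(x, w):
    """Neuron output ignoring the bias: the dot product of W and X."""
    return float(np.dot(w, x))

# A patch pointing in exactly the same direction as the filter
aligned = norm_x * w / np.linalg.norm(w)
# Many random patches, rescaled to the same magnitude
others = [norm_x * v / np.linalg.norm(v) for v in rng.normal(size=(1000, 4))]

best_other = max(response(x, w) for x in others)
print(response(aligned, w) >= best_other)                            # True
print(np.isclose(response(aligned, w), norm_x * np.linalg.norm(w)))  # True: ||W|| ||X|| cos(0)
```

Flipping the aligned patch (θ = 180°) gives the most negative response, which a following ReLU would suppress.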

In other words, we can think of a filter as an image. As we slide the filter over the input from left to right and top to bottom, the neuron fires whenever the filter overlaps a similar-looking portion of the input. For all other parts of the input image that do not align with the filter, the output is low. This is why the kernel, or weight matrix, is called a filter: it filters out the portions of the input image that do not align with it.
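To see the "filter as an image" view in action, the NumPy sketch below hides a copy of a small filter pattern inside an otherwise blank image and slides the filter over it. The strongest response appears exactly where the image patch matches the filter. The diagonal pattern, the 8x8 image size and the hiding position are arbitrary illustration values, and slide_filter is a hypothetical helper.

```python
import numpy as np

def slide_filter(image, kernel):
    """Stride-1, no-padding cross-correlation, the operation PyTorch's Conv2d performs."""
    kh, kw = kernel.shape
    oh, ow = image.shape[0] - kh + 1, image.shape[1] - kw + 1
    out = np.empty((oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# A 3x3 "diagonal line" filter, which we can also view as a tiny image
kernel = np.array([[ 1.0, -1.0, -1.0],
                   [-1.0,  1.0, -1.0],
                   [-1.0, -1.0,  1.0]])

# An 8x8 image that is blank except for a copy of the filter pattern at row 2, column 4
image = np.zeros((8, 8))
image[2:5, 4:7] = kernel

response = slide_filter(image, kernel)
best = tuple(int(v) for v in np.unravel_index(np.argmax(response), response.shape))
print(best)                   # (2, 4): the filter fires where the pattern is
print(float(response.max()))  # 9.0
```

Every partial overlap with the hidden pattern scores strictly less than the perfect match, so the peak of the response map localizes the pattern.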

To understand what kinds of patterns a filter learns, we can simply plot the filter, i.e., the weights associated with it. For filter visualization, we will use AlexNet pre-trained on the ImageNet dataset.

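# AlexNet pre-trained on ImageNet data
# import the model zoo in torchvision
import torchvision.models as models
# torchvision >= 0.13; older versions use models.alexnet(pretrained=True)
alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)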

AlexNet contains 5 convolutional layers and 3 fully connected layers, and a ReLU is applied after every convolution. Remember that when convolving 3D (RGB) images, the kernel does not move along the depth dimension, because the kernel and the image have the same depth. We will visualize these filters (kernels) in two ways:

  1. Visualizing each filter by combining its three channels into an RGB image.
  2. Visualizing each channel in a filter independently using a heatmap.
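The point about depth can be checked with a small NumPy sketch: a filter whose depth equals the image's channel count slides only over height and width, and at each position it sums over all channels at once. A 3-channel image convolved with one 3-channel filter therefore yields a single 2D map. The function conv_rgb and the random inputs are hypothetical illustration choices.

```python
import numpy as np

def conv_rgb(image, kernel):
    """Convolve a (C, H, W) image with one (C, kH, kW) kernel; stride 1, no padding.

    The kernel spans all C channels, so it only slides along height and width.
    """
    c, h, w = image.shape
    kc, kh, kw = kernel.shape
    assert c == kc, "kernel depth must equal image depth"
    out = np.empty((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # one dot product over every channel: no movement along the depth axis
            out[i, j] = np.sum(image[:, i:i + kh, j:j + kw] * kernel)
    return out

rng = np.random.default_rng(0)
image = rng.random((3, 5, 5))          # a small 3-channel (RGB) image
kernel = rng.random((3, 3, 3))         # one filter of depth 3
print(conv_rgb(image, kernel).shape)   # (3, 3): a single 2D map per filter
```

Because each filter produces one 2D map, a layer with 64 such filters outputs 64 channels, and each of those filters can be drawn either as one RGB image or as three per-channel heatmaps.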

The main function for plotting the weights is plot_weights. It takes four parameters:

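model: the AlexNet model or any other trained model

layer_num: the index of the convolution layer whose weights are visualized

single_channel: the visualization mode (True plots each channel of a filter separately as a heatmap; False combines the three channels of each filter into an RGB image)

collated: applicable to single-channel visualization only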

Inside plot_weights, we read the layer at index layer_num from the model’s features. In AlexNet (from the PyTorch model zoo), the first convolution layer has index zero. Once we extract the layer at that index, we check whether it is a convolution layer, since only convolutional layers can be visualized this way. After validating the layer index, we extract the learned weight data stored in that layer.

# getting the weight tensor data, shape (out_channels, in_channels, kernel_h, kernel_w)
weight_tensor = model.features[layer_num].weight.data

Depending on the single_channel argument, we can plot the weight data as single-channel or multi-channel images. AlexNet’s first convolution layer has 64 filters of size 11x11, each with 3 input channels. We will plot these filters in two different ways to understand what kinds of patterns they learn.

Visualizing Filters — Multi-Channel

When single_channel = False, we have 64 filters of depth 3 (RGB). We combine the three channels of each filter into one RGB image of size 11x11x3, so the output is 64 RGB images.

# visualize weights for AlexNet's first conv layer
plot_weights(alexnet, 0, single_channel=False)
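The original plot_weights code is not reproduced here, so the following is only a minimal sketch of the multi-channel mode under stated assumptions. It takes the weights as an array of shape (N, 3, kH, kW), min-max normalizes each filter independently (one common choice; the original notebook may normalize differently) and draws the result as a matplotlib grid. The names filters_to_rgb and plot_rgb_filters are hypothetical.

```python
import numpy as np

def filters_to_rgb(weights):
    """Turn conv weights of shape (N, 3, kH, kW) into N displayable (kH, kW, 3) images.

    Each filter is min-max scaled to [0, 1] on its own so its internal pattern is visible.
    """
    w = np.asarray(weights, dtype=np.float64)
    assert w.ndim == 4 and w.shape[1] == 3, "expected (N, 3, kH, kW) weights"
    w = w.transpose(0, 2, 3, 1)                        # (N, kH, kW, 3) for matplotlib
    mins = w.min(axis=(1, 2, 3), keepdims=True)        # per-filter minimum
    maxs = w.max(axis=(1, 2, 3), keepdims=True)        # per-filter maximum
    return (w - mins) / np.maximum(maxs - mins, 1e-8)  # guard against constant filters

def plot_rgb_filters(weights, ncols=8):
    """Show each filter as a small RGB image in a grid (matplotlib imported lazily)."""
    import matplotlib.pyplot as plt
    images = filters_to_rgb(weights)
    nrows = int(np.ceil(len(images) / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols, nrows), squeeze=False)
    for ax in axes.flat:                               # hide ticks, including unused cells
        ax.axis("off")
    for ax, img in zip(axes.flat, images):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()
```

With the AlexNet model loaded earlier, a call such as plot_rgb_filters(alexnet.features[0].weight.data.cpu().numpy()) would draw the 64 first-layer filters as an 8x8 grid of 11x11 RGB images. Normalizing per filter, rather than over the whole layer, makes every filter's pattern visible but hides differences in overall weight magnitude between filters.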
