15

CNN特征提取结果可视化——hooks简单应用

 4 years ago
source link: https://blog.csdn.net/qq_34769162/article/details/115567093
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

本文代码地址https://github.com/njulhy/funny_code/blob/main/cnn_feature_visualization.ipynb

CNN特征提取结果可视化——hooks简单应用

在神经网络搭建时可能出现各式各样的错误,使用hook而非print或者简单的断点调试有助于你更清晰的意识到错误所在。

hook的使用场景多种多样,本文将使用hooks来简单可视化卷积神经网络的特征提取。用到的神经网络框架为Pytorch

Hooks简单介绍

每个hook都是预先定义好的可调用对象,在pytorch框架中,每个nn.Module对象都能够方便地注册(定义)一个hook。当一些trigger方法调用(如forward()backward())后,注册了hook的nn.Module对象会将相关信息传递到hook里面去。
在PyTorch中,可以注册三种hook:

  1. forward prehook (在forward之前执行)
  2. forward hook (在forward之后执行)
  3. backward hook (在backward之后执行)

具体理解每种hook的使用不是本文讨论的范围,我们将通过一个生动的卷积神经网络可视化例子来介绍hook的使用

CNN特征提取的简单可视化

我们将要进行的工作包括:

  1. 创建CNN特征提取器,本文使用PyTorch自带的resnet34
  2. 创建一个保存hook内容的对象
  3. 为每个卷积层创建hook
  4. 读取图像并进行特征提取
  5. 查看卷积层特征提取效果

本文将对下图进行特征提取并可视化

20210410111323237.jpg

创建CNN特征提取器

import torch
import torchvision

feature_extractor = torchvision.models.resnet34(pretrained=True)
if torch.cuda.is_available():
	feature_extractor.cuda()

创建保存hook内容的对象

class SaveOutput:
	def __init__(self):
		self.outputs = []
	def __call__(self, module, module_in, module_out):
		self.outputs.append(module_out)
	def clear(self):
		self.outputs=[]
		
save_output = SaveOutput()

为卷积层注册hook

hook_handles = []

for layer in feature_extractor.modules():
	if isinstance(layer, torch.nn.Conv2d):
		handle = layer.register_forward_hook(save_output)
		hook_handles.append(handle)

读取图像并进行特整体提取

cat.jpg地址

from PIL import Image
from torchvision import transforms as T

image = Image.open('cat.jpg')
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
X = transform(image).unsqueeze(dim=0).to(device)

out = feature_extractor(X)

查看卷积层特征提取效果

对于resnet来说,其具体结构如下:

2021041012222937.png

卷积层共有1+6+(4*2+1)+(6*2+1)+(3*2+1)=36个,对conv3_x层有4*2+1卷积层的原因是(1)四个basicblock本身有4*2个卷积层(2)其中一个basicblock进行了downsample,又多了一个卷积层

查看卷积层数

此时每个卷积层的结果都通过hook保存到了save_output.outputs里面,我们查看是否为36个结果

20210410124224111.png 可见全部卷积层的输出都保存了下来

可视化第一个卷积层

对resnet34来说,首个卷积层的卷积核为7*7,将输入的三通道彩色图像通道增加至64,尺寸从224*224对折为112*112,tensor的shape为1x64x112x112

20210410125810947.png

我们对首个卷积层的提取结果进行可视化:

import matplotlib.pyplot as plt
plt.figure(figsize = (15,15))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[0].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))

emm这是第一个卷积层的提取结果,可爱的小猫咪开始黑化
在这里插入图片描述

可视化第二、七个卷积层

对resnet34来说,第2-7个卷积层tensor的shape为64x1x56x56,我们对其2个卷积层输出进行可视化:

plt.figure(figsize = (15,15))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[1].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))

可见第二个卷积层的结果更加模糊一些
在这里插入图片描述
第2-7个卷积层tensor的shape为64x1x56x56,我们对第七个卷积层也可视化:

plt.figure(figsize = (15,15))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[6].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))

在这里插入图片描述

可视化第16个卷积层

第16个卷积层对应的是conv3_x的结果,其shape为1x128x28x28,可视化如下

plt.figure(figsize = (15,30))
plt.imshow(torchvision.utils.make_grid(save_output.outputs[15].cpu().permute(1, 0, 2, 3), nrow=8).permute(1, 2, 0))

可见图像经过多层特征提取,提取到的特征变得更加高层,大部分通道已经变得难以辨认
在这里插入图片描述

对神经网络提取结果进行可视化有助于理解其特征提取逐渐高层化的过程。
hook的使用场景还有很多,希望小伙伴们继续探索。


Recommend

  • 84
    • 微信 mp.weixin.qq.com 5 years ago
    • Cache

    图像特征提取实践小结

    一、背景 视觉是人类最高级别的感知,因此图像在人类感知中扮演着重要的角色。统计表明,人类从外界获得的绝大部分信息是来自视觉所接收的图像信息。随着计算机软硬件技术、社交网络和多媒体技术的快速发展,图像的生成、处理、存储、传...

  • 60

    三大特征提取器 - RNN、CNN和Transformer 简介 近年来,深度学习在各个NLP任务中都取得了SOTA结果。这一节,我们先了解一下现阶段在自然语言处理领域最常用的特征抽取结构。 本文部分参考张俊林老师的文...

  • 28

    这些年来,参与或了解了不少模式识别技术相关项目的诞生、成长与死亡。项目会死亡,但是知识不会,人的精力是有限的,如果不记录一下,多年后参与者也会变得和路人一样,忘了那些技术是咋实现的,用来干嘛的。所以这里零散的记录一下吧,...

  • 23

    ↑↑↑关注后" 星标 "Datawhale 每日干货 & 

  • 32

    作者|Ayisha D 编译|VK 来源|Towards Data Science 这篇文章中,我们探讨从语音数据中提取的特征,以及基于这些特征构建模型的不同方法。

  • 11
    • lanbing510.info 4 years ago
    • Cache

    特征提取与特征选择

    特征提取和特征选择都是从原始特征中找出最有效(同类样本的不变性、不同样本的鉴别性、对噪声的鲁棒性)的特征。 区别与联系特征提取:将原始特征转换为一组具有明显物理意义(Gabor、几何特征[角点、不变量]、纹理[LBP HOG])或者统计...

  • 16
    • www.biaodianfu.com 4 years ago
    • Cache

    使用Scikit-Learn提取文本特征

    文本分析是机器学习算法的主要应用领域。由于大部分机器学习算法只能接收固定长度的数值型矩阵特征,导致文本字符串等并不能直接被使用,针对此问题Scikit-Learn提供了将文本转化为数值型特征的方法,今天就一起来学习下。

  • 10

    行为识别(时间序列)特征提取代码 1年前 ⋅...

  • 7
    • bbs.cvmart.net 4 years ago
    • Cache

    行为识别常用提取特征

    行为识别常用提取特征 2年前 ⋅ 3914...

  • 7

    梅尔倒谱系数(Mel-Frequency Ceptral Coeffcients,MFCC),因为其独特的基于倒谱(ceptral)的提取方式,是目前最常用也是最有效的的语音特征提取算法之一。下图描述了MFCC语音特征提取的过程。 1、步骤介绍 1.1、预加重

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK