
Perceiver: One Neural-Network Model for Multiple Input Data Types

source link: https://www.infoq.com/news/2021/04/perceiver-neural-network-model/


Apr 13, 2021 2 min read

Google's DeepMind has recently released Perceiver, a state-of-the-art deep-learning model that receives and processes multiple types of input data, ranging from audio to images, similar to how the human brain perceives multimodal data.

Perceiver can receive and classify multiple input data types, namely point clouds, audio and images. To this end, the model is built on transformers, an attention-based architecture that makes no assumptions about the input data type.

The usual bottleneck of transformers is that self-attention requires a number of operations that grows quadratically with the input size. For instance, an image measuring 224 pixels by 224 pixels has 224^2 = 50,176 pixels, so self-attention over every pair of pixels would involve 50,176^2, or about 2.5 billion, comparisons per layer, which is a huge computational overhead. To solve this problem, DeepMind researchers replaced the self-attention layer applied to the input with a cross-attention layer, in which a small, fixed-size array of learned latent vectors attends to the input, so that the cost grows linearly with the input size.
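To make the difference in scale concrete, the arithmetic can be sketched in a few lines of Python. This is only an illustration: it counts query-key pairs for a single attention layer, the helper function names are hypothetical, and the latent count of 256 is an assumption borrowed from the `num_latents` value in the PyTorch snippet later in this article.

```python
# Count query-key pairs (attention scores) computed by one attention layer.
# Helper names are hypothetical; they are not part of any library.

def self_attention_pairs(num_inputs: int) -> int:
    """Every input element attends to every input element: M * M pairs."""
    return num_inputs * num_inputs


def cross_attention_pairs(num_inputs: int, num_latents: int) -> int:
    """Each latent vector attends to every input element: N * M pairs."""
    return num_latents * num_inputs


height = width = 224
num_pixels = height * width   # M: one input element per pixel
num_latents = 256             # N: assumed value, taken from the snippet's num_latents

self_pairs = self_attention_pairs(num_pixels)
cross_pairs = cross_attention_pairs(num_pixels, num_latents)

print(f"pixels (M):            {num_pixels:,}")
print(f"self-attention pairs:  {self_pairs:,}")
print(f"cross-attention pairs: {cross_pairs:,}")
print(f"reduction factor:      {self_pairs // cross_pairs}x")
```

With these numbers, the cross-attention layer computes 196 times fewer scores than pixel-level self-attention would, and the gap widens as the input grows because only one factor depends on the input size.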

[Figure omitted. Source: Perceiver: General Perception with Iterative Attention]

In addition, the input data used to compute cross attention is converted into a byte array, which means this model is agnostic to the data type. 

The main advance of this model is that it makes no assumptions about the input data type, whereas existing convolutional neural networks, for instance, are designed around grid-structured data such as images.
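The idea can be illustrated with a minimal NumPy sketch, which is not the authors' implementation. It assumes a single attention head, random matrices standing in for learned weights, and no positional encodings; the names `cross_attend` and `summarize` are hypothetical. Each input, whatever its shape, is flattened into an (M, C) array, and a fixed set of latent vectors attends to it.

```python
import numpy as np

rng = np.random.default_rng(0)


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attend(latents, byte_array, w_q, w_k, w_v):
    """Single-head cross-attention: latents (N, D) query inputs (M, C)."""
    q = latents @ w_q                         # (N, d) queries from the latents
    k = byte_array @ w_k                      # (M, d) keys from the input
    v = byte_array @ w_v                      # (M, d) values from the input
    scores = q @ k.T / np.sqrt(q.shape[-1])   # (N, M): size is linear in M
    return softmax(scores, axis=-1) @ v       # (N, d)


num_latents, latent_dim, attn_dim = 8, 16, 16
# Random stand-in for the latent array, which is learned in a real model.
latents = rng.normal(size=(num_latents, latent_dim))


def summarize(data):
    """Flatten any (..., C) input to an (M, C) array and attend to it."""
    byte_array = data.reshape(-1, data.shape[-1])
    channels = byte_array.shape[-1]
    # Random projections; a trained model would learn these per modality.
    w_q = rng.normal(size=(latent_dim, attn_dim))
    w_k = rng.normal(size=(channels, attn_dim))
    w_v = rng.normal(size=(channels, attn_dim))
    return cross_attend(latents, byte_array, w_q, w_k, w_v)


image = rng.normal(size=(32, 32, 3))   # height x width x RGB
audio = rng.normal(size=(4000, 1))     # samples x 1 channel
points = rng.normal(size=(500, 3))     # points x (x, y, z)

for name, data in [("image", image), ("audio", audio), ("point cloud", points)]:
    out = summarize(data)
    print(name, data.shape, "->", out.shape)
```

Every input produces an output of shape (8, 16), regardless of how many elements it contains. Only the key and value projections depend on the number of input channels, which is what allows the rest of the network to stay the same across modalities.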

[Figure omitted. Source: Perceiver: General Perception with Iterative Attention]

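For image classification, the paper reports a top-1 accuracy of 76.4% on ImageNet, which is competitive with ResNet-50. The 39.4% figure cited for ResNet appears to come from the paper's permuted-ImageNet experiment, in which pixel positions are shuffled by a fixed permutation, rather than from standard ImageNet.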

[Figure omitted. Source: Perceiver: General Perception with Iterative Attention]

Perceiver has drawn attention on social media, with thousands of views on YouTube, a discussion thread on Reddit and an ongoing discussion on Twitter. One comment in the Reddit thread illustrates the relevance of this new model:

The basic idea, as I understand it, is to achieve cross-domain generality by recreating the MLP with transformers, where

  • "neurons" and activations are vectors not scalars, and
  • interlayer weights are dynamic, not fixed

You can also reduce input dimensionality by applying cross-attention to a fixed set of learned vectors. Pretty cool.

In addition, a researcher shared this view in a Twitter thread:

This is really great work. There is a community implementation too.
...github.com/lucidrains/per...
Definitely going to be playing around with this. Thanks for the paper.

Finally, members of the deep-learning community have published an open-source implementation in PyTorch. It can be used with the following snippet:

import torch
from perceiver_pytorch import Perceiver

model = Perceiver(
    input_channels = 3,          # number of channels for each token of the input
    input_axis = 2,              # number of axes for input data (2 for images, 3 for video)
    num_freq_bands = 6,          # number of freq bands, with original value (2 * K + 1)
    max_freq = 10.,              # maximum frequency, hyperparameter depending on how fine the data is
    depth = 6,                   # depth of net
    num_latents = 256,           # number of latents, or induced set points, or centroids. different papers giving it different names
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross attention. paper said 1
    latent_heads = 8,            # number of heads for latent self attention, 8
    cross_dim_head = 64,
    latent_dim_head = 64,
    num_classes = 1000,          # output number of classes
    attn_dropout = 0.,
    ff_dropout = 0.,
    weight_tie_layers = False    # whether to weight tie layers (optional, as indicated in the diagram)
)

img = torch.randn(1, 224, 224, 3) # one ImageNet-sized image, channels last

logits = model(img) # output shape: (1, 1000)
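The `num_freq_bands` and `max_freq` parameters control the Fourier position features that are concatenated to each input element. The NumPy sketch below follows the scheme described in the Perceiver paper, with frequencies spaced equally between 1 and `max_freq / 2` and sine, cosine and raw position per axis. It is an illustration under those assumptions, not the library's own code, and the function name `fourier_features` is hypothetical.

```python
import numpy as np


def fourier_features(pos, num_bands, max_freq):
    """Encode positions in [-1, 1] as [sin, cos, raw position] features.

    pos has shape (..., num_axes); the result has shape
    (..., num_axes * (2 * num_bands + 1)).
    """
    freqs = np.linspace(1.0, max_freq / 2.0, num_bands)   # (K,) frequency bands
    angles = pos[..., None] * freqs * np.pi                # (..., axes, K)
    feats = np.concatenate(
        [np.sin(angles), np.cos(angles), pos[..., None]], axis=-1
    )                                                      # (..., axes, 2K + 1)
    return feats.reshape(*pos.shape[:-1], -1)


height = width = 224
# Pixel coordinates normalized to [-1, 1] along each axis.
ys, xs = np.meshgrid(
    np.linspace(-1, 1, height), np.linspace(-1, 1, width), indexing="ij"
)
pos = np.stack([ys, xs], axis=-1)                          # (224, 224, 2)

enc = fourier_features(pos, num_bands=6, max_freq=10.0)    # (224, 224, 26)

rgb = np.zeros((height, width, 3))                         # placeholder image
tokens = np.concatenate([rgb, enc], axis=-1).reshape(-1, 3 + enc.shape[-1])
print(enc.shape, tokens.shape)
```

With 6 bands and 2 axes, each pixel gains 2 * (2 * 6 + 1) = 26 position features on top of its 3 color channels, so the flattened input in this sketch has 50,176 elements with 29 channels each.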



