24

谷歌重磅开源新技术:5行代码打造无限宽神经网络模型

 4 years ago
source link: http://news.51cto.com/art/202003/612577.htm
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
zQFv6rM.jpg!web

开箱即用,5行代码打造无限宽神经网络模型

Neural Tangents 是一个高级神经网络 API,可用于指定复杂、分层的神经网络,在 CPU/GPU/TPU 上开箱即用。

该库用 JAX编写,既可以构建有限宽度神经网络,亦可轻松创建和训练无限宽度神经网络。

有什么用呢?举个例子,你需要训练一个完全连接神经网络。通常,神经网络是随机初始化的,然后采用梯度下降进行训练。

研究人员通过对一组神经网络中不同成员的预测取均值,来提升模型的性能。另外,每个成员预测中的方差可以用来估计不确定性。

如此一来,就需要大量的计算预算。

但当神经网络变得无限宽时,网络集合就可以用高斯过程来描述,其均值和方差可以在整个训练过程中进行计算。

而使用 Neural Tangents ,仅需5行代码,就能完成对无限宽网络集合的构造和训练。

from neural_tangents import predict, stax 
 
init_fn, apply_fn, kernel_fn = stax.serial( 
    stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(), 
    stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(), 
    stax.Dense(1, W_std=1.5, b_std=0.05)) 
 
y_mean, y_var = predict.gp_inference(kernel_fn, x_train, y_train, x_test, ‘ntk’, diag_reg=1e-4, compute_cov=True) 

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK