用纯NumPy码一个RNN、LSTM:这是最好的入门方式了
机器之心报道
参与:思源
随着 TensorFlow 和 PyTorch 等框架的流行,很多时候搭建神经网络也就调用几行 API 的事。大多数开发者对底层运行机制,尤其是如何使用纯 NumPy 实现神经网络变得比较陌生。以前机器之心曾介绍过如何使用 NumPy 实现简单的卷积神经网络,但今天会介绍如何使用 NumPy 实现 LSTM 等循环神经网络。
一般使用纯 NumPy 实现深度网络会面临两大问题,首先对于前向传播,卷积和循环网络并不如全连接网络那样可以直观地实现。为了计算性能,实践代码与理论之间也有差别。其次,我们实现了前向传播后还需要继续实现反向传播,这就要求我们对矩阵微分和链式法则等数学基础都有比较充足的了解。
尽管 NumPy 不能利用 GPU 的并行计算能力,但利用它可以清晰了解底层的数值计算过程,这也许就是为什么 CS231n 等课程最开始都要求使用 NumPy 手动实现深度网络吧。
项目地址:https://github.com/krocki/dnc
在这个项目中,作者主要使用 NumPy 实现了 DNC、RNN 和 LSTM,其中 RNN 代码借鉴了 A.Karpathy 以前写过的代码。此外,作者还写了 Gradient check 以确定实现的正确性,是不是感觉自深度学习框架流行以来,梯度检验这个词就渐渐消失了~
具体而言,这个项目是 DeepMind 于 2016 年发表在 Nature 的论文《Hybrid computing using a neural network with dynamic external memory》的实现,即可微神经计算机(DNC),其示例的任务是字符级预测。repo 中还包括 RNN(rnn-numpy.py) 和 LSTM (lstm-numpy.py) 的实现,一些外部数据(ptb, wiki)需要分别下载。
如下所示为 LSTM 的前向传播过程,Pyhon 2.7 的 xrange 改成 range 就好了 ˉ\(ツ)/ˉ:
loss=0
#forwardpass
fortinxrange(len(inputs)):
#encodein1-of-krepresentation
xs[t]=np.zeros((M,B))
forbinrange(0,B):xs[t][:,b][inputs[t][b]]=1
#gates,linearpart
gs[t]=np.dot(Wxh,xs[t])+np.dot(Whh,hs[t-1])+bh
#gatesnonlinearpart
#i,o,fgates
gs[t][0:3*HN,:]=sigmoid(gs[t][0:3*HN,:])
#cgate
gs[t][3*HN:4*HN,:]=np.tanh(gs[t][3*HN:4*HN,:])
#mem(t)=cgate*igate+fgate*mem(t-1)
cs[t]=gs[t][3*HN:4*HN,:]*gs[t][0:HN,:]+gs[t][2*HN:3*HN,:]*cs[t-1]
#memcell-nonlinearity
cs[t]=np.tanh(cs[t])
#newhiddenstate
hs[t]=gs[t][HN:2*HN,:]*cs[t]
#unnormalizedlogprobabilitiesfornextchars
ys[t]=np.dot(Why,hs[t])+by
###################
mx=np.max(ys[t],axis=0)
#normalize
ys[t]-=mx
#probabilitiesfornextchars
ps[t]=np.exp(ys[t])/np.sum(np.exp(ys[t]),axis=0)
forbinrange(0,B):
#softmax(cross-entropyloss)
ifps[t][targets[t,b],b]0:loss+=-np.log(ps[t][targets[t,b],b])
如上代码所示,最外层的循环 t 表示不同的时间步。而在每一个时间步下,首先需要计算不同的门控激活值,这三个门都是并在一起算的,这和我们在理论上看到的三个独立公式不太一样,但很合理。接下来按照 LSTM 单元的计算过程依次算出当前记忆内容 cs[t]、隐藏单元输出值 hs[t] 和最后的概率预测 ys[t]。最后只需要根据预测算损失值,并加入总体损失就行了。
除了上述的前向传播,更厉害的还是 RNN 和 LSTM 等的反向传播,即沿时间的反向传播(BPTT),这里就需要读者具体参考代码并测试了。
项目的使用
除了读源码外,当然我们也可以通过命令行直接试用模型效果,首先检验梯度等关键结构与代码:
pythondnc-debug.py
下面的版本都是准备好的:
pythonrnn-numpy.py
pythonlstm-numpy.py
pythondnc-numpy.py
该项目具有这些特点:数值计算仅依赖于 NumPy、添加了批处理、可将 RNN 修改为 LSTM,还能进行梯度检查。
该项目已经实现了 LSTM-控制器,2D 内存数组和内容可寻址的读/写。但有一个问题是,关键相似度的 softmax 会导致崩溃(除以 0),如果遇到这种情况,需要重新启动。该 repo 还有一些需要完成或改进的地方,包括动态内存分配和释放,实现更快、可保存的模型等。
在采样输出时,我们可以得到的数据包括时间、迭代次数、BPC(预测误差-每字符的位数,越低越好),以及处理速度(char/s)。
0:4163.009s,iter104800,1.2808BPC,1488.38char/s
如下展示了反向传播的数值梯度检验(最右边列的值应该小于 1e-4),中间列是计算得到的分析和数值梯度范围(这些应该或多或少都能匹配上)。
GRADCHECK
Wxh:n=[-1.828500e-02,5.292866e-03]min3.005175e-09,max3.505012e-07
a=[-1.828500e-02,5.292865e-03]mean5.158434e-08#10/4
Whh:n=[-3.614049e-01,6.580141e-01]min1.549311e-10,max4.349188e-08
a=[-3.614049e-01,6.580141e-01]mean9.340821e-09#10/10
Why:n=[-9.868277e-02,7.518284e-02]min2.378911e-09,max1.901067e-05
a=[-9.868276e-02,7.518284e-02]mean1.978080e-06#10/10
Whr:n=[-3.652128e-02,1.372321e-01]min5.520914e-09,max6.750276e-07
a=[-3.652128e-02,1.372321e-01]mean1.299713e-07#10/10
Whv:n=[-1.065475e+00,4.634808e-01]min6.701966e-11,max1.462031e-08
a=[-1.065475e+00,4.634808e-01]mean4.161271e-09#10/10
Whw:n=[-1.677826e-01,1.803906e-01]min5.559963e-10,max1.096433e-07
a=[-1.677826e-01,1.803906e-01]mean2.434751e-08#10/10
Whe:n=[-2.791997e-02,1.487244e-02]min3.806438e-08,max8.633199e-06
a=[-2.791997e-02,1.487244e-02]mean1.085696e-06#10/10
Wrh:n=[-7.319636e-02,9.466716e-02]min4.183225e-09,max1.369062e-07
a=[-7.319636e-02,9.466716e-02]mean3.677372e-08#10/10
Wry:n=[-1.191088e-01,5.271329e-01]min1.168224e-09,max1.568242e-04
a=[-1.191088e-01,5.271329e-01]mean2.827306e-05#10/10
bh:n=[-1.363950e+00,9.144058e-01]min2.473756e-10,max5.217119e-08
a=[-1.363950e+00,9.144058e-01]mean7.066159e-09#10/10
by:n=[-5.594528e-02,5.814085e-01]min1.604237e-09,max1.017124e-05
a=[-5.594528e-02,5.814085e-01]mean1.026833e-06#10/10
本文为机器之心报道,转载请联系本公众号获得授权。
?------------------------------------------------
加入机器之心(全职记者 / 实习生):hr@jiqizhixin.com
投稿或寻求报道:content@jiqizhixin.com
广告 & 商务合作:bd@jiqizhixin.com
网站开发网络凭借多年的网站建设经验,坚持以“帮助中小企业实现网络营销化”为宗旨,累计为4000多家客户提供品质建站服务,得到了客户的一致好评。如果您有网站建设、网站改版、域名注册、主机空间、手机网站建设、网站备案等方面的需求...
请立即点击咨询我们或拨打咨询热线:13245491521 13245491521 ,我们会详细为你一一解答你心中的疑难。 项目经理在线