本篇文章5093字,读完约13分钟
机器的心
从sjmielke中选择
机器的心被编译
机器的心
说到现在的深度学习框架,我们经常无法绕过tensorflow和pytorch。 但是,除了这两个框架,一点也不能小看新生力量。 其中之一是jax。 具有正方向和反方向的自动微分功能,非常擅长高次导数的计算。 这个引人注目的框架到底有多有用? 如何用它展示神经网络内部许多杂乱的梯度更新和反向传递? 本文是粘贴教程,以了解jax的基本逻辑,方便从pytorch等迁移。
jax是谷歌开发的python库,用于机器学习和数学计算。 上市后,jax将其定义为python+numpy的包。 可以微分、矢量化,具有在tpu和gpu中使用jit语言等特征。 简单来说,这就是gpu版本的numpy,也可以自动微分。 像skye wanderman-milne这样的研究者也在去年的neurlps大会上介绍了jax。
但是,从开发者熟悉的pytorch和tensorflow 2.x迁移到jax,在构建计算和反转的方法方面肯定有本质区别。 pytorch绘制计算图,计算正向和反向的传递过程。 结果节点上的坡度是中间节点的坡度累积而成的。
jax不是。 用python函数表示计算过程,用grad ( )转换为梯度函数,就可以进行评价了。 但是,不是给出结果,而是给出结果的梯度。 两者的对象如下。
那样的话,编程和模型构建的方法就不同了。 所以,可以采用tape-based的自动微分方法,采用有状态的对象。 但是jax可能会吓到你。 因为如果运行grad ( )函数,微分过程就像函数一样。
你可能会在决策中看到flax、haiku等基于jax的工具。 通过resnet等示例,您可以发现它与其他框架的代码不同。 除了定义层次和进行培训以外,基础逻辑是什么? 这些小numpy计划是怎么训练巨大的体系结构的?
这篇文章是介绍jax构建模式的教程,机心摘录了那两部分。
快速回顾pytorch上的lstm-lm应用程序。
查看pytorch样式的代码(基于mutate状态),了解纯函数是如何构建模型的( jax )
可以参考原文的复印件在增加。
pytorch上的lstm语言模型
我们首先使用pytorch实现lstm语言模型。 以下是代码。
importtorchclasslstmcell ( torch.nn.module ):def _ _ init _ _ ( self,in_dim, out _ dim ) self )._ _ init _ _ ( ) self.weight _ ih
def forward(self,inputs,h,c ):ifgo = self.weight _ ih @ inputs + self.weight _ hh @。 4 ) I = torch.sigmoid ( I ) f = torch.sigmoid ( f ) g = torch.tanh ( g ) o = torch.sigmoid
然后我们根据这个lstm神经元构建单层互联网。 这里有埋入层。 这个和可以学习的( h,c)0表示各个参数如何变化。
classlstmlm ( torch.nn.module ):def _ _ init _ _ ( self
@ propertydefhc _0( self ):return ( torch.tanh ( self.c_0),self.c _0)
def forward(self,seq,HC ):loss = torch.tensor (0. ) foridxinseq:loss-= torch.log _。
def greedy_argmax(self,hc
构筑后进行训练。
torch.manual _ seed (0) # astrainingdata、wewillhaveindicesofwords/word pieces/characters、# wejustassumetheyaretokenizedandints : training _ data = JNP.array ( [ 4,8,15,16,23,42 ] )
LM = LST MLM ( VOCAB _ size = VCAB _ size ) Print ( " sample Before:",LM
bptt _ length =3# toillustratehc.detach-ing
forepochinrange ( 101 ):HC = LM.HC _0total LSS =0. Forstartinrange ( 0,LEN c ) = loss c.detach ( ) ) if epoch % 50 = =0: total loss + = loss.item param.graddelparam total loss ) print ( " sample after:",LM
pytorch的代码很清楚,但还是有问题。 我很观察,请关注计算图中的节点数。 哪个中间节点需要在正确的时间被清除。
纯函数
为了理解jax如何解决这个问题首先需要理解纯函数的概念。 如果你以前做过函数编程,纯函数可能就像数学函数或公式。 这定义了如何从某些输入值中获取输出值。 重要的是,没有“副作用”。 这意味着函数的任何部分都不会访问或更改全局状态。
我们在pytorch上写代码时充满了中间变量和状态,而且这些状态总是在变化,所以推理和优化变得非常困难。 因此,jax选择了把程序员限制在纯粹的函数范围内,不要引起这个。
在详细了解jax之前,让我们看看一些纯函数的例子。 纯函数必须满足以下条件:
在什么情况下执行函数,什么时候执行函数应该不影响输出。 如果输入不变,输出也应该不变。
即使我们执行0次、1次或多次函数,事后也应该无法判别。
以下非亚纯函数违反了上述至少一个条件。
importrandomimporttimenr _ executions = 0
def pure _ FN _1( x ):return2* xdefpure _ FN _2( XS )
def impure _ fn _1( XS ):# mutatingargumentshaslastingconsequencesoutsidethefunction! ( xs.append(sum(xs)) return xs
def impure _ fn _2( x ):# veryobviouslymutatingglobalstateisbad ... global NR _ executions NR _ executions
def impure _ fn _3( x ):# ...butjustaccessingitis,too,becausenowthefunctiondependsonthe # execution c return NR _ executions * x
def impure _ fn _4( x ):# thingslikeioareclassicexamplesofimpurity.# allthreeofthefollowinglinesareviolation ) user _ input = input。
def impure _ fn _5( x ):# whichconstraintdoesthisviolate? 老板,actually! youaccessthecurrent # stateofrandomness * and * advancethenumbergenerator! p = random.random ( ) returnp * xlet ' sseeapurefunctionthatjaxoperateson:theexamplefromtheintrofigure。
( almost )1- dlinearregressiondeff ( w,x): return w * x
print(f(13 .,42.))546.0
至今没有出现任何情况。 在jax中,现在可以通过返回函数结果比较函数的第一个参数的梯度而不是返回结果,将以下函数转换为另一个函数:
import jaximport jax.numpy as jnp
# gradient: with respect to weights! jaxusesthefirstargumentbydefault.df _ dw = jax.grad ( f )
def manual_df_dw(w,x): return x assert df_dw(13 .,42.) == manual_df_dw(13
print(df_dw(13 .,42.))42.0
到目前为止,我在jax的自述文件中大致看过所有以前的副本,但副本也很合理。 但是,如何跳转到pytorch代码这样大的模块呢?
首先,追加偏移项,试着把一维线性回归变量包装成我们习性采用的对象线性回归“层”。
classlinearregressor ( ):def _ _ init _ _ ( self,w,b ):self
# a kind of loss fuction,usedfortrainingxs = JNP.Array ( [ 42.0 ] ) ys = JNP.Array ( [ 500.0 ] ) PR inning
# predictionfortestdataprint ( my _ regressor.predict ( 42.) ) 46.0546.0
接下来如何利用坡度进行训练? 需要以模型权重为函数输入参数的纯函数。 有可能变成这样。
def loss_fn(w,b,xs,ys ):my _ regressor = linear regressor ( w,b ) returnmy _ regressor1) totelljaxtogiveus # gradientswrtfirstans
print(loss_fn(13 .,0 .,xs,ys))print(grad_fn(13
你必须说服自己。 那是对的。 目前可以这样做,但不能在loss_fn的定义部分列举所有参数。
幸运的是,jax不仅可以微分标量、矢量、矩阵,还可以微分多个类似树的数据结构。 这样的结构被称为pytree,包括python dicts。
def loss_fn(params,xs,ys ):my _ regressor = linear regressor ( params [ ' w ' ],params[]
grad_fn = jax.grad(loss_fn )
print(loss_fn({'w': 13 .,' b': 0.},xs,ys))print(grad_fn({
现在看得很清楚! 我们可以写以下的训练周期.。
params = {'w': 13 .,' b': 0.}
for _ in range ( 15 ):Print ( LSS _ FN ( Params,xs,ys)) grads = grad_fn(params,XS ) YS Predict:Linear Params [ ' b ' ] ).Predict ( 42.) 46.042.4700338.94000235.41003431.88006628.35009824
观察,现在可以采用越来越多的jax helper进行自我更新。 因为参数和坡度有共同的(像树一样)结构,所以可以想象把它们放在上面,创造新的树。 那个值到处都是这两棵树的“组合”。 如下所示。
def update_combiner(param,grad,LR = 0.002 ):return Param-LR * Grad Params = JAX.tree _ Mu Lula Grads ) # instead of
参考链接: sjmielke/jax-purify
本文为机心报道,转载请联系本公众号获得授权。
原标题:“发现只有tf和pytorch是不够的,我们来看看从pytorch到自动微分神器jax的转移方法。”
阅读原文。
来源:重庆新闻
标题:热门:只知道TF和PyTorch还不够,快来看看如何从PyTorch转向自动微分神器JAX
地址:http://www.ccqdqw.cn/cqyw/26361.html