[经验] 《机器学习算法与实现 —— Python编程与应用实例》神经网络的训练 - 反向传播算法

lospring   2024-8-4 23:49 楼主

在多层神经网络中有这样一个问题:最后一层的参数可以用这样的方式求解得到;隐层节点没有输出的真值,因此无法直接构建损失函数来求解。

反向传播算法可以解决该问题,反射传播自满其实就是链式求导法则的应用。

按照机器学习的通用求解思路,我们先确定神经网络的目标函数,然后用随机梯度下降优化算法去求目标函数最小值时的参数值。

取网络所有输出层节点的误差平方和作为目标函数:

image.png  

其中,Ed表示是样本d的误差, t是样本的标签值,y是神经网络的输出值。

然后,使用随机梯度下降算法对目标函数进行优化:

image.png  

随机梯度下降算法也就是需要求出误差Ed对于每个权重wji的偏导数(也就是梯度),如何求解?

image.png  

观察上图,可发现权重wji仅能通过影响节点j的输入值影响网络的其它部分,设netj是节点j的加权输入,即

image.png  

Ed是netj的函数,而netj是wji的函数。根据链式求导法则,可以得到:

image.png  

上式中,xji是节点传递给节点j的输入值,也就是节点i的输出值。

对于的∂Ed/∂netj推导,需要区分输出层和隐藏层两种情况。

1、输出层权值训练

image.png  

对于输出层来说,netj仅能通过节点j的输出值yj来影响网络其它部分,也就是说Ed是yj的函数,而yj是netj的函数,其中yj=sigmod(netj)。所以我们可以再次使用链式求导法则:

image.png  

其中:

image.png   image.png  

将第一项和第二项带入,得到:

image.png  

如果令δj=−∂Ed/∂netj,也就是一个节点的误差项δ是网络误差对这个节点输入的偏导数的相反数。带入上式,得到:

image.png  

将上述推导带入随机梯度下降公式,得到:

image.png  

2、隐藏层权值训练

 

现在我们要推导出隐藏层的∂Ed/∂netj∂:

image.png  

首先,我们需要定义节点j的所有直接下游节点的集合Downstream(j)。例如,对于节点4来说,它的直接下游节点是节点8、节点9。可以看到netj只能通过影响Downstream(j)再影响Ed。设netk是节点j的下游节点的输入,则Ed是netk的函数,而netk是netj的函数。因为netk有多个,我们应用全导数公式,可以做出如下推导:

image.png  

因为δj=−∂Ed/∂netj,带入上式得到:

image.png  

至此,我们已经推导出了反向传播算法。需要注意的是,我们刚刚推导出的训练规则是根据激活函数是sigmoid函数、平方和误差、全连接网络、随机梯度下降优化算法。如果激活函数不同、误差计算方式不同、网络连接结构不同、优化算法不同,则具体的训练规则也会不一样。但是无论怎样,训练规则的推导方式都是一样的,应用链式求导法则进行推导即可。

3、具体解释

image.png  

然后,按照下面的方法计算出每个节点的误差项δi:

对于输出层节点i

image.png

其中,δi是节点i的误差项,yi是节点i的输出值,ti是样本对应于节点i的目标值。举个例子,根据上图,对于输出层节点8来说,它的输出值是y1,而样本的目标值是t1,带入上面的公式得到节点8的误差项应该是:

image.png  

对于隐藏层节点

image.png  

其中,ai是节点i的输出值,wki是节点i到它的下一层节点k的连接的权重,δk是节点i的下一层节点k的误差项。例如,对于隐藏层节点4来说,计算方法如下:

image.png  

最后,更新每个连接上的权值:

image.png  

其中,wji是节点i到节点j的权重,η是一个成为学习速率的常数,δj是节点j的误差项,xji是节点i传递给节点j的输入。例如,权重w84的更新方法如下:

image.png  

类似的,权重w41的更新方法如下:

image.png  

偏置项的输入值永远为1。例如,节点4的偏置项w4b应该按照下面的方法计算:

image.png  

计算一个节点的误差项,需要先计算每个与其相连的下一层节点的误差项,这就要求误差项的计算顺序必须是从输出层开始,然后反向依次计算每个隐藏层的误差项,直到与输入层相连的那个隐藏层,这就是反向传播算法的名字的含义。当所有节点的误差项计算完毕后,就可以根据式5来更新所有的权重。

以上就是反向传播算法的一个求解过程,整个过程也是搬抄其它大佬的结果,希望对大家有点帮助。

 

回复评论

暂无评论,赶紧抢沙发吧
电子工程世界版权所有 京B2-20211791 京ICP备10001474号-1 京公网安备 11010802033920号
    写回复