[TOC]
This post implements the forward and backward passes of a fully connected layer with numpy and tests them on a linear regression example.
If anything about backpropagation in neural networks is still unclear, see 0_1-全连接层、损失函数的反向传播 (backpropagation for fully connected layers and loss functions).
```python
import numpy as np


def fc_forword(z, W, b):
    """
    Forward pass of a fully connected layer
    :param z: input to the current layer (the previous layer's output), shape (N, in_features)
    :param W: weights of the current layer, shape (in_features, out_features)
    :param b: bias of the current layer, shape (out_features,)
    :return: output of the current layer (the next layer's input)
    """
    return np.dot(z, W) + b
```
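With a batch of $N$ samples stacked as the rows of $z$, the forward pass is just an affine map:

$$
\text{output} = zW + b,\qquad z\in\mathbb{R}^{N\times d_{in}},\; W\in\mathbb{R}^{d_{in}\times d_{out}},\; b\in\mathbb{R}^{d_{out}}
$$

where $b$ is broadcast over the batch dimension.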
```python
def fc_backword(next_dz, W, z):
    """
    Backward pass of a fully connected layer
    :param next_dz: gradient flowing back from the next layer, shape (N, out_features)
    :param W: weights of the current layer
    :param z: input to the current layer (the previous layer's output)
    :return: gradients of the weights, the bias and the layer input (dw, db, dz),
             with dw and db averaged over the batch
    """
    N = z.shape[0]
    dz = np.dot(next_dz, W.T)     # gradient w.r.t. the layer input
    dw = np.dot(z.T, next_dz)     # gradient w.r.t. the layer weights
    db = np.sum(next_dz, axis=0)  # gradient w.r.t. the layer bias, summed over the N samples
    return dw / N, db / N, dz
```
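Writing $\delta$ for `next_dz`, the gradient passed back from the next layer, the three quantities returned above are:

$$
\frac{\partial L}{\partial z} = \delta W^{\top},\qquad
\frac{\partial L}{\partial W} = \frac{1}{N}\, z^{\top}\delta,\qquad
\frac{\partial L}{\partial b} = \frac{1}{N}\sum_{i=1}^{N}\delta_i
$$

The division by $N$ averages the weight and bias gradients over the batch, matching the `dw/N` and `db/N` returned by the code.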
```python
def mean_squared_loss(y_predict, y_true):
    """
    Mean squared error loss
    :param y_predict: predictions, shape (N, d), where N is the batch size
    :param y_true: ground-truth values, same shape as y_predict
    :return: the scalar loss and the gradient of the loss w.r.t. the network output
    """
    loss = np.mean(np.sum(np.square(y_predict - y_true), axis=-1))  # scalar loss value
    dy = y_predict - y_true  # gradient of the loss w.r.t. the network output
    return loss, dy
```
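Strictly speaking, the derivative of the squared error carries a factor of 2 that `dy` drops; it only rescales the gradient and is absorbed by the learning rate. As a quick sanity check of the backward pass, one can compare the analytic gradient against a finite-difference estimate. The sketch below is not part of the original post; `check_fc_gradient` and its `eps` value are assumed names, and it only probes a single weight entry:

```python
def check_fc_gradient(eps=1e-6):
    """Finite-difference check of fc_backword on one weight entry (illustrative sketch)."""
    rng = np.random.RandomState(0)
    z = rng.randn(4, 2)       # small batch: 4 samples, 2 input features
    W = rng.randn(2, 3)
    b = rng.randn(3)
    y_true = rng.randn(4, 3)

    # analytic gradient from the functions defined above
    y = fc_forword(z, W, b)
    _, dy = mean_squared_loss(y, y_true)
    dw, db, _ = fc_backword(dy, W, z)

    # numeric gradient of the loss w.r.t. W[0, 0]
    W_pos = W.copy(); W_pos[0, 0] += eps
    W_neg = W.copy(); W_neg[0, 0] -= eps
    loss_pos, _ = mean_squared_loss(fc_forword(z, W_pos, b), y_true)
    loss_neg, _ = mean_squared_loss(fc_forword(z, W_neg, b), y_true)
    numeric = (loss_pos - loss_neg) / (2 * eps)

    # dw omits the factor 2 from the squared-error derivative, so compare with numeric / 2
    print("analytic: {:.6f}, numeric/2: {:.6f}".format(dw[0, 0], numeric / 2))
```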
```python
# Ground-truth weights and bias
W = np.array([[3, 7, 4],
              [5, 2, 6]])
b = np.array([2, 9, 3])

# Generate the training samples
x_data = np.random.randint(0, 10, 1000).reshape(500, 2)
y_data = np.dot(x_data, W) + b


def next_sample(batch_size=1):
    # note: when idx falls near the end of the dataset, the returned batch can be shorter than batch_size
    idx = np.random.randint(500)
    return x_data[idx:idx + batch_size], y_data[idx:idx + batch_size]


print("x.shape:{},y.shape:{}".format(x_data.shape, y_data.shape))
```
```
x.shape:(500, 2),y.shape:(500, 3)
```
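The training loop below runs plain mini-batch stochastic gradient descent: each step draws a batch of 2 samples, computes the batch-averaged gradients with `fc_backword`, and applies

$$
W \leftarrow W - \eta\,\frac{\partial L}{\partial W},\qquad
b \leftarrow b - \eta\,\frac{\partial L}{\partial b}
$$

with learning rate $\eta = 0.01$, stopping once the loss falls below $10^{-15}$.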
```python
# Randomly initialise the parameters
W1 = np.random.randn(2, 3)
b1 = np.zeros([3])

loss = 100.0
lr = 0.01
i = 0

while loss > 1e-15:
    x, y_true = next_sample(2)  # draw the current mini-batch
    # forward pass
    y = fc_forword(x, W1, b1)
    # backward pass: compute the gradients (already averaged over the batch)
    loss, dy = mean_squared_loss(y, y_true)
    dw, db, _ = fc_backword(dy, W1, x)
    # update the parameters
    W1 -= lr * dw
    b1 -= lr * db
    # update the iteration counter
    i += 1
    if i % 1000 == 0:
        print("\nAfter {} iterations, loss: {}, weights: {}, bias: {}".format(i, loss, W1, b1))

# print the final result
print("\nAfter {} iterations, loss: {}, weights: {}, bias: {}".format(i, loss, W1, b1))
```
```
After 1000 iterations, loss: 0.43387298848896233, weights: [[3.01734672 7.12785625 4.02756123]
 [5.0221794  2.16347613 6.0352396 ]], bias: [1.81757802 7.65543542 2.71016   ]

After 2000 iterations, loss: 0.024775748245913158, weights: [[3.00242166 7.01784918 4.00384764]
 [5.00295757 2.02179914 6.00469912]], bias: [1.96775495 8.76233376 2.94876766]

After 3000 iterations, loss: 0.00014564406568725818, weights: [[3.00082136 7.00605396 4.00130502]
 [5.00061563 2.00453758 6.00097814]], bias: [1.99381124 8.95438495 2.99016703]

After 4000 iterations, loss: 2.6237167410353415e-05, weights: [[3.0001119  7.00082475 4.00017779]
 [5.00008191 2.0006037  6.00013014]], bias: [1.99885749 8.99157899 2.99818473]

After 5000 iterations, loss: 3.713805657221762e-07, weights: [[3.00002322 7.00017112 4.00003689]
 [5.00001109 2.00008176 6.00001763]], bias: [1.99979785 8.99851001 2.99967881]

After 6000 iterations, loss: 8.807646869757514e-09, weights: [[3.0000031  7.00002283 4.00000492]
 [5.00000397 2.00002927 6.00000631]], bias: [1.99996212 8.9997208  2.99993981]

After 7000 iterations, loss: 1.536245925844849e-10, weights: [[3.00000073 7.00000539 4.00000116]
 [5.00000067 2.00000494 6.00000106]], bias: [1.99999324 8.99995017 2.99998926]

After 7398 iterations, loss: 3.3297294256090265e-16, weights: [[3.00000043 7.00000318 4.00000069]
 [5.0000004  2.00000294 6.00000063]], bias: [1.99999655 8.99997456 2.99999452]
```
print("W1==W: {} \nb1==b: {}".format(np.allclose(W1,W),np.allclose(b1,b)))
W1==W: True
b1==b: True