Skip to content

Latest commit

 

History

History
147 lines (106 loc) · 4.31 KB

1_1_1-全连接神经网络做线性回归.md

File metadata and controls

147 lines (106 loc) · 4.31 KB

[TOC]

一、定义前向、后向传播

本文将用numpy实现全连接层的前向过程和反向过程,并使用一个线性回归作为例子进行测试;

如果对神经网络的反向传播过程还有不清楚的,可以0_1-全连接层、损失函数的反向传播

import numpy as np

def fc_forword(z, W, b):
    """
    全连接层的前向传播
    :param z: 当前层的输出
    :param W: 当前层的权重
    :param b: 当前层的偏置
    :return: 下一层的输出
    """
    return np.dot(z, W) + b


def fc_backword(next_dz, W, z):
    """
    全连接层的反向传播
    :param next_dz: 下一层的梯度
    :param W: 当前层的权重
    :param z: 当前层的输出
    :return:
    """
    N = z.shape[0]
    dz = np.dot(next_dz, W.T)  # 当前层的梯度
    dw = np.dot(z.T, next_dz)  # 当前层权重的梯度
    db = np.sum(next_dz,axis=0)  # 当前层偏置的梯度, N个样本的梯度求和
    return dw/N, db/N, dz

二、定义损失函数

def mean_squared_loss(y_predict,y_true):
    """
    均方误差损失函数
    :param y_predict: 预测值,shape (N,d),N为批量样本数
    :param y_true: 真实值
    :return:
    """
    loss = np.mean(np.sum(np.square(y_predict-y_true),axis=-1))  # 损失函数值
    dy = y_predict - y_true  # 损失函数关于网络输出的梯度
    return loss, dy

三、初始化数据

# 实际的权重和偏置
W = np.array([[3,7,4],
              [5,2,6]])
b = np.array([2,9,3])

# 产生训练样本
x_data = np.random.randint(0,10,1000).reshape(500,2)
y_data = np.dot(x_data,W)+b

def next_sample(batch_size=1):
    idx=np.random.randint(500)
    return x_data[idx:idx+batch_size],y_data[idx:idx+batch_size]

print("x.shape:{},y.shape:{}".format(x_data.shape,y_data.shape))
x.shape:(500, 2),y.shape:(500, 3)

四、定义网络、使用SGD训练

# 随机初始化参数
W1 = np.random.randn(2,3)
b1 = np.zeros([3])
loss = 100.0
lr = 0.01
i = 0 

while loss > 1e-15:
    x,y_true=next_sample(2)  # 获取当前样本
    # 前向传播
    y = fc_forword(x,W1,b1)
    # 反向传播更新梯度
    loss,dy=mean_squared_loss(y,y_true)
    dw,db,_ = fc_backword(dy,W,x)
    
    # 在一个batch中梯度取均值
    #print(dw)
    
    # 更新梯度
    W1 -= lr*dw
    b1 -= lr*db
    
    # 更新迭代次数
    i += 1
    if i % 1000 == 0:
        print("\n迭代{}次,当前loss:{}, 当前权重:{},当前偏置{}".format(i,loss,W1,b1))   

# 打印最终结果
print("\n迭代{}次,当前loss:{}, 当前权重:{},当前偏置{}".format(i,loss,W1,b1))
迭代1000次,当前loss:0.43387298848896233, 当前权重:[[3.01734672 7.12785625 4.02756123]
 [5.0221794  2.16347613 6.0352396 ]],当前偏置[1.81757802 7.65543542 2.71016   ]

迭代2000次,当前loss:0.024775748245913158, 当前权重:[[3.00242166 7.01784918 4.00384764]
 [5.00295757 2.02179914 6.00469912]],当前偏置[1.96775495 8.76233376 2.94876766]

迭代3000次,当前loss:0.00014564406568725818, 当前权重:[[3.00082136 7.00605396 4.00130502]
 [5.00061563 2.00453758 6.00097814]],当前偏置[1.99381124 8.95438495 2.99016703]

迭代4000次,当前loss:2.6237167410353415e-05, 当前权重:[[3.0001119  7.00082475 4.00017779]
 [5.00008191 2.0006037  6.00013014]],当前偏置[1.99885749 8.99157899 2.99818473]

迭代5000次,当前loss:3.713805657221762e-07, 当前权重:[[3.00002322 7.00017112 4.00003689]
 [5.00001109 2.00008176 6.00001763]],当前偏置[1.99979785 8.99851001 2.99967881]

迭代6000次,当前loss:8.807646869757514e-09, 当前权重:[[3.0000031  7.00002283 4.00000492]
 [5.00000397 2.00002927 6.00000631]],当前偏置[1.99996212 8.9997208  2.99993981]

迭代7000次,当前loss:1.536245925844849e-10, 当前权重:[[3.00000073 7.00000539 4.00000116]
 [5.00000067 2.00000494 6.00000106]],当前偏置[1.99999324 8.99995017 2.99998926]

迭代7398次,当前loss:3.3297294256090265e-16, 当前权重:[[3.00000043 7.00000318 4.00000069]
 [5.0000004  2.00000294 6.00000063]],当前偏置[1.99999655 8.99997456 2.99999452]
print("W1==W: {} \nb1==b:  {}".format(np.allclose(W1,W),np.allclose(b1,b)))
W1==W: True 
b1==b:  True