• tensorflow 2.0 学习(三)MNIST训练


    用tensorflow2.0 版回顾了一下mnist的学习

    代码如下,感觉这个版本下的mnist学习更简洁,更方便

    关于tensorflow的基础知识,这里就不更新了,用到什么就到网上搜索相关的知识

    # encoding: utf-8
    
    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    
    #加载下载好的mnist数据库 60000张训练 10000张测试 每一张维度(28,28)
    path = r'G:2019pythonmnist.npz'
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    f.close()
    
    #预处理输入数据
    x = 2*tf.convert_to_tensor(x_train, dtype = tf.float32)/255. - 1
    x = tf.reshape(x, [-1, 28*28])
    y = tf.convert_to_tensor(y_train, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    
    #第一层输入256, 第二次输出128, 第三层输出10
    #第一,二,三层参数w,b
    w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))    #正态分布的一种
    b1 = tf.Variable(tf.zeros([256]))
    w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
    b2 = tf.Variable(tf.zeros([128]))
    w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
    b3 = tf.Variable(tf.zeros([10]))
    
    #将60000组数据切分为600组,每组100个数据
    train_db = tf.data.Dataset.from_tensor_slices((x, y)).batch(100)
    lr = 0.001      #学习率
    losses = []     #储存每epoch的loss值,便于观察学习情况
    
    for epoch in range(20):
        #一次性处理100组(x, y)数据
        for step, (x, y) in enumerate(train_db):    #遍历切分好的数据step:0->599
            with tf.GradientTape() as tape:
                #向前传播第一,二,三层
                h1 = x@w1 + tf.broadcast_to(b1, [x.shape[0], 256])  #可以直接写成 +b1
                h1 = tf.nn.relu(h1)
                h2 = h1@w2 + b2
                h2 = tf.nn.relu(h2)
                out = h2@w3 + b3
                #计算mse
                loss = tf.square(y - out)
                loss = tf.reduce_mean(loss)
            #计算参数的梯度,tape.gradient为自动求导函数,loss为目标数据,目的使它越来越接近真实值
            grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
            #更新w,b
            w1.assign_sub(lr*grads[0])  #原地减去给定的值,实现参数的自我更新
            b1.assign_sub(lr*grads[1])
            w2.assign_sub(lr*grads[2])
            b2.assign_sub(lr*grads[3])
            w3.assign_sub(lr*grads[4])
            b3.assign_sub(lr*grads[5])
            #观察学习情况
            if step%500 == 0:
                print(epoch, step, 'loss:', float(loss))
        #将每epoch的loss情况储存起来,最后观察
        losses.append(float(loss))
    
    plt.plot(losses, marker='s', label='training')
    plt.xlabel('Epoch')
    plt.ylabel('MSE')
    plt.legend()
    plt.savefig('exam_mnist_forward.png') plt.show()

    观察结果:

    可由注释理解代码的含义!下一次更新mnist数据集训练的进阶!

  • 相关阅读:
    图形化编程娱乐于教,Kittenblock实例,跟随鼠标指针画画的小螃蟹
    图形化编程娱乐于教,Kittenblock实例,图章效果的音乐画面
    图形化编程娱乐于教,Kittenblock实例,测试声音的响度
    图形化编程娱乐于教,Kittenblock实例,为背景添加音乐
    图形化编程娱乐于教,Kittenblock实例,小猫移动效果,碰触边缘反弹,左右翻转
    图形化编程娱乐于教,Kittenblock实例,按键控制角色发声
    Kittenblock实例,马赛克特效
    图形化编程娱乐于教,Kittenblock实例,广播消息
    redis的基本使用记录
    nginx学习笔记
  • 原文地址:https://www.cnblogs.com/heze/p/12076792.html
Copyright © 2020-2023  润新知