• Tensorflow学习009——softmax多分类


    2.8 softmax分类

    前面学的对数几率回归解决的是二分类问题,对于多分类问题,可以使用softmax函数,它是对数几率再N个可能不同的值上的推广
    神经网络的原始输出并不是一个概率值,实质上只是对输入的数值做了复杂的甲醛和非线性处理之后的一个值而已,而softmax层就可以将这个输出变成概率分布

    image
    图2-17
    如图2-17所示,是softmax的计算公式,softmax要求每个样本必须属于某个类别,且所有可能的样本均被覆盖,softmax样本分量之和为1,当只有两个类别时,与对数几率回归完全相同
    在tf.keras中,对于多分类问题使用categorical_crossentropy和sparse_categorical_crossentropy(当标签是数字的时候)来计算softmax交叉熵

    先介绍一下需要使用的数据集。
    Fashion MnIST时经典MNIST数据集的简易替换,MNIST数据集包含手写数字(0,1,2等)的图像,这些图像的格式与本节课中使用的Fashion MNIST服饰图像的格式相同
    Fashion MNIST比常规的MNIST手写数据集更具有挑战性,这两个数据集都比较小,用于验证某个算法能否如期正常运行,是测试和调试代码的良好起点。
    Fashion MNIST数据集包含了70000张灰度图像,涵盖了10个类别。
    将使用60000张图像训练网络,并使用10000张图像评估经过学习的网络分类图像的准确率,可以从Tensorflow直接访问Fashion MNIST,只需要导入和加载数据即可。

    点击查看代码
    import tensorflow as tf
    (train_image,train_label),(test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()
    

    运行代码

    点击查看代码
    import tensorflow as tf
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    #加载数据集,第一次运行会进行下载,会比较慢一点
    (train_image,train_label),(test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()
    #train_image的形状是(60000,28,28) train_label的形状是(60000,)
    #train_label使用数值0,1.。。9表示类别
    #对数据继续宁归一化
    train_image = train_image / 255
    test_image = test_image / 255
    
    model = tf.keras.Sequential()
    #相比较之前写的代码,这里的每个输入都是2维的(28,28),所以需要flatten对其进行展平,将其映射成一维数据,长度为28*28
    model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
    model.add(tf.keras.layers.Dense(128,activation='relu'))
    model.add(tf.keras.layers.Dense(10,activation='softmax')) #使用softmax机会将结果转换为概率
    model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])
    model.fit(train_image,train_label,epochs=5)
    print(model.evaluate(test_image,test_label))
    

    独热编码:
    当前有三类,如果北京,则标签为[1,0,0]如果时上海是[0,1,0]。这样只有相关类才是1就是独热编码的方式
    train_label_onehot = tf.keras.utils.to_categorical(train_label)
    将数字编码换成独热编码
    按照独热编码方式预测的结果标签也是独热编码方式,我们可以直接使用np.argmax(predict)将预测结果转换为id顺序标签


    作者:孙建钊
    出处:http://www.cnblogs.com/sunjianzhao/
    本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。

  • 相关阅读:
    Python命名规范 Test
    MySQL学习总结之路(第七章:选择合适的数据类型)
    MySQL 遇到错误以及解决
    MySQL学习 配置执行日志记录
    SpringCloud Alibaba整合Sentinel
    Java对姓名, 手机号, 身份证号, 地址进行脱敏
    JSP页面button标签的id与onclick函数名字相同导致函数失效的问题
    mysql游标的使用:对查询的结果进行遍历
    基于SpringBoot+redis如何实现一个点赞功能?
    byte与str的相互转换
  • 原文地址:https://www.cnblogs.com/sunjianzhao/p/15552845.html
Copyright © 2020-2023  润新知