• tf2 fashion_mnist 入门


    学习使用tf2

    视频教程传送门

    知识点:

    loss="sparse_categorical_crossentropy"

    这里的 sparse 表示标签 y 是整数类别索引(尚未做 one-hot),损失函数内部会直接按索引计算,效果等同于先对 y 做 one-hot;如果 y 已经做过 one-hot,则使用 categorical_crossentropy。

    #!/usr/bin/env python
    # coding: utf-8
    
    # In[1]:
    
    
    import tensorflow as tf
    import tensorflow.keras as k
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    # In[21]:
    
    
    # Load Fashion-MNIST through the Keras dataset helper.
    fashion_mnist = k.datasets.fashion_mnist
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    # Hold out the first 5000 samples for validation and train on the remainder.
    # The right-hand side is evaluated before assignment, so both slices are
    # taken from the full, un-split training arrays.
    x_valid, x_train = x_train[:5000], x_train[5000:]
    y_valid, y_train = y_train[:5000], y_train[5000:]
    
    
    # In[7]:
    
    
    def show_single_img(img):
        """Display one image with the "binary" colormap and show the figure."""
        plt.imshow(
            img,
            cmap="binary",
        )
        plt.show()
    
    
    # In[8]:
    
    
    # Preview the first validation image (the variable is x_valid, not x_vaild).
    show_single_img(x_valid[0])
    
    
    # In[16]:
    
    
    def show_imgs(n_rows,n_cols,x,y,classes):
        """Plot the first ``n_rows * n_cols`` images of ``x`` in a titled grid.

        Args:
            n_rows: Number of rows in the subplot grid.
            n_cols: Number of columns in the subplot grid.
            x: Indexable collection of images; ``x[i]`` is passed to ``plt.imshow``.
            y: Indexable collection of labels; ``y[i]`` is used to index ``classes``.
            classes: Sequence mapping a label value to its display name.
        """
        # figsize is (width, height): width scales with the number of columns,
        # height with the number of rows.
        plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
        for row in range(n_rows):
            for col in range(n_cols):
                # Row-major position of this cell; subplot numbering is 1-based.
                index = row * n_cols + col
                plt.subplot(n_rows, n_cols, index + 1)
                plt.imshow(x[index], cmap="binary")
                plt.title(classes[y[index]])
                plt.axis("off")
        # Render the grid explicitly, matching show_single_img.
        plt.show()
    # Display names for the ten labels; the list position is the label value.
    classes = [
        'T-shirt/top',
        'Trouser',
        'Pullover',
        'Dress',
        'Coat',
        'Sandal',
        'Shirt',
        'Sneaker',
        'Bag',
        'Ankle boot',
    ]
    
    
    # In[17]:
    
    
    # Preview the first five training images in a single row.
    show_imgs(
        n_rows=1,
        n_cols=5,
        x=x_train[:5],
        y=y_train[:5],
        classes=classes,
    )
    
    
    # In[24]:
    
    
    # Build the model: flatten each 28x28 image, pass it through two ReLU
    # hidden layers, and finish with a 10-way softmax output.
    model = k.Sequential([
        k.layers.Flatten(input_shape=[28,28]),
        k.layers.Dense(300, activation="relu"),
        k.layers.Dense(100, activation="relu"),
        k.layers.Dense(10, activation="softmax"),
    ])

    # The sparse loss variant is used because the labels are integer class ids.
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )
    
    
    # In[25]:
    
    
    # Train for 10 epochs and keep the returned History object for plotting.
    history = model.fit(
        x_train,
        y_train,
        epochs=10,
        validation_data=(x_valid, y_valid),
    )
    
    
    # In[27]:
    
    
    import pandas as pd
    def plot_curve(history):
        """Plot every series recorded in ``history.history`` on one chart.

        Args:
            history: Object whose ``history`` attribute can be turned into a
                pandas DataFrame (here, the value returned by ``model.fit``).
        """
        curves = pd.DataFrame(history.history)
        curves.plot(figsize=(8, 5))
        plt.grid(True)
        # Clamp the y-axis to [0, 1].
        axes = plt.gca()
        axes.set_ylim(0, 1)
        plt.show()


    plot_curve(history)
    
    
    # In[ ]:
    View Code

    使用 sklearn 对数据集进行归一化(标准化)操作

    #!/usr/bin/env python
    # coding: utf-8
    
    # In[5]:
    
    
    import tensorflow as tf
    import tensorflow.keras as k
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    # In[6]:
    
    
    # Load Fashion-MNIST through the Keras dataset helper.
    fashion_mnist = k.datasets.fashion_mnist
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    # Hold out the first 5000 samples for validation and train on the remainder.
    # The right-hand side is evaluated before assignment, so both slices are
    # taken from the full, un-split training arrays.
    x_valid, x_train = x_train[:5000], x_train[5000:]
    y_valid, y_train = y_train[:5000], y_train[5000:]
    
    
    # In[7]:
    
    
    from sklearn.preprocessing import StandardScaler

    # Each split is flattened to a single column before scaling, so the scaler
    # sees one feature, and is then restored to 28x28 images. The scaler is
    # fitted on the training split only and reused for the other two splits.
    scaler = StandardScaler()
    x_train_scaled = (
        scaler
        .fit_transform(x_train.astype(np.float32).reshape(-1, 1))
        .reshape(-1, 28, 28)
    )
    x_valid_scaled = (
        scaler
        .transform(x_valid.astype(np.float32).reshape(-1, 1))
        .reshape(-1, 28, 28)
    )
    x_test_scaled = (
        scaler
        .transform(x_test.astype(np.float32).reshape(-1, 1))
        .reshape(-1, 28, 28)
    )
    
    
    # In[8]:
    
    
    # Build the model: flatten each 28x28 image, pass it through two ReLU
    # hidden layers, and finish with a 10-way softmax output.
    layer_stack = [
        k.layers.Flatten(input_shape=[28,28]),
        k.layers.Dense(300, activation="relu"),
        k.layers.Dense(100, activation="relu"),
        k.layers.Dense(10, activation="softmax"),
    ]
    model = k.Sequential()
    for layer in layer_stack:
        model.add(layer)

    # The sparse loss variant is used because the labels are integer class ids.
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )
    
    
    # In[9]:
    
    
    # Train on the standardized images; the validation images passed here are
    # the ones transformed with the same fitted scaler, so both splits share
    # one scale.
    history = model.fit(
        x_train_scaled,
        y_train,
        epochs=10,
        validation_data=(x_valid_scaled, y_valid),
    )
    
    
    # In[10]:
    
    
    import pandas as pd
    def plot_curve(history):
        """Draw the training curves stored in ``history.history``.

        Args:
            history: Object whose ``history`` attribute can be turned into a
                pandas DataFrame (here, the value returned by ``model.fit``).
        """
        frame = pd.DataFrame(history.history)
        frame.plot(figsize=(8, 5))
        plt.grid(True)
        # Restrict the visible y-range to [0, 1].
        current_axes = plt.gca()
        current_axes.set_ylim(0, 1)
        plt.show()


    plot_curve(history)
    
    
    # In[ ]:
    View Code
  • 相关阅读:
    蜗牛讲-Fabric入门之架构
    No module named flask 错误解决
    adb测试Doze和App Standby模式
    以太坊挖矿原理
    mac上 go-delve 安装出现The specified item could not be found in the keychain 解决方法
    nginx+lua 根据指定路径反向代理
    asp.net 网站监控方案
    go开源项目influxdb-relay源码分析(一)
    碰到的jpython用ssh连接机器,有些命令无法运行
    git常用命令(备忘)
  • 原文地址:https://www.cnblogs.com/superxuezhazha/p/12257140.html
Copyright © 2020-2023  润新知