• CNN模板


    该代码使用CNN,实现一个简单的10分类问题

    代码如下:

    import tensorflow as tf
    from tensorflow.keras import datasets, layers, models
    import matplotlib.pyplot as plt
    
    #准备数据CIFAR10
    (train_images, train_labels),(test_images, test_labels)=datasets.cifar10.load_data()
    #将像素的值标准化
    train_images=train_images/255.0
    test_images=test_images/255.0
    
    #验证数据,将前25张图片打印出来
    '''
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    
    plt.figure(figsize=(10,10))
    for i in range(25):
        plt.subplot(5,5,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[i], cmap=plt.cm.binary)
        plt.xlabel(class_names[train_labels[i][0]])
    plt.show()
    '''
    
    #构建神经网络模型CNN
    model=models.Sequential()
    model.add(layers.Conv2D(32, (3,3), activation='relu', input_shape=(32, 32, 3)))
    model.add(layers.MaxPooling2D(2,2))
    model.add(layers.Conv2D(64, (3,3), activation='relu'))
    model.add(layers.MaxPooling2D((2,2)))
    model.add(layers.Conv2D(64,(3,3), activation='relu'))
    
    #构建全联接层
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10))
    
    #查看整个CNN结构
    model.summary()
    
    #训练并编译模型
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy']
                  )
    history=model.fit(train_images, train_labels, epochs=10,
                       validation_data=(test_images, test_labels))
    
    #模型评估
    plt.rcParams['font.sans-serif']=['SimHei']
    plt.plot(history.history['accuracy'], label='accuracy')
    plt.plot(history.history['val_accuracy'], label='val_accuracy')
    plt.xlabel('')
    plt.ylabel('准确率:')
    plt.ylim([0.5, 1])
    plt.legend(loc='best')
    plt.show()
  • 相关阅读:
    Linux添加系统环境变量
    keras 或 tensorflow 调用GPU报错:Blas GEMM launch failed
    python 安装虚拟环境
    Seq2Seq 到 Attention的演变
    聊天内容处理笔记
    LSTM 详解
    keras 打印模型图
    zip 的对象是不能用索引去取的
    c# 反射获取属性值 TypeUtils
    .iml文件恢复
  • 原文地址:https://www.cnblogs.com/ALINGMAOMAO/p/14108123.html
Copyright © 2020-2023  润新知