• tensorflow学习018——200种鸟类图片分类实例


    数据链接:https://pan.baidu.com/s/1zxa2KnkW5nFYFhF_bzfxkg

    提取码:eii7

    注意：需要将代码中所有的路径改成自己下载的数据所在的路径。

     

    import glob
    import os

    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf
    from tensorflow import keras
    
    # 1. Collect image paths and their class labels.
    # NOTE(review): each image's parent directory name is expected to contain a
    # "." that separates a prefix from the class name (e.g. "001.Some_Bird") —
    # confirm against the downloaded dataset.
    imgs_path = glob.glob(r"E:\WORK\tensorflow\dataset\日月光华_birds分类竞赛数据\birds_train\*\*.jpg")
    if not imgs_path:
        # Fail early: an empty glob would otherwise surface much later as an
        # obscure error (zero classes, zero training steps).
        raise FileNotFoundError("No training images found; update the birds_train path above.")
    # os.path splits on the running OS's separator(s) instead of a literal
    # backslash, so the script keeps working when the path is changed to a
    # Linux/macOS one. maxsplit=1 keeps class names containing "." intact.
    all_labels_name = [
        os.path.basename(os.path.dirname(img_p)).split(".", 1)[1] for img_p in imgs_path
    ]
    label_names = np.unique(all_labels_name)  # sorted, de-duplicated class names
    label_to_index = {name: i for i, name in enumerate(label_names)}
    index_to_label = {i: name for name, i in label_to_index.items()}

    # 2. Shuffle, split 80/20 into train/test, and wrap in tf.data Datasets.
    all_labels = [label_to_index[name] for name in all_labels_name]
    np.random.seed(2021)  # fixed seed -> reproducible split
    # One permutation applied to both arrays keeps every path aligned with its label.
    random_index = np.random.permutation(len(imgs_path))
    imgs_path = np.array(imgs_path)[random_index]
    all_labels = np.array(all_labels)[random_index]
    train_count = int(len(imgs_path) * 0.8)
    train_path = imgs_path[:train_count]
    train_labels = all_labels[:train_count]
    test_path = imgs_path[train_count:]
    test_labels = all_labels[train_count:]
    train_ds = tf.data.Dataset.from_tensor_slices((train_path, train_labels))
    test_ds = tf.data.Dataset.from_tensor_slices((test_path, test_labels))
    def load_img(path, label):
        """Read one JPEG from *path* and turn it into a normalized float image.

        Intended for ``Dataset.map`` over (path, label) pairs; *label* is
        returned unchanged.

        Returns:
            (image, label), where image is decoded with 3 channels, resized to
            256x256, cast to float32 and divided by 255.
        """
        raw_bytes = tf.io.read_file(path)
        decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
        resized = tf.image.resize(decoded, [256, 256])
        scaled = tf.cast(resized, tf.float32) / 255
        return scaled, label
    
    # 3. Input pipeline, model and loss.

    AUTOTUNE = tf.data.experimental.AUTOTUNE
    # Decode images in parallel; AUTOTUNE lets tf.data choose the parallelism.
    train_ds = train_ds.map(load_img, num_parallel_calls=AUTOTUNE)
    test_ds = test_ds.map(load_img, num_parallel_calls=AUTOTUNE)
    BATCH_SIZE = 8
    # Shuffle *before* repeat so each pass over the data completes before the
    # next begins (repeat-then-shuffle mixes samples across epoch boundaries);
    # prefetch overlaps input preparation with training.
    train_ds = train_ds.shuffle(100).repeat().batch(BATCH_SIZE).prefetch(AUTOTUNE)
    test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
    # Size the output layer from the labels actually found on disk rather than
    # a hard-coded 200, so a different/partial dataset does not mismatch.
    num_classes = len(label_names)
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(64, (3, 3), input_shape=(256, 256, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Conv2D(64, (3, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(128, (3, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Conv2D(128, (3, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(256, (3, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Conv2D(256, (3, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(512, (3, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Conv2D(512, (3, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(512, (3, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Conv2D(512, (3, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Conv2D(512, (3, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(1024),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Dense(512),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Dense(num_classes)  # raw logits; softmax is applied inside the loss
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),
                  # from_logits=True: the loss applies the softmax itself
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['acc']
                  )

    # 4. Training and prediction.
    train_count = len(train_path)
    test_count = len(test_path)
    # train_ds repeats forever, so steps_per_epoch defines an "epoch"; keep it >= 1.
    steps_per_epoch = max(1, train_count // BATCH_SIZE)
    # Ceiling division: also evaluate the final partial test batch (and avoid
    # 0 steps when the test set is smaller than one batch).
    validation_steps = (test_count + BATCH_SIZE - 1) // BATCH_SIZE
    history = model.fit(
        train_ds, epochs=10,
        steps_per_epoch=steps_per_epoch,
        validation_data=test_ds,
        validation_steps=validation_steps
    )
    
    def load_and_preprocess_image(path):
        """Load a single image for inference using exactly the training preprocessing.

        Delegates to ``load_img`` (decode as 3-channel JPEG, resize to 256x256,
        cast to float32, divide by 255) instead of duplicating those steps, so
        prediction inputs cannot drift from what the model was trained on.

        Args:
            path: Path to a JPEG image file.

        Returns:
            The preprocessed image tensor (no batch axis; add one with
            ``tf.expand_dims`` before calling ``model.predict``).
        """
        image, _ = load_img(path, None)  # the label slot is unused here
        return image
    
    # Classify one test image: preprocess it, add a leading batch axis of size 1,
    # then map the highest-scoring output index back to its class name.
    test_img = r"E:\WORK\tensorflow\dataset\日月光华_birds分类竞赛数据\birds_test\0.jpg"
    test_tensor = tf.expand_dims(load_and_preprocess_image(test_img), axis=0)
    pred = model.predict(test_tensor)
    best_index = np.argmax(pred)  # argmax over all (flattened) output scores
    print(index_to_label.get(best_index))

     

     

  • 相关阅读:
    Exception in thread "main" java.lang.IllegalArgumentException:解决方法
    Warning: $HADOOP_HOME is deprecated.解决方法
    Android :TextView使用SpannableString设置复合文本
    一、harbor部署之centos7的基本配置
    木马基础知识要点
    【原创】红客闯关游戏部分题解
    【原创】利用Office宏实现powershell payload远控
    【原创】字典攻击教务处(BurpSuite使用)
    【原创】逆向练习(CrackMe)
    显式游标和隐式游标二者的区别
  • 原文地址:https://www.cnblogs.com/sunjianzhao/p/15961087.html
Copyright © 2020-2023  润新知