• keras实现残差网络(keras搬砖二)


    参考：Keras 官方文档 Functional API 指南中的 toy ResNet 示例

     1 from keras import layers
     2 import keras
     3 import numpy as np
     4 
     5 inputs = keras.Input(shape=(32, 32, 3), name="img")
     6 x = layers.Conv2D(32, 3, activation="relu")(inputs)
     7 x = layers.Conv2D(64, 3, activation="relu")(x)
     8 block_1_output = layers.MaxPooling2D(3)(x)
     9 
    10 x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_1_output)
    11 x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
    12 block_2_output = layers.add([x, block_1_output])
    13 
    14 x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_2_output)
    15 x = layers.Conv2D(64, 3, activation="relu", padding="same")(x)
    16 block_3_output = layers.add([x, block_2_output])
    17 
    18 x = layers.Conv2D(64, 3, activation="relu")(block_3_output)
    19 x = layers.GlobalAveragePooling2D()(x)
    20 x = layers.Dense(256, activation="relu")(x)
    21 x = layers.Dropout(0.5)(x)
    22 outputs = layers.Dense(10)(x)
    23 
    24 model = keras.Model(inputs, outputs, name="toy_resnet")
    25 model.summary()
    26 
    27 # 绘制模型
    28 keras.utils.plot_model(model, "mini_resnet.png", show_shapes=True)
    29 
    30 # 训练模型
    31 (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
    32 
    33 x_train = x_train.astype("float32") / 255.0
    34 x_test = x_test.astype("float32") / 255.0
    35 y_train = keras.utils.to_categorical(y_train, 10)
    36 y_test = keras.utils.to_categorical(y_test, 10)
    37 
    38 model.compile(
    39     optimizer=keras.optimizers.RMSprop(1e-3),
    40     loss="categorical_crossentropy",
    41     metrics=["accuracy"],
    42 )
    43 # We restrict the data to the first 1000 samples so as to limit execution time
    44 # on Colab. Try to train on the entire dataset until convergence!
    45 model.fit(x_train[:1000], y_train[:1000], batch_size=64, epochs=1, validation_split=0.2)
     

    模型结构图（由 plot_model 生成的 mini_resnet.png）

  • 相关阅读:
    netbeans 快捷键
    Netbeans Platform的Lookup 边学边记
    Swing中的并发使用SwingWorker线程模式
    转:ExtJS:tabpanel 多个tab同时渲染问题
    NetBeans Platform Login Tutorial
    检测项目
    C#实现通过程序自动抓取远程Web网页信息
    WebClient类
    string字符串的方法
    WebClient的研究笔记认识WebClient
  • 原文地址:https://www.cnblogs.com/pergrand/p/12924779.html
Copyright © 2020-2023  润新知