• tensorflow2.0——自定义全连接层实现并保存


    import tensorflow as tf
    
    
    def preprocess(x, y):
        """Convert one (image, label) pair to the dtypes used for training.

        Args:
            x: Image values. Cast to float32, divided by 255 and shifted by
                -0.5, so 0-255 pixel input ends up in [-0.5, 0.5].
            y: Label values. Cast to int32.

        Returns:
            The converted ``(image, label)`` tuple.
        """
        image = tf.cast(x, dtype=tf.float32)
        image = image / 255 - 0.5
        label = tf.cast(y, dtype=tf.int32)
        return image, label
    
    
    batchsz = 128
    #   CIFAR-10 as loaded: images [50k,32,32,3], labels [50k,1]
    (x, y), (x_val, y_val) = tf.keras.datasets.cifar10.load_data()
    #   one_hot appends a class axis of length 10: [50k,1] -> [50k,1,10]
    y, y_val = tf.one_hot(y, depth=10), tf.one_hot(y_val, depth=10)
    print(x.shape, y.shape)
    #   squeeze drops every size-1 axis: [50k,1,10] -> [50k,10]
    y, y_val = tf.squeeze(y), tf.squeeze(y_val)
    print('squeeze后:')
    print(x.shape, y.shape, x.min(), x.max())
    
    #   training pipeline: preprocess each sample, shuffle, then batch
    train_db = (
        tf.data.Dataset.from_tensor_slices((x, y))
        .map(preprocess)
        .shuffle(1000)
        .batch(batchsz)
    )
    #   validation pipeline: same preprocessing, no shuffling
    test_db = (
        tf.data.Dataset.from_tensor_slices((x_val, y_val))
        .map(preprocess)
        .batch(batchsz)
    )
    
    #   pull one batch to check shapes; expected (128, 32, 32, 3) (128, 10)
    sample = next(iter(train_db))
    print('batch:', sample[0].shape, sample[1].shape)
    
    
    #   Custom layer
    #   Bias-free stand-in for the standard tf.keras.layers.Dense()
    class MyDense(tf.keras.layers.Layer):
        """Bias-free fully connected layer that computes ``inputs @ kernel``.

        Args:
            inp_dim: Size of the last axis of the input.
            oup_dim: Size of the last axis of the output.
        """

        def __init__(self, inp_dim, oup_dim):
            super(MyDense, self).__init__()
            # add_weight replaces the deprecated add_variable alias. Keyword
            # arguments keep the call valid where the positional order of
            # add_weight / add_variable differs between Keras versions.
            self.kernel = self.add_weight(name='w', shape=[inp_dim, oup_dim])
            # No bias variable is created; the layer is a plain matrix multiply.

        def call(self, inputs, training=None):
            """Return ``inputs @ self.kernel``; ``training`` is accepted but unused."""
            return inputs @ self.kernel
    
    #   Custom network
    class MyNetwork(tf.keras.Model):
        """Five-layer fully connected classifier built from MyDense layers."""

        def __init__(self):
            super().__init__()
            # Layer widths: 32*32*3 -> 256 -> 256 -> 256 -> 32 -> 10.
            # NOTE(review): a tapering 256 -> 128 -> 64 -> 32 stack may have
            # been intended; confirm before changing any width.
            self.fc1 = MyDense(32 * 32 * 3, 256)
            self.fc2 = MyDense(256, 256)
            self.fc3 = MyDense(256, 256)
            self.fc4 = MyDense(256, 32)
            self.fc5 = MyDense(32, 10)

        def call(self, inputs, training=None, mask=None):
            '''
            Compute class logits for a batch of images.

            :param inputs: [b,32,32,3]
            :param training: accepted for the Keras call signature; unused
            :param mask: accepted for the Keras call signature; unused
            :return: [b,10] logits (no activation after the last layer)
            '''
            #   [b,32,32,3] -> [b,32*32*3]
            hidden = tf.reshape(inputs, [-1, 32 * 32 * 3])
            #   fc1..fc4, each followed by ReLU:
            #   [b,32*32*3] -> [b,256] -> [b,256] -> [b,256] -> [b,32]
            for dense in (self.fc1, self.fc2, self.fc3, self.fc4):
                hidden = tf.nn.relu(dense(hidden))
            #   [b,32] -> [b,10]; the final layer has no activation
            return self.fc5(hidden)
    
    #   Train the custom network, evaluate it, and save its weights.
    network = MyNetwork()
    #   `learning_rate` is the supported keyword; `lr` is a deprecated alias
    #   that newer Keras versions reject.
    network.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                    metrics=['accuracy'])
    network.fit(train_db, epochs=13, validation_data=test_db, validation_freq=1)
    
    network.evaluate(test_db)
    network.save_weights('./save_w_model/test1')
    
    #   Restore the saved weights into a freshly constructed model of the same
    #   architecture, then evaluate it on the same data as the trained model.
    network2 = MyNetwork()
    network2.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                     loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                     metrics=['accuracy'])
    network2.load_weights('./save_w_model/test1')
    print('加载仅有参数的模型')
    network2.evaluate(test_db)
  • 相关阅读:
    Linux recordmydesktop
    linux music play
    linux config NDK
    linux install wireshark
    Linux config cocos
    45 线程池都有哪些状态?
    44 创建线程池有哪几种方式?
    final 不能修饰抽象类和接口
    43 线程的 run() 和 start() 有什么区别?
    42 notify()和 notifyAll()有什么区别?
  • 原文地址:https://www.cnblogs.com/cxhzy/p/13697268.html
Copyright © 2020-2023  润新知