• 简单粗暴的tensorflow自定义层、损失函数、评估指标


    # 自定义层  y_pred=w*x+b
    class LinearLayer(tf.keras.layers.Layer):
        def __init__(self, units):
            super().__init__()
            self.units = units
    
        def build(self, input_shape):     # 这里 input_shape 是第一次运行call()时参数inputs的形状
            self.w = self.add_weight(name='w',
                shape=[input_shape[-1], self.units], initializer=tf.zeros_initializer())
            self.b = self.add_weight(name='b',
                shape=[self.units], initializer=tf.zeros_initializer())
    
        def call(self, inputs):
            y_pred = tf.matmul(inputs, self.w) + self.b
            return y_pred
    # 模型定义
    class LinearModel(tf.keras.Model):
        def __init__(self):
            super().__init__()
            self.layer = LinearLayer(units=1)
    
        def call(self, inputs):
            output = self.layer(inputs)
            return output
    # 自定义损失函数
    class MeanSquaredError(tf.keras.losses.Loss):
        def call(self, y_true, y_pred):
            return tf.reduce_mean(tf.square(y_pred - y_true))
    # 自定义评估指标
    class SparseCategoricalAccuracy(tf.keras.metrics.Metric):
        def __init__(self):
            super().__init__()
            self.total = self.add_weight(name='total', dtype=tf.int32, initializer=tf.zeros_initializer())
            self.count = self.add_weight(name='count', dtype=tf.int32, initializer=tf.zeros_initializer())
    
        def update_state(self, y_true, y_pred, sample_weight=None):
            values = tf.cast(tf.equal(y_true, tf.argmax(y_pred, axis=-1, output_type=tf.int32)), tf.int32)
            self.total.assign_add(tf.shape(y_true)[0])
            self.count.assign_add(tf.reduce_sum(values))
    
        def result(self):
            return self.count / self.total
  • 相关阅读:
    MySQL学习——操作表
    MySQL学习——数据类型
    MySQL学习——操作数据库
    MySQL学习——存储引擎
    Linux网络——配置防火墙的相关命令
    查询各分类中最大自增ID
    CentOS7下Rsync+sersync实现数据实时同步
    mysql的join连接查询优化经历
    搭建nginx代理支持前端页面跨域调用接口
    Centos查看系统CPU个数、核心数、线程数
  • 原文地址:https://www.cnblogs.com/wuyuan2011woaini/p/15904736.html
Copyright © 2020-2023  润新知