• tensorflow学习021——自定义训练中的tensorboard可视化


    点击查看代码
    import tensorflow as tf
    import datetime
    
    # Load the MNIST splits; the training images arrive shaped (60000, 28, 28).
    (train_image, train_labels), (test_image, test_labels) = (
        tf.keras.datasets.mnist.load_data()
    )

    # Per split: append a trailing channel axis, (60000, 28, 28) -> (60000, 28, 28, 1)
    # (axis 1 would instead give (60000, 1, 28, 28)), normalize by dividing by 255,
    # and make the dtype float32 explicit because automatic differentiation only
    # works on floating-point tensors.
    train_image = tf.cast(tf.expand_dims(train_image, -1) / 255, tf.float32)
    test_image = tf.cast(tf.expand_dims(test_image, -1) / 255, tf.float32)

    train_labels = tf.cast(train_labels, tf.int64)
    test_labels = tf.cast(test_labels, tf.int64)

    # from_tensor_slices splits along the first axis, pairing every image with its
    # label as one dataset element. Training data is shuffled with a 1000-element
    # buffer and both splits are grouped into batches of 32.
    dataset = (
        tf.data.Dataset.from_tensor_slices((train_image, train_labels))
        .shuffle(1000)
        .batch(32)
    )
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_image, test_labels)
    ).batch(32)
    
    model = tf.keras.Sequential([
        # Height and width are left as None, so single-channel (grayscale)
        # images of any size are accepted.
        tf.keras.layers.Conv2D(
            filters=16,
            kernel_size=(3, 3),
            activation='relu',
            input_shape=(None, None, 1),
        ),
        tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),
        tf.keras.layers.GlobalMaxPool2D(),
        # The output layer applies softmax, so the model emits class
        # probabilities and the loss must NOT be configured with from_logits=True.
        tf.keras.layers.Dense(units=10, activation='softmax'),
    ])
    
    optimizer = tf.keras.optimizers.Adam()
    # Integer labels are compared against the model's softmax output.
    loss_func = tf.keras.losses.SparseCategoricalCrossentropy()


    def loss(model, x, y):
        """Return the loss of ``model`` on inputs ``x`` against labels ``y``.

        Runs a forward pass through ``model`` and feeds the result to the
        module-level ``loss_func``.
        """
        predictions = model(x)
        return loss_func(y, predictions)
    
    # Running aggregates for the current epoch: a mean over the per-batch loss
    # values and a sparse categorical accuracy, kept separately for the
    # training and test passes.
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
    
    def train_step(model, images, labels):
        """Run one optimization step on a batch and update the training metrics.

        Args:
            model: Keras model to optimize; its trainable variables are
                updated through the module-level ``optimizer``.
            images: Batch of input images passed to the model.
            labels: Integer class labels for the batch.
        """
        with tf.GradientTape() as tape:
            # Pass training=True explicitly so that layers whose behavior
            # depends on the mode (e.g. Dropout, BatchNormalization) run in
            # training mode if they are ever added to the model.
            pred = model(images, training=True)
            loss_step = loss_func(labels, pred)
        grads = tape.gradient(loss_step, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Accumulate this batch into the running metrics.
        train_loss(loss_step)
        train_accuracy(labels, pred)
    
    def test_step(model, images, labels):
        """Evaluate one batch and update the test metrics; no weights change.

        Args:
            model: Keras model to evaluate.
            images: Batch of input images passed to the model.
            labels: Integer class labels for the batch.
        """
        # Pass training=False explicitly so mode-dependent layers (e.g. Dropout,
        # BatchNormalization) always run in inference mode during evaluation.
        pred = model(images, training=False)
        loss_step = loss_func(labels, pred)
        # Accumulate this batch into the running metrics.
        test_loss(loss_step)
        test_accuracy(labels, pred)
    
    # Timestamp the run so repeated executions write to distinct directories.
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    # Keep the path components separated, giving
    # logs/gradient_tape/<timestamp>/train and .../test, so TensorBoard groups
    # the two writers of one run under a common parent directory.
    train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
    test_log_dir = 'logs/gradient_tape/' + current_time + '/test'
    train_writer = tf.summary.create_file_writer(train_log_dir)
    test_writer = tf.summary.create_file_writer(test_log_dir)
    
    def train(epoches):
        """Train and evaluate the module-level ``model``, logging to TensorBoard.

        For every epoch: run ``train_step`` over ``dataset``, write the training
        scalars, run ``test_step`` over ``test_dataset``, write the test
        scalars, print a summary line, then reset all four metrics.

        Args:
            epoches: Number of full passes over the training dataset.
        """
        for epoch in range(epoches):
            for images, labels in dataset:
                train_step(model, images, labels)

            # Scalars are written once per epoch, with the epoch index as step.
            with train_writer.as_default():
                tf.summary.scalar('loss', train_loss.result(), step=epoch)
                tf.summary.scalar('acc', train_accuracy.result(), step=epoch)

            for images, labels in test_dataset:
                test_step(model, images, labels)

            with test_writer.as_default():
                tf.summary.scalar('test_loss', test_loss.result(), step=epoch)
                tf.summary.scalar('test_acc', test_accuracy.result(), step=epoch)

            template = 'Epoch {},Loss:{},Acc:{},Test Loss :{},Test Acc:{}'
            print(template.format(
                epoch + 1,
                train_loss.result(),
                train_accuracy.result() * 100,
                test_loss.result(),
                test_accuracy.result() * 100,
            ))

            # Clear the accumulators so the next epoch starts from zero.
            # NOTE(review): newer Keras releases name this method reset_state();
            # confirm against the installed TensorFlow version.
            train_loss.reset_states()
            train_accuracy.reset_states()
            test_loss.reset_states()
            test_accuracy.reset_states()

    train(10)
    
  • 相关阅读:
    【锁】java 锁的技术内幕
    【BlockingQueue】BlockingQueue 阻塞队列实现
    【多线程】获取多个线程任务执行完事件
    【spring cloud】源码分析(一)
    【spring boot】FilterRegistrationBean介绍
    【FAQ】服务下线
    解决org.apache.ibatis.binding.BindingException: Invalid bound statement (not found)...
    实现人民币大写代码解析
    application.yml使用@符合问题:'@' that cannot start any token. (Do not use @ for indentation)
    Maven常见异常及解决方法---测试代码编译错误
  • 原文地址:https://www.cnblogs.com/sunjianzhao/p/15961165.html
Copyright © 2020-2023  润新知