• torch.utils.tensorboard



    TensorBoard 神经网络可视化工具


    1. pytorch 官方文档解析

    from torch.utils.tensorboard import SummaryWriterclass
    
    class SummaryWriter(object)
    	def __init__(
            self,
            log_dir=None, 
            comment='', 
            purge_step=None, 
            max_queue=10, 
            flush_secs=120, 
            filename_suffix=''
        )
        def add_scalar(
            self,
            tag, 
            scalar_value, 
            global_step=None,
            walltime=None, 
            new_style=False, 
            double_precision=False
        )
        def add_scalars(
            self, 
            main_tag, 
            tag_scalar_dict, 
            global_step=None, 
            walltime=None
        )
        def add_histogram(
            self,
            tag,
            values,
            global_step=None,
            bins="tensorflow",
            walltime=None,
            max_bins=None,
        )
        def add_image(
            self, 
            tag, 
            img_tensor, 
            global_step=None, 
            walltime=None, 
            dataformats="CHW"
        )
        def add_images(
            self, 
            tag, 
            img_tensor, 
            global_step=None, 
            walltime=None, 
            dataformats="NCHW"
        )
        def add_graph(
            self, 
            model, 
            input_to_model=None, 
            verbose=False, 
            use_strict_trace=True
        )
    

    将条目直接写入 \(log\_dir\) 中的时间文件以供 \(TensorBoard\) 使用。

    \(SummaryWriter\) 类提供了一个高级 \(API\),用于在给定目录中创建事件文件并向其添加摘要和事件。


    __init__

    • log_dir(\(string\)):保存目录位置。默认为 runs/current_datetime_hostname,每次运行后都会更改。可自定义。
    • comment(\(string\)):不指定 log_dir, 文件夹后缀。
    • filename_suffix(\(int\)):log_dir目录中所有事件文件名后缀。

    add_scalar:记录标量。

    • tag(\(string\)):标签名。
    • scalar_value(\(float、string、blobname\)):要记录的标量。
    • global_step(\(int\)):轮次。
    • new_stype(\(boolean\)):使用新样式(张亮字段)还是旧样式(\(simple\_value\) 字段)。新样式可能有更快的加载速度。

    add_scalars:记录多个标量。

    • main_tag(\(string\)):多个标签名。
    • tag_scalar_dict(\(dict\)):存储标签和对应的键值。
    • global_step(\(int\)):轮次。

    add_histogram:统计直方图与多分位数折线图

    • tag(\(string\)):标签名。
    • values(\(torch.Tensor、numpy.array、string、blobname\)):构建直方图的值。
    • global_step(\(int\)):轮次。
    • bins(\(string\)):取值 \(tensorflow、auto、fd\) 等。这决定如何制作垃圾箱。

    add_image:显示图像

    • tag(\(string\)):标签名。
    • img_tensor(\(torch.Tensor、numpy.array、string、blobname\)):图像数据。
    • global_step(\(int\)):轮次。
    • dataformats(\(string\)):\(CHW、HWC、HW、WH\) 图像数据的格式。

    add_images:批量显示图像

    • tag(\(string\)):标签名。
    • img_tensor(\(torch.Tensor、numpy.array、string、blobname\)):图像数据。
    • global_step(\(int\)):轮次。
    • dataformats(\(string\)):\(NCHW、NHWC、CHW、HWC、HW、WH\) 图像数据的格式。

    add_graph:查看模型图

    • model(\(torch.nn.Model\)):模型,必须是 nn.Module
    • input_to_model(\(torch.Tensor、torch.Tensor列表\)):输出给模型的数据。
    • verbose(\(bool\)):是否打印计算图结构信息。

    写完记得写 writer.close()



    2. 调用方法

    2.1 创建接口

    writer = SummaryWriter('runs')
    

    2.2 记录多个标量

    writer.add_scalars('name', {'dic': val}, epoch)
    

    2.3 统计直方图

    writer.add_histogram('weight', self.fc.weight, epoch)
    

    2.4 批次显示图像

    writer.add_images(“Cifar10”, img_batch, epoch, 'CHW')
    

    2.5 查看模型图

    writer.add_graph(model=net,input_to_model=torch.randn(1,3, 224, 224).to(device))
    


    来自:

    https://pytorch.org/docs/stable/tensorboard.html

    https://m.w3cschool.cn/article/27419536.html

  • 相关阅读:
    [面试没答上的问题1]http请求,请求头和响应头都有什么信息?
    模拟进度条发现的彩蛋
    实现JavaScript forEach
    JavaScript实现动画效果
    Contents Of My Blogs
    阅读笔记-拍出好照片的30个构图基本功
    阅读笔记-李鸿章传
    阅读笔记-人性的弱点
    阅读笔记-XWord:未来进化
    阅读笔记-活法
  • 原文地址:https://www.cnblogs.com/keye/p/16591786.html
Copyright © 2020-2023  润新知