• pytorch 如何使用tensorboard实时查看曲线---- tensorboardX简介


    用惯了tensorflow的小伙伴肯定都用过tensorboard工具吧。虽然Facebook也推出了visdom,但是在一次不小心误触clear之后,我放弃了这个工具(页面的一个clear按钮我本来是想按save的……它们俩一左一右,脑子一热按错了,点击之后clear之后不知道怎么找回曲线数据,真的崩溃)clear按钮示例

     

    说回pytorch使用tensorboard吧……

    1. 首先是安装。
      pip install tensorboardX

      这东西虽然是给pytorch用的,但是其实还是走的tensorboard那一套东西,所以你的环境里还需要有tensorflow。(cpu、gpu版本不限,随手装一个就好了)

    2. 调用
      from tensorboardX import SummaryWriter

      使用的就是SummaryWriter这个类。简单的使用可以直接使用SummaryWriter实例

      # before train
      log_writer = SummaryWriter('log_file_path')
      
      # in training
      log_writer.add_scalar('Train/Loss', loss.data[0], niter) 
      # in pytorch1.0 loss.data[0] should be loss.item()

      如果你不仅仅是需要记录一个loss这么简单,也可以对SummaryWriter做一个封装。

      class Tacotron2Logger(SummaryWriter):
          def __init__(self, logdir):
              super(Tacotron2Logger, self).__init__(logdir)
      
          def log_training(self, reduced_loss, grad_norm, learning_rate, duration,
                           iteration):
                  self.add_scalar("training.loss", reduced_loss, iteration)
                  self.add_scalar("grad.norm", grad_norm, iteration)
                  self.add_scalar("learning.rate", learning_rate, iteration)
                  self.add_scalar("duration", duration, iteration)
      
          def log_validation(self, reduced_loss, model, y, y_pred, iteration):
              self.add_scalar("validation.loss", reduced_loss, iteration)
              _, mel_outputs, gate_outputs, alignments = y_pred
              mel_targets, gate_targets = y
      
              # plot distribution of parameters
              for tag, value in model.named_parameters():
                  tag = tag.replace('.', '/')
                  self.add_histogram(tag, value.data.cpu().numpy(), iteration)
      
              # plot alignment, mel target and predicted, gate target and predicted
              idx = random.randint(0, alignments.size(0) - 1)
              self.add_image(
                  "alignment",
                  plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T),
                  iteration)
              self.add_image(
                  "mel_target",
                  plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()),
                  iteration)
              self.add_image(
                  "mel_predicted",
                  plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()),
                  iteration)
              self.add_image(
                  "gate",
                  plot_gate_outputs_to_numpy(
                      gate_targets[idx].data.cpu().numpy(),
                      F.sigmoid(gate_outputs[idx]).data.cpu().numpy()),
                  iteration)
      View Code

      这段代码是从NVIDIA tacotron2中摘取过来的。使用和前面一样,只不过把类名改一下就是了,调用的时候,按照你自己定义的类函数去调用就好了。基本功能还是都在的,画图画曲线什么的。没有visdom花里胡哨就是了。

      log_writer = Tacotron2Logger('log_file_path')
      
      log_writer.log_training(self, reduced_loss, grad_norm, learning_rate, duration, iteration)

      这样封装后,就不会在train的代码里很凌乱了。

    3. 网页查看,这个就回到tensorboard一样的操作了。
      tensorboard --logdir=./log_file_path --port=8765
      # log_file_path 是初始化log_writer时候的那个参数地址。
      # 这里端口号可以随意改,默认是6006。
    4. 然后命令行会告诉你在浏览器输入 ip:8765进行查看,这个和tensorboard一样了就。

     

  • 相关阅读:
    《人月神话》阅读笔记2
    【个人作业】单词链
    【个人作业】找水王
    【团队】 冲刺一(10/10)
    【团队】 冲刺一(9/10)
    【个人作业】单词统计续
    【团队】 冲刺一(8/10)
    【团队】 冲刺一(7/10)
    【团队】 冲刺一(6/10)
    【团队】 冲刺一(5/10)
  • 原文地址:https://www.cnblogs.com/chengebigdata/p/10121109.html
Copyright © 2020-2023  润新知