• caffe Solve function


    Let's start with Solver<Dtype>::Solve(const char* resume_file)

    solver.cpp

    template <typename Dtype>
    void Solver<Dtype>::Solve(const char* resume_file) {
      CHECK(Caffe::root_solver());
      LOG(INFO) << "Solving " << net_->name();
      LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
    
      // Initialize to false every time we start solving.
      requested_early_exit_ = false;
    
      if (resume_file) {
        LOG(INFO) << "Restoring previous solver status from " << resume_file;
        // Restore training from the previously interrupted solver state
        Restore(resume_file);
      }
    
      // For a network that is trained by the solver, no bottom or top vecs
      // should be given, and we will just provide dummy vecs.
      int start_iter = iter_;
      // The main iterative optimization happens here
      Step(param_.max_iter() - iter_);
      // If we haven't already, save a snapshot after optimization, unless
      // overridden by setting snapshot_after_train := false
      if (param_.snapshot_after_train()
          && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
        Snapshot();
      }
      if (requested_early_exit_) {
        LOG(INFO) << "Optimization stopped early.";
        return;
      }
      // After the optimization is done, run an additional train and test pass to
      // display the train and test loss/outputs if appropriate (based on the
      // display and test_interval settings, respectively).  Unlike in the rest of
      // training, for the train net we only run a forward pass as we've already
      // updated the parameters "max_iter" times -- this final pass is only done to
      // display the loss, which is computed in the forward pass.
      if (param_.display() && iter_ % param_.display() == 0) {
        int average_loss = this->param_.average_loss();
        Dtype loss;
        net_->Forward(&loss);
    
        UpdateSmoothedLoss(loss, start_iter, average_loss);
    
        LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
      }
      if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
        TestAll();
      }
      LOG(INFO) << "Optimization Done.";
    }
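
    For context, here is a minimal caller-side sketch of how a program ends up in Solve() with a resume file. This is only an illustration: the helper name and both file paths are placeholders, and in the official caffe command-line tool the resume path roughly corresponds to the -snapshot flag.

    #include <boost/shared_ptr.hpp>
    #include "caffe/caffe.hpp"

    void ResumeTraining() {  // hypothetical helper, for illustration only
      caffe::SolverParameter solver_param;
      // "solver.prototxt" and the .solverstate path below are placeholder names.
      caffe::ReadSolverParamsFromTextFileOrDie("solver.prototxt", &solver_param);
      boost::shared_ptr<caffe::Solver<float> > solver(
          caffe::SolverRegistry<float>::CreateSolver(solver_param));
      // Solve() first calls Restore() on the state file, then runs Step() for the
      // remaining max_iter - iter_ iterations.
      solver->Solve("net_iter_5000.solverstate");
    }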

    Let's first look at Restore(resume_file), which is called near the start of Solve

    solver.cpp

    template <typename Dtype>
    void Solver<Dtype>::Restore(const char* state_file) {
      string state_filename(state_file);
      if (state_filename.size() >= 3 &&
          state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
        RestoreSolverStateFromHDF5(state_filename);
      } else {
        RestoreSolverStateFromBinaryProto(state_filename);
      }
    }
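
    For reference, here is a simplified sketch of how these two restore hooks are declared (paraphrased from solver.hpp and sgd_solvers.hpp, not the verbatim headers):

    template <typename Dtype>
    class Solver {
     protected:
      // Every concrete solver must supply its own state-restoring logic.
      virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
      virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
      // ...
    };

    template <typename Dtype>
    class SGDSolver : public Solver<Dtype> {
     protected:
      virtual void RestoreSolverStateFromHDF5(const string& state_file);
      virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
      // ...
    };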

    RestoreSolverStateFromHDF5(state_filename) and RestoreSolverStateFromBinaryProto(state_filename) above are both virtual functions, so the calls actually dispatch to the same-named methods of the derived solver class. For example, when solving with SGD, SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto looks like this:

    sgd_solver.cpp

    template <typename Dtype>
    void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
        const string& state_file) {
      SolverState state;
      ReadProtoFromBinaryFile(state_file, &state);
      // Recover the iteration count at which the previous training run was interrupted
      this->iter_ = state.iter();
      if (state.has_learned_net()) {
        NetParameter net_param;
        ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
        this->net_->CopyTrainedLayersFrom(net_param);
      }
      this->current_step_ = state.current_step();
      CHECK_EQ(state.history_size(), history_.size())
          << "Incorrect length of history blobs.";
      LOG(INFO) << "SGDSolver: restoring history";
      for (int i = 0; i < history_.size(); ++i) {
        history_[i]->FromProto(state.history(i));
      }
    }
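
    Note that in the default binary-proto snapshot format, the .solverstate file holds only the solver's bookkeeping (the iteration count, the lr-policy step, and the momentum history blobs) together with learned_net, the path of the companion .caffemodel containing the learned weights; both snapshot files are therefore needed to resume training.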

    Next, the main part: Step(param_.max_iter() - iter_), called from Solve

    solver.cpp

    template <typename Dtype>
    void Solver<Dtype>::Step(int iters) {
      const int start_iter = iter_;
      const int stop_iter = iter_ + iters;
      int average_loss = this->param_.average_loss();
      losses_.clear();
      smoothed_loss_ = 0;
      iteration_timer_.Start();
    
      while (iter_ < stop_iter) {
        // zero-init the params
        // Clear (zero) the gradients of all the network's parameters
        net_->ClearParamDiffs();
        if (param_.test_interval() && iter_ % param_.test_interval() == 0
            && (iter_ > 0 || param_.test_initialization())) {
          if (Caffe::root_solver()) {
            TestAll();
          }
          if (requested_early_exit_) {
            // Break out of the while loop because stop was requested while testing.
            break;
          }
        }
    
        // These callbacks are hooks used by multi-GPU training (e.g. P2PSync) to
        // synchronize the solvers at the start of each iteration.
        for (int i = 0; i < callbacks_.size(); ++i) {
          callbacks_[i]->on_start();
        }
        const bool display = param_.display() && iter_ % param_.display() == 0;
        net_->set_debug_info(display && param_.debug_info());
        // accumulate the loss and gradient
        Dtype loss = 0;
        // param_.iter_size() defaults to 1, so normally only a single forward and backward pass is performed here
        for (int i = 0; i < param_.iter_size(); ++i) {
          loss += net_->ForwardBackward();
        }
        loss /= param_.iter_size();
        // average the loss across iterations for smoothed reporting
        UpdateSmoothedLoss(loss, start_iter, average_loss);
        if (display) {
          float lapse = iteration_timer_.Seconds();
          float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1);
          LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
              << " (" << per_s << " iter/s, " << lapse << "s/"
              << param_.display() << " iters), loss = " << smoothed_loss_;
          iteration_timer_.Start();
          iterations_last_ = iter_;
          const vector<Blob<Dtype>*>& result = net_->output_blobs();
          int score_index = 0;
          for (int j = 0; j < result.size(); ++j) {
            const Dtype* result_vec = result[j]->cpu_data();
            const string& output_name =
                net_->blob_names()[net_->output_blob_indices()[j]];
            const Dtype loss_weight =
                net_->blob_loss_weights()[net_->output_blob_indices()[j]];
            for (int k = 0; k < result[j]->count(); ++k) {
              ostringstream loss_msg_stream;
              if (loss_weight) {
                loss_msg_stream << " (* " << loss_weight
                                << " = " << loss_weight * result_vec[k] << " loss)";
              }
              LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"
                  << score_index++ << ": " << output_name << " = "
                  << result_vec[k] << loss_msg_stream.str();
            }
          }
        }
        // Second multi-GPU hook: signals that this solver's gradients are ready to be
        // combined before the weight update.
        for (int i = 0; i < callbacks_.size(); ++i) {
          callbacks_[i]->on_gradients_ready();
        }
        // The network parameters are updated here. ApplyUpdate() is a virtual function
        // implemented by Solver's derived classes (e.g. SGDSolver).
        ApplyUpdate();
    
        // Increment the internal iter_ counter -- its value should always indicate
        // the number of times the weights have been updated.
        // Each iteration actually feeds batch_size samples through the network; the
        // parameter gradients they produce are accumulated to form the gradient for this
        // iteration. That gradient is then used to update the parameters according to the
        // chosen regularization method and update policy.
        ++iter_;
    
        SolverAction::Enum request = GetRequestedAction();
    
        // Save a snapshot if needed.
        if ((param_.snapshot()
             && iter_ % param_.snapshot() == 0
             && Caffe::root_solver()) ||
             (request == SolverAction::SNAPSHOT)) {
          Snapshot();
        }
        if (SolverAction::STOP == request) {
          requested_early_exit_ = true;
          // Break out of training loop.
          break;
        }
      }
    }
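
    The regularization method and update policy mentioned in the comments above live in ApplyUpdate(), which each concrete solver implements. A simplified sketch of SGDSolver<Dtype>::ApplyUpdate() (paraphrased from sgd_solver.cpp, with the display logging omitted):

    template <typename Dtype>
    void SGDSolver<Dtype>::ApplyUpdate() {
      Dtype rate = GetLearningRate();  // current rate according to lr_policy
      ClipGradients();                 // optional gradient clipping
      for (int param_id = 0; param_id < this->net_->learnable_params().size();
           ++param_id) {
        Normalize(param_id);                 // divide the diffs by iter_size
        Regularize(param_id);                // add the weight-decay term
        ComputeUpdateValue(param_id, rate);  // momentum * history + rate * diff
      }
      this->net_->Update();  // data -= diff for every learnable blob
    }

    Normalize() is also what makes iter_size behave as gradient accumulation: with iter_size > 1 the effective batch size is batch_size × iter_size, and the summed diffs are divided by iter_size before the update.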

    The line loss += net_->ForwardBackward() above is the core of the training process. It takes one batch of batch_size samples, runs a forward pass through the network to obtain the mean loss, and then runs a backward pass to obtain the gradients of the network parameters (each gradient being the mean of the gradients produced by the batch_size samples). A detailed analysis follows in the next section.
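
    For orientation, ForwardBackward() itself is just a thin inline wrapper in net.hpp, roughly:

    Dtype ForwardBackward() {
      Dtype loss;
      Forward(&loss);  // forward pass: computes the outputs and the (mean) loss
      Backward();      // backward pass: fills the parameter diffs (gradients)
      return loss;
    }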

  • Original post: https://www.cnblogs.com/pursuiting/p/8542763.html