• ONNX Runtime 源码阅读:Graph::SetGraphInputsOutputs() 函数


    前言

    为了深入理解ONNX Runtime的底层机制,本文将对 Graph::SetGraphInputsOutputs() 的代码逐行分析。

    正文

    首先判断Graph是否从ONNX文件中加载所得:

    if (is_loaded_from_model_file_) return Status::OK();
    

    如果是,可直接返回;如果不是,则需要解析Graph中的节点,从而设置模型的输入和输出。

    Graph中的成员变量 value_info_graph_inputs_excluding_initializers_graph_inputs_including_initializers_ 以及 graph_outputs_ 全部清空:

    value_info_.clear();
    
    graph_inputs_excluding_initializers_.clear();
    
    if (!graph_inputs_manually_set_) {
      graph_inputs_including_initializers_.clear();
    } else {
      std::unordered_set<std::string> existing_names;
      for (auto arg : graph_inputs_including_initializers_) {
        const std::string& name = arg->Name();
        if (existing_names.count(name) == 0) {
          graph_inputs_excluding_initializers_.push_back(arg);
          existing_names.insert(name);
        }
      }
    }
    
    if (!graph_outputs_manually_set_) {
      graph_outputs_.clear();
    }
    

    设置一些局部变量,方便下面的使用分析:

    std::unordered_map<std::string, size_t> output_name_to_node_arg_index;
    std::vector<const NodeArg*> output_node_args_in_order;
    std::unordered_set<std::string> added_input_names{outer_scope_node_arg_names_};
    

    统计所有节点的输出,添加到以上局部变量(output_name_to_node_arg_index 和 output_node_args_in_order)中:

    for (const auto& node : Nodes()) {
      for (const auto* output_def : node.OutputDefs()) {
        if (output_def->Exists()) {
          output_node_args_in_order.push_back(output_def);
          output_name_to_node_arg_index.insert({output_def->Name(), output_node_args_in_order.size() - 1});
        }
      }
    }
    auto graph_output_args = output_name_to_node_arg_index;  // 拷贝一份输出节点map
    

    然后遍历图中每个节点以及每个节点的输入:

    for (const auto& node : Nodes()) {
      // Go thru all node's inputs.
      for (const auto* input_arg : node.InputDefs()) {
        ...
      }
    }
    

    在输出节点name列表中查找当前输入name

    auto output_arg_iter = output_name_to_node_arg_index.find(input_arg->Name());
    

    如果没有找到,说明这个节点的输入就是图的输入,接下来还需要判断这个输入是否已经放在局部变量added_input_names中:

    if (output_name_to_node_arg_index.end() == output_arg_iter) {
      // This input arg is not the output of another node so must come from either a graph input or an initializer.
      const std::string& name = input_arg->Name();
      if (added_input_names.end() == added_input_names.find(name)) {
        ...
      }
    }
    

    如果已经放到局部变量added_input_names中,就可以判断节点的下一个输入或者下一个节点的输入。如果没有放到局部变量added_input_names中:

    bool is_initializer = name_to_initial_tensor_.find(name) != name_to_initial_tensor_.end();  // 判断当前input_arg是否已初始化过的tensor,如果是就不可以再放置到 graph_inputs_excluding_initializers_ 中
    if (!graph_inputs_manually_set_) {   // 如果未主动调用 SetInputs() 方法
      // if IR version < 4 all initializers must have a matching graph input
      // (even though the graph input is not allowed to override the initializer).
      // if IR version >= 4 initializers are not required to have a matching graph input.
      // any graph inputs that are to override initializers must be specified by calling SetInputs.
      if (!is_initializer || ir_version_ < 4) {
        graph_inputs_including_initializers_.push_back(input_arg);
      }
      if (!is_initializer) {
        // If input_arg is not of an initializer, we add it into graph_inputs_excluding_initializers_.
        graph_inputs_excluding_initializers_.push_back(input_arg);
      }
    } else {  // 如果主动调用了 SetInputs() 方法
      // graph_inputs_including_initializers_ has been manually populated by SetInputs.
      // Validation: the <input_arg> must be in graph inputs or initializers when it's manually set.
      if (!is_initializer) {
        const auto& inputs = graph_inputs_including_initializers_;
        bool in_inputs = std::find(inputs.begin(), inputs.end(), input_arg) != inputs.end();
        if (!in_inputs) {
          return Status(ONNXRUNTIME, FAIL,
                        name + " must be either specified in graph inputs or graph initializers.");
        }
      } else {
        // If arg_input is of an initializer, we remove it from graph_inputs_excluding_initializers_
        // whose initial content has both initializers and non-initializers.
        auto input_pos = std::find(graph_inputs_excluding_initializers_.begin(),
                                    graph_inputs_excluding_initializers_.end(),
                                    input_arg);
        if (input_pos != graph_inputs_excluding_initializers_.end()) {
          graph_inputs_excluding_initializers_.erase(input_pos);
        }
      }
    }
    added_input_names.insert(name);
    

    可以看到,这里会把当前的 input_arg 分别放到 graph_inputs_including_initializers_graph_inputs_excluding_initializers_ 中,并将name放在added_input_names中。

    如果该输入的name已经在输出节点name列表中,说明这个节点是中间输出结果,而非整个图的输出,因此应该将其从图的输出(graph_output_args)中删除,并放在 value_info_ 中:

    if (output_name_to_node_arg_index.end() == output_arg_iter) {
      ...
    }else if(graph_output_args.erase(output_arg_iter->first) >= 1){
      value_info_.insert(input_arg);
    }
    

    以上我们对Graph的三个成员变量:graph_inputs_including_initializers_graph_inputs_excluding_initializers_value_info_分别进行了赋值,其中前两者存储输入,后者存储中间结果。我们还需要处理图的输出结果:`graph_outputs_`:

    if (!graph_outputs_manually_set_) {
      // Set graph outputs in order.
      std::vector<size_t> graph_output_args_index;
      graph_output_args_index.reserve(graph_output_args.size());
      for (const auto& output_arg : graph_output_args) {          // graph_output_args原本存储了所有节点的输出,但是前面的代码已经把中间节点的输出给移除了,因此剩下的就是整个Graph的输出
        graph_output_args_index.push_back(output_arg.second);
      }
    
      std::sort(graph_output_args_index.begin(), graph_output_args_index.end());
      for (auto& output_arg_index : graph_output_args_index) {
        graph_outputs_.push_back(output_node_args_in_order[output_arg_index]);
      }
    }
    

    最后,还需要对 graph_overridable_initializers_ 进行处理:

    ComputeOverridableInitializers();
    

    进入这个函数内部:

    void Graph::ComputeOverridableInitializers() {
      graph_overridable_initializers_.clear();
      if (CanOverrideInitializer()) {
        // graph_inputs_excluding_initializers_ and graph_inputs_including_initializers_
        // are inserted in the same order. So we walk and compute the difference.
        auto f_incl = graph_inputs_including_initializers_.cbegin();
        const auto l_incl = graph_inputs_including_initializers_.cend();
        auto f_excl = graph_inputs_excluding_initializers_.cbegin();
        const auto l_excl = graph_inputs_excluding_initializers_.cend();
    
        while (f_incl != l_incl) {
          // Equal means not an initializer
          if (f_excl != l_excl && *f_incl == *f_excl) {
            ++f_incl;
            ++f_excl;
            continue;
          }
          graph_overridable_initializers_.push_back(*f_incl);
          ++f_incl;
        }
      }
    }
    

    这是一个很简单的算法,通过比较 graph_inputs_including_initializers_graph_inputs_excluding_initializers_,提取出 initializer 并放置到 graph_overridable_initializers_ 中。

    至此,我们完成了对 Graph::SetGraphInputsOutputs() 函数的解析。

    总结

    针对这个函数的解析不仅理解了如何从Graph的nodes中分析出graph的输入和输出,而且懂得了graph_overridable_initializers_以及value_info_的作用。

  • 相关阅读:
    【leetcode】1442. Count Triplets That Can Form Two Arrays of Equal XOR
    【leetcode】1441. Build an Array With Stack Operations
    【leetcode】1437. Check If All 1's Are at Least Length K Places Away
    cxCheckCombobox
    修改现有字段默认值
    2018.01.02 exprottoexcel
    Statusbar OwnerDraw
    dxComponentPrinter记录
    单据暂存操作思路整理
    设置模式9(装饰者,责任链,桥接,访问者)
  • 原文地址:https://www.cnblogs.com/xxxxxxxxx/p/16219014.html
Copyright © 2020-2023  润新知