• OneFlow: 启动 Runtime


    前言

    我们前面介绍了从 Op 到 Job,又从 Job 到 Plan,这篇文章将分析运行时(Runtime)的启动过程,以及 Actor 是如何启动的。运行时在启动 Session 时启动:Job 被编译成物理上可以执行的 Plan 之后,就可以按照 Plan 启动运行时、启动 Actor 了。

    流程回顾

    运行时 Runtime 在什么时候启动的呢?在 Python 调用 StartLazyGlobalSession 的时候,在这个方法初始化全局 OneFlow 对象,将 JobSet 编译成 Plan,使用这个 Plan 启动 Runtime。

    Runtime 的初始化流程如下。我们知道 Plan 是物理上可以执行的计算图,Plan 中的节点 TaskProto 则对应计算图上的节点,一个 Task 对应一个 Actor。Runtime 启动的时候,调用 HandoutTasks 将 Task 分发出去,构造 Actor。

    // oneflow/core/job/runtime.cpp
    // Builds the runtime for `plan` on this rank:
    //  1. registers the plan with the runtime Global objects,
    //  2. hands this rank's tasks to their threads so the actors get constructed,
    //  3. waits until the actors on this rank (and, via the session barrier, on
    //     every rank) are constructed,
    //  4. creates one running-actor counter per job and sends kStart to the
    //     source actors.
    Runtime::Runtime(const Plan& plan, const HashMap<std::string, Blob*>& variable_op_name2eager_blob) {
      {
        // NOTE(chengcheng): All runtime Global objects AddPlan
        Global<RegstMgr>::Get()->AddPlan(plan, variable_op_name2eager_blob);
        Global<ThreadMgr>::Get()->AddPlan(plan);
        Global<RuntimeJobDescs>::Get()->AddPlan(plan);
        collective_boxing_executor_plan_token_ =
            Global<boxing::collective::CollectiveBoxingExecutor>::Get()->AddPlan(plan);
      }
      // Tasks without any non-ctrl consumed regst desc are the "source" tasks; they
      // are the only ones that receive kStart at the end of this constructor.
      std::vector<const TaskProto*> source_tasks;
      std::vector<const TaskProto*> other_tasks;
      int64_t this_machine_task_num = 0;
      for (const TaskProto& task : plan.task()) {
        if (task.machine_id() != GlobalProcessCtx::Rank()) { continue; }
        if (!HasNonCtrlConsumedRegstDescId(task)) {
          source_tasks.push_back(&task);
        } else {
          other_tasks.push_back(&task);
        }
        // operator[] value-initializes a job's counter to 0 the first time the job
        // id is seen, so this counts the actors of each job on this rank.
        job_id2actor_size_[task.job_id()] += 1;
        this_machine_task_num++;
      }
      RuntimeCtx* runtime_ctx = Global<RuntimeCtx>::Get();
      // Each Thread decrements this counter once per actor it constructs.
      runtime_ctx->NewCounter("constructing_actor_cnt", this_machine_task_num);
      HandoutTasks(source_tasks);
      HandoutTasks(other_tasks);
      runtime_ctx->WaitUntilCntEqualZero("constructing_actor_cnt");
      LOG(INFO) << "Actors on this machine constructed";
      OF_SESSION_BARRIER();
      LOG(INFO) << "Actors on every machine constructed";
      // Threads decrement these per-job counters as actors finish.
      for (const auto& pair : job_id2actor_size_) {
        runtime_ctx->NewCounter(GetRunningActorCountKeyByJobId(pair.first), pair.second);
      }
      SendCmdMsg(source_tasks, ActorCmd::kStart);
    }
    

    HandoutTasks 接受 Task 数组作为参数,这些 Task 将会逐个添加到对应的 Thread 里面,最后通过基于消息机制的 ActorMsgBus 发送构造指令来构造 Actor。

    // oneflow/core/job/runtime.cpp: 36
    // Posts the command `cmd` to the actor of every task in `tasks`, one message
    // per task, through the global actor message bus.
    void SendCmdMsg(const std::vector<const TaskProto*>& tasks, ActorCmd cmd) {
      for (size_t idx = 0; idx < tasks.size(); ++idx) {
        const TaskProto* target = tasks[idx];
        ActorMsg cmd_msg = ActorMsg::BuildCommandMsg(target->task_id(), cmd);
        Global<ActorMsgBus>::Get()->SendMsg(cmd_msg);
      }
    }
    
    // Gives each task to the Thread named by its thrd_id(), then sends a
    // kConstructActor command for every task so the owning threads build the actors.
    void HandoutTasks(const std::vector<const TaskProto*>& tasks) {
      for (size_t idx = 0; idx < tasks.size(); ++idx) {
        const TaskProto& task = *tasks[idx];
        Thread* owner = Global<ThreadMgr>::Get()->GetThrd(task.thrd_id());
        owner->AddTask(task);
      }
      SendCmdMsg(tasks, ActorCmd::kConstructActor);
    }
    

    ThreadMgr

    下面是 ThreadMgr 的头文件,提供了两个成员方法,和两个函数。前面在启动 Runtime 的时候,会将具有全局信息的 Plan 加入到 ThreadMgr 当中。

    • AddPlan 成员方法,初始化 Thread 对象。
    • GetThrd 成员方法,根据线程 id 获取对应的线程,找不到时直接 CHECK 失败。一个线程上可以运行多个 Actor;每个 Task 在 Plan 中已经记录了所属线程的 id(thrd_id),分发 Task 时根据这个 id 找到对应的线程即可。
    • SingleThreadLoop,单线程循环调用一个函数
    • MultiThreadLoop,用多个线程多次调用一个函数。它一共向线程池提交 thread_num 个任务(thread_num 取线程池线程数与调用次数 num 中的较小值)。那如果调用次数多于线程数怎么办?OneFlow 提供了 BalancedSplitter,把 [0, num) 这 num 次调用尽量均匀地划分给各个线程。这个方法使用 BlockingCounter 来同步:计数器初始化为线程数量,每个线程执行完自己负责的那一段就减一,所有线程都执行完毕后,才允许继续往下执行。
    // oneflow/core/thread/thread_manager.h
    namespace oneflow {

    class Plan;

    // Owns the actor threads of this process, keyed by thread id (the serialized
    // StreamId of the stream the thread serves).
    class ThreadMgr final {
     public:
      OF_DISALLOW_COPY_AND_MOVE(ThreadMgr);
      ThreadMgr() = default;
      ~ThreadMgr();

      // Creates a Thread for every stream of this rank that appears in `plan`
      // and does not have a Thread yet.
      void AddPlan(const Plan& plan);
      // Returns the Thread registered under `thrd_id`; CHECK-fails if it is absent.
      Thread* GetThrd(int64_t thrd_id);

     private:
      friend class Global<ThreadMgr>;

      HashMap<int64_t, std::unique_ptr<Thread>> threads_;
    };

    // Calls Callback(i) for each i in [0, num), in order, on the calling thread.
    void SingleThreadLoop(size_t num, std::function<void(size_t i)> Callback);
    // Calls Callback(i) for each i in [0, num) on the global ThreadPool and blocks
    // until every call has returned.
    void MultiThreadLoop(size_t num, std::function<void(size_t i)> Callback);

    // Registers `creator` (a callable taking `const StreamId&` and returning a new
    // Thread*) as the Thread factory for device type `device`. The trailing
    // backslash is required: the macro body continues on the next line.
    #define REGISTER_DEVICE_THREAD_CREATOR_WITH_STREAM_ID(device, creator) \
      REGISTER_CLASS_CREATOR(int, device, Thread, creator, const StreamId&)

    }  // namespace oneflow
    
    // oneflow/core/thread/thread_manager.cpp: 28
    namespace oneflow {
    
    ThreadMgr::~ThreadMgr() {
      for (auto& thread_pair : threads_) {
        ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread);
        thread_pair.second->GetMsgChannelPtr()->Send(msg);
        thread_pair.second.reset();
        LOG(INFO) << "actor thread " << thread_pair.first << " finish";
      }
    }
    
    Thread* ThreadMgr::GetThrd(int64_t thrd_id) {
      auto iter = threads_.find(thrd_id);
      CHECK(iter != threads_.end()) << "thread " << thrd_id << " not found";
      return iter->second.get();
    }
    
    void ThreadMgr::AddPlan(const Plan& plan) {
      const int64_t this_rank = GlobalProcessCtx::Rank();
      for (const TaskProto& task : plan.task()) {
        TaskId task_id = DeserializeTaskIdFromInt64(task.task_id());
        StreamId stream_id = task_id.stream_id();
        if (stream_id.device_id().rank() != this_rank) { continue; }
        int64_t thrd_id = SerializeStreamIdToInt64(stream_id);
        if (threads_.find(thrd_id) != threads_.end()) { continue; }
        Thread* thread =
            NewObj<int, Thread, const StreamId&>(stream_id.device_id().device_type(), stream_id);
        CHECK_NOTNULL(thread);
        threads_[thrd_id].reset(thread);
      }
    }
    
    void SingleThreadLoop(size_t num, std::function<void(size_t i)> Callback) {
      FOR_RANGE(size_t, i, 0, num) { Callback(i); }
    }
    
    void MultiThreadLoop(size_t num, std::function<void(size_t i)> Callback) {
      size_t thread_num = Global<ThreadPool>::Get()->thread_num();
      thread_num = std::min(num, thread_num);
      BalancedSplitter bs(num, thread_num);
      BlockingCounter bc(thread_num);
      FOR_RANGE(size_t, range_id, 0, thread_num) {
        Global<ThreadPool>::Get()->AddWork([&bc, &bs, range_id, Callback] {
          FOR_RANGE(size_t, i, bs.At(range_id).begin(), bs.At(range_id).end()) { Callback(i); }
          bc.Decrease();
        });
      }
      bc.WaitUntilCntEqualZero();
    }
    
    }  // namespace oneflow
    

    Thread

    接下来考察一下 Thread 这个类,从接口来看,这个类提供的接口允许添加 Task,给 Actor 发送消息。从类成员来看,需要存储各种映射,存储线程对象和 mutex,存储当前线程 id,是否使用本地的消息队列,是否开启 light actor。

    // oneflow/core/thread/thread.h
    namespace oneflow {

    // An actor thread: one OS thread that polls a message channel and dispatches
    // each message to the actors living on that thread.
    class Thread {
     public:
      OF_DISALLOW_COPY_AND_MOVE(Thread);
      virtual ~Thread();

      // Registers a task; its actor is built on this thread when a kConstructActor
      // command for the task arrives.
      void AddTask(const TaskProto&);

      Channel<ActorMsg>* GetMsgChannelPtr() { return &msg_channel_; }

      // Delivers `msg` to this thread. When called from the actor thread itself and
      // the local queue is enabled, the message goes to local_msg_queue_ (a plain
      // std::queue, no locking); otherwise it is sent through msg_channel_.
      inline void EnqueueActorMsg(const ActorMsg& msg) {
        if (UseLocalMsgQueue()) {
          local_msg_queue_.push(msg);
        } else {
          msg_channel_.Send(msg);
        }
      }

      // Range version of EnqueueActorMsg; the queue choice is made once for the range.
      template<typename InputIt>
      inline void EnqueueActorMsg(InputIt first, InputIt last) {
        if (UseLocalMsgQueue()) {
          for (auto it = first; it != last; ++it) { local_msg_queue_.push(*it); }
        } else {
          for (auto it = first; it != last; ++it) { msg_channel_.Send(*it); }
        }
      }

      void JoinAllActor() { actor_thread_.join(); }

     protected:
      Thread();
      std::thread& mut_actor_thread() { return actor_thread_; }
      void PollMsgChannel(const ThreadCtx& thread_ctx);
      void set_thrd_id(int64_t val) { thrd_id_ = val; }

     private:
      void ConstructActor(int64_t actor_id, const ThreadCtx& thread_ctx);

      // True only on the actor thread itself, and only if the local queue is enabled.
      inline bool UseLocalMsgQueue() const {
        return local_msg_queue_enabled_ && std::this_thread::get_id() == actor_thread_.get_id();
      }

      // Tasks added by AddTask whose actors are not constructed yet; guarded by id2task_mtx_.
      HashMap<int64_t, TaskProto> id2task_;
      std::mutex id2task_mtx_;

      std::thread actor_thread_;
      Channel<ActorMsg> msg_channel_;
      // Live actors on this thread, and the job each one belongs to, keyed by actor id.
      HashMap<int64_t, std::unique_ptr<ActorBase>> id2actor_ptr_;
      HashMap<int64_t, int64_t> id2job_id_;
      std::queue<ActorMsg> local_msg_queue_;
      // Default initializers keep these well-defined before Thread() or set_thrd_id()
      // assign them (thrd_id_ is read in log messages).
      bool local_msg_queue_enabled_ = false;
      int64_t thrd_id_ = -1;
      bool light_actor_enabled_ = false;
    };

    }  // namespace oneflow
    

    Thread 的方法是如何实现的呢?

    • AddTask,加锁,然后直接往映射的数据结构中加东西。
    • ConstructActor,根据 actor_id 构建 Actor,根据是否 light,调用不同的方法初始化。如果是 light,那么调用 TryNewLightActor。如果不是,那么调用 NewActor。构造完成之后,往映射的数据结构中加东西。
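    • PollMsgChannel,这个方法非常重要!!它做了什么事呢?看名字,拉取消息。它在一个循环里不断从消息队列中取出消息:如果消息是 kCmdMsg 类型,那么这是一条控制命令。kStopThread 会结束轮询循环(线程随之结束),kConstructActor 会在本线程上构造对应的 Actor;其他命令(例如 kStart)和普通消息一样,交给目标 Actor 的 ProcessMsg 处理。如果 ProcessMsg 返回 1,说明该 Actor 已经结束,线程会将它析构,并把所属 Job 的运行中 Actor 计数减一。那么 PollMsgChannel 由谁调用呢?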
    // oneflow/core/thread/thread.cpp
    namespace oneflow {

    // Reads the two feature switches from the environment; both default to off.
    Thread::Thread() {
      local_msg_queue_enabled_ =
          ParseBooleanFromEnv("ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE", false);
      light_actor_enabled_ = ParseBooleanFromEnv("ONEFLOW_ACTOR_ENABLE_LIGHT_ACTOR", false);
    }

    // Waits for the actor thread to leave PollMsgChannel (it does so on a
    // kStopThread command), then closes the message channel.
    Thread::~Thread() {
      // JoinAllActor() may already have joined the thread, or no thread may have been
      // started; calling join() on a non-joinable std::thread throws.
      if (actor_thread_.joinable()) { actor_thread_.join(); }
      // Every task handed to AddTask must have been turned into an actor by now.
      CHECK(id2task_.empty());
      msg_channel_.Close();
    }

    // Stores a copy of `task`; its actor is built when kConstructActor arrives.
    void Thread::AddTask(const TaskProto& task) {
      std::unique_lock<std::mutex> lck(id2task_mtx_);
      CHECK(id2task_.emplace(task.task_id(), task).second);
    }

    // Message loop of the actor thread. Commands kStopThread and kConstructActor
    // are handled here; every other message (including other commands such as
    // kStart) goes to the destination actor's ProcessMsg. An actor whose ProcessMsg
    // returns 1 is destroyed and its job's running-actor counter is decremented.
    void Thread::PollMsgChannel(const ThreadCtx& thread_ctx) {
      while (true) {
        // Refill the local queue from the channel only when it has run dry.
        if (local_msg_queue_.empty()) {
          CHECK_EQ(msg_channel_.ReceiveMany(&local_msg_queue_), kChannelStatusSuccess);
        }
        ActorMsg msg = std::move(local_msg_queue_.front());
        local_msg_queue_.pop();
        if (msg.msg_type() == ActorMsgType::kCmdMsg) {
          if (msg.actor_cmd() == ActorCmd::kStopThread) {
            CHECK(id2actor_ptr_.empty());
            break;
          } else if (msg.actor_cmd() == ActorCmd::kConstructActor) {
            ConstructActor(msg.dst_actor_id(), thread_ctx);
            continue;
          } else {
            // do nothing
          }
        }
        int64_t actor_id = msg.dst_actor_id();
        auto actor_it = id2actor_ptr_.find(actor_id);
        CHECK(actor_it != id2actor_ptr_.end());
        int process_msg_ret = actor_it->second->ProcessMsg(msg);
        if (process_msg_ret == 1) {
          LOG(INFO) << "thread " << thrd_id_ << " deconstruct actor " << actor_id;
          auto job_id_it = id2job_id_.find(actor_id);
          // Every constructed actor is recorded in id2job_id_ by ConstructActor.
          CHECK(job_id_it != id2job_id_.end());
          const int64_t job_id = job_id_it->second;
          id2job_id_.erase(job_id_it);
          id2actor_ptr_.erase(actor_it);
          Global<RuntimeCtx>::Get()->DecreaseCounter(GetRunningActorCountKeyByJobId(job_id));
        } else {
          CHECK_EQ(process_msg_ret, 0);
        }
      }
    }

    // Builds the actor for a task previously added by AddTask: a light actor when
    // enabled and TryNewLightActor succeeds, otherwise a regular one via NewActor.
    void Thread::ConstructActor(int64_t actor_id, const ThreadCtx& thread_ctx) {
      std::unique_lock<std::mutex> lck(id2task_mtx_);
      auto task_it = id2task_.find(actor_id);
      // Dereferencing end() would be undefined behavior; fail loudly instead.
      CHECK(task_it != id2task_.end()) << "task " << actor_id << " was not added to thread "
                                       << thrd_id_;
      std::unique_ptr<ActorBase> actor_ptr;
      const TaskProto& task = task_it->second;
      if (light_actor_enabled_) { actor_ptr = TryNewLightActor(task, thread_ctx); }
      if (!actor_ptr) {
        actor_ptr = NewActor(task, thread_ctx);
        LOG(INFO) << "Thread " << thrd_id_ << " construct Actor " << TaskType_Name(task.task_type())
                  << " " << actor_id;
      } else {
        LOG(INFO) << "Thread " << thrd_id_ << " construct LightActor "
                  << TaskType_Name(task.task_type()) << " " << actor_id;
      }
      CHECK(id2actor_ptr_.emplace(actor_id, std::move(actor_ptr)).second);
      CHECK(id2job_id_.emplace(actor_id, task.job_id()).second);
      id2task_.erase(task_it);
      Global<RuntimeCtx>::Get()->DecreaseCounter("constructing_actor_cnt");
    }

    }  // namespace oneflow
    

    搜索代码,看看哪些地方调用了 PollMsgChannel。

    • cpu_thread.cpp
    • gpu_thread.cpp

    两种方法的结构是类似的,通过 std::thread 来启动 PollMsgChannel,接着这个 Thread 将从消息队列中拉取消息,然后执行。那这些 CpuThread 和 GpuThread 又是如何启动的呢?在 ThreadMgr 的 AddPlan 里面!

    // oneflow/core/thread/cpu_thread.cpp
    namespace oneflow {

    // Starts the OS thread behind a CPU actor thread. That thread runs
    // PollMsgChannel until a kStopThread command ends the loop.
    CpuThread::CpuThread(int64_t thrd_id) {
      set_thrd_id(thrd_id);
      auto poll_loop = [this, thrd_id]() {
        OF_PROFILER_NAME_THIS_HOST_THREAD("CPU Actor : (" + std::to_string(thrd_id) + ")");
        ThreadCtx thread_ctx;
    #ifdef WITH_CUDA
        // A CPU actor thread has no CUDA callback-event channel.
        thread_ctx.cb_event_chan = nullptr;
    #endif  // WITH_CUDA
        PollMsgChannel(thread_ctx);
      };
      mut_actor_thread() = std::thread(poll_loop);
    }

    // Factory for CPU streams: the thread id is the serialized StreamId.
    REGISTER_DEVICE_THREAD_CREATOR_WITH_STREAM_ID(DeviceType::kCPU,
                                                  ([](const StreamId& stream_id) -> Thread* {
                                                    const int64_t thrd_id =
                                                        SerializeStreamIdToInt64(stream_id);
                                                    return new CpuThread(thrd_id);
                                                  }));

    }  // namespace oneflow
    

    Actor

    前面分析了线程是如何产生的,线程运行的核心是 Actor。一个线程上有多个 Actor,线程通过轮询消息队列,然后将消息发送给不同的 Actor 来执行。真正干活的 Actor 是如何构造,如何执行的呢?

    Actor 的构造很简单,通过 TaskProto 上面的类型,去选择一个对应的 Actor 进行初始化。

    // oneflow/core/actor/actor_base.cpp
    std::unique_ptr<ActorBase> NewActor(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
      ActorBase* rptr = NewObj<int32_t, ActorBase>(task_proto.task_type());
      const auto& job_descs = *Global<RuntimeJobDescs>::Get();
      rptr->Init(&job_descs.job_desc(task_proto.job_id()), task_proto, thread_ctx);
      return std::unique_ptr<ActorBase>(rptr);
    }
    

    Actor 的执行通过 ProcessMsg 方法来进行。前面我们已经看到,线程会轮询消息队列拉取消息,然后把消息发送给对应的 Actor 处理。下面的分析可能有点零碎,但核心只抓一条主线:Actor 从拿到一条消息开始,一直到启动 Kernel 为止,经历了哪些步骤。

    • 线程通过轮询消息队列,拉取消息。将消息发送给 Actor 去处理,Actor 交给 msg_handler_ 处理。
    // 1: success, and actor finish
    // 0: success, and actor not finish
    int ProcessMsg(const ActorMsg& msg) override { return (this->*msg_handler_)(msg); }
    
    • msg_handler_ 可以被设置,不同的 Actor 可以设置 msg_handler_ 来处理。
      // Msg Handler
      // Replaces the handler that ProcessMsg dispatches to.
      void set_msg_handler(MsgHandler val) { msg_handler_ = val; }
    // Logs the switch (using the stringized argument) and installs `val` as the
    // actor's message handler. The trailing backslashes are required: without
    // them the macro body would end on the first line.
    #define OF_SET_MSG_HANDLER(val)                                   \
      do {                                                            \
        LOG(INFO) << "actor " << actor_id() << " switch to " << #val; \
        set_msg_handler(static_cast<MsgHandler>(val));                \
      } while (0)
    
    • NaiveActor 中设置 handler,设置了 HandlerNormal。其他各种各样的 Actor 都可以设置 handler 来设置消息不同的处理方法。
    // NaiveActor needs no task-specific setup beyond installing the generic
    // Actor::HandlerNormal as its message handler. The TaskProto is unused.
    void NaiveActor::VirtualActorInit(const TaskProto&) {
      OF_SET_MSG_HANDLER(&NaiveActor::HandlerNormal);
    }
    
    • NaiveActor 中设置的 HandlerNormal 在 Actor 中提供了实现,它调用了 ActUntilFail 来执行 Act 方法。
    // oneflow/core/actor/actor.cpp: 258
    // Normal-state message handler (installed e.g. by NaiveActor::VirtualActorInit).
    // Updates the actor's register bookkeeping for `msg`, runs ActUntilFail() when
    // new registers or a kStart command arrive, and then decides whether the actor
    // can stop consuming.
    // Returns 1 when the actor has finished (its handler is cleared), 0 otherwise.
    int Actor::HandlerNormal(const ActorMsg& msg) {
      if (msg.msg_type() == ActorMsgType::kEordMsg) {
        // EORD for one consumed regst desc (presumably "end of regst desc": its
        // producer will send nothing more -- TODO confirm). Record it and mark
        // which consumed group (naive / inplace / customized) reached its end.
        remaining_eord_cnt_ -= 1;
        CHECK(eord_regst_desc_ids_.insert(msg.eord_regst_desc_id()).second);
        if (naive_consumed_rs_.HasRegstDescId(msg.eord_regst_desc_id())) {
          is_naive_consumed_eord_ = true;
        } else if (inplace_consumed_rs_.HasRegstDescId(msg.eord_regst_desc_id())) {
          is_inplace_consumed_eord_ = true;
        } else {
          NormalProcessCustomizedEordMsg(msg);
        }
      } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
        if (msg.SrcMachineId() == GlobalProcessCtx::Rank()) {
          // Register message from an actor on this rank.
          Regst* regst = msg.regst();
          if (naive_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) {
            // A register this actor consumes: queue it in the naive consumed slot.
            CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst));
            const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(regst->regst_desc_id());
            CHECK(rdeq.empty() == false);
            if (rdeq.front()->regst_desc()->regst_desc_type().has_data_regst_desc()) {
              NormalProcessNaiveReadableDataRegstMsg(rdeq);
            }
          } else if (inplace_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) {
            // Inplace input: its blob must share memory with the matching inplace output.
            CHECK_EQ(0, inplace_consumed_rs_.TryPushBackRegst(regst));
            int64_t out_regst_desc_id = inplace_regst_desc_id_in2out_.at(regst->regst_desc_id());
            CHECK(regst->GetSoleBlob()->dptr()
                  == inplace_produced_rs_.Front(out_regst_desc_id)->GetSoleBlob()->dptr());
          } else if (TryUpdtStateAsProducedRegst(regst) == 0) {
            // presumably one of this actor's own produced registers returned by a consumer
            // do nothing
          } else {
            NormalProcessCustomizedReadableRegstMsg(msg);
          }
        } else {
          // Register message from another rank.
          if (NormalTryProcessReadableMsgFromOtherMachine(msg) == false) {
            // process ctrl msg from other rank
            if (IsConsumedCtrlRegstDescId(msg.regst_desc_id())) {
              Regst* regst = msg.regst();
              CHECK(naive_consumed_rs_.HasRegstDescId(msg.regst_desc_id()));
              CHECK(Global<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(msg.regst_desc_id()));
              CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst, msg.regst_desc_id()));
              const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(msg.regst_desc_id());
              CHECK(rdeq.empty() == false);
            } else {
              CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst()), 0);
            }
          }
        }
        // New registers may have made the actor ready: act as many times as possible.
        ActUntilFail();
      } else if (msg.msg_type() == ActorMsgType::kCmdMsg) {
        // kStart is the only command accepted here (Runtime sends it to source actors).
        CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart);
        ActUntilFail();
      } else {
        UNIMPLEMENTED();
      }
      // handler halts
      // The actor stops consuming when either
      //  (a) it has naive/inplace inputs, one of those groups has seen EORD, and no
      //      naive/inplace register is currently available; or
      //  (b) it has no naive/inplace inputs and its customized inputs report that
      //      they will never be ready again.
      bool has_naive_or_inplace = naive_consumed_rs_.total_regst_desc_cnt() != 0
                                  || inplace_consumed_rs_.total_regst_desc_cnt() != 0;
      bool naive_or_inplace_eord_and_empty =
          (is_naive_consumed_eord_ || is_inplace_consumed_eord_)
          && (naive_consumed_rs_.available_regst_desc_cnt() == 0
              && inplace_consumed_rs_.available_regst_desc_cnt() == 0);
      bool customized_eord = IsCustomizedReadAlwaysUnReadyFromNow();
      if ((has_naive_or_inplace && naive_or_inplace_eord_and_empty)
          || (!has_naive_or_inplace && customized_eord)) {
        CHECK_EQ(naive_consumed_rs_.available_regst_desc_cnt(), 0);
        // Give back customized inputs and send EORD for every produced regst desc.
        AsyncReturnAllCustomizedReadableRegst();
        AsyncSendEORDMsgForAllProducedRegstDesc();
        if (remaining_eord_cnt_ == 0 && total_reading_cnt_ == 0) {
          // All EORDs received and (presumably) no produced register is still being read.
          OF_SET_MSG_HANDLER(nullptr);
          return 1;
        } else {
          // Otherwise keep draining in the zombie handler.
          OF_SET_MSG_HANDLER(&Actor::HandlerZombie);
          return 0;
        }
      }
      return 0;
    }
    
    • 当读和写都准备好了之后,ActUntilFail 就会调用 Act 方法去执行。
    // oneflow/core/actor/actor.cpp
    // Calls Act() again and again for as long as the actor is both read-ready and
    // write-ready. After every Act() the produced-side messages are sent first,
    // then the consumed-side ones, then any queued messages are flushed.
    void Actor::ActUntilFail() {
      for (;;) {
        const bool can_act = IsReadReady() && IsWriteReady();
        if (!can_act) { break; }

        Act();

        // Produced registers -> consumers.
        AsyncSendCustomizedProducedRegstMsgToConsumer();
        AsyncSendNaiveProducedRegstMsgToConsumer();
        AsyncSendInplaceProducedRegstMsgToConsumer();

        // Consumed registers -> producers.
        AsyncSendCustomizedConsumedRegstMsgToProducer();
        AsyncSendNaiveConsumedRegstMsgToProducer();
        AsyncRetInplaceConsumedRegstIfNoConsumer();

        AsyncSendQueuedMsg();
      }
      // NOTE(liujuncheng): return inplace consumed
      AsyncSendQueuedMsg();
    }
    
    • Act 方法中会启动 Kernel。作者的理解是:函数名中的“Async”是相对主线程而言的,因为 Kernel 是在 Actor 所在的线程上被启动的;但对这个线程自身来说,这些调用并不是异步的,而是一行一行顺序执行下来的。
    void NaiveActor::Act() {
      KernelCtx kernel_ctx = GenDefaultKernelCtx();
      AsyncLaunchKernel(kernel_ctx, [&](int64_t regst_desc_id) -> Regst* { return nullptr; });
    }
    
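    • 启动 ExecKernel。ExecKernel 是一个结构体,保存了 Kernel 本身,以及每个 blob 名(bn_in_op)到 BlobInfo 的映射。启动 Kernel 时,需要传入 context 和一个函数 Regst4RegstDescId。这个函数的作用是:当某个 blob 的 BlobInfo 没有记录 RegstSlot(info.rs 为空)时,根据 regst_desc_id 找到对应的 Regst。NaiveActor 传入的版本总是返回 nullptr。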
    // oneflow/core/actor/actor.h: 58
    // One kernel run by an actor, together with the information needed to find
    // each of its blobs at launch time.
    struct ExecKernel {
      std::unique_ptr<const Kernel> kernel;
      // Blob name within the op (bn_in_op) -> BlobInfo, which names the register
      // (regst_desc_id / rs) and the blob inside it (ordinal or lbi).
      HashMap<std::string, BlobInfo> bn_in_op2blob_info;
    };
    
    // oneflow/core/actor/actor.cpp: 470
    // Launches every kernel of this actor in order. Each kernel resolves its blob
    // names through BnInOp2Blob below. A blob's register is taken from info.rs
    // when it is set, and from Regst4RegstDescId otherwise. Within the register,
    // the blob is found by ordinal when ordinal >= 0, and by lbi otherwise. Any
    // missing piece yields nullptr.
    void Actor::AsyncLaunchKernel(const KernelCtx& kernel_ctx,
                                  std::function<Regst*(int64_t)> Regst4RegstDescId) {
      for (const ExecKernel& exec_kernel : exec_kernel_vec_) {
        auto BnInOp2Blob = [&](const std::string& bn_in_op) -> Blob* {
          const auto info_it = exec_kernel.bn_in_op2blob_info.find(bn_in_op);
          if (info_it == exec_kernel.bn_in_op2blob_info.cend()) { return nullptr; }
          const BlobInfo& info = info_it->second;
          if (info.regst_desc_id == -1) { return nullptr; }
          Regst* regst = nullptr;
          if (info.rs == nullptr) {
            regst = Regst4RegstDescId(info.regst_desc_id);
          } else {
            regst = info.rs->Front(info.regst_desc_id);
          }
          if (regst == nullptr) { return nullptr; }
          if (info.ordinal < 0) { return regst->GetBlobByLbi(info.lbi); }
          return regst->GetBlobByOrdinal(info.ordinal);
        };
        exec_kernel.kernel->Launch(kernel_ctx, BnInOp2Blob);
      }
    }
    
    • Kernel Launch -> Forward -> ForwardDataContent。Forward 先调用 ForwardHeader(推测是做输入相关的检查/处理,这里没有展开确认)。如果 blob 不全是静态的、所有输出 blob 都为空,且 Kernel 是无状态的,就直接返回;否则调用 ForwardDataContent 进行真正的计算。
    // oneflow/core/kernel/kernel.cpp: 43
    // Launching a kernel currently amounts to running its forward pass.
    void Kernel::Launch(const KernelCtx& ctx,
                        const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
      this->Forward(ctx, BnInOp2Blob);
    }
    
    // Forward pass: ForwardHeader first, then (unless it can be skipped)
    // ForwardDataContent. Unless blob_access_checker_disabled_ is set, a different
    // output-blob access checker is installed before each phase and after the
    // computation.
    void Kernel::Forward(const KernelCtx& ctx,
                         const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
      if (!blob_access_checker_disabled_) { SetOutputBlobProducerInferAccessChecker(BnInOp2Blob); }
      ForwardHeader(ctx, BnInOp2Blob);
      // Skip the data computation when not all blobs are static, every output blob
      // is empty, and the kernel is stateless.
      if ((!kernel_conf_.all_blobs_are_static())
          && IsAllBlobEmpty(op_attribute().output_bns(), BnInOp2Blob) && IsStateless()) {
        return;
      }
      if (!blob_access_checker_disabled_) { SetOutputBlobProducerComputeAccessChecker(BnInOp2Blob); }
      // Profiler hooks bracket the actual computation.
      OF_PROFILER_ONLY_CODE(profiler::TraceKernelForwardDataContentStart(this, ctx, BnInOp2Blob));
      ForwardDataContent(ctx, BnInOp2Blob);
      OF_PROFILER_ONLY_CODE(profiler::TraceKernelForwardDataContentEnd(this, ctx, BnInOp2Blob));
      if (!blob_access_checker_disabled_) { SetOutputBlobConsumerAccessChecker(BnInOp2Blob); }
    }
    
    • ForwardDataContent 是 Kernel 提供的虚函数,每个子类实现不一样。UserKernel 和 OpKernel 用于定义扩展算子,OpKernel 中提供了 Compute 虚函数用于计算,需要注意的是 OpKernel 其实并没有继承 Kernel,OpKernel 作为 UserKernel 的一个成员存在。当调用 ForwardDataContent 的时候,它会调用 ForwardUserKernel,进而调用 OpKernel 的计算函数 Compute。
    // Forwards to ForwardUserKernel with this kernel's OpKernelState. The KernelCtx
    // argument is not used here.
    void UserKernel::ForwardDataContent(
        const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
      auto* state = opkernel_state_.get();
      this->ForwardUserKernel(BnInOp2Blob, state);
    }
    
    // Runs the user op kernel's Compute() on ctx_. `updated` is the result of
    // ctx_->UpdateTensorWithCorrBlob(); presumably it reports whether any tensor
    // was re-bound to a different blob since the last call (TODO confirm).
    void UserKernel::ForwardUserKernel(const std::function<Blob*(const std::string&)>& BnInOp2Blob,
                                       user_op::OpKernelState* opkernel_state) const {
      const bool updated = ctx_->UpdateTensorWithCorrBlob(BnInOp2Blob);
    
    #ifdef WITH_CUDA_GRAPHS
      // With a CUDA graph context and no capture in progress: if a graph was already
      // captured and nothing was updated, replay it and skip Compute(); otherwise
      // begin a capture so that Compute() below is recorded into the graph.
      bool capturing = false;
      if (cuda_graph_ctx_) {
        if (!cuda_graph_ctx_->IsCapturing()) {
          if (cuda_graph_ctx_->IsCaptured() && (!updated)) {
            cuda_graph_ctx_->Launch();
            return;
          }
          capturing = true;
          cuda_graph_ctx_->BeginCapture();
        }
      }
    #endif  // WITH_CUDA_GRAPHS
    
      kernel_->Compute(ctx_.get(), opkernel_state);
    
    #ifdef WITH_CUDA_GRAPHS
      // End the capture started above and launch the freshly captured graph.
      if (cuda_graph_ctx_ && capturing) {
        cuda_graph_ctx_->EndCapture();
        cuda_graph_ctx_->Launch();
      }
    #endif  // WITH_CUDA_GRAPHS
    }
    
    • Compute 是如何计算的呢?下面随便找一个 Kernel 来看看,我找的是 CpuAddNKernel,重点关注其中的 Compute 函数。它的主要工作是从 KernelComputeContext 中取出输入和输出的指针,最后调用 cpu_add 把所有输入逐元素相加,写入输出。至此,我们终于走完了一次 Kernel 计算的全过程。
    // oneflow/user/kernels/add_n_kernel.cpp: 22
    // Element-wise sum of all inputs: out[i] = in[0][i] + in[1][i] + ... for i in [0, n).
    // Throws std::out_of_range (from in.at(0)) if n > 0 and `in` is empty.
    // `out` may alias any input buffer (in-place add): every input element i is
    // read before out[i] is written.
    template<typename T>
    void cpu_add(const int64_t n, T* out, const std::vector<const T*>& in) {
      // Also guards against a negative n, which made the old `i != n` loop run away.
      if (n <= 0) { return; }
      const T* first = in.at(0);
      const size_t in_num = in.size();
      for (int64_t i = 0; i < n; ++i) {
        // Accumulate locally (same addition order as before) so that an aliased
        // input is not clobbered before it has been read.
        T acc = first[i];
        for (size_t j = 1; j < in_num; ++j) { acc += in[j][i]; }
        out[i] = acc;
      }
    }
    
    // oneflow/user/kernels/add_n_kernel.cpp: 32
    // CPU kernel for add_n: writes the element-wise sum of every "in" tensor into "out".
    template<typename T>
    class CpuAddNKernel : public user_op::OpKernel {
     public:
      CpuAddNKernel() = default;
      ~CpuAddNKernel() = default;

      bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }

     private:
      void Compute(user_op::KernelComputeContext* ctx) const override {
        const size_t input_cnt = ctx->inputs().size();

        user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0);
        const int64_t elem_cnt = out_tensor->shape().elem_cnt();
        T* out_ptr = out_tensor->mut_dptr<T>();

        // Gather the data pointer of each "in" tensor, in index order.
        std::vector<const T*> in_ptrs;
        in_ptrs.reserve(input_cnt);
        for (int32_t idx = 0; static_cast<size_t>(idx) < input_cnt; ++idx) {
          in_ptrs.push_back(ctx->Tensor4ArgNameAndIndex("in", idx)->dptr<T>());
        }

        cpu_add<T>(elem_cnt, out_ptr, in_ptrs);
      }
    };
    

    总结

    这篇文章从 Runtime 启动开始,讲了如何启动线程,启动 Actor。线程通过轮询消息队列拉取消息,将消息转发给对应的 Actor 去执行。Actor 将启动 Kernel,Kernel 从 KernelComputeContext 获取输入和输出的信息,最后执行运算。

  • 相关阅读:
    KMP算法理解
    vimium 快捷键
    如何选择优化器 optimizer
    用python实现归并排序
    用python实现快速排序
    用python实现插入排序
    使用PyCharm进行远程开发和调试
    查看python iterpreter的路径和当前选择的解释器
    grid search
    一些书单
  • 原文地址:https://www.cnblogs.com/zzk0/p/15226851.html
Copyright © 2020-2023  润新知