    [Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

    Next we introduce the caching mechanism. Why cache? Because a cluster contains many workers, and interactions occur both between the Master and the Workers and among the Workers themselves, it makes sense to cache the workers together with their gRPC channels. One could say that caching is used just about everywhere in the TensorFlow distributed runtime.

    The other articles in this series are:

    [Translation] TensorFlow Distributed Papers: "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

    [Translation] TensorFlow Distributed Papers: "Implementation of Control Flow in TensorFlow"

    [Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

    [Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

    [Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic

    1. WorkerCache

    The role of WorkerCache is to obtain WorkerInterface instances; a WorkerInterface instance can access the remote WorkerService. The typical WorkerInterface instance is GrpcRemoteWorker.

    1.1 How It Is Used

    When MasterEnv was initialized earlier, WorkerCacheFactory was configured into master_env_.worker_cache_factory.

    master_env_.worker_cache_factory =
        [this](const WorkerCacheFactoryOptions& options,
               WorkerCacheInterface** worker_cache) {
          return WorkerCacheFactory(options, worker_cache);
        };
    

    Later, in Master::CreateSession, the following abridged code shows how worker_cache (a WorkerCacheInterface instance) is obtained from the factory, and how worker_cache is then used for subsequent operations.

    void Master::CreateSession(const CreateSessionRequest* req,
                               CreateSessionResponse* resp, MyClosure done) {
      SchedClosure([this, req, resp, done]() {
          // Configure the options
          WorkerCacheFactoryOptions worker_cache_factory_options;
          worker_cache_factory_options.protocol = &grpc_protocol;
          worker_cache_factory_options.rpc_options = &req->config().rpc_options();
        
          // Create the worker cache from the computed server_def.
          status = env_->worker_cache_factory(worker_cache_factory_options,
                                              &worker_cache);
    
          // Use worker_cache to perform the subsequent operations
          status =
              DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
                                             worker_cache, remote_devices.get());
    
      });
    }
    

    1.2 Configuration

    WorkerCacheFactoryOptions is equivalent to ServerDef; it holds the ClusterDef, job_name, task_index, and related information (a construction sketch follows the definition below).

    // Options passed to the worker_cache_factory function.
    struct WorkerCacheFactoryOptions {
      const ClusterDef* cluster_def = nullptr;
      const string* job_name = nullptr;
      int task_index;
      const string* protocol = nullptr;
      const RPCOptions* rpc_options = nullptr;
    
      WorkerCacheFactoryOptions() {}
    
      // Construct from a ServerDef proto.
      //
      // Note: server_def must outlive WorkerCacheFactoryOptions!
      WorkerCacheFactoryOptions(const ServerDef& server_def) {
        if (server_def.has_cluster() && !server_def.job_name().empty()) {
          cluster_def = &server_def.cluster();
          job_name = &server_def.job_name();
          task_index = server_def.task_index();
          protocol = &server_def.protocol();
          rpc_options = &server_def.default_session_config().rpc_options();
        }
      }
    };
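
    As an illustration, here is a minimal sketch (with made-up job name, task index, and addresses) of building these options from a ServerDef:

    // Minimal sketch, hypothetical values: construct options from a ServerDef.
    ServerDef server_def;
    server_def.set_job_name("worker");
    server_def.set_task_index(0);
    server_def.set_protocol("grpc");
    // One job named "worker" with two tasks.
    JobDef* job = server_def.mutable_cluster()->add_job();
    job->set_name("worker");
    (*job->mutable_tasks())[0] = "localhost:2222";
    (*job->mutable_tasks())[1] = "localhost:2223";
    
    // server_def must outlive options (the struct stores raw pointers into it).
    WorkerCacheFactoryOptions options(server_def);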
    

    1.3 The Factory Function

    WorkerCacheFactory is a function whose work is as follows:

    • Use ParseChannelSpec to obtain a GrpcChannelSpec instance. GrpcChannelSpec is equivalent to ClusterSpec and holds the basic configuration of the cluster.
    • Use NewGrpcChannelCache to obtain a GrpcChannelCache named channel_cache. GetChannelCreationFunction is used here; we introduce it later.
    • Use NewGrpcWorkerCacheWithLocalWorker(channel_cache) to obtain the worker_cache.

    Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
                                          WorkerCacheInterface** worker_cache) {
    
      // Obtain the GrpcChannelSpec
      GrpcChannelSpec channel_spec;
      TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
    
      // Obtain the GrpcChannelCache
      std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
          channel_spec, GetChannelCreationFunction(), *options.rpc_options));
    
      string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
                                           "/task:", options.task_index);
    
      const string host_port = channel_cache->TranslateTask(name_prefix);
      int requested_port;
    
      auto colon_index = host_port.find_last_of(':');
      if (!strings::safe_strto32(host_port.substr(colon_index + 1),
                                 &requested_port)) {
        return errors::Internal("Could not parse port for local server from \"",
                                host_port, "\".");
      }
      if (requested_port != bound_port_) {
        return errors::InvalidArgument("Requested port ", requested_port,
                                       " differs from expected port ", bound_port_);
      }
      // Obtain the worker cache
      *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
          channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
      return Status::OK();
    }
    

    1.3.1 ParseChannelSpec

    ParseChannelSpec is used to obtain the GrpcChannelSpec instance. GrpcChannelSpec is equivalent to ClusterSpec; it holds the basic configuration of the cluster.

    Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
                                        GrpcChannelSpec* channel_spec) {
      for (const auto& job : options.cluster_def->job()) {
        std::map<int, string> host_ports;
        for (const auto& task : job.tasks()) {
          string& host_port = host_ports[task.first];
          if (!host_port.empty()) {
            return errors::InvalidArgument("JobDef for job \"", job.name(),
                                           "\" specified two addresses for task \"",
                                           task.first, "\": ", host_port, " and ",
                                           task.second);
          }
          if (job.name() == *options.job_name && task.first == options.task_index) {
            host_port = strings::StrCat(host_name_, ":", bound_port_);
          } else {
            host_port = task.second;
          }
        }
        TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
      }
      return Status::OK();
    }
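
    As a worked example with the hypothetical cluster above, assuming this server is /job:worker/task:0 and bound_port_ is 2222, the loop produces a spec equivalent to the following sketch:

    // Sketch of the equivalent result; the cluster and addresses are made up.
    // Task 0 is this server, so its entry is host_name_ + ":" + bound_port_.
    GrpcChannelSpec channel_spec;
    TF_CHECK_OK(channel_spec.AddHostPortsJob(
        "worker", {{0, "localhost:2222"},     // host_name_:bound_port_
                   {1, "localhost:2223"}}));  // taken from the ClusterDef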
    

    1.3.2 NewGrpcChannelCache

    NewGrpcChannelCache creates the GrpcChannelCache instance. As the code shows, each Job corresponds to one SparseGrpcChannelCache. If there is only one SparseGrpcChannelCache it is returned directly; otherwise the SparseGrpcChannelCaches are combined into a MultiGrpcChannelCache. The channel_func passed in is GetChannelCreationFunction, which we will introduce later.

    GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
                                          ChannelCreationFunction channel_func,
                                          const RPCOptions& options) {
      const int num_jobs = spec.host_ports_jobs().size();
      if (!num_jobs) {
        return nullptr;
      }
      std::vector<GrpcChannelCache*> caches;
      caches.reserve(num_jobs);
      for (auto& job : spec.host_ports_jobs()) {
        caches.push_back(
            new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func,
                                       options.num_channels_per_target()));
      }
      return caches.size() == 1 ? caches[0]
                                : new MultiGrpcChannelCache(
                                      caches, options.num_channels_per_target());
    }
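
    A hedged sketch of calling it directly follows; the spec and addresses are assumptions for the example. With two jobs the result is a MultiGrpcChannelCache owning two SparseGrpcChannelCaches; with a single job it would be the SparseGrpcChannelCache itself.

    // Sketch, hypothetical addresses: two jobs yield a MultiGrpcChannelCache.
    GrpcChannelSpec spec;
    TF_CHECK_OK(spec.AddHostPortsJob("ps", {{0, "ps0:2222"}}));
    TF_CHECK_OK(spec.AddHostPortsJob("worker",
                                     {{0, "w0:2222"}, {1, "w1:2222"}}));
    RPCOptions rpc_options;  // Defaults, i.e. one channel per target.
    std::shared_ptr<GrpcChannelCache> cache(NewGrpcChannelCache(
        spec, ConvertToChannelCreationFunction(NewHostPortGrpcChannel),
        rpc_options));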
    

    1.3.3 NewGrpcWorkerCacheWithLocalWorker

    The NewGrpcWorkerCacheWithLocalWorker method creates the GrpcWorkerCache instance.

    WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
        std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
        WorkerInterface* local_worker, const string& local_target) {
      return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
    }
    

    The local_worker argument is obtained via worker_impl() and passed in; it is created in GrpcServer::Init, and it is the local GrpcWorker.

    GrpcWorker* worker_impl() const { return worker_impl_.get(); }
    
    std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
                                              const ConfigProto& config) {
      return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
    }
    
    Status GrpcServer::Init(const GrpcServerOptions& opts) {
      // Omitted
      worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
                                      : NewGrpcWorker(&worker_env_, config);
      // Omitted
    }
    

    Let us review the factory flow so far: the initial input is a WorkerCacheFactoryOptions, which is processed step by step by the functions above until a GrpcWorkerCache is finally produced.

    Figure 1: Factory flow

    1.4 WorkerCacheInterface

    1.4.1 The Interface

    WorkerCacheInterface is an interface class; the GrpcWorkerCache in the figure above is a derived class of this interface.

    class WorkerCacheInterface {
     public:
      virtual ~WorkerCacheInterface() {}
    
      // Updates *workers with strings naming the remote worker tasks to
      // which open channels have been established.
      virtual void ListWorkers(std::vector<string>* workers) const = 0;
      virtual void ListWorkersInJob(const string& job_name,
                                    std::vector<string>* workers) const = 0;
    
      // If "target" names a remote task for which an RPC channel exists
      // or can be constructed, returns a pointer to a WorkerInterface object
      // wrapping that channel. The returned value must be destroyed by
      // calling `this->ReleaseWorker(target, ret)`
      virtual WorkerInterface* GetOrCreateWorker(const string& target) = 0;
    
      // Release a worker previously returned by this->GetOrCreateWorker(target).
      //
      // TODO(jeff,sanjay): Consider moving target into WorkerInterface.
      // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a
      //                    per-rpc-subsystem WorkerInterface creator.
      virtual void ReleaseWorker(const string& target, WorkerInterface* worker) {
        // Subclasses may override to reuse worker objects.
        delete worker;
      }
    
      // Set *locality with the DeviceLocality of the specified remote device
      // within its local environment.  Returns true if *locality
      // was set, using only locally cached data.  Returns false
      // if status data for that device was not available.  Never blocks.
      virtual bool GetDeviceLocalityNonBlocking(const string& device,
                                                DeviceLocality* locality) = 0;
    
      // Set *locality with the DeviceLocality of the specified remote device
      // within its local environment.  Callback gets Status::OK if *locality
      // was set.
      virtual void GetDeviceLocalityAsync(const string& device,
                                          DeviceLocality* locality,
                                          StatusCallback done) = 0;
    
      // TODO(b/189159585): Define a general client cache maker function to
      // construct client cache of different types sharing the same underling RPC
      // channels, to replace the eager and coordination cache function.
      // Build and return a EagerClientCache object wrapping that channel.
      virtual Status GetEagerClientCache(
          std::unique_ptr<eager::EagerClientCache>* eager_client_cache) = 0;
    
      // Build and return a CoordinationClientCache object wrapping that channel.
      virtual Status GetCoordinationClientCache(
          std::unique_ptr<CoordinationClientCache>* coordination_client_cache) = 0;
    
      // Start/stop logging activity.
      virtual void SetLogging(bool active) {}
    
      // Discard any saved log data.
      virtual void ClearLogs() {}
    
      // Return logs for the identified step in *ss.  Any returned data will no
      // longer be stored.
      virtual bool RetrieveLogs(int64_t step_id, StepStats* ss) { return false; }
    };
    

    WorkerCachePartial in turn inherits from WorkerCacheInterface.

    // Implements the part of the interface that caches and returns remote
    // device status attributes.
    class WorkerCachePartial : public WorkerCacheInterface {
     public:
      bool GetDeviceLocalityNonBlocking(const string& device,
                                        DeviceLocality* locality) override;
    
      void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
                                  StatusCallback) override;
    
      ~WorkerCachePartial() override {}
    
      // Clear all entries from the DeviceStatus cache.
      void FlushStatusCache();
    
     private:
      mutex mu_;
    
      // Initiate a GetStatusAsync to the remote task named by "task", and
      // update the cache with all the DeviceAttributes reported.
      Status RefreshDeviceStatus(const string& device_name);
    
      typedef std::unordered_map<string, DeviceAttributes> StatusMap;
      StatusMap device_status_cache_ TF_GUARDED_BY(mu_);
    };
    

    1.4.2 GrpcWorkerCache

    GrpcWorkerCache then inherits from WorkerCachePartial.

    class GrpcWorkerCache : public WorkerCachePartial {
     public:
      explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
                               WorkerInterface* local_worker,
                               const string& local_target,
                               GrpcWorkerEnv* worker_env)
          : local_target_(local_target),
            local_worker_(local_worker),
            channel_cache_(channel_cache),
            worker_env_(worker_env),
            next_round_robin_assignment_(0) {}
    
      const string local_target_;
      WorkerInterface* const local_worker_;  // Not owned.
      std::shared_ptr<GrpcChannelCache> channel_cache_;
      WorkerCacheLogger logger_;
      GrpcWorkerEnv* worker_env_;  // Not owned
    
      mutex assignment_mu_;
      std::unordered_map<std::string, size_t> target_assignments_
          TF_GUARDED_BY(assignment_mu_);
      size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
    };
    

    Its main listing functions, ListWorkers and ListWorkersInJob, enumerate the names of the workers in the cluster by delegating to the channel cache.

    void ListWorkers(std::vector<string>* workers) const override {
      channel_cache_->ListWorkers(workers);
    }
    
    void ListWorkersInJob(const string& job_name,
                          std::vector<string>* workers) const override {
      channel_cache_->ListWorkersInJob(job_name, workers);
    }
    

    GetOrCreateWorker builds a worker on top of the worker's RPC channel; if the target is local, it directly returns local_worker_, i.e., the local GrpcWorker we set up earlier.

    WorkerInterface* GetOrCreateWorker(const string& target) override {
      if (target == local_target_) {
        return local_worker_;
      } else {
        SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
        if (!channel) {
          return nullptr;
        }
        size_t index = AssignWorkerToThread(target);
        return NewGrpcRemoteWorker(
            channel, worker_env_->GetCompletionQueue(index),
            worker_env_->GetThreadPool(), &logger_, target);
      }
    }
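
    A minimal usage sketch, assuming a worker_cache obtained from the factory and a made-up target name. Note that workers returned by GetOrCreateWorker must be handed back via ReleaseWorker rather than deleted, since the local worker is shared:

    // Hypothetical sketch: look up a worker, use it, then release it.
    void ExampleUse(WorkerCacheInterface* worker_cache) {
      const string target = "/job:worker/replica:0/task:1";
      WorkerInterface* wi = worker_cache->GetOrCreateWorker(target);
      if (wi == nullptr) return;  // No RPC channel could be constructed.
      // ... issue RPCs through wi, e.g. wi->GetStatusAsync(...) ...
      worker_cache->ReleaseWorker(target, wi);  // Never delete wi directly.
    }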
    

    2. RPC Channels

    Workers operate on top of RPC channels, so next let us look at how those RPC channels are established. Just as workers are cached, RPC channels are cached too: GrpcChannelCache is that cache, used to obtain or create the RPC channels to the remote workers in the cluster.

    2.1 The GrpcChannelCache Interface

    GrpcChannelCache is an interface class that defines a set of interfaces, for example:

    • ListWorkers: returns the names of the workers in the cluster.
    • TranslateTask: converts a worker name into its address, in host:port format.
    • FindWorkerChannel: looks up the grpc::Channel instance in the cache; on a miss, dynamically creates one from the address information and inserts it into the cache.

    class GrpcChannelCache {
     public:
      virtual ~GrpcChannelCache() {}
    
      // Populates *workers with names of all workers which this object
      // was created to handle.  Worker names are in the format
      //  /job:<job identifier>/task:<task id>
      // e.g. /job:mnist/task:2
      virtual void ListWorkers(std::vector<string>* workers) = 0;
      virtual void ListWorkersInJob(const string& job_name,
                                    std::vector<string>* workers) = 0;
    
      // If found, returns a gRPC channel that is connected to the remote
      // worker named by 'target'. 'target' is of the following
      // format: /job:<job identifier>/task:<task id>
      // E.g., /job:mnist/task:2
      virtual SharedGrpcChannelPtr FindWorkerChannel(const string& target) = 0;
    
      // Translates a string in the form `/job:X/task:Z` into a host_port.
      virtual string TranslateTask(const string& task) = 0;
    };
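
    A short usage sketch of this interface, with a made-up task name:

    // Hypothetical sketch: resolve a task name, then fetch its channel.
    void ExampleLookup(GrpcChannelCache* cache) {
      const string task = "/job:worker/replica:0/task:2";
      string host_port = cache->TranslateTask(task);  // e.g. "w2:2222", or "".
      SharedGrpcChannelPtr ch = cache->FindWorkerChannel(task);
      if (ch != nullptr) {
        // ch can now back a gRPC stub that talks to the remote WorkerService.
      }
    }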
    

    2.2 The Caching Mechanism

    CachingGrpcChannelCache is the caching class that avoids the cost of creating a grpc::Channel on every call. It is defined as follows: concretely, it is GenericCachingChannelCache instantiated over GrpcChannelCache.

    // GrpcChannelCache that caches results to FindWorkerChannel() calls.
    using CachingGrpcChannelCache = GenericCachingChannelCache<GrpcChannelCache>;
    

    GenericCachingChannelCache caches the results of FindWorkerChannel() calls: it first looks for the grpc::Channel instance in the cache; on a miss, it calls FindChannelOnce to create one from the address information, then inserts it into the cache.

    GenericCachingChannelCache allows multiple channels to communicate with the same target in order to improve throughput. When multiple channels exist for the same target, each call to FindWorkerChannel selects among them in round-robin fashion (see the sketch after the code below).

    Note that, given the following definition, absl::flat_hash_map<string, ChannelState> channels_ is precisely the cache of ::grpc::Channel objects.

    typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;
    

    The code is as follows:

    template <typename ChannelCacheT>
    class GenericCachingChannelCache : public ChannelCacheT {
     public:
      explicit GenericCachingChannelCache(int num_channels_per_target)
          : num_channels_per_target_(
                num_channels_per_target > 0 ? num_channels_per_target : 1) {}
    
      ~GenericCachingChannelCache() override {}
    
      SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
        {
          mutex_lock l(mu_);
          auto iter = channels_.find(target);
          if (iter != channels_.end()) {
            return GetNextChannelPtrAndUpdateState(iter->second);
          }
        }
        ChannelState new_chan_state;
        for (int indx = 0; indx < num_channels_per_target_; indx++) {
          auto ch = FindChannelOnce(target);
          if (!ch) return nullptr;
          new_chan_state.channels.push_back(ch);
        }
        new_chan_state.last_used = num_channels_per_target_ - 1;
    
        {
          mutex_lock l(mu_);
          typename absl::flat_hash_map<string, ChannelState>::iterator iter;
          bool was_inserted;
          std::tie(iter, was_inserted) = channels_.insert({target, new_chan_state});
          return GetNextChannelPtrAndUpdateState(iter->second);
        }
      }
    
     protected:
      // Find the ClientChannel for "target".  Only called when no channel was
      // found in the channels_ cache for "target".  A non nullptr result will be
      // cached in channels_.
      virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;
    
     private:
      struct ChannelState {
        std::vector<SharedGrpcChannelPtr> channels; 
        int last_used;
      };
    
      // Should be called with mu_ held.
      SharedGrpcChannelPtr GetNextChannelPtrAndUpdateState(
          ChannelState& chan_state) {
        // Following statement is marked as Crash OK as this is an invariant of
        // code flow in this class.
        CHECK_EQ(chan_state.channels.size(), num_channels_per_target_);  // Crash OK
        chan_state.last_used =
            (chan_state.last_used + 1) % num_channels_per_target_;
        return chan_state.channels[chan_state.last_used];
      }
    
      const int num_channels_per_target_;
      // TODO(zhifengc): Eviction when the map becomes too big.
      mutex mu_;
      absl::flat_hash_map<string, ChannelState> channels_ TF_GUARDED_BY(mu_);
    };
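
    To make the round-robin behavior concrete, here is a test-style sketch: a trivial subclass whose FindChannelOnce always succeeds. Everything below is made up for illustration and is not part of TensorFlow. With num_channels_per_target == 2, the first FindWorkerChannel call creates both channels and returns channels[0], the second call returns channels[1], the third returns channels[0] again, and so on.

    // Hypothetical illustration only; this class does not exist in TensorFlow.
    class FakeChannelCache : public CachingGrpcChannelCache {
     public:
      explicit FakeChannelCache(int n) : CachingGrpcChannelCache(n) {}
      void ListWorkers(std::vector<string>* workers) override {}
      void ListWorkersInJob(const string& job_name,
                            std::vector<string>* workers) override {}
      string TranslateTask(const string& task) override { return "host:1234"; }
    
     protected:
      SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
        // Each call creates a distinct channel to the same made-up address.
        return ::grpc::CreateChannel("host:1234",
                                     ::grpc::InsecureChannelCredentials());
      }
    };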
    

    2.3 Derived Business Classes

    Two classes are derived from CachingGrpcChannelCache, as follows:

    2.3.1 Leaf Nodes

    SparseGrpcChannelCache is the leaf node. Each Job in the cluster corresponds to one SparseGrpcChannelCache, whose internal grpc::Channel collection is the set of grpc::Channels for that Job's Tasks, one grpc::Channel per Task.

    The main member variables of SparseGrpcChannelCache are:

    • const string job_id_: which Job this cache corresponds to.
    • const std::map<int, string> host_ports_: the host:port addresses of this Job's Tasks, keyed by task index.
    • const ChannelCreationFunction channel_func_: the function used to create a grpc::Channel.

    The main functions of SparseGrpcChannelCache are:

    • ListWorkers: returns the list of Task names for this Job.
    • TranslateTask: maps a Task name to its address (in host:port format); for example, /job:ps/replica:0/task:1 might resolve to ps1:1111 (note that only replica 0 is supported; a worked example follows the class definition below).
    • FindChannelOnce: creates the grpc::Channel for a given Task name. Concretely, it first uses TranslateTask to resolve the Task name to its address, then uses the address to build the grpc::Channel.

    class SparseGrpcChannelCache : public CachingGrpcChannelCache {
     public:
      SparseGrpcChannelCache(const string& job_id,
                             const std::map<int, string>& host_ports,
                             ChannelCreationFunction channel_func,
                             int num_channels_per_target)
          : CachingGrpcChannelCache(num_channels_per_target),
            job_id_(job_id),
            host_ports_(host_ports),
            channel_func_(std::move(channel_func)) {
      }
      ~SparseGrpcChannelCache() override {}
    
      void ListWorkers(std::vector<string>* workers) override {
        workers->reserve(workers->size() + host_ports_.size());
        for (const auto& id_host_port : host_ports_) {
          workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
        }
      }
    
      void ListWorkersInJob(const string& job_name,
                            std::vector<string>* workers) override {
        if (job_name == job_id_) {
          ListWorkers(workers);
        }
      }
    
      string TranslateTask(const string& target) override {
        DeviceNameUtils::ParsedName parsed;
        if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
          return "";
        }
    
        if (!parsed.has_job || parsed.job != job_id_) {
          return "";
        }
        if (!parsed.has_replica || parsed.replica != 0) {
          return "";
        }
        int32_t task = parsed.has_task ? parsed.task : -1;
        auto iter = host_ports_.find(task);
        if (iter == host_ports_.end()) {
          return "";
        }
        return iter->second;
      }
    
     protected:
      SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
        const string host_port = TranslateTask(target);
        if (host_port.empty()) {
          return nullptr;
        }
        auto chan_ptr = channel_func_(host_port);
        return chan_ptr;
      }
    
     private:
    
      const string job_id_;
      const std::map<int, string> host_ports_;
      const ChannelCreationFunction channel_func_;
      TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
    };
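
    As a worked example with made-up addresses, take a SparseGrpcChannelCache for job "ps" whose host_ports_ is {0: "ps0:2222", 1: "ps1:2222"}:

    // Hypothetical sketch of TranslateTask results for such a cache.
    cache.TranslateTask("/job:ps/replica:0/task:1");      // -> "ps1:2222"
    cache.TranslateTask("/job:ps/replica:1/task:1");      // -> "" (replica must be 0)
    cache.TranslateTask("/job:worker/replica:0/task:0");  // -> "" (different job)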
    

    2.3.2 Non-Leaf Nodes

    To speed up lookups across SparseGrpcChannelCaches and to manage all of the cluster's Worker nodes as a whole, TF combines the cluster's SparseGrpcChannelCaches into a MultiGrpcChannelCache. MultiGrpcChannelCache also remembers which SparseGrpcChannelCache served each target it has seen.

    // A ChannelCache that is the union of multiple ChannelCaches.
    // Takes ownership of the caches passed to the constructor.
    class MultiGrpcChannelCache : public CachingGrpcChannelCache {
     public:
      explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches,
                                     int num_channels_per_target)
          : CachingGrpcChannelCache(num_channels_per_target), caches_(caches) {}
    
      ~MultiGrpcChannelCache() override {
        for (GrpcChannelCache* cache : caches_) {
          delete cache;
        }
      }
    
      void ListWorkers(std::vector<string>* workers) override {
        for (GrpcChannelCache* cache : caches_) {
          cache->ListWorkers(workers);
        }
      }
    
      void ListWorkersInJob(const string& job_name,
                            std::vector<string>* workers) override {
        for (GrpcChannelCache* cache : caches_) {
          cache->ListWorkersInJob(job_name, workers);
        }
      }
    
      string TranslateTask(const string& target) override {
        mutex_lock l(mu_);  // could use reader lock
        GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
        if (cache == nullptr) {
          for (GrpcChannelCache* c : caches_) {
            string r = c->TranslateTask(target);
            if (!r.empty()) {
              target_caches_.insert({target, c});
              cache = c;
              break;
            }
          }
        }
        CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
                     << target;
        return cache->TranslateTask(target);
      }
    
     protected:
      SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
        for (GrpcChannelCache* cache : caches_) {
          SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
          if (ch) {
            mutex_lock l(mu_);
            target_caches_.insert({target, cache});
            return ch;
          }
        }
        return nullptr;
      }
    
     private:
      // List of channels used by this MultiGrpcChannelCache.
      const std::vector<GrpcChannelCache*> caches_;
    
      mutex mu_;
      // Cache of channels keyed by the target they are handling.
      // The same GrpcChannelCache can appear multiple times in the cache.
      std::unordered_map<string, GrpcChannelCache*> target_caches_
          TF_GUARDED_BY(mu_);
    };
    

    The structure so far is as follows:

    Figure 2: Logical relationships of the caches

    2.4 Creating the GrpcChannelCache

    When the GrpcChannelCache was created earlier, GetChannelCreationFunction was passed in; we skipped over it then, so let us go through it now.

      // Obtain the GrpcChannelCache
      std::shared_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
          channel_spec, GetChannelCreationFunction(), *options.rpc_options));
    

    2.4.1 Goal & Usage

    Let us first look at how it is used, i.e., the goal: to produce a SharedGrpcChannelPtr from a target (a host:port style string). As noted above, SharedGrpcChannelPtr is a std::shared_ptr<::grpc::Channel>.

    SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
      const string host_port = TranslateTask(target);
      if (host_port.empty()) {
        return nullptr;
      }
      auto chan_ptr = channel_func_(host_port);
      VLOG(5) << "Channel created for: job: " << job_id_
              << " host_port: " << host_port << " target : " << target
              << " Ptr: " << chan_ptr.get();
      return chan_ptr;
    }
    

    2.4.2 NewHostPortGrpcChannel

    First we introduce NewHostPortGrpcChannel, an existing TF API. It calls ::grpc::CreateCustomChannel (a gRPC API) to obtain a grpc::Channel, stores it into SharedGrpcChannelPtr* channel_pointer, and returns it through channel_pointer (i.e., the grpc::Channel). The result is what we want, but the calling convention does not match, so it needs to be wrapped or converted.

    Status NewHostPortGrpcChannel(const string& target,
                                  const RPCOptions* rpc_options,
                                  SharedGrpcChannelPtr* channel_pointer) {
      // Minimally ensure that the target is valid
      TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
    
      ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
      *channel_pointer = ::grpc::CreateCustomChannel(
          "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
      return Status::OK();
    }
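
    A minimal sketch of calling it directly, with a made-up address:

    // Hypothetical sketch: create a channel for one host:port directly.
    SharedGrpcChannelPtr channel;
    Status s = NewHostPortGrpcChannel("localhost:2222", /*rpc_options=*/nullptr,
                                      &channel);
    if (s.ok()) {
      // channel lazily connects to dns:///localhost:2222 when first used.
    }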
    

    2.4.3 ConvertToChannelCreationFunction

    The ConvertToChannelCreationFunction method converts the passed-in new_channel_func_ptr into a function that produces a SharedGrpcChannelPtr from just a const string& target.

    ChannelCreationFunction ConvertToChannelCreationFunction(
        const std::function<Status(string, const RPCOptions*,
                                   SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
      return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
        SharedGrpcChannelPtr channel_ptr;
        if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
                .ok()) {
          return channel_ptr;
        } else {
          return nullptr;
        }
      };
    }
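
    Usage is then a single call; a sketch with a made-up address:

    // Sketch: adapt the two-step API into a one-argument creation function.
    ChannelCreationFunction func =
        ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
    SharedGrpcChannelPtr ch = func("localhost:2222");  // nullptr on failure.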
    

    2.4.4 GetChannelCreationFunction

    GetChannelCreationFunction simply passes NewHostPortGrpcChannel into ConvertToChannelCreationFunction, producing a function in the form that the WorkerCache factory can use.

    ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
      // We can do this because SparseGrpcChannelCache is robust to nullptr being
      // returned by the channel creation function
      return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
    }
    

    2.4.5 Usage Analysis

    Back to our call site. channel_func_ is the function produced by GetChannelCreationFunction, so calling it directly yields the grpc::Channel.

    SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
      const string host_port = TranslateTask(target);
      auto chan_ptr = channel_func_(host_port);  // Creates the grpc::Channel.
      return chan_ptr;
    }
    

    We can now extend the earlier diagram as follows; one step has been added in the middle: passing in a target yields a grpc::Channel:

    Figure 3: The conversion

    3. Where the Cache Sits in the System

    Although we have summarized how the cache is initialized and used, we have lost track of where the cache sits in the system, so let us look at its position. The GrpcChannelCache inside GrpcWorkerCache points to the system's gRPC channel cache, used to obtain cached gRPC channels, while local_worker stores the local Worker.

    Figure 4: Where the cache sits

    When GrpcWorkerCache's GetOrCreateWorker is called, if the target is local it directly returns local_worker (the local GrpcWorker we set up earlier); otherwise it creates a remote GrpcRemoteWorker on top of the worker's RPC channel.

    Figure 5: Creating a worker

    WorkerCacheInterface (that is, GrpcWorkerCache) appears everywhere in Master, Worker, MasterSession, and WorkerSession; many classes hold a member variable pointing to a WorkerCacheInterface, so it is used very widely.

    4. Finding the Device Set

    To create WorkerSessions, the MasterSession needs to know the set of devices on all remote Workers, so before creating a MasterSession the Master iterates over all Workers to collect their device information. Since this relies on GrpcWorkerCache, we cover it here. The basic logic is:

    • Use GrpcWorkerCache::ListWorkers to obtain the names of all Workers in the cluster.
    • For each worker_name, call GetOrCreateWorker, which looks up the WorkerInterface object inside worker_cache, creating it if it does not exist.
    • Build a GetStatusRequest and send it to the found Worker via GetStatusAsync.
    • After the Worker returns GetStatusResponse, the function object in the callback cb (the WhenFound method) extracts the Worker's device information; the device information is post-processed so that the device names include the worker_name.

    Figure 6: Obtaining the devices

    4.1 DeviceFinder

    4.1.1 Definition

    DeviceFinder is a function object that implements the algorithm for finding the devices of remote workers. Let us first look at its member variables:

    class DeviceFinder {
      ~DeviceFinder() {
        for (Device* dev : found_) delete dev;
      }
    
      typedef DeviceFinder ME;
      const MasterEnv* env_;
      WorkerCacheInterface* worker_cache_;
      std::vector<DeviceNameUtils::ParsedName> filters_;
    
      mutex mu_;
      int num_pending_ TF_GUARDED_BY(mu_);
      condition_variable pending_zero_;
      std::vector<Device*> found_ TF_GUARDED_BY(mu_);
      // List of targets to be contacted by this DeviceFinder. The
      // respective `bool` in `seen_targets_` indicates whether we have
      // heard from this target or not.
      std::vector<string> targets_;
      std::vector<bool> seen_targets_ TF_GUARDED_BY(mu_);
      Status status_;
    
      TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
    };
    

    4.1.2 Initialization

    The main logic: obtain the name list of all workers in the cluster via GrpcWorkerCache::ListWorkers.

    explicit DeviceFinder(
        const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
        WorkerCacheInterface* worker_cache)
        : env_(env), worker_cache_(worker_cache) {
      CHECK(worker_cache) << "Worker cache was null!";
      auto process_filter = [this](const string& filter) {
        DeviceNameUtils::ParsedName parsed;
        if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
          filters_.push_back(parsed);
        } else {
          LOG(FATAL) << "Skipping invalid filter: " << filter;
        }
      };
      for (const string& filter : device_filters) {
        process_filter(filter);
      }
      // Enumerates all known workers' target. A target name is a
      // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
      if (filters_.empty()) {
        // If no filters were specified, we list all known workers in
        // `worker_cache`.
        std::vector<string> workers;
        worker_cache->ListWorkers(&workers);
        std::swap(workers, targets_);
      } else {
        // When applying filters, we must include the local worker, even if it
        // does not match any of the filters.
        CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
        const string& local_device_name = env_->local_devices[0]->name();
        DeviceNameUtils::ParsedName local_parsed_name;
        CHECK(DeviceNameUtils::ParseFullName(local_device_name,
                                             &local_parsed_name));
        bool all_filters_have_job = true;
        std::unordered_set<string> filter_job_names({local_parsed_name.job});
        for (const DeviceNameUtils::ParsedName& filter : filters_) {
          all_filters_have_job = all_filters_have_job && filter.has_job;
          if (filter.has_job) {
            filter_job_names.insert(filter.job);
          }
        }
    
        std::vector<string> workers;
        if (all_filters_have_job) {
          // If all of the device filters have a job specified, then we only need
          // to list the workers in the jobs named in the filter, because a worker
          // in any other job would not match any filter.
          for (const string& job_name : filter_job_names) {
            VLOG(2) << "Selectively listing workers in job: " << job_name;
            std::vector<string> workers_in_job;
            worker_cache->ListWorkersInJob(job_name, &workers_in_job);
            workers.insert(workers.end(), workers_in_job.begin(),
                           workers_in_job.end());
          }
        } else {
          // If any of the device filters does not have a job specified, then we
          // must list the workers from all jobs.
          VLOG(2) << "Listing workers in all jobs because some device "
                  << "filter has no job specified. Filters were:";
          if (device_filters.empty()) {
            VLOG(2) << "- <NO FILTERS>";
          } else {
            for (const string& filter : device_filters) {
              VLOG(2) << "- " << filter;
            }
          }
          worker_cache->ListWorkers(&workers);
        }
        for (const string& name : workers) {
          if (MatchFilters(name) ||
              DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
            targets_.push_back(name);
          }
        }
      }
      seen_targets_.assign(targets_.size(), false);
    }
    

    4.1.3 GetRemoteDevices

    The GetRemoteDevices method obtains the remote devices. Its logic is:

    • Use finder.Start() to broadcast GetStatusRequest to every Worker in the cluster.
    • Use finder.Wait() to collect the GetStatusResponse messages returned by all Workers.
    • Use finder.GetRemoteDevices to fetch the query results and return them to the caller.

    static Status GetRemoteDevices(
        const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
        WorkerCacheInterface* worker_cache,
        std::vector<std::unique_ptr<Device>>* out_remote) {
      DeviceFinder finder(device_filters, env, worker_cache);
      finder.Start();
      TF_RETURN_IF_ERROR(finder.Wait());
      finder.GetRemoteDevices(env->local_devices, out_remote);
      return Status::OK();
    }
    

    4.1.3.1 Start

    The Start method initializes the counter num_pending_ to the number of workers, then iterates over the workers and calls NewRemoteDevices for each one.

    void Start() {
      {
        mutex_lock l(mu_);
        num_pending_ = targets_.size();
        if (num_pending_ == 0) {
          pending_zero_.notify_all();
        }
      }
      // Talk to all workers to get the list of available devices.
      using std::placeholders::_1;
      using std::placeholders::_2;
      for (size_t i = 0; i < targets_.size(); ++i) {
        // TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may
        // never be called.
        NewRemoteDevices(env_->env, worker_cache_, targets_[i],
                         std::bind(&ME::WhenFound, this, i, _1, _2));
      }
    }
    

    The logic of NewRemoteDevices is:

    • Use worker_name to call GetOrCreateWorker, which looks up the WorkerInterface object inside worker_cache, creating it if it does not exist.
    • Build a GetStatusRequest and send it to the found Worker via GetStatusAsync.
    • After the Worker returns GetStatusResponse, the function object in the callback cb (the WhenFound method) extracts the Worker's device information; the device information is post-processed so that the device names include the worker_name.

    void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
                          const string& worker_name, NewRemoteDevicesDone done) {
      WorkerInterface* wi = worker_cache->GetOrCreateWorker(worker_name);
      if (wi == nullptr) {
        std::vector<Device*> empty;
        done(errors::NotFound("Device ", worker_name, " is not found."), &empty);
        return;
      }
      struct Call {
        GetStatusRequest req;    // The request message
        GetStatusResponse resp;  // The response message
      };
      Call* call = new Call;
      // The callback function
      auto cb = [env, worker_cache, worker_name, done, wi,
                 call](const Status& status) {
        Status s = status;
        std::vector<Device*> remote_devices;
        auto cleanup = gtl::MakeCleanup(
            [&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] {
              worker_cache->ReleaseWorker(worker_name, wi);
              done(s, &remote_devices);
              delete call;
            });
        if (s.ok()) {
          DeviceNameUtils::ParsedName worker_name_parsed;
          if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) ||
              !worker_name_parsed.has_job || !worker_name_parsed.has_replica ||
              !worker_name_parsed.has_task) {
            s = errors::InvalidArgument("Could not parse worker name: ",
                                        worker_name);
            return;
          }
          remote_devices.reserve(call->resp.device_attributes_size());
          for (const DeviceAttributes& da : call->resp.device_attributes()) {
            DeviceNameUtils::ParsedName device_name_parsed;
            CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed))
                << "Device attribute name '" << da.name() << "' could not be "
                << "parsed. Device Attribute: " << da.DebugString();
            // Preserve the exact name, if possible.
            if (device_name_parsed.job == worker_name_parsed.job &&
                device_name_parsed.replica == worker_name_parsed.replica &&
                device_name_parsed.task == worker_name_parsed.task) {
              auto d = new RemoteDevice(env, da);
              remote_devices.push_back(d);
            } else {
              DeviceAttributes da_rewritten = da;
              da_rewritten.set_name(DeviceNameUtils::FullName(
                  worker_name_parsed.job, worker_name_parsed.replica,
                  worker_name_parsed.task, device_name_parsed.type,
                  device_name_parsed.id));
              auto d = new RemoteDevice(env, da_rewritten);
    
              // Experimental: Skipping over adding any TPU-type devices that aren't
              // on the job called "worker" (but still adds the CPUs of other jobs).
              if (getenv("TPU_NO_POPULATE_DEVICE_LIST_FROM_CLUSTER_SPEC") !=
                  nullptr) {
                if (worker_name_parsed.job == "worker" ||
                    device_name_parsed.type.find("TPU") == std::string::npos) {
                  remote_devices.push_back(d);
                }
              } else {
                remote_devices.push_back(d);
              }
            }
          }
        }
      };
      wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
                         /*fail_fast=*/false, cb);
    }
    

    4.1.3.2 Wait

    In the Wait method, as long as the counter is not zero, it keeps calling pending_zero_.wait_for, with the main thread periodically sleeping for 10 seconds at a time and logging the workers that have not yet responded.

    Status Wait() {
      mutex_lock l(mu_);
      // TODO(mrry): Propagate a timeout here, since `num_pending_` may
      // never become zero.
      while (num_pending_ != 0) {
        pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs));
        if (num_pending_ != 0) {
          for (size_t i = 0; i < targets_.size(); ++i) {
            if (!seen_targets_[i]) {
              LOG(INFO)
                  << "CreateSession still waiting for response from worker: "
                  << targets_[i];
            }
          }
        }
      }
      return status_;
    }
    

    4.1.3.3 The Callback

    The callback for Start is shown below; when the GetStatusResponse from some Worker arrives, Start's callback ends up here. WhenFound decrements the counter; when the counter reaches 0 it calls pending_zero_.notify_all(), which wakes the pending_zero_.wait_for statement in Wait, and GetRemoteDevices then uses finder.GetRemoteDevices to fetch the query results and return them to the caller.

    void WhenFound(int target_index, const Status& s,
                   std::vector<Device*>* devices) {
      mutex_lock l(mu_);
      seen_targets_[target_index] = true;
      if (!s.ok()) {
        LOG(ERROR) << "CreateSession failed because worker "
                   << targets_[target_index] << " returned error: " << s;
        status_.Update(s);
      } else {
        found_.insert(found_.end(), devices->begin(), devices->end());
        devices->clear();
      }
      --num_pending_;
      if (num_pending_ == 0) {
        pending_zero_.notify_all();
      }
    }
    

    4.2 Interaction with the Worker

    In NewRemoteDevices, a GetStatusRequest is built and sent via GetStatusAsync to the Worker that was found.

    WorkerInterface* wi = worker_cache->GetOrCreateWorker(worker_name);
    wi->GetStatusAsync(/*opts=*/nullptr, &call->req, &call->resp,
                       /*fail_fast=*/false, cb);
    

    4.2.1 GrpcRemoteWorker

    wi is the WorkerInterface that was found; it is actually a GrpcRemoteWorker, the gRPC client that invokes the corresponding services of the remote WorkerService through its stub.

    void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
                        GetStatusResponse* response, bool fail_fast,
                        StatusCallback done) override {
      IssueRequest(request, response, getstatus_, std::move(done), call_opts,
                   fail_fast);
    }
    

    4.2.2 GrpcWorkerService

    On the remote Worker, the message is received in GrpcWorkerService. When a GetStatusRequest message arrives, it is handled by the GetStatusHandler callback; GetStatusHandler is generated by a macro.

    #define HANDLE_CALL(method, may_block_on_compute_pool)                        \
      void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
        auto closure = [this, call]() {                                           \
          Status s = worker_->method(&call->request, &call->response);            \
          if (!s.ok()) {                                                          \
            VLOG(3) << "Bad response from " << #method << ": " << s;              \
          }                                                                       \
          call->SendResponse(ToGrpcStatus(s));                                    \
        };                                                                        \
        if ((may_block_on_compute_pool)) {                                        \
          worker_->env()->env->SchedClosure(std::move(closure));                  \
        } else {                                                                  \
          worker_->env()->compute_pool->Schedule(std::move(closure));             \
        }                                                                         \
        ENQUEUE_REQUEST(method, false);                                           \
      }
    
      HANDLE_CALL(GetStatus, false);
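
    For readability, this is the handler that HANDLE_CALL(GetStatus, false) expands to, hand-expanded here rather than literal preprocessor output:

    // Hand expansion of HANDLE_CALL(GetStatus, false), for readability.
    void GetStatusHandler(
        WorkerCall<GetStatusRequest, GetStatusResponse>* call) {
      auto closure = [this, call]() {
        Status s = worker_->GetStatus(&call->request, &call->response);
        if (!s.ok()) {
          VLOG(3) << "Bad response from GetStatus: " << s;
        }
        call->SendResponse(ToGrpcStatus(s));
      };
      // may_block_on_compute_pool is false, so schedule on the compute pool.
      worker_->env()->compute_pool->Schedule(std::move(closure));
      ENQUEUE_REQUEST(GetStatus, false);
    }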
    

    4.2.3 Worker

    Finally we come to the Worker class. It simply forwards the work to DeviceMgr and eventually returns the result to the remote caller via a GetStatusResponse message.

    void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                                GetStatusResponse* response, bool fail_fast,
                                StatusCallback done) {
      const DeviceMgr* dm = env_->device_mgr;
      std::vector<DeviceAttributes> devices;
      dm->ListDeviceAttributes(&devices);
      response->mutable_device_attributes()->Reserve(devices.size());
      for (auto& d : devices) {
        response->add_device_attributes()->Swap(&d);
      }
      done(Status::OK());
    }
    

    4.2.4 DeviceMgr

    ListDeviceAttributes, which aggregates local device information, has two implementations, as follows.

    void StaticDeviceMgr::ListDeviceAttributes(
        std::vector<DeviceAttributes>* devices) const {
      devices->reserve(devices_.size());
      for (const auto& dev : devices_) {
        devices->emplace_back(dev->attributes());
      }
    }
    

    The second implementation is as follows:

    void DynamicDeviceMgr::ListDeviceAttributes(
        std::vector<DeviceAttributes>* devices) const {
      tf_shared_lock l(devices_mu_);
      devices->reserve(dynamic_devices_.size());
      for (const auto& d : dynamic_devices_) {
        devices->emplace_back(d->attributes());
      }
    }
    

    At this point we have finished analyzing the cache and the device-set discovery; next we will look at how the business logic is handled.

