@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16- #include " tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17-
1816#include < memory>
1917
2018#include " grpc++/grpc++.h"
@@ -33,6 +31,7 @@ limitations under the License.
3331#include " tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
3432#include " tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
3533#include " tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
34+ #include " tensorflow/core/distributed_runtime/server_lib.h"
3635#include " tensorflow/core/distributed_runtime/worker_env.h"
3736#include " tensorflow/core/framework/op.h"
3837#include " tensorflow/core/lib/strings/strcat.h"
@@ -41,14 +40,14 @@ limitations under the License.
4140#include " tensorflow/core/public/session_options.h"
4241
4342namespace tensorflow {
44-
4543namespace {
46- class TensorFlowServer : public ServerInterface {
44+
45+ class GrpcServer : public ServerInterface {
4746 public:
48- TensorFlowServer (const ServerDef& server_def, Env* env)
47+ GrpcServer (const ServerDef& server_def, Env* env)
4948 : server_def_(server_def), env_(env), state_(NEW) {}
5049
51- ~TensorFlowServer () {
50+ ~GrpcServer () {
5251 Stop ();
5352 Join ();
5453
@@ -59,8 +58,14 @@ class TensorFlowServer : public ServerInterface {
5958 // to destroy them.
6059 delete master_env_.worker_cache ; // Shared with worker_env.worker_cache.
6160
62- delete worker_env_.device_mgr ;
61+ // We must delete graph_mgr before device_mgr, due to shared
62+ // ownership of OpKernels in the executors. (The graph_mgr will
63+ // free all stateless OpKernels, and pass over borrowed stateful
64+ // OpKernels, which are also held in their respective devices'
65+ // OpSegments.)
6366 delete worker_env_.graph_mgr ;
67+ delete worker_env_.device_mgr ;
68+
6469 delete worker_env_.rendezvous_mgr ;
6570
6671 // Do not delete (as these are not owned by the server):
@@ -91,6 +96,56 @@ class TensorFlowServer : public ServerInterface {
9196 return errors::Internal (" Could not parse worker name." );
9297 }
9398
99+ // Look up the port that has been requested for this task in `server_def_`.
100+ requested_port_ = -1 ;
101+ for (const auto & job : server_def_.cluster ().job ()) {
102+ if (job.name () == server_def_.job_name ()) {
103+ auto iter = job.tasks ().find (server_def_.task_index ());
104+ if (iter == job.tasks ().end ()) {
105+ return errors::InvalidArgument (" Task " , server_def_.task_index (),
106+ " was not defined in job \" " ,
107+ server_def_.job_name (), " \" " );
108+ } else if (!str_util::NumericParse32 (
109+ str_util::Split (iter->second , ' :' )[1 ],
110+ &requested_port_)) {
111+ return errors::Internal (
112+ " Could not parse port for local server from \" " , iter->second ,
113+ " \" " );
114+ } else {
115+ break ;
116+ }
117+ }
118+ }
119+ if (requested_port_ == -1 ) {
120+ return errors::Internal (" Job \" " , server_def_.job_name (),
121+ " \" was not defined in cluster" );
122+ }
123+
124+ // N.B. The order of initialization here is intricate, because we
125+ // wish to allow `requested_port_ == 0` (for choosing any port,
126+ // mostly for testing). Therefore, the construction of the channel
127+ // and worker caches depends on `bound_port_`, which is not set
128+ // until we call `builder.BuildAndStart()`. We must create the
129+ // service objects before calling `builder.BuildAndStart()`, but
130+ // `master_env_` and `worker_env_` are only partially
131+ // configured. However, this is not dangerous, because we do not
132+ // start serving requests until `this->Start()` is called, which
133+ // happens after this method returns.
134+ //
135+ // TODO(mrry): Provide a general mechanism for dynamically setting
136+ // the identities of tasks in the worker pool after the service is
137+ // running.
138+ ::grpc::ServerBuilder builder;
139+ builder.AddListeningPort (strings::StrCat (" 0.0.0.0:" , requested_port_),
140+ ::grpc::InsecureServerCredentials (), &bound_port_);
141+ master_service_ = NewGrpcMasterService (&master_env_, &builder);
142+ worker_service_ = NewGrpcWorkerService (&worker_env_, &builder);
143+ server_ = builder.BuildAndStart ();
144+
145+ if (!server_) {
146+ return errors::Internal (" Could not start gRPC server" );
147+ }
148+
94149 GrpcChannelSpec channel_spec;
95150 for (const auto & job : server_def_.cluster ().job ()) {
96151 int max_task_id = -1 ;
@@ -99,7 +154,12 @@ class TensorFlowServer : public ServerInterface {
99154 }
100155 std::vector<string> host_ports (max_task_id + 1 );
101156 for (const auto & task : job.tasks ()) {
102- host_ports[task.first ] = task.second ;
157+ if (job.name () == server_def_.job_name () &&
158+ task.first == server_def_.task_index ()) {
159+ host_ports[task.first ] = strings::StrCat (" localhost:" , bound_port_);
160+ } else {
161+ host_ports[task.first ] = task.second ;
162+ }
103163 }
104164 channel_spec.AddHostPortsJob (job.name (), host_ports, host_ports.size ());
105165 }
@@ -133,12 +193,6 @@ class TensorFlowServer : public ServerInterface {
133193 mutex_lock l (mu_);
134194 switch (state_) {
135195 case NEW: {
136- ::grpc::ServerBuilder builder;
137- builder.AddListeningPort (strings::StrCat (" 0.0.0.0:" , requested_port_),
138- ::grpc::InsecureServerCredentials ());
139- master_service_ = NewGrpcMasterService (&master_env_, &builder);
140- worker_service_ = NewGrpcWorkerService (&worker_env_, &builder);
141- server_ = builder.BuildAndStart ();
142196 master_thread_.reset (
143197 env_->StartThread (ThreadOptions (), " TF_master_service" ,
144198 [this ] { master_service_->HandleRPCsLoop (); }));
@@ -196,16 +250,19 @@ class TensorFlowServer : public ServerInterface {
196250 }
197251 }
198252
199- const string& target () const override { return target_; }
253+ const string target () const override {
254+ return strings::StrCat (" grpc://localhost:" , bound_port_);
255+ }
200256
201257 private:
202258 // The overall server configuration.
203259 const ServerDef server_def_;
204260 Env* env_;
205261
206262 // The port requested for this server.
207- // TODO(mrry): Support requested_port_ == 0 to bind to any available port.
208263 int requested_port_;
264+ // The port to which this server is bound.
265+ int bound_port_ = 0 ;
209266
210267 // The `SessionOptions.target` to be used when connecting to this
211268 // server (as a master).
@@ -238,15 +295,30 @@ class TensorFlowServer : public ServerInterface {
238295
239296 std::unique_ptr<::grpc::Server> server_ GUARDED_BY (mu_);
240297};
241- } // namespace
242298
243- Status NewServer (const ServerDef& server_def,
244- std::unique_ptr<ServerInterface>* out_server) {
245- std::unique_ptr<TensorFlowServer> ret (
246- new TensorFlowServer (server_def, Env::Default ()));
247- TF_RETURN_IF_ERROR (ret->Init ());
248- *out_server = std::move (ret);
249- return Status::OK ();
250- }
299+ class GrpcServerFactory : public ServerFactory {
300+ public:
301+ bool AcceptsOptions (const ServerDef& server_def) override {
302+ return server_def.protocol () == " grpc" ;
303+ }
251304
305+ Status NewServer (const ServerDef& server_def,
306+ std::unique_ptr<ServerInterface>* out_server) override {
307+ std::unique_ptr<GrpcServer> ret (new GrpcServer (server_def, Env::Default ()));
308+ TF_RETURN_IF_ERROR (ret->Init ());
309+ *out_server = std::move (ret);
310+ return Status::OK ();
311+ }
312+ };
313+
314+ // Registers a `ServerFactory` for `GrpcServer` instances.
315+ class GrpcServerRegistrar {
316+ public:
317+ GrpcServerRegistrar () {
318+ ServerFactory::Register (" GRPC_SERVER" , new GrpcServerFactory ());
319+ }
320+ };
321+ static GrpcServerRegistrar registrar;
322+
323+ } // namespace
252324} // namespace tensorflow
0 commit comments