Skip to content

Commit 848d554

Browse files
mrry authored and tensorflower-gardener committed
Prototype of an in-process gRPC server for TensorFlow/Python.
Adds support for binding a TensorFlow server to any port, to support single-process testing. This interface is a work in progress. In particular, it supports launching a server, but the support for clean shutdown is incomplete. Change: 116593644
1 parent 13024c5 commit 848d554

18 files changed

Lines changed: 548 additions & 96 deletions

tensorflow/core/distributed_runtime/BUILD

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -269,6 +269,18 @@ cc_library(
269269
],
270270
)
271271

272+
cc_library(
273+
name = "server_lib",
274+
srcs = ["server_lib.cc"],
275+
hdrs = ["server_lib.h"],
276+
deps = [
277+
"//tensorflow/core:framework",
278+
"//tensorflow/core:framework_internal",
279+
"//tensorflow/core:lib",
280+
"//tensorflow/core:protos_all_cc",
281+
],
282+
)
283+
272284
# TODO(mrry): Move executor_test.cc to ../common_runtime once it no longer depends
273285
# on grpc_testlib.
274286
tf_cc_tests(

tensorflow/core/distributed_runtime/rpc/BUILD

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -211,7 +211,7 @@ cc_library(
211211
srcs = [
212212
"grpc_server_lib.cc",
213213
],
214-
hdrs = ["grpc_server_lib.h"],
214+
linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
215215
deps = [
216216
"@grpc//:grpc++_unsecure",
217217
":async_service_interface",
@@ -230,8 +230,10 @@ cc_library(
230230
"//tensorflow/core/distributed_runtime:master_env",
231231
"//tensorflow/core/distributed_runtime:master_session",
232232
"//tensorflow/core/distributed_runtime:process_util",
233+
"//tensorflow/core/distributed_runtime:server_lib",
233234
"//tensorflow/core/distributed_runtime:worker_env",
234235
],
236+
alwayslink = 1,
235237
)
236238

237239
cc_binary(
@@ -247,6 +249,7 @@ cc_binary(
247249
"//tensorflow/core:framework_internal",
248250
"//tensorflow/core:lib",
249251
"//tensorflow/core:protos_all_cc",
252+
"//tensorflow/core/distributed_runtime:server_lib",
250253
],
251254
)
252255

@@ -276,6 +279,7 @@ cc_binary(
276279
"//tensorflow/core:core_cpu",
277280
"//tensorflow/core:framework_internal",
278281
"//tensorflow/core:lib",
282+
"//tensorflow/core/distributed_runtime:server_lib",
279283
],
280284
)
281285

@@ -344,5 +348,6 @@ tf_cc_tests(
344348
"//tensorflow/core:test_main",
345349
"//tensorflow/core:testlib",
346350
"//tensorflow/core/distributed_runtime:process_util",
351+
"//tensorflow/core/distributed_runtime:server_lib",
347352
],
348353
)

tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc

Lines changed: 97 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17-
1816
#include <memory>
1917

2018
#include "grpc++/grpc++.h"
@@ -33,6 +31,7 @@ limitations under the License.
3331
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
3432
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
3533
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
34+
#include "tensorflow/core/distributed_runtime/server_lib.h"
3635
#include "tensorflow/core/distributed_runtime/worker_env.h"
3736
#include "tensorflow/core/framework/op.h"
3837
#include "tensorflow/core/lib/strings/strcat.h"
@@ -41,14 +40,14 @@ limitations under the License.
4140
#include "tensorflow/core/public/session_options.h"
4241

4342
namespace tensorflow {
44-
4543
namespace {
46-
class TensorFlowServer : public ServerInterface {
44+
45+
class GrpcServer : public ServerInterface {
4746
public:
48-
TensorFlowServer(const ServerDef& server_def, Env* env)
47+
GrpcServer(const ServerDef& server_def, Env* env)
4948
: server_def_(server_def), env_(env), state_(NEW) {}
5049

51-
~TensorFlowServer() {
50+
~GrpcServer() {
5251
Stop();
5352
Join();
5453

@@ -59,8 +58,14 @@ class TensorFlowServer : public ServerInterface {
5958
// to destroy them.
6059
delete master_env_.worker_cache; // Shared with worker_env.worker_cache.
6160

62-
delete worker_env_.device_mgr;
61+
// We must delete graph_mgr before device_mgr, due to shared
62+
// ownership of OpKernels in the executors. (The graph_mgr will
63+
// free all stateless OpKernels, and pass over borrowed stateful
64+
// OpKernels, which are also held in their respective devices'
65+
// OpSegments.)
6366
delete worker_env_.graph_mgr;
67+
delete worker_env_.device_mgr;
68+
6469
delete worker_env_.rendezvous_mgr;
6570

6671
// Do not delete (as these are not owned by the server):
@@ -91,6 +96,56 @@ class TensorFlowServer : public ServerInterface {
9196
return errors::Internal("Could not parse worker name.");
9297
}
9398

99+
// Look up the port that has been requested for this task in `server_def_`.
100+
requested_port_ = -1;
101+
for (const auto& job : server_def_.cluster().job()) {
102+
if (job.name() == server_def_.job_name()) {
103+
auto iter = job.tasks().find(server_def_.task_index());
104+
if (iter == job.tasks().end()) {
105+
return errors::InvalidArgument("Task ", server_def_.task_index(),
106+
" was not defined in job \"",
107+
server_def_.job_name(), "\"");
108+
} else if (!str_util::NumericParse32(
109+
str_util::Split(iter->second, ':')[1],
110+
&requested_port_)) {
111+
return errors::Internal(
112+
"Could not parse port for local server from \"", iter->second,
113+
"\"");
114+
} else {
115+
break;
116+
}
117+
}
118+
}
119+
if (requested_port_ == -1) {
120+
return errors::Internal("Job \"", server_def_.job_name(),
121+
"\" was not defined in cluster");
122+
}
123+
124+
// N.B. The order of initialization here is intricate, because we
125+
// wish to allow `requested_port_ == 0` (for choosing any port,
126+
// mostly for testing). Therefore, the construction of the channel
127+
// and worker caches depends on `bound_port_`, which is not set
128+
// until we call `builder.BuildAndStart()`. We must create the
129+
// service objects before calling `builder.BuildAndStart()`, but
130+
// `master_env_` and `worker_env_` are only partially
131+
// configured. However, this is not dangerous, because we do not
132+
// start serving requests until `this->Start()` is called, which
133+
// happens after this method returns.
134+
//
135+
// TODO(mrry): Provide a general mechanism for dynamically setting
136+
// the identities of tasks in the worker pool after the service is
137+
// running.
138+
::grpc::ServerBuilder builder;
139+
builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_),
140+
::grpc::InsecureServerCredentials(), &bound_port_);
141+
master_service_ = NewGrpcMasterService(&master_env_, &builder);
142+
worker_service_ = NewGrpcWorkerService(&worker_env_, &builder);
143+
server_ = builder.BuildAndStart();
144+
145+
if (!server_) {
146+
return errors::Internal("Could not start gRPC server");
147+
}
148+
94149
GrpcChannelSpec channel_spec;
95150
for (const auto& job : server_def_.cluster().job()) {
96151
int max_task_id = -1;
@@ -99,7 +154,12 @@ class TensorFlowServer : public ServerInterface {
99154
}
100155
std::vector<string> host_ports(max_task_id + 1);
101156
for (const auto& task : job.tasks()) {
102-
host_ports[task.first] = task.second;
157+
if (job.name() == server_def_.job_name() &&
158+
task.first == server_def_.task_index()) {
159+
host_ports[task.first] = strings::StrCat("localhost:", bound_port_);
160+
} else {
161+
host_ports[task.first] = task.second;
162+
}
103163
}
104164
channel_spec.AddHostPortsJob(job.name(), host_ports, host_ports.size());
105165
}
@@ -133,12 +193,6 @@ class TensorFlowServer : public ServerInterface {
133193
mutex_lock l(mu_);
134194
switch (state_) {
135195
case NEW: {
136-
::grpc::ServerBuilder builder;
137-
builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_),
138-
::grpc::InsecureServerCredentials());
139-
master_service_ = NewGrpcMasterService(&master_env_, &builder);
140-
worker_service_ = NewGrpcWorkerService(&worker_env_, &builder);
141-
server_ = builder.BuildAndStart();
142196
master_thread_.reset(
143197
env_->StartThread(ThreadOptions(), "TF_master_service",
144198
[this] { master_service_->HandleRPCsLoop(); }));
@@ -196,16 +250,19 @@ class TensorFlowServer : public ServerInterface {
196250
}
197251
}
198252

199-
const string& target() const override { return target_; }
253+
const string target() const override {
254+
return strings::StrCat("grpc://localhost:", bound_port_);
255+
}
200256

201257
private:
202258
// The overall server configuration.
203259
const ServerDef server_def_;
204260
Env* env_;
205261

206262
// The port requested for this server.
207-
// TODO(mrry): Support requested_port_ == 0 to bind to any available port.
208263
int requested_port_;
264+
// The port to which this server is bound.
265+
int bound_port_ = 0;
209266

210267
// The `SessionOptions.target` to be used when connecting to this
211268
// server (as a master).
@@ -238,15 +295,30 @@ class TensorFlowServer : public ServerInterface {
238295

239296
std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_);
240297
};
241-
} // namespace
242298

243-
Status NewServer(const ServerDef& server_def,
244-
std::unique_ptr<ServerInterface>* out_server) {
245-
std::unique_ptr<TensorFlowServer> ret(
246-
new TensorFlowServer(server_def, Env::Default()));
247-
TF_RETURN_IF_ERROR(ret->Init());
248-
*out_server = std::move(ret);
249-
return Status::OK();
250-
}
299+
class GrpcServerFactory : public ServerFactory {
300+
public:
301+
bool AcceptsOptions(const ServerDef& server_def) override {
302+
return server_def.protocol() == "grpc";
303+
}
251304

305+
Status NewServer(const ServerDef& server_def,
306+
std::unique_ptr<ServerInterface>* out_server) override {
307+
std::unique_ptr<GrpcServer> ret(new GrpcServer(server_def, Env::Default()));
308+
TF_RETURN_IF_ERROR(ret->Init());
309+
*out_server = std::move(ret);
310+
return Status::OK();
311+
}
312+
};
313+
314+
// Registers a `ServerFactory` for `GrpcServer` instances.
315+
class GrpcServerRegistrar {
316+
public:
317+
GrpcServerRegistrar() {
318+
ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
319+
}
320+
};
321+
static GrpcServerRegistrar registrar;
322+
323+
} // namespace
252324
} // namespace tensorflow

tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h

Lines changed: 0 additions & 65 deletions
This file was deleted.

tensorflow/core/distributed_runtime/rpc/grpc_server_lib_test.cc

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
16+
#include "tensorflow/core/distributed_runtime/server_lib.h"
1717

1818
#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
1919
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,6 +25,7 @@ namespace tensorflow {
2525
// when no calls are made against the server.
2626
TEST(Server, StopAfterNoop) {
2727
ServerDef def;
28+
def.set_protocol("grpc");
2829
def.set_job_name("localhost");
2930
def.set_task_index(0);
3031
JobDef* job_def = def.mutable_cluster()->add_job();
@@ -42,6 +43,7 @@ TEST(Server, StopAfterNoop) {
4243
// when a simple call is made against the server.
4344
TEST(Server, StopAfterCall) {
4445
ServerDef def;
46+
def.set_protocol("grpc");
4547
def.set_job_name("localhost");
4648
def.set_task_index(0);
4749
JobDef* job_def = def.mutable_cluster()->add_job();

tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ limitations under the License.
1919
#include "grpc++/security/credentials.h"
2020
#include "grpc++/server_builder.h"
2121

22-
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
22+
#include "tensorflow/core/distributed_runtime/server_lib.h"
2323

2424
#include "tensorflow/core/lib/core/errors.h"
2525
#include "tensorflow/core/lib/core/status.h"
@@ -31,10 +31,13 @@ limitations under the License.
3131
#include "tensorflow/core/util/command_line_flags.h"
3232

3333
// This binary starts a TensorFlow server (master and worker).
34+
//
35+
// TODO(mrry): Replace with a py_binary that uses `tf.GrpcServer()`.
3436
namespace tensorflow {
3537
namespace {
3638

3739
Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) {
40+
options->set_protocol("grpc");
3841
string cluster_spec;
3942
int task_index = 0;
4043
const bool parse_result = ParseFlags(

tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ limitations under the License.
1717
#include "grpc++/security/credentials.h"
1818
#include "grpc++/server_builder.h"
1919

20-
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
20+
#include "tensorflow/core/distributed_runtime/server_lib.h"
2121

2222
#include "tensorflow/core/lib/core/errors.h"
2323
#include "tensorflow/core/lib/core/status.h"
@@ -33,6 +33,7 @@ namespace tensorflow {
3333
namespace {
3434

3535
Status ParseFlagsForTask(int argc, char* argv[], ServerDef* options) {
36+
options->set_protocol("grpc");
3637
string job_spec;
3738
int num_cpus = 1;
3839
int num_gpus = 0;

0 commit comments

Comments
 (0)