Skip to content

Commit d82e858

Browse files
sherrymtensorflower-gardener
authored andcommitted
Added replica_device_setter() for multi-replica training.
Example usage: cluster_spec = { "ps": ["ps0:2222", "ps1:2222"], "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]} with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)): ... Change: 118245590
1 parent a25af1b commit d82e858

8 files changed

Lines changed: 326 additions & 96 deletions

File tree

tensorflow/all_files.bzl

Lines changed: 0 additions & 54 deletions
This file was deleted.

tensorflow/python/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
__all__.extend([
128128
'AttrValue',
129129
'ClusterDef',
130+
'ClusterSpec',
130131
'ConfigProto',
131132
'Event',
132133
'GPUOptions',
@@ -161,7 +162,6 @@
161162
'initialize_all_tables',
162163
'lin_space',
163164
'list_diff',
164-
'make_cluster_def',
165165
'parse_single_sequence_example',
166166
'py_func',
167167
'scalar_mul',

tensorflow/python/client/client_lib.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,8 @@
5555
from tensorflow.core.protobuf.tensorflow_server_pb2 import ClusterDef
5656
from tensorflow.core.protobuf.tensorflow_server_pb2 import JobDef
5757
from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
58+
from tensorflow.python.client.server_lib import ClusterSpec
5859
from tensorflow.python.client.server_lib import GrpcServer
59-
from tensorflow.python.client.server_lib import make_cluster_def
60-
6160

6261
from tensorflow.python.client.session import InteractiveSession
6362
from tensorflow.python.client.session import Session

tensorflow/python/client/server_lib.py

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -111,42 +111,64 @@ def create_local_server(start=True):
111111
return GrpcServer(server_def, start)
112112

113113

114-
def make_cluster_def(cluster_spec):
115-
"""Returns a `tf.ClusterDef` based on the given `cluster_spec`.
114+
class ClusterSpec(object):
115+
"""A class for representing a Cluster."""
116116

117-
Args:
118-
cluster_spec: A dictionary mapping one or more job names to lists
119-
of network addresses.
117+
def __init__(self, cluster):
118+
"""Creates a `ClusterSpec`.
120119
121-
Returns:
122-
A `tf.ClusterDef` protocol buffer.
123-
124-
Raises:
125-
TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
126-
of strings.
127-
"""
128-
if not isinstance(cluster_spec, dict):
129-
raise TypeError("`cluster_spec` must be a dictionary mapping one or more "
130-
"job names to lists of network addresses")
131-
132-
cluster_def = tensorflow_server_pb2.ClusterDef()
133-
134-
# NOTE(mrry): Sort by job_name to produce deterministic protobufs.
135-
for job_name, task_list in sorted(cluster_spec.items()):
136-
try:
137-
job_name = compat.as_bytes(job_name)
138-
except TypeError:
139-
raise TypeError("Job name %r must be bytes or unicode" % job_name)
120+
Args:
121+
cluster: A dictionary mapping one or more job names to lists of network
122+
addresses, or a `tf.ClusterDef` protocol buffer.
140123
141-
job_def = cluster_def.job.add()
142-
job_def.name = job_name
124+
Raises:
125+
TypeError: If `cluster` is not a dictionary mapping strings to lists
126+
of strings, and not a `ClusterDef` proto buf.
127+
"""
128+
if isinstance(cluster, dict):
129+
self._cluster_spec = cluster
130+
self._make_cluster_def()
131+
elif isinstance(cluster, tensorflow_server_pb2.ClusterDef):
132+
self._cluster_def = cluster
133+
self._cluster_spec = {}
134+
for job_def in self._cluster_def.job:
135+
self._cluster_spec[job_def.name] = [t for t in job_def.tasks.values()]
136+
else:
137+
raise TypeError("`cluster` must be a dictionary mapping one or more "
138+
"job names to lists of network addresses, or a "
139+
"`ClusterDef` protocol buffer")
140+
141+
def as_cluster_spec(self):
142+
"""Returns a dictionary from job names to list of network addresses."""
143+
return self._cluster_spec
144+
145+
def as_cluster_def(self):
146+
"""Returns a `tf.ClusterDef` protocol buffer."""
147+
return self._cluster_def
148+
149+
def _make_cluster_def(self):
150+
"""Creates a `tf.ClusterDef` based on the given `cluster_spec`.
151+
152+
Raises:
153+
TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
154+
of strings.
155+
"""
156+
self._cluster_def = tensorflow_server_pb2.ClusterDef()
143157

144-
for i, task_address in enumerate(task_list):
158+
# NOTE(mrry): Sort by job_name to produce deterministic protobufs.
159+
for job_name, task_list in sorted(self._cluster_spec.items()):
145160
try:
146-
task_address = compat.as_bytes(task_address)
161+
job_name = compat.as_bytes(job_name)
147162
except TypeError:
148-
raise TypeError(
149-
"Task address %r must be bytes or unicode" % task_address)
150-
job_def.tasks[i] = task_address
151-
152-
return cluster_def
163+
raise TypeError("Job name %r must be bytes or unicode" % job_name)
164+
165+
job_def = self._cluster_def.job.add()
166+
job_def.name = job_name
167+
168+
for i, task_address in enumerate(task_list):
169+
try:
170+
task_address = compat.as_bytes(task_address)
171+
except TypeError:
172+
raise TypeError(
173+
"Task address %r must be bytes or unicode" % task_address)
174+
job_def.tasks[i] = task_address

tensorflow/python/client/server_lib_test.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def testLargeFeed(self):
8383
class ServerDefTest(tf.test.TestCase):
8484

8585
def testLocalServer(self):
86-
cluster_def = tf.make_cluster_def({"local": ["localhost:2222"]})
86+
cluster_def = tf.ClusterSpec({"local": ["localhost:2222"]}).as_cluster_def()
8787
server_def = tf.ServerDef(cluster=cluster_def,
8888
job_name="local", task_index=0, protocol="grpc")
8989

@@ -94,9 +94,13 @@ def testLocalServer(self):
9494
job_name: 'local' task_index: 0 protocol: 'grpc'
9595
""", server_def)
9696

97+
# Verifies round trip from Proto->Spec->Proto is correct.
98+
cluster_spec = tf.ClusterSpec(cluster_def)
99+
self.assertProtoEquals(cluster_def, cluster_spec.as_cluster_def())
100+
97101
def testTwoProcesses(self):
98-
cluster_def = tf.make_cluster_def({"local": ["localhost:2222",
99-
"localhost:2223"]})
102+
cluster_def = tf.ClusterSpec({"local": ["localhost:2222",
103+
"localhost:2223"]}).as_cluster_def()
100104
server_def = tf.ServerDef(cluster=cluster_def,
101105
job_name="local", task_index=1, protocol="grpc")
102106

@@ -108,10 +112,14 @@ def testTwoProcesses(self):
108112
job_name: 'local' task_index: 1 protocol: 'grpc'
109113
""", server_def)
110114

115+
# Verifies round trip from Proto->Spec->Proto is correct.
116+
cluster_spec = tf.ClusterSpec(cluster_def)
117+
self.assertProtoEquals(cluster_def, cluster_spec.as_cluster_def())
118+
111119
def testTwoJobs(self):
112-
cluster_def = tf.make_cluster_def({
113-
"ps": ["ps0:2222", "ps1:2222"],
114-
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]})
120+
cluster_def = tf.ClusterSpec({"ps": ["ps0:2222", "ps1:2222"],
121+
"worker": ["worker0:2222", "worker1:2222",
122+
"worker2:2222"]}).as_cluster_def()
115123
server_def = tf.ServerDef(cluster=cluster_def,
116124
job_name="worker", task_index=2, protocol="grpc")
117125

@@ -126,6 +134,10 @@ def testTwoJobs(self):
126134
job_name: 'worker' task_index: 2 protocol: 'grpc'
127135
""", server_def)
128136

137+
# Verifies round trip from Proto->Spec->Proto is correct.
138+
cluster_spec = tf.ClusterSpec(cluster_def)
139+
self.assertProtoEquals(cluster_def, cluster_spec.as_cluster_def())
140+
129141

130142
if __name__ == "__main__":
131143
tf.test.main()

0 commit comments

Comments
 (0)