-
Notifications
You must be signed in to change notification settings - Fork 75.3k
Expand file tree
/
Copy pathconfig_test.py
More file actions
92 lines (76 loc) · 3.42 KB
/
config_test.py
File metadata and controls
92 lines (76 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the open source DTensor Python API."""
import os
# pylint: disable=g-direct-tensorflow-import
from tensorflow.dtensor.python import config
from tensorflow.dtensor.python.tests import test_util
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.platform import test as tf_test
# pylint: enable=g-direct-tensorflow-import
class ConfigTest(tf_test.TestCase):
  """Checks the values reported by the accessors of the DTensor `config` module."""

  def setUp(self):
    """Requests two logical CPU devices, plus two GPU devices if a GPU exists."""
    super().setUp()
    test_util.reset_logical_devices('CPU', 2)
    if not test_util.is_gpu_present():
      return
    test_util.reset_logical_devices('GPU', 2)

  def tearDown(self):
    """Drops the jobs environment variable so tests do not leak into each other."""
    # A default is supplied so that the pop does not raise KeyError for tests
    # that never set the variable.
    os.environ.pop(config._DT_JOBS, [])  # pylint: disable=protected-access
    super().tearDown()

  def test_env_vars(self):
    """Checks the accessor values when this test sets no jobs variable."""
    expectations = (
        (config.client_id, 0),
        (config.num_clients, 1),
        (config.job_name, 'localhost'),
        (config.full_job_name, 'localhost/replica:0/task:0'),
        (config.jobs, []),
    )
    # Each accessor is invoked only when its own assertion runs, in the listed
    # order.
    for accessor, expected in expectations:
      self.assertEqual(accessor(), expected)

  def test_list_devices(self):
    """Checks the local/global device listing for the preferred device type."""
    device_type = config.preferred_device_type()
    first_device = tf_device.DeviceSpec.from_string(
        f'/job:localhost/replica:0/task:0/device:{device_type}:0')
    second_device = tf_device.DeviceSpec.from_string(
        f'/job:localhost/replica:0/task:0/device:{device_type}:1')

    self.assertEqual(
        config.local_devices(device_type), [first_device, second_device])
    self.assertEqual(config.num_local_devices(device_type), 2)
    self.assertEqual(config.num_global_devices(device_type), 2)

    # None of the calls above should have initialized the eager context.
    eager_context = context.context()
    self.assertFalse(eager_context._initialized)  # pylint: disable=protected-access

  def test_sort_jobs_with_bns_names(self):
    """Checks that sorted bns job lists are accepted and unsorted ones rejected."""
    # bns names listed in bns order, with and without a port suffix, must be
    # returned unchanged.
    sorted_templates = (
        '/bns/localhost/{task_id}',
        '/bns/localhost/{task_id}:8888',
    )
    for template in sorted_templates:
      expected_jobs = [template.format(task_id=i) for i in range(16)]
      os.environ[config._DT_JOBS] = ','.join(expected_jobs)  # pylint: disable=protected-access
      self.assertListEqual(expected_jobs, config.jobs())

    # Task ids 100, 99, ..., 85: a descending list is not in bns order and
    # must be rejected.
    descending_jobs = [
        '/bns/localhost/{task_id}'.format(task_id=task_id)
        for task_id in range(100, 84, -1)
    ]
    os.environ[config._DT_JOBS] = ','.join(descending_jobs)  # pylint: disable=protected-access
    with self.assertRaisesRegex(ValueError, 'Unexpected DTENSOR_JOBS'):
      config.jobs()

  def test_jobs_with_ip_port(self):
    """Checks that host:port job lists are returned in the order given."""
    # The ip:port format is not a bns address, so it is not required to be
    # sorted. Ports run 16, 15, ..., 1.
    unsorted_jobs = [
        'localhost:{port}'.format(port=port) for port in range(16, 0, -1)
    ]
    os.environ[config._DT_JOBS] = ','.join(unsorted_jobs)  # pylint: disable=protected-access
    self.assertListEqual(unsorted_jobs, config.jobs())
if __name__ == '__main__':
  # Run the tests through the TensorFlow test entry point only when this file
  # is executed directly, not when it is imported.
  tf_test.main()