Skip to content

Commit a304f0a

Browse files
caisqtensorflower-gardener
authored andcommitted
tfdbg: improvements and fixes to tensor display in CLI
1) Enable scrolling to next regex match with command "/" following "/regex". 2) Enable scrolling to tensor indices with command such as "@[1, 2]" and "@100,30,0". 3) Display tensor indices at the top and bottom of the screen, and in scroll status info bar. 4) Handle invalid regex search commands, e.g., "/[", without crashing. Doc updated accordingly. Change: 137518091
1 parent b281537 commit a304f0a

8 files changed

Lines changed: 764 additions & 103 deletions

File tree

tensorflow/python/debug/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ py_library(
8989
deps = [
9090
":command_parser",
9191
":debugger_cli_common",
92+
":tensor_format",
9293
],
9394
)
9495

@@ -180,6 +181,7 @@ py_test(
180181
deps = [
181182
":curses_ui",
182183
":debugger_cli_common",
184+
":tensor_format",
183185
"//tensorflow/python:framework",
184186
"//tensorflow/python:framework_test_lib",
185187
],

tensorflow/python/debug/cli/command_parser.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,11 @@ def parse_tensor_name_with_slicing(in_str):
7777
Args:
7878
in_str: (str) Input name of the tensor, potentially followed by a slicing
7979
string. E.g.: Without slicing string: "hidden/weights/Variable:0", with
80-
slicing string: "hidden/weights/Varaible:0[1, :]"
80+
slicing string: "hidden/weights/Variable:0[1, :]"
8181
8282
Returns:
8383
(str) name of the tensor
84-
(str) sliciing string, if any. If no slicing string is present, return "".
84+
(str) slicing string, if any. If no slicing string is present, return "".
8585
"""
8686

8787
if in_str.count("[") == 1 and in_str.endswith("]"):
@@ -108,3 +108,27 @@ def validate_slicing_string(slicing_string):
108108
"""
109109

110110
return bool(re.search(r"^\[(\d|,|\s|:)+\]$", slicing_string))
111+
112+
113+
def parse_indices(indices_string):
114+
"""Parse a string representing indices.
115+
116+
For example, if the input is "[1, 2, 3]", the return value will be a list of
117+
indices: [1, 2, 3]
118+
119+
Args:
120+
indices_string: (str) a string representing indices. Can optionally be
121+
surrounded by a pair of brackets.
122+
123+
Returns:
124+
(list of int): Parsed indices.
125+
"""
126+
127+
# Strip whitespace.
128+
indices_string = re.sub(r"\s+", "", indices_string)
129+
130+
# Strip any brackets at the two ends.
131+
if indices_string.startswith("[") and indices_string.endswith("]"):
132+
indices_string = indices_string[1:-1]
133+
134+
return [int(element) for element in indices_string.split(",")]

tensorflow/python/debug/cli/command_parser_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,5 +129,34 @@ def testValidateInvalidSlicingStrings(self):
129129
self.assertFalse(command_parser.validate_slicing_string("[5, bar]"))
130130

131131

132+
class ParseIndicesTest(test_util.TensorFlowTestCase):
133+
134+
def testParseValidIndicesStringsWithBrackets(self):
135+
self.assertEqual([0], command_parser.parse_indices("[0]"))
136+
self.assertEqual([0], command_parser.parse_indices(" [0] "))
137+
self.assertEqual([-1, 2], command_parser.parse_indices("[-1, 2]"))
138+
self.assertEqual([3, 4, -5],
139+
command_parser.parse_indices("[3,4,-5]"))
140+
141+
def testParseValidIndicesStringsWithoutBrackets(self):
142+
self.assertEqual([0], command_parser.parse_indices("0"))
143+
self.assertEqual([0], command_parser.parse_indices(" 0 "))
144+
self.assertEqual([-1, 2], command_parser.parse_indices("-1, 2"))
145+
self.assertEqual([3, 4, -5], command_parser.parse_indices("3,4,-5"))
146+
147+
def testParseInvalidIndicesStringsWithoutBrackets(self):
148+
with self.assertRaisesRegexp(
149+
ValueError, r"invalid literal for int\(\) with base 10: 'a'"):
150+
self.assertEqual([0], command_parser.parse_indices("0,a"))
151+
152+
with self.assertRaisesRegexp(
153+
ValueError, r"invalid literal for int\(\) with base 10: '2\]'"):
154+
self.assertEqual([0], command_parser.parse_indices("1, 2]"))
155+
156+
with self.assertRaisesRegexp(
157+
ValueError, r"invalid literal for int\(\) with base 10: ''"):
158+
self.assertEqual([0], command_parser.parse_indices("3, 4,"))
159+
160+
132161
if __name__ == "__main__":
133162
googletest.main()

0 commit comments

Comments
 (0)