-
Notifications
You must be signed in to change notification settings - Fork 399
Expand file tree
/
Copy pathdebugger_example.py
More file actions
72 lines (58 loc) · 2.11 KB
/
debugger_example.py
File metadata and controls
72 lines (58 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
.. _debugger_example:
Debugging Torch-TensorRT Compilation
===================================================================
TensorRT conversion can perform many graph transformations and backend specific
optimizations that are sometimes hard to inspect. Torch-TensorRT provides a
Debugger utility to help visualize FX graphs around lowering passes, monitor
engine building, and capture profiling or TensorRT API traces.
In this example, we demonstrate how to:
1. Enable the Torch-TensorRT Debugger context
2. Capture and visualize FX graphs before and/or after specific lowering passes
3. Configure logging directory and verbosity
"""
import os
import tempfile
import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
temp_dir = os.path.join(tempfile.gettempdir(), "torch_tensorrt_debugger_example")
np.random.seed(0)
torch.manual_seed(0)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
model = models.resnet18(pretrained=False).to("cuda").eval()
exp_program = torch.export.export(model, tuple(inputs))
workspace_size = 20 << 30
min_block_size = 0
use_python_runtime = False
torch_executed_ops = {}
with torch_trt.dynamo.Debugger(
log_level="debug",
logging_dir=temp_dir,
engine_builder_monitor=False, # whether to monitor the engine building process
capture_fx_graph_after=[
"complex_graph_detection"
], # fx graph visualization after certain lowering pass
capture_fx_graph_before=[
"remove_detach"
], # fx graph visualization before certain lowering pass
):
trt_gm = torch_trt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
immutable_weights=False,
reuse_cached_engines=False,
)
trt_output = trt_gm(*inputs)
"""
The logging directory will contain the following files:
- /tmp/torch_tensorrt_debugger_example/
torch_tensorrt_logging.log
- /lowering_passes_visualization/
after_complex_graph_detection.svg
before_remove_detach.svg
"""