-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Expand file tree
/
Copy pathsample.py
More file actions
179 lines (146 loc) · 6.1 KB
/
sample.py
File metadata and controls
179 lines (146 loc) · 6.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
# This sample uses an MNIST PyTorch model to create a TensorRT Inference Engine
import model
import numpy as np
import tensorrt as trt
# Add the parent of the script directory (sys.path[0]) to the module search path
# so that the `common` import that follows can be resolved from there
# (presumably the helper module shared between samples - confirm layout).
sys.path.insert(1, os.path.join(sys.path[0], os.path.pardir))
import common
# Module-wide TensorRT logger; build_engine passes it to both trt.Builder and trt.Runtime.
# You can set the logger severity higher to suppress messages (or lower to display more messages).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
class ModelData:
    """Constants describing the network's input and output tensors."""

    # Name given to the network input tensor.
    INPUT_NAME = "data"
    # Shape of the input tensor (presumably N, C, H, W for one 28x28 grayscale image).
    INPUT_SHAPE = (1, 1, 28, 28)
    # Name assigned to the tensor that is marked as the network output.
    OUTPUT_NAME = "prob"
    # Number of outputs requested from the final fully connected layer.
    OUTPUT_SIZE = 10
    # Data type of the network input tensor.
    DTYPE = trt.float32
def populate_network(network, weights):
    """Define the MNIST network on ``network`` using the given trained weights.

    Layers are added in this order: input, conv1 (20 maps, 5x5), 2x2 max-pool,
    conv2 (50 maps, 5x5), 2x2 max-pool, fc1, ReLU, fc2. The fc2 output tensor is
    renamed to ``ModelData.OUTPUT_NAME`` and marked as the network output.

    Args:
        network: TensorRT network definition that receives the layers.
        weights: Mapping from parameter names such as ``"conv1.weight"`` to
            objects exposing ``.cpu().numpy()`` (presumably a PyTorch state
            dict - confirm against the caller).
    """

    def as_numpy(key):
        # Convert one stored parameter into a NumPy array for TensorRT.
        return weights[key].cpu().numpy()

    def add_matmul_as_fc(net, tensor, outputs, w, b):
        """Emulate a fully connected layer with shuffle, matmul and bias add.

        ``tensor`` is reshaped to (m, k), multiplied by the transposed (n, k)
        weight constant, summed with a (1, n) bias constant and finally
        reshaped to (m, n, 1, 1). ``outputs`` is accepted but not used: the
        output width ``n`` is derived from ``w.size`` instead.

        Returns the final shuffle layer.
        """
        rank = len(tensor.shape)
        assert rank >= 3
        # A rank-3 input is treated as a single row; otherwise dim 0 gives the rows.
        m = tensor.shape[0] if rank != 3 else 1
        volume = np.prod(tensor.shape)
        k = int(volume / m)
        assert volume == m * k
        n = int(w.size / k)
        assert w.size == n * k
        assert b.size == n

        flattened = net.add_shuffle(tensor)
        flattened.reshape_dims = trt.Dims2(m, k)

        weight_layer = net.add_constant(trt.Dims2(n, k), w)
        product = net.add_matrix_multiply(
            flattened.get_output(0),
            trt.MatrixOperation.NONE,
            weight_layer.get_output(0),
            trt.MatrixOperation.TRANSPOSE,
        )

        bias_layer = net.add_constant(trt.Dims2(1, n), b)
        biased = net.add_elementwise(
            product.get_output(0), bias_layer.get_output(0), trt.ElementWiseOperation.SUM
        )

        restored = net.add_shuffle(biased.get_output(0))
        restored.reshape_dims = trt.Dims4(m, n, 1, 1)
        return restored

    # Configure the network layers based on the weights provided.
    data = network.add_input(
        name=ModelData.INPUT_NAME, dtype=ModelData.DTYPE, shape=ModelData.INPUT_SHAPE
    )

    # First convolution + max-pooling stage.
    conv1_kernel = as_numpy("conv1.weight")
    conv1_bias = as_numpy("conv1.bias")
    conv1 = network.add_convolution_nd(
        input=data,
        num_output_maps=20,
        kernel_shape=(5, 5),
        kernel=conv1_kernel,
        bias=conv1_bias,
    )
    conv1.stride_nd = (1, 1)

    pool1 = network.add_pooling_nd(
        input=conv1.get_output(0), type=trt.PoolingType.MAX, window_size=(2, 2)
    )
    pool1.stride_nd = trt.Dims2(2, 2)

    # Second convolution + max-pooling stage.
    conv2_kernel = as_numpy("conv2.weight")
    conv2_bias = as_numpy("conv2.bias")
    conv2 = network.add_convolution_nd(
        input=pool1.get_output(0),
        num_output_maps=50,
        kernel_shape=(5, 5),
        kernel=conv2_kernel,
        bias=conv2_bias,
    )
    conv2.stride_nd = (1, 1)

    pool2 = network.add_pooling_nd(
        input=conv2.get_output(0), type=trt.PoolingType.MAX, window_size=(2, 2)
    )
    pool2.stride_nd = trt.Dims2(2, 2)

    # Fully connected head: fc1 -> ReLU -> fc2.
    fc1_kernel = as_numpy("fc1.weight")
    fc1_bias = as_numpy("fc1.bias")
    fc1 = add_matmul_as_fc(network, pool2.get_output(0), 500, fc1_kernel, fc1_bias)

    relu1 = network.add_activation(
        input=fc1.get_output(0), type=trt.ActivationType.RELU
    )

    fc2_kernel = as_numpy("fc2.weight")
    fc2_bias = as_numpy("fc2.bias")
    fc2 = add_matmul_as_fc(
        network, relu1.get_output(0), ModelData.OUTPUT_SIZE, fc2_kernel, fc2_bias
    )

    # Name the final tensor and expose it as the network output.
    prob = fc2.get_output(0)
    prob.name = ModelData.OUTPUT_NAME
    network.mark_output(tensor=prob)
def build_engine(weights):
    """Build a TensorRT engine for the MNIST network from trained weights.

    Args:
        weights: Parameter mapping forwarded unchanged to ``populate_network``.

    Returns:
        The engine obtained by deserializing the freshly built plan.

    Raises:
        RuntimeError: If the builder does not produce a serialized network.
    """
    # For more information on TRT basics, refer to the introductory samples.
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    )
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, common.GiB(1))

    # Populate the network using weights from the PyTorch model.
    populate_network(network, weights)

    # A failed build is reported by returning None rather than raising, so
    # check explicitly instead of handing None to the runtime, which would
    # only produce a confusing error further down the line.
    plan = builder.build_serialized_network(network, config)
    if plan is None:
        raise RuntimeError(
            "TensorRT failed to build the serialized network; "
            "check the logger output for details."
        )

    # Deserialize the plan and return the resulting engine.
    runtime = trt.Runtime(TRT_LOGGER)
    return runtime.deserialize_cuda_engine(plan)
def load_random_test_case(model, pagelocked_buffer):
    """Load one random test case into the input buffer.

    Asks ``model.get_random_testcase()`` for a pair, copies the first element
    into ``pagelocked_buffer`` in place and returns the second element
    (presumably the expected label for that sample).

    Args:
        model: Object providing ``get_random_testcase()``.
        pagelocked_buffer: Destination array that receives the sample data.

    Returns:
        The second element of the pair returned by ``get_random_testcase()``.
    """
    # Select a sample at random to be the test case.
    sample, label = model.get_random_testcase()
    # Fill the caller's buffer in place rather than rebinding it.
    np.copyto(pagelocked_buffer, sample)
    return label
def main():
    """Train the PyTorch MNIST model, rebuild it in TensorRT and classify one sample.

    Prints the selected test case value and the TensorRT prediction to stdout.
    The allocated buffers are released even when inference fails.
    """
    common.add_help(description="Runs an MNIST network using a PyTorch model")

    # Train the PyTorch model and extract its weights.
    mnist_model = model.MnistModel()
    mnist_model.learn()
    weights = mnist_model.get_weights()

    # Build a TensorRT engine from the trained weights.
    engine = build_engine(weights)

    # Allocate buffers for the engine.
    # For more information on buffer allocation, refer to the introductory samples.
    inputs, outputs, bindings = common.allocate_buffers(engine)
    try:
        context = engine.create_execution_context()

        # Use context manager for proper stream lifecycle management
        with common.CudaStreamContext() as stream:
            case_num = load_random_test_case(mnist_model, pagelocked_buffer=inputs[0].host)
            # For more information on performing inference, refer to the introductory samples.
            # The common.do_inference function will return a list of outputs - we only have one in this case.
            [output] = common.do_inference(
                context,
                engine=engine,
                bindings=bindings,
                inputs=inputs,
                outputs=outputs,
                stream=stream,
            )
            # Reduce to a scalar now: `output` presumably aliases a buffer
            # that is released in the `finally` clause.
            pred = np.argmax(output)
    finally:
        # Release the buffers on every path, including when context creation
        # or inference raises.
        common.free_buffers(inputs, outputs)

    print("Test Case: " + str(case_num))
    print("Prediction: " + str(pred))
# Run the sample only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()