1+ import argparse
2+ import os
3+ import pathlib
4+ import time
5+ import tempfile
6+
7+ import tensorflow as tf
8+ import numpy as np
9+
10+ from tensorflow .python .compiler .tensorrt import trt_convert as trt
11+
12+ import dllogger
13+
14+ from runtime import runner_utils
15+ from runtime import runner
16+ from model .resnet import model_architectures
17+ from utils import data_utils
18+ from utils import hvd_wrapper as hvd
19+
# Destination directory for the TF-TRT-converted SavedModel.
# Created eagerly at import time so the path exists before conversion runs.
OUTPUT_SAVED_MODEL_PATH = tempfile.mkdtemp(prefix="tftrt-converted")

# Emit an intermediate throughput log entry every LOG_FREQUENCY batches.
LOG_FREQUENCY = 100
22+
def argument_parser() -> argparse.Namespace:
    """Build and parse the command-line arguments for this inference script.

    ``--model`` and ``--architecture`` are mutually exclusive: either point
    at an existing SavedModel, or name an architecture so a model is
    trained/exported first.
    """
    parser = argparse.ArgumentParser()

    # Exactly one way to obtain the model: a pre-exported SavedModel path,
    # or an architecture name to build one from.
    model_source = parser.add_mutually_exclusive_group()
    model_source.add_argument("--model", type=str, default=None,
                              help="Saved model location to use for inference")
    model_source.add_argument("--architecture", type=str,
                              choices=model_architectures.keys())

    parser.add_argument("--log-path", type=str, default="./log.json",
                        help="Path to log file")
    parser.add_argument("--tf-trt", action="store_true", default=False,
                        help="Use TF-TRT for inference")
    parser.add_argument("--amp", action="store_true", default=False,
                        help="Use AMP for inference")
    parser.add_argument("--data-dir", type=str, required=False, default=None,
                        help="Localization of validation data")
    parser.add_argument("--batch-size", type=int, default=1,
                        help="Batch size for inference")

    return parser.parse_args()
38+
def main(args: argparse.Namespace):
    """Run inference benchmarking on a ResNet SavedModel, optionally via TF-TRT.

    Steps: (1) obtain a SavedModel (load from ``args.model`` or train/export
    one from ``args.architecture``), (2) optionally convert it with TF-TRT or
    enable the AMP graph rewrite, (3) feed it either real TFRecord validation
    data or a single random batch, logging throughput and final accuracy via
    dllogger.
    """
    hvd.init()

    # Log to both a JSON file (args.log_path) and stdout.
    dllogger.init(backends=[
        dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE, filename=args.log_path),
        dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE)
    ])
    dllogger.log(data=vars(args), step='PARAMETER')

    if args.model is None:
        # No SavedModel supplied: train briefly (benchmark mode) and export one.
        saved_model_to_load = tempfile.mkdtemp(prefix="tftrt-savedmodel")
        r = runner.Runner(n_classes=1001, architecture=args.architecture, use_tf_amp=args.amp,
                          model_dir=saved_model_to_load)
        r.train("batch", 1, 1, args.batch_size, is_benchmark=True)
        r.evaluate("batch", 1, args.batch_size, export_dir=saved_model_to_load,
                   is_benchmark=True)

        # exported_path is bytes; decode to a filesystem path string.
        saved_model_to_load = r.exported_path.decode("utf-8")
    else:
        saved_model_to_load = args.model

    # The TF-TRT-converted graph exposes a different output tensor name.
    output_tensor_name = "y_preds_ref:0" if not args.tf_trt else "ArgMax:0"
    batch_size = args.batch_size

    if args.tf_trt:
        # Convert the SavedModel with TF-TRT; FP16 only when AMP was requested.
        converter = trt.TrtGraphConverter(input_saved_model_dir=str(saved_model_to_load),
                                          precision_mode="FP16" if args.amp else "FP32")
        converter.convert()
        converter.save(OUTPUT_SAVED_MODEL_PATH)
        saved_model_to_load = OUTPUT_SAVED_MODEL_PATH
    elif args.amp:
        # Plain-TF path: enable automatic mixed precision via the graph rewrite.
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

    if args.data_dir is not None:
        # Real validation data: one epoch over the TFRecords in data_dir.
        filenames, _, num_steps, _, _ = runner_utils.parse_tfrecords_dataset(
            data_dir=str(args.data_dir),
            mode="validation",
            iter_unit="epoch",
            num_iter=1,
            global_batch_size=batch_size,
        )

        # Deterministic, single-threaded pipeline so runs are reproducible.
        dataset = data_utils.get_tfrecords_input_fn(filenames=filenames,
                                                    batch_size=batch_size,
                                                    height=224,
                                                    width=224,
                                                    training=False,
                                                    distort_color=False,
                                                    num_threads=1,
                                                    deterministic=True)
        iterator = dataset.make_initializable_iterator()
        next_item = iterator.get_next()
    else:
        # Synthetic mode: approximate one pass over a 60k-image set.
        # NOTE(review): float division — relies on int(num_steps) below.
        num_steps = 60000 / batch_size

    with tf.Session() as sess:
        if args.data_dir is not None:
            sess.run(iterator.initializer)
        tf.saved_model.loader.load(sess,
                                   [tf.saved_model.tag_constants.SERVING],
                                   str(saved_model_to_load))

        try:
            start_time = time.time()
            last_time = start_time
            image_processed = 0
            image_correct = 0

            for samples_processed in range(int(num_steps)):
                if args.data_dir is not None:
                    next_batch_image, next_batch_target = sess.run(next_item)
                else:
                    # Synthetic mode reuses a single random batch for every step,
                    # so the reported "accuracy" is meaningless there.
                    if samples_processed == 0:
                        next_batch_image = np.random.normal(size=(batch_size, 224, 224, 3))
                        next_batch_target = np.random.randint(0, 1000, size=(batch_size,))
                # NOTE(review): sess.run([...]) returns a 1-element *list*;
                # `output == next_batch_target` then broadcasts against that list.
                # Verify the accuracy computation — output[0] was likely intended,
                # and "y_preds_ref:0" may not hold class indices like "ArgMax:0".
                output = sess.run([output_tensor_name], feed_dict={"input_tensor:0": next_batch_image})
                image_processed += args.batch_size
                image_correct += np.sum(output == next_batch_target)

                # Periodic throughput over the last LOG_FREQUENCY batches
                # (skipped at step 0 where no interval has elapsed).
                if samples_processed % LOG_FREQUENCY == 0 and samples_processed != 0:
                    current_time = time.time()
                    current_throughput = LOG_FREQUENCY * batch_size / (current_time - last_time)
                    dllogger.log(step=(0, samples_processed), data={"throughput": current_throughput})
                    last_time = current_time

        except tf.errors.OutOfRangeError:
            # Dataset exhausted before num_steps — fall through to the summary.
            pass
        finally:
            # NOTE(review): if fewer than LOG_FREQUENCY steps ran, last_time ==
            # start_time and this divides by zero; likewise image_processed may
            # be 0 if the loop never executed — confirm against expected usage.
            dllogger.log(step=tuple(), data={"throughput": image_processed / (last_time - start_time),
                                             "accuracy": image_correct / image_processed})
131+
132+
if __name__ == "__main__":
    # Parse CLI flags, then run the benchmark.
    cli_args = argument_parser()
    main(cli_args)