2121
2222from runtime .losses import partial_losses
2323from runtime .parse_results import process_performance_stats
24+ from model .tf_trt import export_model , TFTRTModel
2425
2526
2627def train (params , model , dataset , logger ):
@@ -101,6 +102,11 @@ def train_step(features, labels, warmup_batch=False):
101102 break
102103 if hvd .rank () == 0 :
103104 checkpoint .save (file_prefix = os .path .join (params .model_dir , "checkpoint" ))
105+ if params .use_savedmodel :
106+ prec = 'amp' if params .use_amp else 'fp32'
107+ model .save (os .path .join (params .model_dir , f'saved_model_{ prec } ' ))
108+ if params .use_tftrt :
109+ export_model (params .model_dir , prec , os .path .join (params .model_dir , f'tf-trt_model_{ prec } ' ))
104110
105111 logger .flush ()
106112
@@ -110,9 +116,15 @@ def evaluate(params, model, dataset, logger, restore_checkpoint=True):
110116 print ("No fold specified for evaluation. Please use --fold [int] to select a fold." )
111117 ce_loss = tf .keras .metrics .Mean (name = 'ce_loss' )
112118 f1_loss = tf .keras .metrics .Mean (name = 'dice_loss' )
113- checkpoint = tf .train .Checkpoint (model = model )
114119 if params .model_dir and restore_checkpoint :
115- checkpoint .restore (tf .train .latest_checkpoint (params .model_dir )).expect_partial ()
120+ prec = 'amp' if params .use_amp else 'fp32'
121+ if params .use_savedmodel :
122+ model = tf .keras .models .load_model (os .path .join (params .model_dir , f'saved_model_{ prec } ' ))
123+ elif params .use_tftrt :
124+ model = TFTRTModel (model_dir = params .model_dir , precision = prec )
125+ else :
126+ checkpoint = tf .train .Checkpoint (model = model )
127+ checkpoint .restore (tf .train .latest_checkpoint (params .model_dir )).expect_partial ()
116128
117129 def validation_step (features , labels ):
118130 output_map = model (features , training = False )
@@ -135,9 +147,15 @@ def validation_step(features, labels):
135147
136148
137149def predict (params , model , dataset , logger ):
138- checkpoint = tf . train . Checkpoint ( model = model )
150+ prec = 'amp' if params . use_amp else 'fp32'
139151 if params .model_dir :
140- checkpoint .restore (tf .train .latest_checkpoint (params .model_dir )).expect_partial ()
152+ if params .use_savedmodel :
153+ model = tf .keras .models .load_model (os .path .join (params .model_dir , f'saved_model_{ prec } ' ))
154+ elif params .use_tftrt :
155+ model = TFTRTModel (model_dir = params .model_dir , precision = prec )
156+ else :
157+ checkpoint = tf .train .Checkpoint (model = model )
158+ checkpoint .restore (tf .train .latest_checkpoint (params .model_dir )).expect_partial ()
141159
142160 @tf .function
143161 def prediction_step (features ):
0 commit comments