diff --git a/FasterTransformer/v1/fastertransformer/common.h b/FasterTransformer/v1/fastertransformer/common.h index df71f6a4c..c90673694 100644 --- a/FasterTransformer/v1/fastertransformer/common.h +++ b/FasterTransformer/v1/fastertransformer/common.h @@ -14,7 +14,7 @@ * limitations under the License. */ #pragma once - +#include #include #include #include diff --git a/FasterTransformer/v1/fastertransformer/trt_plugin/trt_model.h b/FasterTransformer/v1/fastertransformer/trt_plugin/trt_model.h index 0e6d620e3..8ec545fe4 100644 --- a/FasterTransformer/v1/fastertransformer/trt_plugin/trt_model.h +++ b/FasterTransformer/v1/fastertransformer/trt_plugin/trt_model.h @@ -65,7 +65,7 @@ class TRT_Transformer auto from_tensor = network->addInput(INPUT_BLOB_NAME, dtype_, nvinfer1::Dims2{seq_len_, hidden_dim_}); auto mask_tensor = network->addInput(MASK_BLOB_NAME, dtype_, nvinfer1::Dims2{seq_len_, seq_len_}); - assert(input_tensor); + assert(from_tensor); assert(mask_tensor); nvinfer1::ITensor* output_tensor = nullptr;