
Commit b2e89e6

[FT] FasterTransformer 3.0 Release (NVIDIA#696)
[FT] feat: Add FasterTransformer v3.0
1. Add support for INT8 quantization in the C++ and TensorFlow ops.
2. Provide the tools to quantize the model.
3. Fix the bugs that prevented CMake 3.15 and 3.16 from building this project.
4. Deprecate FasterTransformer v1.
1 parent 66d1891 commit b2e89e6

265 files changed

Lines changed: 124599 additions & 23 deletions
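The headline feature of this release is INT8 quantization of the encoder. As a rough, self-contained illustration of the general idea behind INT8 quantization (symmetric per-tensor scaling, sketched here with NumPy), not the exact scheme implemented in this commit:

```python
import numpy as np

def quantize_int8(x, amax=None):
    """Symmetric per-tensor INT8 quantization (illustrative only)."""
    # In practice amax comes from calibration or quantization-aware training,
    # not from a simple max over the tensor.
    if amax is None:
        amax = np.abs(x).max()
    scale = amax / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_int8(q, scale):
    """Recover an approximate float tensor from INT8 values and the scale."""
    return q.astype(np.float32) * scale

# Round-trip example: the dequantized tensor approximates the original,
# while matrix multiplies can run on INT8 Tensor Cores.
x = np.random.randn(4, 8).astype(np.float32)
q, s = quantize_int8(x)
x_hat = dequantize_int8(q, s)
```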


FasterTransformer/README.md

Lines changed: 39 additions & 23 deletions
@@ -6,47 +6,63 @@ This repository provides a script and recipe to run the highly optimized transfo
  - [FasterTransformer](#fastertransformer)
  - [Table Of Contents](#table-of-contents)
  - [Model overview](#model-overview)
- - [FasterTransformer V1](#fastertransformer-v1)
- - [FasterTransformer V2](#fastertransformer-v2)
- - [FasterTransformer V2.1](#fastertransformer-v21)
+ - [FasterTransformer v1](#fastertransformer-v1)
+ - [FasterTransformer v2](#fastertransformer-v2)
+ - [FasterTransformer v2.1](#fastertransformer-v21)
+ - [FasterTransformer v3.0](#fastertransformer-v30)
  - [Architecture matrix](#architecture-matrix)
  - [Release notes](#release-notes)
  - [Changelog](#changelog)
  - [Known issues](#known-issues)

  ## Model overview

- ### FasterTransformer V1
+ ### FasterTransformer v1

- FasterTransformer V1 provides a highly optimized BERT-equivalent Transformer layer for inference, including a C++ API, TensorFlow op and TensorRT plugin. The experiments show that FasterTransformer V1 can provide a 1.3 ~ 2x speedup on NVIDIA Tesla T4 and NVIDIA Tesla V100 for inference.
+ FasterTransformer v1 provides a highly optimized BERT-equivalent Transformer layer for inference, including a C++ API, TensorFlow op and TensorRT plugin. The experiments show that FasterTransformer v1 can provide a 1.3 ~ 2x speedup on NVIDIA Tesla T4 and NVIDIA Tesla V100 for inference.

- ### FasterTransformer V2
+ ### FasterTransformer v2

- FasterTransformer V2 adds a highly optimized OpenNMT-tf based decoder and decoding to FasterTransformer V1 for inference, including a C++ API and TensorFlow op. The experiments show that FasterTransformer V2 can provide a 1.5 ~ 11x speedup on NVIDIA Tesla T4 and NVIDIA Tesla V100 for inference.
+ FasterTransformer v2 adds a highly optimized OpenNMT-tf based decoder and decoding to FasterTransformer v1 for inference, including a C++ API and TensorFlow op. The experiments show that FasterTransformer v2 can provide a 1.5 ~ 11x speedup on NVIDIA Tesla T4 and NVIDIA Tesla V100 for inference.

- ### FasterTransformer V2.1
+ ### FasterTransformer v2.1

- FasterTransformer V2.1 optimizes some encoder and decoder kernels, and adds support for PyTorch, for removing the padding in the encoder, and for sampling algorithms in decoding.
+ FasterTransformer v2.1 optimizes some encoder and decoder kernels, and adds support for PyTorch, for removing the padding in the encoder, and for sampling algorithms in decoding.
+
+ ### FasterTransformer v3.0
+
+ FasterTransformer v3.0 adds support for INT8 quantization of the C++ and TensorFlow encoder models on Turing and Ampere GPUs.

  ### Architecture matrix

  The following matrix shows the architecture differences between the models.

- | Architecture | Encoder | Decoder | Decoding with beam search | Decoding with sampling |
- |------------------------|---------|---------|---------------------------|------------------------|
- | FasterTransformer V1   | Yes     | No      | No                        | No                     |
- | FasterTransformer V2   | Yes     | Yes     | Yes                       | No                     |
- | FasterTransformer V2.1 | Yes     | Yes     | Yes                       | Yes                    |
-
+ | Architecture | Encoder | Encoder INT8 quantization | Decoder | Decoding with beam search | Decoding with sampling |
+ |------------------------|---------|---------------------------|---------|---------------------------|------------------------|
+ | FasterTransformer v1   | Yes     | No                        | No      | No                        | No                     |
+ | FasterTransformer v2   | Yes     | No                        | Yes     | Yes                       | No                     |
+ | FasterTransformer v2.1 | Yes     | No                        | Yes     | Yes                       | Yes                    |
+ | FasterTransformer v3.0 | Yes     | Yes                       | Yes     | Yes                       | Yes                    |

  ## Release notes

- FasterTransformer V1 will be deprecated in July 2020.
+ FasterTransformer v1 was deprecated in July 2020.

- FasterTransformer V2 will be deprecated in Dec 2020.
+ FasterTransformer v2 will be deprecated in Dec 2020.
+
+ FasterTransformer v2.1 will be deprecated in July 2021.

  ### Changelog

+ Sep 2020
+ - **Release the FasterTransformer 3.0**
+ - Support INT8 quantization of the encoder in the C++ and TensorFlow ops.
+ - Add the bert-tf-quantization tool.
+ - Fix the issue that CMake 3.15 and 3.16 fail to build this project.
+
+ Aug 2020
+ - Fix a bug in the TensorRT plugin.
+
  June 2020
  - **Release the FasterTransformer 2.1**
  - Add [effective transformer](https://github.com/bytedance/effective_transformer) support to the encoder.
@@ -85,14 +101,14 @@ March 2020
  - Add a normalization for inputs of decoder

  February 2020
- * Release the FasterTransformer 2.0
- * Provide a highly optimized OpenNMT-tf based decoder and decoding, including C++ API and TensorFlow OP.
- * Refine the sample codes of encoder.
- * Add dynamic batch size feature into encoder op.
+ - **Release the FasterTransformer 2.0**
+ - Provide a highly optimized OpenNMT-tf based decoder and decoding, including C++ API and TensorFlow OP.
+ - Refine the sample codes of encoder.
+ - Add dynamic batch size feature into encoder op.

  July 2019
- * Release the FasterTransformer 1.0
- * Provide a highly optimized bert equivalent transformer layer, including C++ API, TensorFlow OP and TensorRT plugin.
+ - **Release the FasterTransformer 1.0**
+ - Provide a highly optimized bert equivalent transformer layer, including C++ API, TensorFlow OP and TensorRT plugin.


  ## Known issues
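Since the README above advertises the TensorFlow op, here is a minimal sketch of how a custom TensorFlow op library of this kind is usually loaded from Python. The `.so` path below is hypothetical, and the exact op names and arguments live in the sample scripts under `sample/tensorflow`, not in this snippet:

```python
import tensorflow as tf

# Hypothetical path: the library name and location depend on the CMake build
# settings in the v3.0 CMakeLists.txt added by this commit.
ft_module = tf.load_op_library("./lib/libtf_fastertransformer.so")

# The encoder/decoder ops exposed by the library become attributes of
# ft_module; see the copied sample scripts for their names and arguments.
```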

FasterTransformer/v3.0/.gitignore

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+ *~
+ *.o
+ build*/
+ *.pyc
+ .vscode/

FasterTransformer/v3.0/.gitmodules

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+ [submodule "sample/fastertransformer_bert/bert"]
+ path = sample/tensorflow_bert/bert
+ url = https://github.com/google-research/bert.git
+
+ [submodule "OpenNMT-tf"]
+ path = OpenNMT-tf
+ url = https://github.com/OpenNMT/OpenNMT-tf
FasterTransformer/v3.0/CMakeLists.txt

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ cmake_minimum_required(VERSION 3.8 FATAL_ERROR) # for PyTorch extensions, version should be greater than 3.13
+ project(FasterTransformer LANGUAGES CXX CUDA)
+
+ find_package(CUDA 10.0 REQUIRED)
+
+ option(BUILD_TRT "Build in TensorRT mode" OFF)
+ option(BUILD_TF "Build in TensorFlow mode" OFF)
+ option(BUILD_THE "Build in PyTorch eager mode" OFF)
+ option(BUILD_THS "Build in TorchScript class mode" OFF)
+ option(BUILD_THSOP "Build in TorchScript OP mode" OFF)
+
+ set(CXX_STD "11" CACHE STRING "C++ standard")
+
+ set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
+
+ set(TF_PATH "" CACHE STRING "TensorFlow path")
+ #set(TF_PATH "/usr/local/lib/python3.5/dist-packages/tensorflow")
+
+ if(BUILD_TF AND NOT TF_PATH)
+     message(FATAL_ERROR "TF_PATH must be set if BUILD_TF(=TensorFlow mode) is on.")
+ endif()
+
+ set(TRT_PATH "" CACHE STRING "TensorRT path")
+ #set(TRT_PATH "/myspace/TensorRT-5.1.5.0")
+
+ if(BUILD_TRT AND NOT TRT_PATH)
+     message(FATAL_ERROR "TRT_PATH must be set if BUILD_TRT(=TensorRT mode) is on.")
+ endif()
+
+ list(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64)
+ find_package(CUDA REQUIRED)
+
+ if (${CUDA_VERSION} GREATER_EQUAL 11.0)
+     message(STATUS "Add DCUDA11_MODE")
+     add_definitions("-DCUDA11_MODE")
+ endif()
+
+ # setting compiler flags
+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall")
+
+ if (SM STREQUAL 70 OR
+     SM STREQUAL 75 OR
+     SM STREQUAL 61 OR
+     SM STREQUAL 60)
+     set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM},code=\\\"sm_${SM},compute_${SM}\\\" -rdc=true")
+     if (SM STREQUAL 70 OR SM STREQUAL 75)
+         set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA")
+         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA")
+         set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA")
+     endif()
+     if(BUILD_THE OR BUILD_THS OR BUILD_THSOP)
+         string(SUBSTRING ${SM} 0 1 SM_MAJOR)
+         string(SUBSTRING ${SM} 1 1 SM_MINOR)
+         set(ENV{TORCH_CUDA_ARCH_LIST} "${SM_MAJOR}.${SM_MINOR}")
+     endif()
+     message("-- Assign GPU architecture (sm=${SM})")
+
+ else()
+     set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \
+ -gencode=arch=compute_60,code=\\\"sm_60,compute_60\\\" \
+ -gencode=arch=compute_61,code=\\\"sm_61,compute_61\\\" \
+ -gencode=arch=compute_70,code=\\\"sm_70,compute_70\\\" \
+ -gencode=arch=compute_75,code=\\\"sm_75,compute_75\\\" \
+ -rdc=true")
+     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA")
+     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA")
+     set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA")
+     if(BUILD_THE OR BUILD_THS OR BUILD_THSOP)
+         set(ENV{TORCH_CUDA_ARCH_LIST} "6.0;6.1;7.0;7.5")
+     endif()
+     message("-- Assign GPU architecture (sm=60,61,70,75)")
+ endif()
+
+ set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0")
+ set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0")
+ set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall --ptxas-options=-v --resource-usage")
+
+ set(CMAKE_CXX_STANDARD "${CXX_STD}")
+ set(CMAKE_CXX_STANDARD_REQUIRED ON)
+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}")
+
+ set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
+ set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3")
+
+ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
+ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
+ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+
+ set(COMMON_HEADER_DIRS
+     ${PROJECT_SOURCE_DIR}
+     ${CUDA_PATH}/include
+ )
+
+ set(COMMON_LIB_DIRS
+     ${CUDA_PATH}/lib64
+ )
+
+ if(BUILD_TF)
+     list(APPEND COMMON_HEADER_DIRS ${TF_PATH}/include)
+     list(APPEND COMMON_LIB_DIRS ${TF_PATH})
+ endif()
+
+ if(BUILD_TRT)
+     list(APPEND COMMON_HEADER_DIRS ${TRT_PATH}/include)
+     list(APPEND COMMON_LIB_DIRS ${TRT_PATH}/lib)
+ endif()
+
+ set(PYTHON_PATH "python" CACHE STRING "Python path")
+ if(BUILD_THE OR BUILD_THS OR BUILD_THSOP)
+     execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import os; import torch;
+ print(os.path.dirname(torch.__file__),end='');"
+                     RESULT_VARIABLE _PYTHON_SUCCESS
+                     OUTPUT_VARIABLE TORCH_DIR)
+     if (NOT _PYTHON_SUCCESS MATCHES 0)
+         message(FATAL_ERROR "Torch config Error.")
+     endif()
+     list(APPEND CMAKE_PREFIX_PATH ${TORCH_DIR})
+     find_package(Torch REQUIRED)
+
+     execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; from distutils import sysconfig;
+ print(sysconfig.get_python_inc());
+ print(sysconfig.get_config_var('SO'));"
+                     RESULT_VARIABLE _PYTHON_SUCCESS
+                     OUTPUT_VARIABLE _PYTHON_VALUES)
+     if (NOT _PYTHON_SUCCESS MATCHES 0)
+         message(FATAL_ERROR "Python config Error.")
+     endif()
+     string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES})
+     string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES})
+     list(GET _PYTHON_VALUES 0 PY_INCLUDE_DIR)
+     list(GET _PYTHON_VALUES 1 PY_SUFFIX)
+     list(APPEND COMMON_HEADER_DIRS ${PY_INCLUDE_DIR})
+
+     execute_process(COMMAND ${PYTHON_PATH} "-c" "from torch.utils import cpp_extension; print(' '.join(cpp_extension._prepare_ldflags([],True,False)),end='');"
+                     RESULT_VARIABLE _PYTHON_SUCCESS
+                     OUTPUT_VARIABLE TORCH_LINK)
+     if (NOT _PYTHON_SUCCESS MATCHES 0)
+         message(FATAL_ERROR "PyTorch link config Error.")
+     endif()
+ endif()
+
+
+ include_directories(
+     ${COMMON_HEADER_DIRS}
+ )
+
+ link_directories(
+     ${COMMON_LIB_DIRS}
+ )
+
+ add_subdirectory(tools)
+ add_subdirectory(fastertransformer)
+ add_subdirectory(sample)
+
+ if(BUILD_TF)
+     add_custom_target(copy ALL COMMENT "Copying tensorflow test scripts")
+     add_custom_command(TARGET copy
+         POST_BUILD
+         COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/ ${PROJECT_BINARY_DIR} -r
+         COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow_bert ${PROJECT_BINARY_DIR}/tensorflow -r
+     )
+ endif()
+
+ if(BUILD_THE OR BUILD_THS OR BUILD_THSOP)
+     add_custom_target(copy ALL COMMENT "Copying pytorch test scripts")
+     add_custom_command(TARGET copy
+         POST_BUILD
+         COMMAND cp ${PROJECT_SOURCE_DIR}/sample/pytorch/ ${PROJECT_BINARY_DIR} -r
+         COMMAND mkdir -p ${PROJECT_BINARY_DIR}/pytorch/translation/data/
+         COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/utils/translation/test.* ${PROJECT_BINARY_DIR}/pytorch/translation/data/
+     )
+ endif()