11import jinja2
2+ from enum import Enum
23from pathlib import Path
34import numpy as np
5+ from abc import ABCMeta , abstractmethod
6+ from .quantization_util import get_quantization_params , quantize
7+
48
59_template_dir = Path (__file__ ).parent / "templates"
610_template2_dir = Path (__file__ ).parent / "templates_v2"
4145 np .int32 : "int32_t" ,
4246 np .uint32 : "uint32_t" ,
4347 np .float : "float" ,
48+ np .dtype ('int8' ): "int8_t" ,
49+ np .dtype ('uint8' ): "uint8_t" ,
50+ np .dtype ('int16' ): "int16_t" ,
51+ np .dtype ('uint16' ): "uint16_t" ,
52+ np .dtype ('int32' ): "int32_t" ,
53+ np .dtype ('uint32' ): "uint32_t" ,
4454 np .dtype ('float32' ): "float" ,
4555 }
4656env2 = jinja2 .Environment (
5363 NUMPY_2_CMAP = NUMPY_2_CMAP ,
5464 )
5565
class QuantizationType(Enum):
    """Supported quantization schemes for a Tensor.

    NONE disables quantization entirely. The remaining members choose
    per-tensor vs. per-channel parameters and asymmetric vs. symmetric
    quantization ranges.
    """
    NONE = 0
    PER_TENSOR_ASYMMETRIC = 1
    PER_CHANNEL_ASYMMETRIC = 2
    PER_TENSOR_SYMMETRIC = 3
    PER_CHANNEL_SYMMETRIC = 4
72+
class UnknownQuantizationTypeError(Exception):
    """Raised when quantization parameters are queried before being computed."""
75+
class QuantizationParams(object):
    """Zero-point / scale storage for one Tensor's quantization.

    ``zp`` and ``scale`` are parallel lists with one entry per quantized
    channel; a single entry means per-tensor quantization, an empty list
    means the parameters have not been computed yet.
    """

    def __init__(self, tensor):
        self.tensor = tensor  # Store ref to parent Tensor
        self.ref_name = tensor.ref_name
        self.zp = []      # zero points, one per channel
        self.scale = []   # scales, one per channel

    @property
    def ref_zp(self):
        """C identifier used for the zero-point constant in generated code."""
        if not self.ref_name:
            print("WARNING: No reference name set for Quantization Param")
        return "%s_zp" % self.ref_name

    @property
    def ref_scale(self):
        """C identifier used for the scale constant in generated code."""
        if not self.ref_name:
            print("WARNING: No reference name set for Quantization Param")
        return "%s_scale" % self.ref_name

    def render_set_quantization_params(self):
        """Render the snippet installing these params, or "" if none are set."""
        if not self.zp:
            return ""
        return env2.get_template('set_quantization_params.cpp').render(qp=self)

    @property
    def num_channels(self):
        # len([]) is already 0, so no branch (and no stray semicolon) needed.
        return len(self.zp)

    @property
    def quantization_type(self):
        """C++ parameter-type name matching the channel count.

        Raises:
            UnknownQuantizationTypeError: if no params have been computed yet.
        """
        if self.num_channels == 1:
            return "PerTensorQuantizationParams"
        if self.num_channels > 1:
            return "PerChannelQuantizationParams"
        raise UnknownQuantizationTypeError
116+
56117
57118class Tensor :
    def __init__(self, name, np_array, ref_name=None, quantization_type=QuantizationType.NONE, quantize_dim=None, narrow_range=False, num_quant_bits=8):
        """Wrap a numpy array destined for code generation.

        Args:
            name: identifier for this tensor in the generated code.
            np_array: backing numpy data.
            ref_name: if set, the tensor is rendered as a ROM constant under
                this reference name; otherwise as a RAM tensor (see
                render_declaration).
            quantization_type: QuantizationType member; NONE disables
                quantization entirely.
            quantize_dim: axis holding channels for per-channel quantization.
            narrow_range: forwarded to the quantizer in quantization_util —
                presumably restricts the signed range; TODO confirm there.
            num_quant_bits: target bit width; quantize() handles 8 directly
                and falls back to 32-bit containers otherwise.
        """
        self.name = name
        self.np_array = np_array
        self.ref_name = ref_name
        self.quantize_params = QuantizationParams(self)  # filled lazily
        self.quantization_type = quantization_type
        self.quantize_dim = quantize_dim
        self.narrow_range = narrow_range
        self.quantized = False  # guards against double quantization
        self.num_quant_bits = num_quant_bits
63129
64130 @property
65131 def shape (self ):
@@ -68,6 +134,8 @@ def shape(self):
68134 @property
69135 def dtype (self ):
70136 return NUMPY_2_CMAP [self .np_array .dtype ]
    def get_dtype(self):
        # Plain-method accessor for the dtype property, so it can be passed
        # around as a zero-arg callable (Operator's dtypes list expects
        # bound get_dtype methods — see Operator.__init__ docstring).
        return self.dtype
71139
72140 @property
73141 def utype (self ):
@@ -77,28 +145,108 @@ def flatten(self):
77145 return self .np_array .flatten ()
78146
79147 def render_constant (self ):
80- if self .ref_name :
81- return env2 .get_template ('def_constant.hpp' ).render (tensor = self )
82- else :
83- return ""
148+ return env2 .get_template ('def_constant.hpp' ).render (tensor = self )
84149 def render_declaration (self ):
85150 if self .ref_name :
86151 return env2 .get_template ('declare_rom_tensor.cpp' ).render (tensor = self )
87152 else :
88153 return env2 .get_template ('declare_ram_tensor.cpp' ).render (tensor = self )
154+
155+ def is_quantized (self ):
156+ return self .quantized and self .is_quantizable ()
157+
158+ def is_quantizable (self ):
159+ return self .quantization_type != QuantizationType .NONE
89160
161+ @property
162+ def symmetric (self ):
163+ if self .is_quantizable () and (self .quantization_type == QuantizationType .PER_TENSOR_SYMMETRIC or self .quantization_type == QuantizationType .PER_CHANNEL_SYMMETRIC ):
164+ return True
165+ else :
166+ return False
167+
168+ @property
169+ def per_tensor_quantization (self ):
170+ return self .is_quantizable () and (self .quantization_type == QuantizationType .PER_TENSOR_ASYMMETRIC or self .quantization_type == QuantizationType .PER_TENSOR_SYMMETRIC )
171+
172+ @property
173+ def per_channel_quantization (self ):
174+ return self .is_quantizable () and self .quantize_dim != None and (self .quantization_type == QuantizationType .PER_CHANNEL_ASYMMETRIC or self .quantization_type == QuantizationType .PER_CHANNEL_SYMMETRIC )
175+
176+ def get_quantization_params (self ):
177+ if not self .is_quantizable ():
178+ return (None , None )
179+ if not self .quantize_params .zp and not self .quantize_params .scale :
180+ # Else compute them
181+ if self .per_channel_quantization :
182+ num_dims = len (self .np_array .shape )
183+ num_channels = self .np_array .shape [self .quantize_dim ]
184+ for i in range (num_channels ):
185+ c = tuple ([ i if j == self .quantize_dim else slice (None ) for j in range (num_dims )])
186+ zp , scale = get_quantization_params (self .np_array [c ], symmetric = self .symmetric , narrow_range = self .narrow_range , num_quant_bits = self .num_quant_bits )
187+ self .quantize_params .zp .append (zp )
188+ self .quantize_params .scale .append (scale )
189+ else :
190+ zp , scale = get_quantization_params (self .np_array , symmetric = self .symmetric , narrow_range = self .narrow_range , num_quant_bits = self .num_quant_bits )
191+ self .quantize_params .zp .append (zp )
192+ self .quantize_params .scale .append (scale )
193+ return (self .quantize_params .zp , self .quantize_params .scale )
194+
    def quantize(self):
        """Quantize self.np_array in place using the configured scheme.

        Idempotent: if already quantized this is a no-op. Returns None for
        non-quantizable tensors (and implicitly otherwise).
        """
        if self.quantized:
            return
        if not self.is_quantizable():
            return None
        zp, scale = self.get_quantization_params()
        if self.per_channel_quantization:
            # Container dtype: signed for symmetric schemes, unsigned for
            # asymmetric; any width other than 8 falls back to 32 bits.
            if self.symmetric:
                if self.num_quant_bits == 8:
                    dtype = np.int8
                else:
                    dtype = np.int32
            else:
                if self.num_quant_bits == 8:
                    dtype = np.uint8
                else:
                    dtype = np.uint32

            num_dims = len(self.np_array.shape)
            num_channels = self.np_array.shape[self.quantize_dim]
            q_array = np.zeros(self.np_array.shape, dtype=dtype)
            for i in range(num_channels):
                # Index tuple selecting channel i along quantize_dim with
                # full slices on every other axis.
                c = tuple([i if j == self.quantize_dim else slice(None) for j in range(num_dims)])
                slc = self.np_array[c]
                if isinstance(slc, np.float32):
                    # A 1-D array indexes down to a scalar here; re-wrap it
                    # so quantize() always receives an ndarray.
                    # NOTE(review): this only catches float32 scalars —
                    # other dtypes would pass a bare scalar through; confirm
                    # against quantization_util.quantize.
                    tmp = np.ndarray((1), dtype=self.np_array.dtype)
                    tmp[0] = slc
                else:
                    tmp = slc
                q = quantize(tmp, zp[i], scale[i], self.symmetric, self.narrow_range, self.num_quant_bits)
                q_array[c] = q
            self.np_array = q_array
        else:
            # Per-tensor: single zp/scale pair quantizes the whole array.
            q = quantize(self.np_array, zp[0], scale[0], self.symmetric, self.narrow_range, self.num_quant_bits)
            self.np_array = q
        self.quantized = True
90231
91232class Operator :
92233 def __init__ (self , op_type , name , dtypes = [], param_str = None ):
234+ """
235+ dtypes should be bound to get_dtype methods on a tensor
236+ """
93237 self .op_type = op_type
94238 self .name = name
95- self .dtypes = dtypes
239+ self ._dtypes = dtypes
96240 self .param_str = param_str
97241 self .array_template = env2 .get_template ('array_template.cpp' )
98242 self .input_map = {}
99243 self .output_map = {}
100244 self .type_signature = env2 .get_template ('op_type_signature.cpp' ).render (op = self )
101245
246+ @property
247+ def dtype (self ):
248+ return [dt () for dt in self ._dtypes ]
249+
102250 def set_inputs (self , input_map ):
103251 self .input_map = input_map
104252 return self
@@ -112,6 +260,11 @@ def render_declaration(self):
112260
113261 def render_eval (self ):
114262 return env2 .get_template ('eval_operator.cpp' ).render (op = self )
263+ def quantize (self ):
264+ for thing in self .input_map :
265+ self .input_map [thing ].quantize ()
266+ for thing in self .output_map :
267+ self .output_map [thing ].quantize ()
115268
116269class SingleOpTest :
117270 def __init__ (self , test_group , test_name , target_op ):
@@ -133,6 +286,15 @@ def add_tensor_comparison(self, a, b):
133286 self .tensor_set .add (a )
134287 self .tensor_set .add (b )
135288
289+ def quantize (self ):
290+ self .target_op .quantize ()
291+ # Duplicate quantization because we can
292+ for (a , b ) in self .compare_tensors :
293+ a .quantize ()
294+ b .quantize ()
295+ for thing in tensor_set :
296+ thing .quantize ()
297+
136298 def render (self ):
137299 const_snippets = []
138300 tensor_decls = []
0 commit comments