Skip to content

Commit 202bcf8

Browse files
digantdesai (via facebook-github-bot)
authored and committed
Add groupwise quant support (#2512)
Summary: Pull Request resolved: #2512 Reviewed By: kimishpatel, mcr229 Differential Revision: D55079666 Pulled By: digantdesai fbshipit-source-id: 63042d71dd46a75c443bdb186da2174ebb5c79cd
1 parent f9cad4e commit 202bcf8

14 files changed

Lines changed: 327 additions & 40 deletions

File tree

.gitmodules

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
url = https://github.com/Maratyszcza/FXdiv.git
3131
[submodule "backends/xnnpack/third-party/XNNPACK"]
3232
path = backends/xnnpack/third-party/XNNPACK
33-
url = https://github.com/google/XNNPACK.git
33+
url = https://github.com/digantdesai/XNNPACK.git
3434
[submodule "backends/arm/third-party/serialization_lib"]
3535
path = backends/arm/third-party/serialization_lib
3636
url = https://review.mlplatform.org/tosa/serialization_lib

backends/xnnpack/operators/node_visitor.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import sys
99

1010
from pathlib import Path
11-
from typing import cast, Dict, Optional, Tuple
11+
from typing import cast, Dict, List, Optional, Tuple
1212

1313
import torch
1414
from executorch.backends.transforms import get_shape
@@ -21,6 +21,7 @@
2121

2222
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
2323
ConstantDataOffset,
24+
PerChannelGroupQuant,
2425
PerChannelQuant,
2526
PerTensorQuant,
2627
PerTokenDynamicQuant,
@@ -229,12 +230,20 @@ def get_per_channel_dtype(
229230
if quant_params.dtype == torch.int32:
230231
return XNNDatatype.xnn_datatype_qcint32
231232
elif quant_params.dtype == torch.int8:
232-
# 4/8-bit per channel quantized weights
233-
return (
234-
XNNDatatype.xnn_datatype_qcint4
235-
if quant_params.is_qc4w
236-
else XNNDatatype.xnn_datatype_qcint8
237-
)
233+
if quant_params.is_per_channel_group:
234+
# 4-bit per channel group quantized weights
235+
# No 8-bit support yet
236+
assert (
237+
quant_params.is_qc4w is True
238+
), "Only 4-bit per channel group quantization is supported"
239+
return XNNDatatype.xnn_datatype_qbint4
240+
else:
241+
# 4/8-bit per channel quantized weights
242+
return (
243+
XNNDatatype.xnn_datatype_qcint4
244+
if quant_params.is_qc4w
245+
else XNNDatatype.xnn_datatype_qcint8
246+
)
238247
else:
239248
raise RuntimeError(
240249
f"Unable to resolve static quantized tensor dtype using quant params dtype: {quant_params.dtype}, [qmin, qmax]: {quant_params.qmin}, {quant_params.qmax} for per channel quantization"
@@ -266,10 +275,17 @@ def get_per_channel_dtype(
266275
def get_quant_params(self, quant_params: QuantParams) -> XNNQuantParams:
267276
if quant_params.per_channel:
268277
scale = cast(torch.Tensor, quant_params.scale)
269-
return PerChannelQuant(
270-
scale=scale.tolist(),
271-
channel_dim=quant_params.axis,
272-
)
278+
if quant_params.is_per_channel_group:
279+
return PerChannelGroupQuant(
280+
scale=scale.flatten().tolist(),
281+
channel_dim=quant_params.axis,
282+
group_size=quant_params.group_size,
283+
)
284+
else: # per_channel quant
285+
return PerChannelQuant(
286+
scale=scale.tolist(),
287+
channel_dim=quant_params.axis,
288+
)
273289
elif quant_params.is_dynamic:
274290
# NB:
275291
# We use per_token quantization for per_tensor quantization
@@ -284,6 +300,42 @@ def get_quant_params(self, quant_params: QuantParams) -> XNNQuantParams:
284300
zero_point=cast(int, quant_params.zp),
285301
)
286302

303+
@staticmethod
304+
def _check_per_channel_group_params(
305+
quant_params: QuantParams, dims: List[int]
306+
) -> None:
307+
# Make sure things are lining up for per_channel_group quantization case
308+
# Has to be done this late because we don't have clean access to the actual tensor
309+
assert quant_params.is_per_channel_group, "Not per_channel_group quantization"
310+
# linear weights will be in [oc, ic]. And per_channel quantization must be on axis 0
311+
num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
312+
assert (
313+
quant_params.axis == 0
314+
), f"For per_channel_group quant, axis must be 0, but got {quant_params.axis}"
315+
assert (
316+
len(dims) == 2
317+
), "For per_channel_group quant, expecting linear weights to be 2d, but got {len(dims)}"
318+
assert (
319+
num_groups > 0 and quant_params.group_size > 0
320+
), "For per_channel_group quant, num_groups and group_size must be > 0, but got num_groups: {num_groups}, group_size: {quant_params.group_size}"
321+
output_channels = dims[quant_params.axis]
322+
input_channels = dims[quant_params.axis ^ 1]
323+
assert (
324+
output_channels == cast(torch.Tensor, quant_params.scale).shape[0]
325+
), f"For per_channel_group quant, expecting output channels to match scale.shape[0], but got: {output_channels}, scale.shape[0]: {quant_params.scale.shape[0]}"
326+
assert (
327+
input_channels % num_groups == 0
328+
), "For per_channel_group quant, expecting input channels to be divisible by num_groups, but got ic: {input_channels}, num_groups: {num_groups}"
329+
assert (
330+
input_channels % quant_params.group_size == 0
331+
), "For per_channel_group quant, expecting input channels to be divisible by group_size, but got ic: {input_channels}, group_size: {quant_params.group_size}"
332+
assert (
333+
input_channels / quant_params.group_size == num_groups
334+
), "For per_channel_group quant, expecting input channels // group_size == num_groups, but got ic: {input_channels}, group_size: {quant_params.group_size}, num_groups: {num_groups}"
335+
336+
# For now group quantization is only supported for 4b weights
337+
assert quant_params.is_qc4w, "Only 4b group quantization is supported"
338+
287339
def define_tensor(
288340
self,
289341
tensor: torch.fx.Node,
@@ -331,6 +383,10 @@ def define_tensor(
331383
dims = get_shape(tensor)
332384
dims = [1] if len(dims) == 0 else dims
333385

386+
# check for per_channel_group quantization
387+
if quant_params and quant_params.per_channel_group:
388+
self._check_per_channel_group_params(quant_params, dims)
389+
334390
# constant values serialize data
335391
buffer_idx = self.get_serialized_buffer_index(
336392
tensor,
@@ -376,6 +432,7 @@ def define_tensor(
376432
else:
377433
assert f"Unsupported weight per channel quantization axis for depthwise conv2d: {quant_params.axis}, expecting 0."
378434

435+
# Serialize tensor value
379436
ser_val = (
380437
XValue(xvalue_union=tvalue)
381438
if quant_params is None

backends/xnnpack/operators/op_skip_ops.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,21 @@ class OpChooseQparamsToken(OpSkipOps):
104104
"""
105105

106106
target = "quantized_decomposed.choose_qparams_per_token_asymmetric.default"
107+
108+
109+
@register_node_visitor
110+
class OpQuantizePerChannelGroupDefault(OpSkipOps):
111+
"""
112+
do nothing if node is quantize_per_channel_group.default
113+
"""
114+
115+
target = "quantized_decomposed.quantize_per_channel_group.default"
116+
117+
118+
@register_node_visitor
119+
class OpDequantizePerChannelGroupDefault(OpSkipOps):
120+
"""
121+
do nothing if node is dequantize_per_channel_group.default
122+
"""
123+
124+
target = "quantized_decomposed.dequantize_per_channel_group.default"

backends/xnnpack/operators/quant_params.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
is_input: bool,
5757
is_dynamic: bool = False,
5858
num_nonbatch_dims: int = 1,
59+
group_size: int = 0,
5960
) -> None:
6061
self.per_channel = per_channel
6162
self.q_input = q_input
@@ -77,12 +78,29 @@ def __init__(
7778
and self.dtype == torch.int8
7879
)
7980

81+
# Groupwise quantization for weight
82+
self.per_channel_group = False
83+
self.group_size = group_size
84+
if self.group_size > 0:
85+
assert (
86+
self.per_channel is True
87+
), "Only per channel quantization supports groupwise quantization"
88+
assert (
89+
cast(torch.Tensor, scale).ndim == 2
90+
), "Scale must be 2D for per channel groupwise quant"
91+
self.per_channel_group = True
92+
assert group_size > 0, "Group size must be greater than 0"
93+
self.is_per_channel_group = self.per_channel and self.group_size > 0
94+
8095
def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
8196
# Do nothing if already quantized by the Quantizer
8297
if tensor.dtype == self.dtype:
8398
return tensor
8499

85100
if self.per_channel:
101+
assert (
102+
self.per_channel_group is False
103+
), f"Not expecting per channel group quantization, got q dtype: {self.dtype}, tensor.dtype {tensor.dtype}"
86104
assert (
87105
tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
88106
), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"
@@ -148,6 +166,16 @@ def from_q_dq_node(
148166
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
149167
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
150168
]
169+
170+
_groupwise = False
171+
if quant_node.target in [
172+
exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default,
173+
exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default,
174+
]:
175+
# This is a sub-category of per channel quantization
176+
per_channel = True
177+
_groupwise = True
178+
151179
scale = quant_node.args[1]
152180
zp = quant_node.args[2]
153181
axis = 0
@@ -166,16 +194,34 @@ def _get_tensor(node):
166194
scale = _get_tensor(scale)
167195
zp = _get_tensor(zp)
168196
axis = cast(int, quant_node.args[3])
197+
198+
if _groupwise:
199+
scale_tensor = cast(torch.Tensor, scale)
200+
assert (
201+
scale_tensor.ndim == 2
202+
), f"Weight scale must be 2D for per_channel_group [de]quant node, got {scale_tensor.ndim}D"
203+
axis = 0 # axis is ignored for groupwise quantization
204+
169205
check_or_raise(
170206
bool(
171207
quant_node.args[-1] != torch.uint8
172208
or quant_node.args[-1] != torch.quint8
173209
),
174210
"XNNPACK does not support unsigned quantization",
175211
)
176-
dtype = cast(torch.dtype, quant_node.args[-1])
177-
qmax = cast(int, quant_node.args[-2])
178-
qmin = cast(int, quant_node.args[-3])
212+
213+
if _groupwise:
214+
_ = quant_node.args[-1] # output dtype - not used
215+
group_size = cast(int, quant_node.args[-2])
216+
dtype = cast(torch.dtype, quant_node.args[-3])
217+
qmax = cast(int, quant_node.args[-4])
218+
qmin = cast(int, quant_node.args[-5])
219+
else:
220+
group_size = 0
221+
dtype = cast(torch.dtype, quant_node.args[-1])
222+
qmax = cast(int, quant_node.args[-2])
223+
qmin = cast(int, quant_node.args[-3])
224+
179225
is_output = any(
180226
user_node.op == "output" for user_node in quant_node.users.keys()
181227
)
@@ -191,6 +237,7 @@ def _get_tensor(node):
191237
qmin,
192238
is_output,
193239
is_input,
240+
group_size=group_size,
194241
)
195242

196243
@classmethod

backends/xnnpack/partition/configs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,13 @@
137137

138138
# Modules which support dynamic quantization
139139
# These already support dynamic shape.
140-
SUPPORTED_DYN_QUANT_MODULES = [
140+
SUPPORTED_DYN_QUANT_LINEAR_MODULES = [
141141
torch.nn.Linear,
142142
torch.nn.functional.linear,
143143
]
144144

145+
SUPPORTED_DYN_QUANT_MODULES = SUPPORTED_DYN_QUANT_LINEAR_MODULES
146+
145147
# TODO delete this once we catch up to 100% of the supported op with dynamic shape support.
146148
# This is tobe used only during the transition when we may not want to partition all the
147149
# nodes for a dynamic model.

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from executorch.backends.xnnpack.partition.configs import (
1515
_SUPPORTED_MODULES_WITH_DYNAMIC_SHAPE,
1616
_SUPPORTED_OPS_WITH_DYNAMIC_SHAPE,
17+
SUPPORTED_DYN_QUANT_LINEAR_MODULES,
1718
SUPPORTED_DYN_QUANT_MODULES,
1819
SUPPORTED_MODULES,
1920
SUPPORTED_OPS,
@@ -26,7 +27,11 @@
2627
FuseBatchNormWithConvPass,
2728
)
2829
from executorch.backends.xnnpack.utils.quant_utils import is_dequant
29-
from executorch.backends.xnnpack.utils.utils import get_input_node, is_param_node
30+
from executorch.backends.xnnpack.utils.utils import (
31+
get_input_node,
32+
get_source_fn,
33+
is_param_node,
34+
)
3035
from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend
3136

3237
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
@@ -333,10 +338,14 @@ def choose_qparams_tensor(cqp: torch.fx.Node, ep: ExportedProgram) -> bool: # n
333338
def dequant_per_token(dq: torch.fx.Node, ep: ExportedProgram) -> bool: # noqa
334339
node = list(dq.users.keys())[0]
335340
assert isinstance(node, torch.fx.Node)
336-
return node.target in [
337-
exir_ops.edge.aten.mm.default,
338-
exir_ops.edge.aten.addmm.default,
339-
]
341+
return (
342+
node.target
343+
in [
344+
exir_ops.edge.aten.mm.default,
345+
exir_ops.edge.aten.addmm.default,
346+
]
347+
or get_source_fn(node) in SUPPORTED_DYN_QUANT_LINEAR_MODULES
348+
)
340349

341350
@_constraint(exir_ops.edge.quantized_decomposed.quantize_per_token.default)
342351
def quant_per_token(q: torch.fx.Node, ep: ExportedProgram) -> bool: # noqa
@@ -363,6 +372,38 @@ def choose_qparams_per_token_asymmetric(
363372
and XnnpackOperatorSupport.check_constraint(q, ep)
364373
)
365374

375+
@_constraint(
376+
exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default
377+
)
378+
def dequant_per_channel_group_default(
379+
dq: torch.fx.Node, ep: ExportedProgram # noqa
380+
) -> bool:
381+
# Currently only supported by dqlinear weights
382+
permute_node = list(dq.users.keys())[0]
383+
assert isinstance(permute_node, torch.fx.Node)
384+
# We must have a transpose on [add]mm weights
385+
if permute_node.target != exir_ops.edge.aten.permute_copy.default:
386+
return False
387+
mm_node = list(permute_node.users.keys())[0]
388+
assert isinstance(mm_node, torch.fx.Node)
389+
return mm_node.target in [
390+
exir_ops.edge.aten.mm.default,
391+
exir_ops.edge.aten.addmm.default,
392+
]
393+
394+
@_constraint(exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default)
395+
def quant_per_channel_group_default(
396+
q: torch.fx.Node, ep: ExportedProgram # noqa
397+
) -> bool:
398+
# we shouldn't have this with folded quant weights but doesn't hurt to lower it
399+
dq = list(q.users.keys())[0]
400+
assert isinstance(dq, torch.fx.Node)
401+
return (
402+
dq.target
403+
== exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default
404+
and XnnpackOperatorSupport.dequant_per_channel_default(dq, ep)
405+
)
406+
366407
@_constraint(exir_ops.edge.aten.pow.Tensor_Scalar)
367408
def pow_tensor_scalar(node: torch.fx.Node, ep: ExportedProgram) -> bool: # noqa
368409
"""
@@ -612,13 +653,15 @@ class XnnpackQuantizedPartitioner(XnnpackFloatingPointPartitioner):
612653
_Q_OPS = [
613654
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
614655
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
656+
exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default,
615657
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
616658
exir_ops.edge.quantized_decomposed.quantize_per_token.default,
617659
]
618660

619661
_DQ_OPS = [
620662
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
621663
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
664+
exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default,
622665
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
623666
exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
624667
]
@@ -763,13 +806,17 @@ class XnnpackPartitioner(Partitioner):
763806
_Q_OPS = [
764807
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
765808
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
809+
exir_ops.edge.quantized_decomposed.quantize_per_channel_group.default,
766810
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
811+
exir_ops.edge.quantized_decomposed.quantize_per_token.default,
767812
]
768813

769814
_DQ_OPS = [
770815
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
771816
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
817+
exir_ops.edge.quantized_decomposed.dequantize_per_channel_group.default,
772818
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
819+
exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
773820
]
774821

775822
_QPARAM_OPS = [

0 commit comments

Comments
 (0)