-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexport_executorch.py
More file actions
737 lines (650 loc) · 29.8 KB
/
export_executorch.py
File metadata and controls
737 lines (650 loc) · 29.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
# SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
import os
from pathlib import Path
from typing import Sequence
import torch
from arm_backend_monkey_patch import apply_arm_backend_monkey_patch
from executorch.backends.arm.quantizer import (
VgfQuantizer,
get_symmetric_quantization_config,
)
from executorch.backends.arm.quantizer.arm_quantizer import (
QuantizationSpec,
annotate_input_qspec_map,
annotate_output_qspec,
is_annotated,
mark_node_as_annotated,
)
from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from executorch.exir.capture._config import EdgeCompileConfig
from executorch.exir.passes.quantize_io_pass import extract_io_quant_params
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.fx.passes.infra.pass_base import PassResult
from torchao.quantization.pt2e.observer import FixedQParamsObserver
from torchao.quantization.pt2e.quantizer.quantizer import (
QuantizationAnnotation,
SharedQuantizationSpec,
)
from executorch_grid_sampler_backend import (
GridSamplerOnlyPartitioner,
pad_rgb_grid_sampler_channels_pass,
)
from utils import _emit_export_reports, _enable_vgf_debug_logging
# Apply the Arm backend monkey patch once, at import time, so it is already in
# effect before any function or class defined below is used.
# NOTE(review): what the patch changes is defined in `arm_backend_monkey_patch`
# and is not visible from this file.
apply_arm_backend_monkey_patch()
def _is_q_op(node) -> bool:
target = str(node.target)
return (
"quantized_decomposed.quantize_per_tensor" in target
or "quantized_decomposed.quantize_per_channel" in target
)
def _is_dq_op(node) -> bool:
target = str(node.target)
return (
"quantized_decomposed.dequantize_per_tensor" in target
or "quantized_decomposed.dequantize_per_channel" in target
)
def _min_bilinear_resize_stage_count(
input_size: int,
output_size: int,
) -> int:
if input_size <= 0 or output_size <= 0:
raise ValueError(
f"Resize sizes must be positive, got input={input_size}, output={output_size}"
)
if output_size >= input_size or input_size < 16 * output_size:
return 1
stage_count = 1
max_total_factor = 16
while input_size >= output_size * max_total_factor:
stage_count += 1
max_total_factor *= 16
return stage_count
def _build_bilinear_resize_axis_sizes(
input_size: int,
output_size: int,
*,
stage_count: int,
) -> list[int]:
if stage_count <= 0:
raise ValueError(f"stage_count must be positive, got {stage_count}")
if stage_count == 1:
return [output_size]
if output_size >= input_size or input_size < 16 * output_size:
return [input_size] * (stage_count - 1) + [output_size]
next_size = min(input_size, output_size * (16 ** (stage_count - 1)) - 1)
if next_size <= 0:
raise ValueError(
f"Computed invalid intermediate resize size {next_size} "
f"for input={input_size}, output={output_size}, stages={stage_count}"
)
return [next_size] + _build_bilinear_resize_axis_sizes(
next_size,
output_size,
stage_count=stage_count - 1,
)
def _make_bilinear_resize_stage_shapes(
input_hw: tuple[int, int],
output_hw: tuple[int, int],
) -> list[tuple[int, int]]:
stage_count = max(
_min_bilinear_resize_stage_count(input_hw[0], output_hw[0]),
_min_bilinear_resize_stage_count(input_hw[1], output_hw[1]),
)
if stage_count == 1:
return [output_hw]
axis_sizes = [
_build_bilinear_resize_axis_sizes(
input_size=input_size,
output_size=output_size,
stage_count=stage_count,
)
for input_size, output_size in zip(input_hw, output_hw, strict=True)
]
return list(zip(*axis_sizes, strict=True))
#
# Work around ExecuTorch Arm lowering emitting invalid TOSA for exact large
# bilinear downscales such as 1/16. Split oversized downscales into multiple
# legal resize stages before quantization/lowering.
#
# Upstream issue:
# https://github.com/pytorch/executorch/issues/19069
def _rewrite_large_bilinear_downsample_pass(graph_module):
    """Replace oversized bilinear downscales with chains of smaller resizes.

    A node is rewritten only if all of the following hold:
    - it is a ``call_function`` to ``aten.upsample_bilinear2d.vec``;
    - it has exactly four positional args, with ``align_corners`` (``args[2]``)
      being ``False``;
    - both its input and itself carry a 4-D ``meta["val"]``.

    ``_make_bilinear_resize_stage_shapes`` chooses the per-stage (H, W) sizes
    from dims 2 and 3; nodes that need a single stage are left alone.

    Each new stage is a fresh ``upsample_bilinear2d.vec`` call with an explicit
    output size and ``scale_factors=None``. The last stage takes over the
    original node's name and a copy of its ``meta``. Earlier stages are named
    ``<name>_split_<idx>`` and get ``val``/``tensor_meta`` recomputed by running
    the op on the previous stage's ``val``.

    Args:
        graph_module: FX ``GraphModule`` rewritten in place.

    Returns:
        ``(PassResult, rewritten_count)``. If anything was rewritten, the graph
        is linted and the module is recompiled.
    """
    target_op = torch.ops.aten.upsample_bilinear2d.vec
    graph = graph_module.graph
    rewritten = 0

    for resize in list(graph.nodes):
        if resize.op != "call_function" or resize.target != target_op:
            continue
        if len(resize.args) != 4 or resize.args[2] is not False:
            continue

        src = resize.args[0]
        src_val = getattr(src, "meta", {}).get("val")
        dst_val = resize.meta.get("val")
        if src_val is None or dst_val is None:
            continue
        if getattr(src_val, "ndim", None) != 4 or getattr(dst_val, "ndim", None) != 4:
            continue

        stages = _make_bilinear_resize_stage_shapes(
            (int(src_val.shape[2]), int(src_val.shape[3])),
            (int(dst_val.shape[2]), int(dst_val.shape[3])),
        )
        if len(stages) == 1:
            continue

        last_idx = len(stages) - 1
        chain_tail = src
        tail_val = src_val
        with graph.inserting_before(resize):
            for idx, hw in enumerate(stages):
                stage_node = graph.call_function(
                    target_op,
                    args=(chain_tail, list(hw), False, None),
                )
                if idx == last_idx:
                    # The final stage stands in for the original node.
                    stage_node.name = resize.name
                    stage_node.meta = dict(resize.meta)
                else:
                    stage_node.name = f"{resize.name}_split_{idx}"
                    tail_val = target_op(tail_val, list(hw), False, None)
                    stage_node.meta = dict(resize.meta)
                    stage_node.meta["val"] = tail_val
                    stage_node.meta["tensor_meta"] = _extract_tensor_metadata(tail_val)
                chain_tail = stage_node

        resize.replace_all_uses_with(chain_tail)
        graph.erase_node(resize)
        rewritten += 1

    if rewritten:
        graph.lint()
        graph_module.recompile()
    return PassResult(graph_module=graph_module, modified=rewritten > 0), rewritten
class VgfPartitionerWrapper(Partitioner):
    """
    Run an inner VGF partitioner, then pull graph-boundary Q/DQ nodes back out
    of the delegated partitions.

    Motivation:
    - With the INT+FP profile, the Arm TOSA partitioner does not detag Q/DQ
      nodes at partition boundaries by itself. VGF can therefore absorb the
      graph-input quantize and graph-output dequantize nodes, after which
      ``extract_io_quant_params`` can no longer find and strip them.
    - The custom grid-sampler backend needs the dequantize feeding
      ``grid_sampler`` and the quantize consuming it to stay at top level. It
      can then replace the whole dequant -> grid_sampler -> quant sequence with
      one quantized grid-sampler operation.

    After the inner partitioner runs, ``delegation_tag`` is removed from:
    1) quantize nodes with a placeholder input whose user-input index is in
       ``quantized_input_idxs``;
    2) dequantize nodes feeding the output node at a position (within the
       output node's ``all_input_nodes``) listed in ``quantized_output_idxs``;
    3) quantize nodes consuming, and dequantize nodes feeding, a
       ``grid_sampler_2d`` node.

    Every other tag, including Q/DQ inside delegated partitions, is kept.
    ``partition_tags`` is then pruned to the tags still present in the graph.
    """

    def __init__(
        self,
        inner_partitioner,
        *,
        quantized_input_idxs: Sequence[int] = (),
        quantized_output_idxs: Sequence[int] = (),
    ) -> None:
        super().__init__()
        self.inner = inner_partitioner
        self._quantized_input_idxs = frozenset(quantized_input_idxs)
        self._quantized_output_idxs = frozenset(quantized_output_idxs)

    def ops_to_not_decompose(self, ep):
        """Forward to the wrapped partitioner unchanged."""
        return self.inner.ops_to_not_decompose(ep)

    def partition(self, exported_program) -> PartitionResult:
        """Partition with the inner partitioner, then detag boundary Q/DQ nodes."""
        result = self.inner.partition(exported_program)
        tagged_program = result.tagged_exported_program
        graph = tagged_program.graph_module.graph

        user_inputs = list(tagged_program.graph_signature.user_inputs)
        quantized_input_names = {
            user_inputs[i]
            for i in self._quantized_input_idxs
            if i < len(user_inputs)
        }

        # Position of each node among the (first) output node's inputs.
        output_position = {}
        for candidate in graph.nodes:
            if candidate.op == "output":
                output_position = {
                    parent: pos
                    for pos, parent in enumerate(candidate.all_input_nodes)
                }
                break

        def touches_grid_sampler(neighbours) -> bool:
            return any(
                "aten.grid_sampler_2d.default" in str(other.target)
                for other in neighbours
            )

        def is_boundary_quantize(q_node) -> bool:
            from_quantized_input = any(
                src.op == "placeholder" and src.name in quantized_input_names
                for src in q_node.all_input_nodes
            )
            return from_quantized_input or touches_grid_sampler(q_node.all_input_nodes)

        def is_boundary_dequantize(dq_node) -> bool:
            to_quantized_output = any(
                user.op == "output"
                and output_position.get(dq_node) in self._quantized_output_idxs
                for user in dq_node.users
            )
            return to_quantized_output or touches_grid_sampler(dq_node.users)

        for node in graph.nodes:
            if node.meta.get("delegation_tag") is None:
                continue
            if _is_q_op(node):
                should_detag = is_boundary_quantize(node)
            elif _is_dq_op(node):
                should_detag = is_boundary_dequantize(node)
            else:
                should_detag = False
            if should_detag:
                del node.meta["delegation_tag"]

        surviving_tags = {
            node.meta["delegation_tag"]
            for node in graph.nodes
            if "delegation_tag" in node.meta
        }
        result.partition_tags = {
            tag: spec
            for tag, spec in result.partition_tags.items()
            if tag in surviving_tags
        }
        return result
def _feeds_grid_sampler_grid_input(node) -> bool:
    """Return whether ``node`` reaches an edge ``grid_sampler_2d`` as its grid argument.

    The grid placeholder may reach grid_sampler through inserted Q/DQ nodes,
    so a direct-user check is not enough. Starting at ``node``, the search
    follows users that are quantize/dequantize ``call_function`` nodes. It
    succeeds when any reached node is ``args[1]`` (the grid) of a
    ``grid_sampler_2d`` call. The channels-last pass uses this to keep
    sampling grids out of the image rewrite even after quantization.
    """
    pending = [node]
    seen = set()
    while pending:
        current = pending.pop()
        if current in seen or not hasattr(current, "users"):
            continue
        seen.add(current)
        for user in current.users:
            if user.op != "call_function":
                continue
            if (
                user.target == exir_ops.edge.aten.grid_sampler_2d.default
                and len(user.args) > 1
                and user.args[1] is current
            ):
                return True
            if _is_q_op(user) or _is_dq_op(user):
                pending.append(user)
    return False
class VgfQuantizerWrapper(VgfQuantizer):
    """
    VGF quantizer wrapper that customizes graph IO and `grid_sampler` quantization.

    It applies caller-provided qspecs to graph inputs and outputs by position.
    It also rewrites `grid_sampler` annotations, so callers can separately
    control the image input, grid input, and output qspecs at the
    grid-sampler boundary. This keeps the quantized image path aligned with
    the separately lowered grid-sampler backend, while still allowing a
    different quantization scheme for the sampling grid itself.

    Args:
        compile_spec: Passed straight to ``VgfQuantizer``.
        input_qspecs: ``input_qspecs[i]`` is applied to the i-th placeholder of
            the annotated graph. ``None`` (or a missing entry) leaves that
            input unannotated here.
        output_qspecs: ``output_qspecs[i]`` is applied to the i-th entry of the
            output node's ``all_input_nodes``. ``None`` skips that output.
        grid_sampler_image_input_qspec: qspec for the grid-sampler image input.
            When the image comes from an ``aten.pad`` node, the qspec goes on
            the pad's input and is shared through the pad.
        grid_sampler_grid_input_qspec: qspec for the sampling grid. When the
            grid-sampler rewrite runs and this is ``None``, any existing grid
            annotation is removed.
        grid_sampler_output_qspec: Written verbatim as the grid-sampler
            output qspec when the grid-sampler rewrite runs.

    The grid-sampler rewrite runs only if at least one of the three
    grid-sampler qspecs is not ``None``.
    """

    def __init__(
        self,
        compile_spec,
        *,
        input_qspecs: Sequence[QuantizationSpec | None],
        output_qspecs: Sequence[QuantizationSpec | None],
        grid_sampler_image_input_qspec: QuantizationSpec | None = None,
        grid_sampler_grid_input_qspec: QuantizationSpec | None = None,
        grid_sampler_output_qspec: QuantizationSpec | None = None,
    ) -> None:
        super().__init__(compile_spec)
        self._input_qspecs = tuple(input_qspecs)
        self._output_qspecs = tuple(output_qspecs)
        self._grid_sampler_image_input_qspec = grid_sampler_image_input_qspec
        self._grid_sampler_grid_input_qspec = grid_sampler_grid_input_qspec
        self._grid_sampler_output_qspec = grid_sampler_output_qspec

    def _annotate_io(self, model, quantization_config):
        """Annotate graph inputs/outputs with the caller-provided qspecs.

        Overrides the base quantizer's IO annotation hook, so callers choose
        the boundary qspec for each graph input/output explicitly;
        ``quantization_config`` is ignored.

        Input qspecs are matched to placeholders by position among *all*
        placeholders. This keeps ``input_qspecs[i]`` tied to the i-th input
        even when an earlier input is unused; counting only used placeholders
        would shift every later qspec. Unused placeholders keep their position
        but are not annotated.

        NOTE(review): assumes every placeholder in ``model`` is a graph input
        (as for an ``ExportedProgram.module()`` graph). Confirm this for any
        submodule graphs annotated through this hook.
        """
        del quantization_config
        placeholder_nodes = [
            node for node in model.graph.nodes if node.op == "placeholder"
        ]
        # zip() stops at the shorter sequence: placeholders past the end of
        # input_qspecs are left alone.
        for node, input_qspec in zip(placeholder_nodes, self._input_qspecs):
            if not node.users or is_annotated(node) or input_qspec is None:
                continue
            annotate_output_qspec(node, input_qspec)
            mark_node_as_annotated(node)

        # FX/export graphs normally have one `output` node whose inputs are the
        # actual graph outputs, so iterate that node's parents rather than
        # expecting one node per returned value.
        output_node = next(
            (node for node in model.graph.nodes if node.op == "output"),
            None,
        )
        if output_node is None or is_annotated(output_node):
            return
        did_annotate_any_output = False
        for parent, output_qspec in zip(
            output_node.all_input_nodes, self._output_qspecs
        ):
            if output_qspec is None:
                continue
            annotate_input_qspec_map(output_node, parent, output_qspec)
            did_annotate_any_output = True
        if did_annotate_any_output:
            mark_node_as_annotated(output_node)

    def _annotate_grid_sample(self, model) -> None:
        """Override the quantization annotations on every grid_sampler node.

        For each ``aten.grid_sampler`` / ``aten.grid_sampler_2d`` call:
        - the image input gets ``_grid_sampler_image_input_qspec`` when it is
          set. If the image is produced by ``aten.pad``, the qspec goes on the
          pad's input and the sampler shares the pad's output qparams.
        - the grid input gets ``_grid_sampler_grid_input_qspec``, or loses its
          entry when that is ``None``.
        - the output qspec is set to ``_grid_sampler_output_qspec``.
        """

        def _annotate_grid_sample_image_pad(pad_node) -> None:
            # Keep the RGB->RGBA pad in the same fixed image domain as the
            # delegated grid-sampler image input. Without this, Arm's default
            # activation annotation gives the inserted pad its own activation
            # qparams, and PT2E inserts an extra requantize before grid_sample.
            pad_annotation = pad_node.meta.get("quantization_annotation")
            if pad_annotation is None:
                pad_annotation = QuantizationAnnotation()
            if pad_annotation.input_qspec_map is None:
                pad_annotation.input_qspec_map = {}
            source_node = pad_node.args[0]
            pad_input_qspec_map = dict(pad_annotation.input_qspec_map)
            pad_input_qspec_map[source_node] = self._grid_sampler_image_input_qspec
            pad_annotation.input_qspec_map = pad_input_qspec_map
            # The pad output reuses the qparams of its (source -> pad) input edge.
            pad_annotation.output_qspec = SharedQuantizationSpec((source_node, pad_node))
            pad_node.meta["quantization_annotation"] = pad_annotation
            mark_node_as_annotated(pad_node)

        grid_sampler_targets = {
            torch.ops.aten.grid_sampler.default,
            torch.ops.aten.grid_sampler_2d.default,
        }
        for node in model.graph.nodes:
            if node.op != "call_function" or node.target not in grid_sampler_targets:
                continue
            quant_annotation = node.meta.get("quantization_annotation")
            if quant_annotation is None:
                quant_annotation = QuantizationAnnotation()
            if quant_annotation.input_qspec_map is None:
                quant_annotation.input_qspec_map = {}
            image_input_node = node.args[0]
            grid_input_node = node.args[1]
            input_qspec_map = dict(quant_annotation.input_qspec_map)
            if self._grid_sampler_image_input_qspec is not None:
                if (
                    image_input_node.op == "call_function"
                    and image_input_node.target == torch.ops.aten.pad.default
                ):
                    _annotate_grid_sample_image_pad(image_input_node)
                    input_qspec_map[image_input_node] = SharedQuantizationSpec(
                        image_input_node
                    )
                else:
                    input_qspec_map[image_input_node] = (
                        self._grid_sampler_image_input_qspec
                    )
            if self._grid_sampler_grid_input_qspec is not None:
                input_qspec_map[grid_input_node] = self._grid_sampler_grid_input_qspec
            else:
                input_qspec_map.pop(grid_input_node, None)
            quant_annotation.input_qspec_map = input_qspec_map
            quant_annotation.output_qspec = self._grid_sampler_output_qspec
            node.meta["quantization_annotation"] = quant_annotation

    def _annotate_for_static_quantization_config(self, model):
        """Run the base static annotation, then apply the grid-sampler overrides.

        This override is reached through the base quantizer's annotate()
        method during PT2E prepare.
        """
        model = super()._annotate_for_static_quantization_config(model)
        if (
            self._grid_sampler_image_input_qspec is not None
            or self._grid_sampler_grid_input_qspec is not None
            or self._grid_sampler_output_qspec is not None
        ):
            self._annotate_grid_sample(model)
        return model
def _set_tensors_to_channels_last_pass(graph_module):
    """Mark image tensors at the VGF / grid-sampler boundaries as channels-last.

    Only ``meta["val"]`` is rewritten, via
    ``.contiguous(memory_format=torch.channels_last)``. No nodes are added or
    removed. This lets the VGF backend avoid inserting a transpose, and lets
    the scenario runner alias these tensors as images.

    The pass runs three sweeps:
    1. Each 4-D placeholder, except one that reaches a grid_sampler *grid*
       argument (possibly through Q/DQ), plus any per-tensor quantize node
       that directly consumes it.
    2. Each 4-D graph output. If the output is a per-tensor dequantize, that
       dequantize's input is updated too.
    3. For every edge ``grid_sampler_2d`` node:
       - its image input, plus the producer behind that input when it is a
         per-tensor dequantize;
       - the node itself;
       - any per-tensor quantize consuming it.
       The grid-sampler backend delegates these dequant -> grid_sampler ->
       quant sequences and aliases their ends as images.

    NOTE(review): sweep 1 matches every 4-D placeholder, not only user inputs.
    If lifted parameters/buffers appear as placeholders, 4-D ones (e.g. conv
    weights) are relabelled too. Confirm this is intended. Sweep 2 assumes
    every entry of the output node's ``args[0]`` is a node with ``meta["val"]``.

    Returns:
        ``PassResult`` with ``modified=True`` unconditionally.
    """
    quantize_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    dequantize_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    grid_sampler_op = exir_ops.edge.aten.grid_sampler_2d.default

    def to_channels_last(target_node) -> None:
        target_node.meta["val"] = target_node.meta["val"].contiguous(
            memory_format=torch.channels_last
        )

    def is_call_to(candidate, op) -> bool:
        return candidate.op == "call_function" and candidate.target == op

    nodes = graph_module.graph.nodes

    # 1) Graph-input images. The second grid_sampler argument is a sampling
    #    grid, not an image: keeping it in its native layout preserves the
    #    logical [N, H, W, 2] coordinate order when scenario tensor dims and
    #    .npy contents are later serialized from FX metadata.
    for node in nodes:
        if node.op != "placeholder" or node.meta["val"].ndim != 4:
            continue
        if _feeds_grid_sampler_grid_input(node):
            continue
        to_channels_last(node)
        # A quantize applied directly to the input is what the VGF backend sees.
        for consumer in node.users:
            if is_call_to(consumer, quantize_op):
                to_channels_last(consumer)

    # 2) Graph-output images.
    for node in nodes:
        if node.op != "output":
            continue
        for produced in node.args[0]:
            if produced.meta["val"].ndim != 4:
                continue
            to_channels_last(produced)
            # With a trailing dequantize, the VGF backend emits its input.
            if is_call_to(produced, dequantize_op):
                to_channels_last(produced.args[0])

    # 3) Grid-sampler image input and output, so the VGF side of the boundary
    #    needs no transposes.
    for node in nodes:
        if not is_call_to(node, grid_sampler_op):
            continue
        image = node.args[0]
        to_channels_last(image)
        if is_call_to(image, dequantize_op):
            to_channels_last(image.args[0])
        to_channels_last(node)
        for consumer in node.users:
            if is_call_to(consumer, quantize_op):
                to_channels_last(consumer)

    return PassResult(graph_module=graph_module, modified=True)
def _fixed_symmetric_qspec(
    dtype: torch.dtype,
    observer_dtype: torch.dtype,
    qmax: int,
) -> QuantizationSpec:
    """Build a static, symmetric qspec with fixed scale ``1 / qmax`` and zero point 0.

    The integer range is ``[-qmax, qmax]``, so real values in ``[-1, 1]`` map
    onto the full symmetric range.
    """
    return QuantizationSpec(
        dtype=dtype,
        observer_or_fake_quant_ctr=FixedQParamsObserver.with_args(
            scale=1.0 / qmax,
            zero_point=0,
            dtype=observer_dtype,
            qscheme=torch.per_tensor_symmetric,
            quant_min=-qmax,
            quant_max=qmax,
        ),
        quant_min=-qmax,
        quant_max=qmax,
        qscheme=torch.per_tensor_symmetric,
        is_dynamic=False,
    )


def _pt2e_quantize_for_vgf(
    captured_graph,
    example_inputs: tuple,
    vgf_compile_spec,
    *,
    input_qspecs,
    output_qspecs,
    module_configs: dict,
    grid_input_qspec: QuantizationSpec | None,
    quantize_grid_sample_output: bool,
):
    """Calibrate and quantize ``captured_graph`` with ``VgfQuantizerWrapper``.

    Defaults are int8 symmetric (range [-127, 127]) for activations and
    weights, with a fixed int8 domain for images at the grid-sampler boundary.
    Sampling grids default to a fixed int16 qspec with scale 1/32767.
    """
    if grid_input_qspec is None:
        grid_input_qspec = _fixed_symmetric_qspec(torch.int16, torch.int16, 32767)
    image_qspec = _fixed_symmetric_qspec(torch.int8, torch.qint8, 127)

    quantizer = VgfQuantizerWrapper(
        vgf_compile_spec,
        input_qspecs=input_qspecs,
        output_qspecs=output_qspecs,
        grid_sampler_image_input_qspec=image_qspec,
        grid_sampler_grid_input_qspec=grid_input_qspec,
        grid_sampler_output_qspec=image_qspec if quantize_grid_sample_output else None,
    )
    int8_config = get_symmetric_quantization_config(
        is_per_channel=True,
        is_qat=False,
        is_dynamic=False,
        act_qmin=-127,
        act_qmax=127,
        weight_qmin=-127,
        weight_qmax=127,
    )
    # `set_io(...)` turns on the base IO-annotation path, which the wrapper
    # overrides with the per-input/per-output qspecs. The base quantizer still
    # expects a real IO config here and may fall back to it.
    quantizer.set_global(int8_config).set_io(int8_config)
    for module_name, config in module_configs.items():
        quantizer.set_module_name(module_name, config)
    return quantizer.quantize_with_submodules(
        captured_graph,
        calibration_samples=[example_inputs],
        is_qat=False,
    )


def export_model(
    model: torch.nn.Module,
    model_inputs,
    *,
    artifacts_root: str | Path = Path("artifacts"),
    intermediate_artifacts_root: str | Path | None = None,
    quantized_input_qspecs: Sequence[QuantizationSpec | None] = (),
    quantized_output_qspecs: Sequence[QuantizationSpec | None] = (),
    module_name_quant_configs: dict[str, object | None] | None = None,
    grid_sampler_grid_input_qspec: QuantizationSpec | None = None,
    quantize_grid_sample_output: bool = True,
    enable_pt2e_quantization: bool = True,
) -> tuple[object, tuple[torch.Tensor, ...], dict]:
    """Export ``model`` and lower it to the VGF and grid-sampler backends.

    Pipeline:
    1. strict ``torch.export``;
    2. split oversized bilinear downscales;
    3. pad RGB grid-sampler inputs to RGBA;
    4. optional PT2E quantization and re-export;
    5. ``to_edge_transform_and_lower`` with the channels-last pass and the
       VGF partitioner;
    6. a separate ``to_backend`` call for the grid-sampler partitioner;
    7. IO quant-param extraction and report emission.

    Set ``EXPORT_EXECUTORCH_VERBOSE=1`` for verbose logging and reports.

    Args:
        model: Module to export.
        model_inputs: Example inputs; materialized into a tuple.
        artifacts_root: Created if missing; also the default intermediates root.
        intermediate_artifacts_root: VGF intermediates are dumped under
            ``<root>/vgf_intermediates``. Defaults to ``artifacts_root``.
        quantized_input_qspecs: One qspec (or ``None``) per model input. The
            count must match the number of inputs when PT2E is enabled.
        quantized_output_qspecs: One qspec (or ``None``) per graph output. The
            count must match when PT2E is enabled.
        module_name_quant_configs: Optional per-module-name configs passed to
            ``set_module_name``.
        grid_sampler_grid_input_qspec: qspec for sampling grids. Defaults to a
            fixed int16 qspec with scale 1/32767.
        quantize_grid_sample_output: Whether the grid-sampler output uses the
            fixed int8 image qspec.
        enable_pt2e_quantization: Run PT2E quantization before lowering.

    Returns:
        ``(edge_program, example_inputs, io_quant_params)``. ``io_quant_params``
        is ``{"inputs": {}, "outputs": {}}`` when PT2E is disabled.

    Raises:
        ValueError: If ``model_inputs`` is empty, or if PT2E is enabled and a
            qspec sequence does not match the model's input or output count.
    """
    verbose = os.environ.get("EXPORT_EXECUTORCH_VERBOSE", "0") == "1"
    _enable_vgf_debug_logging(verbose)

    example_inputs = tuple(model_inputs)
    if not example_inputs:
        raise ValueError("export_model requires at least one model input tensor.")
    input_qspecs = tuple(quantized_input_qspecs)
    output_qspecs = tuple(quantized_output_qspecs)
    per_module_configs = dict(module_name_quant_configs or {})
    if enable_pt2e_quantization and len(input_qspecs) != len(example_inputs):
        raise ValueError(
            "quantized_input_qspecs must have one entry per model input when PT2E quantization is enabled, "
            f"got {len(input_qspecs)} qspecs for {len(example_inputs)} inputs"
        )

    artifacts_dir = Path(artifacts_root)
    artifacts_dir.mkdir(parents=True, exist_ok=True)
    intermediates_dir = Path(
        artifacts_dir if intermediate_artifacts_root is None else intermediate_artifacts_root
    )
    intermediates_dir.mkdir(parents=True, exist_ok=True)

    # FP is needed for grid-sample position math; int16 for 16-bit SNORM quantisation.
    vgf_compile_spec = VgfCompileSpec("TOSA-1.0+INT+FP+int16")
    vgf_compile_spec = vgf_compile_spec.dump_intermediate_artifacts_to(
        str(intermediates_dir / "vgf_intermediates")
    )

    exported_program = torch.export.export(model, example_inputs, strict=True)
    rewrite_result, rewritten_resizes = _rewrite_large_bilinear_downsample_pass(
        exported_program.module()
    )
    captured_graph = rewrite_result.graph_module
    if verbose and rewritten_resizes:
        print(
            "Rewrote oversized bilinear downsample nodes before quantization:",
            rewritten_resizes,
        )

    graph_output = next(
        node for node in captured_graph.graph.nodes if node.op == "output"
    )
    output_count = len(graph_output.all_input_nodes)
    if enable_pt2e_quantization and len(output_qspecs) != output_count:
        raise ValueError(
            "quantized_output_qspecs must have one entry per model output when PT2E quantization is enabled, "
            f"got {len(output_qspecs)} qspecs for {output_count} outputs"
        )

    quantized_input_idxs = [i for i, spec in enumerate(input_qspecs) if spec is not None]
    quantized_output_idxs = [i for i, spec in enumerate(output_qspecs) if spec is not None]

    # Pad RGB grid-sampler inputs to RGBA so they can be aliased as 4-channel images.
    pad_rgb_grid_sampler_channels_pass(captured_graph)

    if enable_pt2e_quantization:
        quantized_graph = _pt2e_quantize_for_vgf(
            captured_graph,
            example_inputs,
            vgf_compile_spec,
            input_qspecs=input_qspecs,
            output_qspecs=output_qspecs,
            module_configs=per_module_configs,
            grid_input_qspec=grid_sampler_grid_input_qspec,
            quantize_grid_sample_output=quantize_grid_sample_output,
        )
        exported = torch.export.export(quantized_graph, example_inputs)
    else:
        exported = torch.export.export(captured_graph, example_inputs, strict=True)

    from executorch.exir import to_edge_transform_and_lower

    vgf_partitioner = VgfPartitionerWrapper(
        VgfPartitioner(vgf_compile_spec),
        quantized_input_idxs=quantized_input_idxs,
        quantized_output_idxs=quantized_output_idxs,
    )
    grid_sampler_partitioner = GridSamplerOnlyPartitioner()
    edge_program = to_edge_transform_and_lower(
        exported,
        # Runs after any RGB->RGBA padding, so padded image tensors are covered.
        transform_passes=[_set_tensors_to_channels_last_pass],
        partitioner=[vgf_partitioner],
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    # Lower the grid-sampler backend in its own to_backend() call. Passing both
    # partitioners together breaks op decomposition: the custom backend does
    # not list ops to keep, so everything is decomposed into ops the VGF
    # backend cannot handle (e.g. affine_grid -> linspace -> int64 ops).
    edge_program = edge_program.to_backend(grid_sampler_partitioner)

    if enable_pt2e_quantization:
        io_quant_params = extract_io_quant_params(
            edge_program,
            input_idxs=quantized_input_idxs,
            output_idxs=quantized_output_idxs,
        )
    else:
        io_quant_params = {"inputs": {}, "outputs": {}}

    _emit_export_reports(edge_program, verbose, io_quant_params)
    return edge_program, example_inputs, io_quant_params