# compile.py (forked from zama-ai/concrete-ml)
"""torch compilation function."""
import tempfile
import warnings
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple, Union
import numpy
import onnx
import torch
from brevitas.export.onnx.qonnx.manager import QONNXManager as BrevitasONNXManager
from brevitas.nn.quant_layer import QuantInputOutputLayer as QNNMixingLayer
from brevitas.nn.quant_layer import QuantNonLinearActLayer as QNNUnivariateLayer
from concrete.fhe import ParameterSelectionStrategy
from concrete.fhe.compilation.artifacts import DebugArtifacts
from concrete.fhe.compilation.configuration import Configuration
from ..common.debugging import assert_false, assert_true
from ..common.utils import (
MAX_BITWIDTH_BACKWARD_COMPATIBLE,
check_there_is_no_p_error_options_in_configuration,
get_onnx_opset_version,
manage_parameters_for_pbs_errors,
process_rounding_threshold_bits,
to_tuple,
)
from ..onnx.convert import OPSET_VERSION_FOR_ONNX_EXPORT
from ..onnx.onnx_utils import remove_initializer_from_input
from ..quantization import PostTrainingAffineQuantization, PostTrainingQATImporter, QuantizedModule
from . import NumpyModule
Tensor = Union[torch.Tensor, numpy.ndarray]
Dataset = Union[Tensor, Tuple[Tensor, ...]]
def has_any_qnn_layers(torch_model: torch.nn.Module) -> bool:
"""Check if a torch model has QNN layers.
This is useful to check if a model is a QAT model.
Args:
torch_model (torch.nn.Module): a torch model
Returns:
bool: whether this torch model contains any QNN layer.
"""
return any(
isinstance(layer, (QNNMixingLayer, QNNUnivariateLayer)) for layer in torch_model.modules()
)
def convert_torch_tensor_or_numpy_array_to_numpy_array(
torch_tensor_or_numpy_array: Tensor,
) -> numpy.ndarray:
"""Convert a torch tensor or a numpy array to a numpy array.
Args:
torch_tensor_or_numpy_array (Tensor): the value that is either
a torch tensor or a numpy array.
Returns:
numpy.ndarray: the value converted to a numpy array.
"""
return (
torch_tensor_or_numpy_array
if isinstance(torch_tensor_or_numpy_array, numpy.ndarray)
else torch_tensor_or_numpy_array.cpu().numpy()
)
def build_quantized_module(
    model: Union[torch.nn.Module, onnx.ModelProto],
    torch_inputset: Dataset,
    import_qat: bool = False,
    n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
    rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
    reduce_sum_copy=False,
) -> QuantizedModule:
    """Build a quantized module from a Torch or ONNX model.

    Take a model in torch or ONNX, turn it to numpy, quantize its inputs / weights / outputs
    and retrieve the associated quantized module.

    Args:
        model (Union[torch.nn.Module, onnx.ModelProto]): The model to quantize, either in torch
            or in ONNX.
        torch_inputset (Dataset): the calibration input-set, can contain either torch
            tensors or numpy.ndarray
        import_qat (bool): Flag to signal that the network being imported contains quantizers
            in its computation graph and that Concrete ML should not re-quantize it
        n_bits: the number of bits for the quantization
        rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
            rounding for model accumulators. Accepts None, an int, or a dict.
            The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
            and 'n_bits' ('auto' or int)
        reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
            bit-width propagation

    Returns:
        QuantizedModule: The resulting QuantizedModule.
    """
    rounding_threshold_bits = process_rounding_threshold_bits(rounding_threshold_bits)

    inputset_as_numpy_tuple = tuple(
        convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
    )

    # Tracing needs to be done with a batch size of 1 since we compile our models to FHE with
    # this batch size. The input set contains many examples in order to determine a representative
    # bit-width, but for tracing we only take a single one. We need the ONNX tracing batch size to
    # match the batch size during FHE inference, which can only be 1 for the moment.
    # We need to convert float64 to float32 to avoid errors later in the ONNX export.
    # When we have integer inputs, we can keep them as integers.
    dummy_input_for_tracing = tuple(
        (
            torch.from_numpy(val[[0], ::]).float()
            if val.dtype == numpy.float64
            else torch.from_numpy(val[[0], ::])
        )
        for val in inputset_as_numpy_tuple
    )

    # Create corresponding numpy model
    numpy_model = NumpyModule(model, dummy_input_for_tracing)

    # Quantize with post-training static method, to have a model with integer weights
    post_training = PostTrainingQATImporter if import_qat else PostTrainingAffineQuantization
    post_training_quant = post_training(n_bits, numpy_model, rounding_threshold_bits)

    # Build the quantized module
    # FIXME: mismatch here. We traced with dummy_input_for_tracing, which made some operators
    # only work over shapes of (1, ., .). For example, some reshapes have newshape hardcoded
    # based on the input set we sent to the NumpyModule.
    quantized_module = post_training_quant.quantize_module(*inputset_as_numpy_tuple)

    # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4127
    if reduce_sum_copy:
        quantized_module.set_reduce_sum_copy()

    return quantized_module

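# Illustrative usage sketch (not part of the library): building a QuantizedModule from a small
# float torch model with a random calibration set. The model and data below are hypothetical.
#
#     import numpy
#     import torch
#     from concrete.ml.torch.compile import build_quantized_module
#
#     model = torch.nn.Sequential(torch.nn.Linear(10, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2))
#     calibration_data = numpy.random.uniform(-1, 1, size=(100, 10)).astype(numpy.float32)
#
#     # Post-training quantization with 6 bits for op inputs and weights
#     quantized_module = build_quantized_module(model, calibration_data, n_bits=6)
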
# pylint: disable-next=too-many-arguments
def _compile_torch_or_onnx_model(
    model: Union[torch.nn.Module, onnx.ModelProto],
    torch_inputset: Dataset,
    import_qat: bool = False,
    configuration: Optional[Configuration] = None,
    artifacts: Optional[DebugArtifacts] = None,
    show_mlir: bool = False,
    n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
    rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
    p_error: Optional[float] = None,
    global_p_error: Optional[float] = None,
    verbose: bool = False,
    inputs_encryption_status: Optional[Sequence[str]] = None,
    reduce_sum_copy: bool = False,
    composition_mapping: Optional[Dict] = None,
    device: str = "cpu",
) -> QuantizedModule:
    """Compile a torch module or an ONNX model into an FHE equivalent.

    Take a model in torch or ONNX, turn it to numpy, quantize its inputs / weights / outputs
    and finally compile it with Concrete.

    Args:
        model (Union[torch.nn.Module, onnx.ModelProto]): the model to quantize, either in torch
            or in ONNX
        torch_inputset (Dataset): the calibration input-set, can contain either torch
            tensors or numpy.ndarray
        import_qat (bool): Flag to signal that the network being imported contains quantizers
            in its computation graph and that Concrete ML should not re-quantize it
        configuration (Configuration): Configuration object to use during compilation
        artifacts (DebugArtifacts): Artifacts object to fill during compilation
        show_mlir (bool): if set, the MLIR produced by the converter and which is going
            to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
        n_bits (Union[int, Dict[str, int]]): number of bits for quantization, can be a single
            value or a dictionary with the following keys:
            - "op_inputs" and "op_weights" (mandatory)
            - "model_inputs" and "model_outputs" (optional, default to 5 bits).
            When using a single integer for n_bits, its value is assigned to "op_inputs" and
            "op_weights" bits. Default is 8 bits.
        rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
            rounding for model accumulators. Accepts None, an int, or a dict.
            The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
            and 'n_bits' ('auto' or int)
        p_error (Optional[float]): probability of error of a single PBS
        global_p_error (Optional[float]): probability of error of the full circuit. In FHE
            simulation `global_p_error` is set to 0
        verbose (bool): whether to show compilation information
        inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear',
            'encrypted') for each input. By default all arguments will be encrypted.
        reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
            bit-width propagation
        composition_mapping (Optional[Dict]): Dictionary that maps output positions with input
            positions in the case of composable circuits. Setting this parameter triggers a
            re-quantization step at the end of the FHE circuit. This makes sure outputs are
            de-quantized using their output quantizer and then re-quantized using their
            associated input quantizer. Default to None.
        device: FHE compilation device, can be either 'cpu' or 'cuda'.

    Returns:
        QuantizedModule: The resulting compiled QuantizedModule.

    Raises:
        ValueError: If an input-output mapping ('composition_mapping') is set but composition is
            not enabled at the Concrete level (in 'configuration').
    """
    rounding_threshold_bits = process_rounding_threshold_bits(rounding_threshold_bits)

    inputset_as_numpy_tuple = tuple(
        convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
    )

    # Check that composition is enabled if an input-output mapping has been set
    if composition_mapping is not None and (configuration is None or not configuration.composable):
        raise ValueError(
            "Composition must be enabled in 'configuration' in order to trigger a re-quantization "
            "step on the circuit's outputs."
        )

    # Build the quantized module
    quantized_module = build_quantized_module(
        model=model,
        torch_inputset=inputset_as_numpy_tuple,
        import_qat=import_qat,
        n_bits=n_bits,
        rounding_threshold_bits=rounding_threshold_bits,
        reduce_sum_copy=reduce_sum_copy,
    )

    # Check that p_error or global_p_error is not set in both the configuration and in the
    # direct parameters
    check_there_is_no_p_error_options_in_configuration(configuration)

    if (
        rounding_threshold_bits is not None
        and configuration is not None
        and configuration.parameter_selection_strategy != ParameterSelectionStrategy.MULTI
    ):
        warnings.warn(
            "It is recommended to set the optimization strategy to multi-parameter when using "
            "rounding as it should provide better performance.",
            stacklevel=2,
        )

    # Find the right way to set parameters for the compiler, depending on how we want to default
    p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)

    # If a mapping between input and output quantizers is set, add a re-quantization step at the
    # end of the forward call. This is only useful for composable circuits in order to make sure
    # that input and output quantizers match
    if composition_mapping is not None:
        # pylint: disable-next=protected-access
        quantized_module._add_requant_for_composition(composition_mapping)

    quantized_module.compile(
        inputset_as_numpy_tuple,
        configuration,
        artifacts,
        show_mlir=show_mlir,
        p_error=p_error,
        global_p_error=global_p_error,
        verbose=verbose,
        inputs_encryption_status=inputs_encryption_status,
        device=device,
    )

    return quantized_module

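# Illustrative sketch (not part of the library): the dict form of `rounding_threshold_bits`
# described in the docstring above. The exact values chosen here are hypothetical.
#
#     from concrete import fhe
#
#     # Approximate rounding of accumulators to 6 bits
#     rounding_config = {"method": fhe.Exactness.APPROXIMATE, "n_bits": 6}
#     # Or let the rounding bit-width be determined automatically, with exact rounding
#     rounding_config_auto = {"method": fhe.Exactness.EXACT, "n_bits": "auto"}
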
# pylint: disable-next=too-many-arguments
def compile_torch_model(
    torch_model: torch.nn.Module,
    torch_inputset: Dataset,
    import_qat: bool = False,
    configuration: Optional[Configuration] = None,
    artifacts: Optional[DebugArtifacts] = None,
    show_mlir: bool = False,
    n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
    rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
    p_error: Optional[float] = None,
    global_p_error: Optional[float] = None,
    verbose: bool = False,
    inputs_encryption_status: Optional[Sequence[str]] = None,
    reduce_sum_copy: bool = False,
    device: str = "cpu",
) -> QuantizedModule:
    """Compile a torch module into an FHE equivalent.

    Take a model in torch, turn it to numpy, quantize its inputs / weights / outputs and finally
    compile it with Concrete.

    Args:
        torch_model (torch.nn.Module): the model to quantize
        torch_inputset (Dataset): the calibration input-set, can contain either torch
            tensors or numpy.ndarray.
        import_qat (bool): Set to True to import a network that contains quantizers and was
            trained using quantization aware training
        configuration (Configuration): Configuration object to use during compilation
        artifacts (DebugArtifacts): Artifacts object to fill during compilation
        show_mlir (bool): if set, the MLIR produced by the converter and which is going
            to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
        n_bits (Union[int, Dict[str, int]]): number of bits for quantization, can be a single
            value or a dictionary with the following keys:
            - "op_inputs" and "op_weights" (mandatory)
            - "model_inputs" and "model_outputs" (optional, default to 5 bits).
            When using a single integer for n_bits, its value is assigned to "op_inputs" and
            "op_weights" bits. Default is 8 bits.
        rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
            rounding for model accumulators. Accepts None, an int, or a dict.
            The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
            and 'n_bits' ('auto' or int)
        p_error (Optional[float]): probability of error of a single PBS
        global_p_error (Optional[float]): probability of error of the full circuit. In FHE
            simulation `global_p_error` is set to 0
        verbose (bool): whether to show compilation information
        inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear',
            'encrypted') for each input. By default all arguments will be encrypted.
        reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
            bit-width propagation
        device: FHE compilation device, can be either 'cpu' or 'cuda'.

    Returns:
        QuantizedModule: The resulting compiled QuantizedModule.
    """
    assert_true(
        isinstance(torch_model, torch.nn.Module),
        "The compile_torch_model function must be called on a torch.nn.Module",
    )

    assert_false(
        has_any_qnn_layers(torch_model),
        "The compile_torch_model was called on a torch.nn.Module that contains "
        "Brevitas quantized layers. These models must be imported "
        "using compile_brevitas_qat_model instead.",
    )

    return _compile_torch_or_onnx_model(
        torch_model,
        torch_inputset,
        import_qat,
        configuration=configuration,
        artifacts=artifacts,
        show_mlir=show_mlir,
        n_bits=n_bits,
        rounding_threshold_bits=rounding_threshold_bits,
        p_error=p_error,
        global_p_error=global_p_error,
        verbose=verbose,
        inputs_encryption_status=inputs_encryption_status,
        reduce_sum_copy=reduce_sum_copy,
        device=device,
    )

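# Illustrative usage sketch (not part of the library): compiling a small float torch model and
# running it in FHE simulation. The model, data and the `fhe="simulate"` forward mode follow
# Concrete ML's documented usage, but are assumptions rather than something defined in this file.
#
#     import numpy
#     import torch
#     from concrete.ml.torch.compile import compile_torch_model
#
#     model = torch.nn.Sequential(torch.nn.Linear(10, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2))
#     calibration_data = numpy.random.uniform(-1, 1, size=(100, 10)).astype(numpy.float32)
#
#     quantized_module = compile_torch_model(model, calibration_data, n_bits=6)
#     prediction = quantized_module.forward(calibration_data[[0]], fhe="simulate")
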
# pylint: disable-next=too-many-arguments
def compile_onnx_model(
    onnx_model: onnx.ModelProto,
    torch_inputset: Dataset,
    import_qat: bool = False,
    configuration: Optional[Configuration] = None,
    artifacts: Optional[DebugArtifacts] = None,
    show_mlir: bool = False,
    n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
    rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
    p_error: Optional[float] = None,
    global_p_error: Optional[float] = None,
    verbose: bool = False,
    inputs_encryption_status: Optional[Sequence[str]] = None,
    reduce_sum_copy: bool = False,
    device: str = "cpu",
) -> QuantizedModule:
    """Compile an ONNX model into an FHE equivalent.

    Take a model in ONNX, turn it to numpy, quantize its inputs / weights / outputs and finally
    compile it with Concrete-Python.

    Args:
        onnx_model (onnx.ModelProto): the model to quantize
        torch_inputset (Dataset): the calibration input-set, can contain either torch
            tensors or numpy.ndarray.
        import_qat (bool): Flag to signal that the network being imported contains quantizers
            in its computation graph and that Concrete ML should not re-quantize it.
        configuration (Configuration): Configuration object to use during compilation
        artifacts (DebugArtifacts): Artifacts object to fill during compilation
        show_mlir (bool): if set, the MLIR produced by the converter and which is going
            to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
        n_bits (Union[int, Dict[str, int]]): number of bits for quantization, can be a single
            value or a dictionary with the following keys:
            - "op_inputs" and "op_weights" (mandatory)
            - "model_inputs" and "model_outputs" (optional, default to 5 bits).
            When using a single integer for n_bits, its value is assigned to "op_inputs" and
            "op_weights" bits. Default is 8 bits.
        rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
            rounding for model accumulators. Accepts None, an int, or a dict.
            The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
            and 'n_bits' ('auto' or int)
        p_error (Optional[float]): probability of error of a single PBS
        global_p_error (Optional[float]): probability of error of the full circuit. In FHE
            simulation `global_p_error` is set to 0
        verbose (bool): whether to show compilation information
        inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear',
            'encrypted') for each input. By default all arguments will be encrypted.
        reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
            bit-width propagation
        device: FHE compilation device, can be either 'cpu' or 'cuda'.

    Returns:
        QuantizedModule: The resulting compiled QuantizedModule.
    """
    onnx_model_opset_version = get_onnx_opset_version(onnx_model)
    assert_true(
        onnx_model_opset_version == OPSET_VERSION_FOR_ONNX_EXPORT,
        f"ONNX version must be {OPSET_VERSION_FOR_ONNX_EXPORT} "
        f"but it is {onnx_model_opset_version}",
    )

    return _compile_torch_or_onnx_model(
        onnx_model,
        torch_inputset,
        import_qat,
        configuration=configuration,
        artifacts=artifacts,
        show_mlir=show_mlir,
        n_bits=n_bits,
        rounding_threshold_bits=rounding_threshold_bits,
        p_error=p_error,
        global_p_error=global_p_error,
        verbose=verbose,
        inputs_encryption_status=inputs_encryption_status,
        reduce_sum_copy=reduce_sum_copy,
        device=device,
    )

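# Illustrative usage sketch (not part of the library): exporting a torch model to ONNX with the
# opset version this function asserts on, then compiling the loaded ONNX model. The model, data
# and file path below are hypothetical.
#
#     import numpy
#     import onnx
#     import torch
#     from concrete.ml.onnx.convert import OPSET_VERSION_FOR_ONNX_EXPORT
#     from concrete.ml.torch.compile import compile_onnx_model
#
#     model = torch.nn.Sequential(torch.nn.Linear(10, 2))
#     calibration_data = numpy.random.uniform(-1, 1, size=(100, 10)).astype(numpy.float32)
#
#     torch.onnx.export(
#         model,
#         torch.from_numpy(calibration_data[[0]]),
#         "model.onnx",
#         opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
#     )
#     quantized_module = compile_onnx_model(onnx.load("model.onnx"), calibration_data, n_bits=6)
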
# pylint: disable-next=too-many-arguments
def compile_brevitas_qat_model(
    torch_model: torch.nn.Module,
    torch_inputset: Dataset,
    n_bits: Optional[Union[int, Dict[str, int]]] = None,
    configuration: Optional[Configuration] = None,
    artifacts: Optional[DebugArtifacts] = None,
    show_mlir: bool = False,
    rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
    p_error: Optional[float] = None,
    global_p_error: Optional[float] = None,
    output_onnx_file: Union[None, Path, str] = None,
    verbose: bool = False,
    inputs_encryption_status: Optional[Sequence[str]] = None,
    reduce_sum_copy: bool = False,
    device: str = "cpu",
) -> QuantizedModule:
    """Compile a Brevitas Quantization Aware Training model.

    The torch_model parameter is a subclass of torch.nn.Module that uses quantized
    operations from brevitas.qnn. The model is trained before calling this function. This
    function compiles the trained model to FHE.

    Args:
        torch_model (torch.nn.Module): the model to quantize
        torch_inputset (Dataset): the calibration input-set, can contain either torch
            tensors or numpy.ndarray.
        n_bits (Optional[Union[int, dict]]): the number of bits for the quantization. By default,
            for most models, a value of None should be given, which instructs Concrete ML to use
            the bit-widths configured using Brevitas quantization options. For some networks that
            perform a non-linear operation on an input or an output, if None is given, a default
            value of 8 bits is used for the input/output quantization. For such models the user
            can also specify a dictionary with model_inputs/model_outputs keys to override the
            8-bit default, or a single integer for both values.
        configuration (Configuration): Configuration object to use during compilation
        artifacts (DebugArtifacts): Artifacts object to fill during compilation
        show_mlir (bool): if set, the MLIR produced by the converter and which is going
            to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
        rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
            rounding for model accumulators. Accepts None, an int, or a dict.
            The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
            and 'n_bits' ('auto' or int)
        p_error (Optional[float]): probability of error of a single PBS
        global_p_error (Optional[float]): probability of error of the full circuit. In FHE
            simulation `global_p_error` is set to 0
        output_onnx_file (str): temporary file to store the ONNX model. If None, a temporary file
            is generated
        verbose (bool): whether to show compilation information
        inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear',
            'encrypted') for each input. By default all arguments will be encrypted.
        reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
            bit-width propagation
        device: FHE compilation device, can be either 'cpu' or 'cuda'.

    Returns:
        QuantizedModule: The resulting compiled QuantizedModule.
    """
    inputset_as_numpy_tuple = tuple(
        convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
    )

    dummy_input_for_tracing = tuple(
        torch.from_numpy(val[[0], ::]).float() for val in inputset_as_numpy_tuple
    )

    output_onnx_file_path = Path(
        tempfile.mkstemp(suffix=".onnx")[1] if output_onnx_file is None else output_onnx_file
    )
    use_tempfile: bool = output_onnx_file is None

    assert_true(
        has_any_qnn_layers(torch_model),
        "The compile_brevitas_qat_model was called on a torch.nn.Module that contains "
        "no Brevitas quantized layers, consider using compile_torch_model instead",
    )

    # Brevitas to ONNX
    exporter = BrevitasONNXManager()

    # Here we add an "eliminate_nop_pad" optimization step for onnxoptimizer
    # https://github.com/onnx/optimizer/blob/master/onnxoptimizer/passes/eliminate_nop_pad.h#L5
    # It deletes 0-values padding.
    # This is needed because AvgPool2d adds a 0-Pad operation that then breaks the compilation.
    # A list of steps that can be added can be found at the following link:
    # https://github.com/onnx/optimizer/blob/master/onnxoptimizer/pass_registry.h
    # In the export function, the `args` parameter is used instead of the `input_shape` one in
    # order to be able to handle multi-input models
    exporter.onnx_passes += [
        "eliminate_nop_pad",
        "fuse_pad_into_conv",
        "fuse_matmul_add_bias_into_gemm",
    ]

    onnx_model = exporter.export(
        torch_model,
        args=dummy_input_for_tracing,
        export_path=str(output_onnx_file_path),
        keep_initializers_as_inputs=False,
        opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
    )
    onnx_model = remove_initializer_from_input(onnx_model)

    if n_bits is None:
        n_bits = {
            "model_inputs": 8,
            "op_weights": 8,
            "op_inputs": 8,
            "model_outputs": 8,
        }
    elif isinstance(n_bits, int):
        n_bits = {
            "model_inputs": n_bits,
            "op_weights": n_bits,
            "op_inputs": n_bits,
            "model_outputs": n_bits,
        }
    elif isinstance(n_bits, dict):
        assert_true(
            set(n_bits.keys()) == {"model_inputs", "model_outputs"},
            "When importing a Brevitas QAT network, n_bits should only contain the following "
            f'keys: "model_inputs", "model_outputs". Instead, got {n_bits.keys()}',
        )
        n_bits = {
            "model_inputs": n_bits["model_inputs"],
            "op_weights": n_bits["model_inputs"],
            "op_inputs": n_bits["model_inputs"],
            "model_outputs": n_bits["model_outputs"],
        }
    assert_true(
        n_bits is None or isinstance(n_bits, (int, dict)),
        "The n_bits parameter must be either a dictionary, an integer or None",
    )

    # Compile using the ONNX conversion flow, in QAT mode
    q_module = compile_onnx_model(
        onnx_model,
        torch_inputset,
        n_bits=n_bits,
        import_qat=True,
        artifacts=artifacts,
        show_mlir=show_mlir,
        rounding_threshold_bits=rounding_threshold_bits,
        configuration=configuration,
        p_error=p_error,
        global_p_error=global_p_error,
        verbose=verbose,
        inputs_encryption_status=inputs_encryption_status,
        reduce_sum_copy=reduce_sum_copy,
        device=device,
    )

    # Remove the tempfile if we used one
    if use_tempfile:
        output_onnx_file_path.unlink()

    return q_module

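# Illustrative usage sketch (not part of the library): compiling a small Brevitas QAT model.
# The Brevitas layers and their parameters below are assumptions about a typical QAT network,
# not something defined in this module.
#
#     import numpy
#     import torch
#     from brevitas import nn as qnn
#     from concrete.ml.torch.compile import compile_brevitas_qat_model
#
#     model = torch.nn.Sequential(
#         qnn.QuantIdentity(bit_width=4, return_quant_tensor=True),
#         qnn.QuantLinear(10, 2, weight_bit_width=4, bias=True),
#     )
#     calibration_data = numpy.random.uniform(-1, 1, size=(100, 10)).astype(numpy.float32)
#
#     # n_bits is left as None so the bit-widths configured in Brevitas are used
#     quantized_module = compile_brevitas_qat_model(model, calibration_data)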