"""Base classes for all estimators."""
from __future__ import annotations
import copy
import os
import tempfile
# Disable pylint as some names like X and q_X are used, following scikit-learn's standard. The file
# is also more than 1000 lines long.
# pylint: disable=too-many-lines,invalid-name
import warnings
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, TextIO, Type, Union
import brevitas.nn as qnn
# pylint: disable-next=ungrouped-imports
import concrete.fhe as cp
import numpy
import onnx
import sklearn
import skorch.net
import torch
from brevitas.export.onnx.qonnx.manager import QONNXManager as BrevitasONNXManager
from concrete.fhe.compilation.artifacts import DebugArtifacts
from concrete.fhe.compilation.circuit import Circuit
from concrete.fhe.compilation.compiler import Compiler
from concrete.fhe.compilation.configuration import Configuration
from concrete.fhe.dtypes.integer import Integer
from sklearn.base import clone
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.utils.validation import check_is_fitted
from xgboost.sklearn import XGBModel
from ..common.check_inputs import check_array_and_assert, check_X_y_and_assert_multi_output
from ..common.debugging.custom_assert import assert_true
from ..common.serialization.dumpers import dump, dumps
from ..common.utils import (
USE_OLD_VL,
FheMode,
check_compilation_device_is_valid_and_is_cuda,
check_execution_device_is_valid_and_is_cuda,
check_there_is_no_p_error_options_in_configuration,
generate_proxy_function,
manage_parameters_for_pbs_errors,
)
from ..onnx.convert import OPSET_VERSION_FOR_ONNX_EXPORT
from ..onnx.onnx_model_manipulations import clean_graph_after_node_op_type, remove_node_types
# The sigmoid and softmax functions are already defined in the ONNX module and thus are imported
# here in order to avoid duplicating them.
from ..onnx.ops_impl import numpy_sigmoid, numpy_softmax
from ..quantization import (
PostTrainingQATImporter,
QuantizedArray,
_get_n_bits_dict_trees,
_inspect_tree_n_bits,
get_n_bits_dict,
)
from ..quantization.quantized_module import QuantizedModule, _get_inputset_generator
from ..quantization.quantizers import (
QuantizationOptions,
UniformQuantizationParameters,
UniformQuantizer,
)
from ..torch import NumpyModule
from .qnn_module import SparseQuantNeuralNetwork
from .tree_to_numpy import (
get_equivalent_numpy_forward_from_onnx_tree,
get_onnx_model,
is_regressor_or_partial_regressor,
onnx_fp32_model_to_quantized_model,
tree_to_numpy,
)
# Disable pylint to import Hummingbird while ignoring the warnings
# pylint: disable=wrong-import-position,wrong-import-order
# Silence Hummingbird warnings
warnings.filterwarnings("ignore")
from hummingbird.ml import convert as hb_convert # noqa: E402
from hummingbird.ml.operator_converters import constants # noqa: E402
_ALL_SKLEARN_MODELS: Set[Type] = set()
_LINEAR_MODELS: Set[Type] = set()
_TREE_MODELS: Set[Type] = set()
_NEURALNET_MODELS: Set[Type] = set()
_NEIGHBORS_MODELS: Set[Type] = set()
# Define the supported types for both the input data and the target values. Since the Pandas
# library is currently only a dev dependency, we cannot import it. We therefore need to use type
# strings and ignore the `name-defined` mypy error.
Data = Union[
numpy.ndarray,
torch.Tensor,
"pandas.DataFrame", # type: ignore[name-defined] # noqa: F821
List,
]
Target = Union[
numpy.ndarray,
torch.Tensor,
"pandas.DataFrame", # type: ignore[name-defined] # noqa: F821
"pandas.Series", # type: ignore[name-defined] # noqa: F821
List,
]
# Define the QNN attributes that are auto-generated when fitting
QNN_AUTO_KWARGS = ["module__n_outputs", "module__input_dim"]
# Enable rounding feature for all tree-based models by default
# Note: This setting is fixed and cannot be altered by users
# However, for internal testing purposes, we retain the capability to disable this feature
os.environ["TREES_USE_ROUNDING"] = os.environ.get("TREES_USE_ROUNDING", "1")
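# Illustrative note (internal testing only, assuming the flag keeps its current meaning): the
# rounding feature can be turned off by exporting the variable before this module is imported:
#
#   import os
#   os.environ["TREES_USE_ROUNDING"] = "0"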
# pylint: disable=too-many-public-methods
class BaseEstimator:
"""Base class for all estimators in Concrete ML.
This class does not inherit from sklearn.base.BaseEstimator as it creates some conflicts
with skorch in QuantizedTorchEstimatorMixin's subclasses (more specifically, the `get_params`
method is not properly inherited).
Attributes:
_is_a_public_cml_model (bool): Private attribute indicating if the class is a public model
(as opposed to base or mixin classes).
"""
    #: Base float estimator class to consider for the model. Is set for each subclass.
sklearn_model_class: Type
_is_a_public_cml_model: bool = False
def __init__(self):
"""Initialize the base class with common attributes used in all estimators.
An underscore "_" is appended to attributes that were created while fitting the model. This
        is done in order to follow scikit-learn's standard format. More information available
in their documentation:
https://scikit-learn.org/stable/developers/develop.html#:~:text=Estimated%20Attributes%C2%B6
"""
#: The equivalent fitted float model. Is None if the model is not fitted.
self.sklearn_model: Optional[sklearn.base.BaseEstimator] = None
#: The list of quantizers, which contain all information necessary for applying uniform
#: quantization to inputs and provide quantization/de-quantization functionalities. Is empty
#: if the model is not fitted
self.input_quantizers: List[UniformQuantizer] = []
#: The list of quantizers, which contain all information necessary for applying uniform
#: quantization to outputs and provide quantization/de-quantization functionalities. Is
#: empty if the model is not fitted
self.output_quantizers: List[UniformQuantizer] = []
#: The parameters needed for post-processing the outputs.
#: Can be empty if no post-processing operations are needed for the associated model
#: This attribute is typically used for serving models
self.post_processing_params: Dict[str, Any] = {}
#: Indicate if the model is fitted.
self._is_fitted: bool = False
#: Indicate if the model is compiled.
self._is_compiled: bool = False
self._compiled_for_cuda: bool = False
self.fhe_circuit_: Optional[Circuit] = None
self.onnx_model_: Optional[onnx.ModelProto] = None
def __getattr__(self, attr: str):
"""Get the model's attribute.
If the attribute's name ends with an underscore ("_"), get the attribute from the underlying
scikit-learn model as it represents a training attribute:
https://scikit-learn.org/stable/glossary.html#term-attributes
This method is only called if the attribute has not been found in the class instance:
https://docs.python.org/3/reference/datamodel.html?highlight=getattr#object.__getattr__
Args:
attr (str): The attribute's name.
Returns:
The attribute value.
Raises:
AttributeError: If the attribute cannot be found or is not a training attribute.
"""
# If the attribute ends with a single underscore and can be found in the underlying
# scikit-learn model (once fitted), retrieve its value
# Enable non-training attributes as well once Concrete ML models initialize their
# underlying scikit-learn models during initialization
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3373
if (
attr.endswith("_")
and not attr.endswith("__")
and getattr(self, "sklearn_model", None) is not None
):
return getattr(self.sklearn_model, attr)
        raise AttributeError(
            f"Attribute {attr} cannot be found in the Concrete ML {self.__class__.__name__} object "
            f"and is not a training attribute from the underlying scikit-learn "
            f"{self.sklearn_model_class} one. If the attribute is meant to represent one from "
            f"the latter, please check that the model is properly fitted."
        )
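    # Usage sketch for the trailing-underscore forwarding above; the model class and attribute
    # names are illustrative and assume a fitted Concrete ML linear model:
    #
    #   model = SomeConcreteLinearModel(n_bits=8)  # hypothetical public model class
    #   model.fit(X, y)
    #   model.coef_  # resolved by __getattr__ to model.sklearn_model.coef_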
# We need to specifically call the default __setattr__ method as QNN models still inherit from
# skorch, which provides its own __setattr__ implementation and creates a cyclic loop
# with __getattr__. Removing this inheritance once and for all should fix the issue
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3373
def __setattr__(self, name: str, value: Any):
"""Set the value as a model attribute.
Args:
name (str): The attribute's name to consider.
value (Any): The attribute's value to consider.
"""
object.__setattr__(self, name, value)
@abstractmethod
def dump_dict(self) -> Dict[str, Any]:
"""Dump the object as a dict.
Returns:
Dict[str, Any]: Dict of serialized objects.
"""
@classmethod
@abstractmethod
def load_dict(cls, metadata: Dict[str, Any]) -> BaseEstimator:
"""Load itself from a dict.
Args:
metadata (Dict[str, Any]): Dict of serialized objects.
Returns:
BaseEstimator: The loaded object.
"""
def dumps(self) -> str:
"""Dump itself to a string.
Returns:
metadata (str): String of the serialized object.
"""
return dumps(self)
def dump(self, file: TextIO) -> None:
"""Dump itself to a file.
Args:
file (TextIO): The file to dump the serialized object into.
"""
dump(self, file)
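    # Usage sketch for the two serialization helpers above; the file name is illustrative:
    #
    #   serialized: str = model.dumps()
    #   with open("model.json", "w") as f:
    #       model.dump(f)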
@property
def onnx_model(self) -> Optional[onnx.ModelProto]:
"""Get the ONNX model.
Is None if the model is not fitted.
Returns:
onnx.ModelProto: The ONNX model.
"""
assert isinstance(self.onnx_model_, onnx.ModelProto) or self.onnx_model_ is None
return self.onnx_model_
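    # Usage sketch: once fitted, the exported graph can be saved with the `onnx` package
    # imported above (the path is illustrative):
    #
    #   onnx.save(model.onnx_model, "model.onnx")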
@property
def fhe_circuit(self) -> Optional[Circuit]:
"""Get the FHE circuit.
        The FHE circuit combines the computational graph, MLIR, client and server into a single
        object. More information is available in the Concrete documentation
(https://docs.zama.ai/concrete/get-started/terminology)
Is None if the model is not fitted.
Returns:
Circuit: The FHE circuit.
"""
assert isinstance(self.fhe_circuit_, Circuit) or self.fhe_circuit_ is None
return self.fhe_circuit_
def _sklearn_model_is_not_fitted_error_message(self) -> str:
return (
f"The underlying model (class: {self.sklearn_model_class}) is not fitted and thus "
"cannot be quantized."
) # pragma: no cover
@property
def is_fitted(self) -> bool:
"""Indicate if the model is fitted.
Returns:
bool: If the model is fitted.
"""
return self._is_fitted
def _is_not_fitted_error_message(self) -> str:
return (
f"The {self.__class__.__name__} model is not fitted. "
"Please run fit(...) on proper arguments first."
)
def check_model_is_fitted(self) -> None:
"""Check if the model is fitted.
Raises:
AttributeError: If the model is not fitted.
"""
if not self.is_fitted:
raise AttributeError(self._is_not_fitted_error_message())
@property
def is_compiled(self) -> bool:
"""Indicate if the model is compiled.
Returns:
bool: If the model is compiled.
"""
return self._is_compiled
def _is_not_compiled_error_message(self) -> str:
return (
f"The {self.__class__.__name__} model is not compiled. "
"Please run compile(...) first before executing the prediction in FHE."
)
def check_model_is_compiled(self) -> None:
"""Check if the model is compiled.
Raises:
AttributeError: If the model is not compiled.
"""
if not self.is_compiled:
raise AttributeError(self._is_not_compiled_error_message())
def get_sklearn_params(self, deep: bool = True) -> dict:
"""Get parameters for this estimator.
This method is used to instantiate a scikit-learn model using the Concrete ML model's
parameters. It does not override scikit-learn's existing `get_params` method in order to
not break its implementation of `set_params`.
Args:
deep (bool): If True, will return the parameters for this estimator and contained
subobjects that are estimators. Default to True.
Returns:
params (dict): Parameter names mapped to their values.
"""
# Here, the `get_params` method is the `BaseEstimator.get_params` method from scikit-learn,
# which will become available once a subclass inherits from it. We therefore disable both
# pylint and mypy as this behavior is expected
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3373
# pylint: disable-next=no-member
params = super().get_params(deep=deep) # type: ignore[misc]
        # Remove the n_bits parameter as this attribute is added by Concrete ML
params.pop("n_bits", None)
return params
def _set_post_processing_params(self) -> None:
"""Set parameters used in post-processing."""
self.post_processing_params = {}
def _fit_sklearn_model(self, X: Data, y: Target, **fit_parameters):
"""Fit the model's scikit-learn equivalent estimator.
Args:
X (Data): The training data, as a Numpy array, Torch tensor, Pandas DataFrame or List.
y (Target): The target data, as a Numpy array, Torch tensor, Pandas DataFrame, Pandas
Series or List.
**fit_parameters: Keyword arguments to pass to the scikit-learn estimator's fit method.
Returns:
The fitted scikit-learn estimator.
"""
# Initialize the underlying scikit-learn model if it has not already been done or if
# `warm_start` is set to False (for neural networks)
# This model should be directly initialized in the model's __init__ method instead
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3373
if self.sklearn_model is None or not getattr(self, "warm_start", False):
# Retrieve the init parameters
params = self.get_sklearn_params()
self.sklearn_model = self.sklearn_model_class(**params)
# Fit the scikit-learn model
self.sklearn_model.fit(X, y, **fit_parameters)
return self.sklearn_model
@abstractmethod
def fit(self, X: Data, y: Target, **fit_parameters):
"""Fit the estimator.
This method trains a scikit-learn estimator, computes its ONNX graph and defines the
quantization parameters needed for proper FHE inference.
Args:
X (Data): The training data, as a Numpy array, Torch tensor, Pandas DataFrame or List.
y (Target): The target data, as a Numpy array, Torch tensor, Pandas DataFrame, Pandas
Series or List.
**fit_parameters: Keyword arguments to pass to the float estimator's fit method.
Returns:
The fitted estimator.
"""
    # Several attributes and methods are called in `fit_benchmark` but will only be accessible
    # in subclasses. We therefore need to disable pylint and mypy from checking these no-member
    # issues
# pylint: disable=no-member
def fit_benchmark(
self,
X: Data,
y: Target,
random_state: Optional[int] = None,
**fit_parameters,
):
"""Fit both the Concrete ML and its equivalent float estimators.
Args:
X (Data): The training data, as a Numpy array, Torch tensor, Pandas DataFrame or List.
y (Target): The target data, as a Numpy array, Torch tensor, Pandas DataFrame, Pandas
Series or List.
random_state (Optional[int]): The random state to use when fitting. Defaults to None.
**fit_parameters: Keyword arguments to pass to the float estimator's fit method.
Returns:
The Concrete ML and float equivalent fitted estimators.
"""
# Retrieve sklearn's init parameters
params = self.get_sklearn_params()
# Make sure the random_state is set or both algorithms will diverge
# due to randomness in the training.
if "random_state" in params:
if random_state is not None:
params["random_state"] = random_state
elif getattr(self, "random_state", None) is not None:
# Disable mypy attribute definition errors as it does not seem to see that we make
# sure this attribute actually exists before calling it
params["random_state"] = self.random_state # type: ignore[attr-defined]
else:
params["random_state"] = numpy.random.randint(0, 2**15)
# Initialize the scikit-learn model
sklearn_model = self.sklearn_model_class(**params)
# Train the scikit-learn model
sklearn_model.fit(X, y, **fit_parameters)
# Update the Concrete ML model's parameters
# Disable mypy attribute definition errors as this attribute is expected to be
# initialized once the model inherits from skorch
self.set_params(n_bits=self.n_bits, **params) # type: ignore[attr-defined]
# Train the Concrete ML model
self.fit(X, y, **fit_parameters)
return self, sklearn_model
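    # Usage sketch (variable names are illustrative): fit both estimators on the same data and
    # keep them for comparison:
    #
    #   cml_model, skl_model = model.fit_benchmark(X_train, y_train, random_state=42)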
# pylint: enable=no-member
@abstractmethod
def quantize_input(self, X: numpy.ndarray) -> numpy.ndarray:
"""Quantize the input.
This step ensures that the fit method has been called.
Args:
X (numpy.ndarray): The input values to quantize.
Returns:
numpy.ndarray: The quantized input values.
"""
@abstractmethod
def dequantize_output(self, q_y_preds: numpy.ndarray) -> numpy.ndarray:
"""De-quantize the output.
This step ensures that the fit method has been called.
Args:
q_y_preds (numpy.ndarray): The quantized output values to de-quantize.
Returns:
numpy.ndarray: The de-quantized output values.
"""
@abstractmethod
def _get_module_to_compile(self) -> Union[Compiler, QuantizedModule]:
"""Retrieve the module instance to compile.
Returns:
Union[Compiler, QuantizedModule]: The module instance to compile.
"""
def compile(
self,
X: Data,
configuration: Optional[Configuration] = None,
artifacts: Optional[DebugArtifacts] = None,
show_mlir: bool = False,
p_error: Optional[float] = None,
global_p_error: Optional[float] = None,
verbose: bool = False,
device: str = "cpu",
) -> Circuit:
"""Compile the model.
Args:
X (Data): A representative set of input values used for building cryptographic
parameters, as a Numpy array, Torch tensor, Pandas DataFrame or List. This is
usually the training data-set or a sub-set of it.
            configuration (Optional[Configuration]): Options to use for compilation. Defaults
                to None.
            artifacts (Optional[DebugArtifacts]): Artifacts information about the compilation
                process to store for debugging. Defaults to None.
            show_mlir (bool): Indicate if the MLIR graph should be printed during compilation.
                Defaults to False.
            p_error (Optional[float]): Probability of error of a single PBS. A p_error value
                cannot be given if a global_p_error value is already set. Defaults to None,
                which sets this error to a default value.
            global_p_error (Optional[float]): Probability of error of the full circuit. A
                global_p_error value cannot be given if a p_error value is already set. This
                feature is not supported in FHE simulation mode, meaning the probability is
                currently set to 0. Defaults to None, which sets this error to a default value.
            verbose (bool): Indicate if compilation information should be printed
                during compilation. Defaults to False.
            device (str): The FHE compilation device, either 'cpu' or 'cuda'. Defaults to 'cpu'.
Returns:
Circuit: The compiled Circuit.
"""
# Reset for double compile
self._is_compiled = False
# Check that the model is correctly fitted
self.check_model_is_fitted()
# Cast pandas, list or torch to numpy
X = check_array_and_assert(X)
# p_error or global_p_error should not be set in both the configuration and direct arguments
check_there_is_no_p_error_options_in_configuration(configuration)
        # Determine the p_error and global_p_error values to use for compilation, applying
        # defaults if needed
p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error)
use_gpu = check_compilation_device_is_valid_and_is_cuda(device)
# Quantize the inputs
q_X = self.quantize_input(X)
# Generate the compilation input-set with proper dimensions
inputset = _get_inputset_generator(q_X)
# Retrieve the compiler instance
module_to_compile = self._get_module_to_compile()
# Compiling using a QuantizedModule requires different steps and should not be done here
assert isinstance(module_to_compile, Compiler), (
"Wrong module to compile. Expected to be of type `Compiler` but got "
f"{type(module_to_compile)}."
)
# Enable input ciphertext compression
enable_input_compression = os.environ.get("USE_INPUT_COMPRESSION", "1") == "1"
# Enable evaluation key compression
enable_key_compression = os.environ.get("USE_KEY_COMPRESSION", "1") == "1"
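        # Sketch (internal/debugging use, assuming these variables keep their meaning): either
        # compression can be disabled by exporting the variable before calling compile, e.g.
        #   os.environ["USE_INPUT_COMPRESSION"] = "0"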
self.fhe_circuit_ = module_to_compile.compile(
inputset,
configuration=configuration,
artifacts=artifacts,
show_mlir=show_mlir,
p_error=p_error,
global_p_error=global_p_error,
verbose=verbose,
single_precision=False,
use_gpu=use_gpu,
compress_input_ciphertexts=enable_input_compression,
compress_evaluation_keys=enable_key_compression,
)
self._is_compiled = True
self._compiled_for_cuda = use_gpu
# For mypy
assert isinstance(self.fhe_circuit, Circuit)
return self.fhe_circuit
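    # Usage sketch (the p_error value is illustrative): compile on representative data after
    # fitting, then inspect the resulting circuit:
    #
    #   model.fit(X_train, y_train)
    #   circuit = model.compile(X_train, p_error=2**-40)
    #   print(circuit.statistics)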
@abstractmethod
def _inference(self, q_X: numpy.ndarray) -> numpy.ndarray:
"""Inference function to consider when executing in the clear.
Args:
q_X (numpy.ndarray): The quantized input values.
Returns:
numpy.ndarray: The quantized predicted values.
"""
def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray:
"""Predict values for X, in FHE or in the clear.
Args:
X (Data): The input values to predict, as a Numpy array, Torch tensor, Pandas DataFrame
or List.
fhe (Union[FheMode, str]): The mode to use for prediction.
Can be FheMode.DISABLE for Concrete ML Python inference,
FheMode.SIMULATE for FHE simulation and FheMode.EXECUTE for actual FHE execution.
Can also be the string representation of any of these values.
Default to FheMode.DISABLE.
Returns:
np.ndarray: The predicted values for X.
"""
assert_true(
FheMode.is_valid(fhe),
"`fhe` mode is not supported. Expected one of 'disable' (resp. FheMode.DISABLE), "
"'simulate' (resp. FheMode.SIMULATE) or 'execute' (resp. FheMode.EXECUTE). Got "
f"{fhe}",
ValueError,
)
# Check that the model is properly fitted
self.check_model_is_fitted()
check_execution_device_is_valid_and_is_cuda(self._compiled_for_cuda, fhe=fhe)
# Ensure inputs are 2D
if isinstance(X, (numpy.ndarray, torch.Tensor)) and X.ndim == 1:
X = X.reshape((1, -1))
# Check that X's type and shape are supported
X = check_array_and_assert(X)
# Quantize the input
q_X = self.quantize_input(X)
# If the inference is executed in FHE or simulation mode
if fhe in ["simulate", "execute"]:
# Check that the model is properly compiled
self.check_model_is_compiled()
q_y_pred_list = []
for q_X_i in q_X:
                # encrypt_run_decrypt expects an input of shape (1, n_features) while q_X_i
                # is of shape (n_features,)
q_X_i = numpy.expand_dims(q_X_i, 0)
# For mypy, even though we already check this with self.check_model_is_compiled()
assert self.fhe_circuit is not None
# If the inference should be executed using simulation
if fhe == "simulate":
is_crt_encoding = self.fhe_circuit.statistics["packing_key_switch_count"] != 0
# If the virtual library method should be used
# For now, use the virtual library when simulating
# circuits that use CRT encoding because the official simulation is too slow
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4391
if USE_OLD_VL or is_crt_encoding:
predict_method = partial(
self.fhe_circuit.graph, p_error=self.fhe_circuit.p_error
) # pragma: no cover
# Else, use the official simulation method
else:
predict_method = self.fhe_circuit.simulate
# Else, use the FHE execution method
else:
predict_method = self.fhe_circuit.encrypt_run_decrypt
# Execute the inference in FHE or with simulation
q_y_pred_i = predict_method(q_X_i)
assert isinstance(q_y_pred_i, numpy.ndarray)
q_y_pred_list.append(q_y_pred_i[0])
q_y_pred = numpy.array(q_y_pred_list)
        # Else, the prediction is computed in the clear
else:
q_y_pred = self._inference(q_X)
# De-quantize the predicted values in the clear
y_pred = self.dequantize_output(q_y_pred)
assert isinstance(y_pred, numpy.ndarray)
return y_pred
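    # Usage sketch: the `fhe` argument accepts the enum or its string form.
    #
    #   y_clear = model.predict(X_test)                     # FheMode.DISABLE by default
    #   y_sim = model.predict(X_test, fhe="simulate")       # requires compile(...) first
    #   y_fhe = model.predict(X_test, fhe=FheMode.EXECUTE)  # requires compile(...) first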
def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
"""Apply post-processing to the de-quantized predictions.
This post-processing step can include operations such as applying the sigmoid or softmax
function for classifiers, or summing an ensemble's outputs. These steps are done in the
clear because of current technical constraints. They most likely will be integrated in the
FHE computations in the future.
        For some simple models such as linear regression, there is no post-processing step, but
        the method is kept to make the API consistent with the client-server API. Other models
        might need to use attributes stored in `post_processing_params`.
Args:
y_preds (numpy.ndarray): The de-quantized predictions to post-process.
Returns:
numpy.ndarray: The post-processed predictions.
"""
assert isinstance(y_preds, numpy.ndarray), "Output predictions must be an array."
return y_preds
# This class is only an equivalent of BaseEstimator applied to classifiers, therefore not all
# methods are implemented and we need to disable pylint from checking that
# pylint: disable-next=abstract-method
class BaseClassifier(BaseEstimator):
"""Base class for linear and tree-based classifiers in Concrete ML.
This class inherits from BaseEstimator and modifies some of its methods in order to align them
    with classifier behaviors. This notably includes applying a sigmoid/softmax post-processing to
the predicted values as well as handling a mapping of classes in case they are not ordered.
"""
    # Remove in our next major release
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3994
@property
def target_classes_(self) -> Optional[numpy.ndarray]: # pragma: no cover
"""Get the model's classes.
Using this attribute is deprecated.
Returns:
Optional[numpy.ndarray]: The model's classes.
"""
warnings.warn(
"Attribute 'target_classes_' is now deprecated and will be removed in a future "
"version. Please use 'classes_' instead.",
category=UserWarning,
stacklevel=2,
)
return self.classes_
    # Remove in our next major release
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3994
@property
def n_classes_(self) -> int: # pragma: no cover
"""Get the model's number of classes.
Using this attribute is deprecated.
Returns:
int: The model's number of classes.
"""
# Tree-based classifiers from scikit-learn provide a `n_classes_` attribute
if self.sklearn_model is not None and hasattr(self.sklearn_model, "n_classes_"):
return self.sklearn_model.n_classes_
warnings.warn(
"Attribute 'n_classes_' is now deprecated and will be removed in a future version. "
"Please use 'len(classes_)' instead.",
category=UserWarning,
stacklevel=2,
)
return len(self.classes_)
def fit(self, X: Data, y: Target, **fit_parameters):
X, y = check_X_y_and_assert_multi_output(X, y)
# Retrieve the different target classes
classes = numpy.unique(y)
# Make sure y contains at least two classes
assert_true(len(classes) > 1, "You must provide at least 2 classes in y.")
# Change to composition in order to avoid diamond inheritance and indirect super() calls
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3249
return super().fit(X, y, **fit_parameters) # type: ignore[safe-super]
def predict_proba(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray:
"""Predict class probabilities.
Args:
X (Data): The input values to predict, as a Numpy array, Torch tensor, Pandas DataFrame
or List.
fhe (Union[FheMode, str]): The mode to use for prediction.
Can be FheMode.DISABLE for Concrete ML Python inference,
FheMode.SIMULATE for FHE simulation and FheMode.EXECUTE for actual FHE execution.
Can also be the string representation of any of these values.
Default to FheMode.DISABLE.
Returns:
numpy.ndarray: The predicted class probabilities.
"""
return super().predict(X, fhe=fhe)
def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray:
# Compute the predicted probabilities
y_proba = self.predict_proba(X, fhe=fhe)
# Retrieve the class with the highest probability
y_preds = numpy.argmax(y_proba, axis=1)
return self.classes_[y_preds]
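    # Usage sketch: predict() derives class labels from predict_proba().
    #
    #   proba = clf.predict_proba(X_test)  # shape (n_samples, n_classes)
    #   labels = clf.predict(X_test)       # equals clf.classes_[proba.argmax(axis=1)]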
def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
y_preds = super().post_processing(y_preds)
        # If the prediction array is 1D, which happens with some models such as XGBClassifier or
# LogisticRegression models, we have a binary classification problem
n_classes = y_preds.shape[1] if y_preds.ndim > 1 and y_preds.shape[1] > 1 else 2
# For binary classification problem, apply the sigmoid operator
if n_classes == 2:
y_preds = numpy_sigmoid(y_preds)[0]
# If the prediction array is 1D, transform the output into a 2D array [1-p, p],
# with p the initial output probabilities
# This is similar to what is done in scikit-learn
if y_preds.ndim == 1 or y_preds.shape[1] == 1:
y_preds = y_preds.flatten()
return numpy.vstack([1 - y_preds, y_preds]).T
# Else, apply the softmax operator
else:
y_preds = numpy_softmax(y_preds)[0]
return y_preds
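    # Worked sketch of the binary branch above, with illustrative logits: a 1D output such as
    # numpy.array([0.3, -1.2]) is mapped through the sigmoid to probabilities p, then stacked
    # as numpy.vstack([1 - p, p]).T, giving one row per sample with the columns
    # [P(class 0), P(class 1)], matching scikit-learn's predict_proba layout.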
# Pylint complains that this class does not override the `dump_dict` and `load_dict` methods. This
# is expected as the QuantizedTorchEstimatorMixin class is not supposed to be used as such. This
# disable could probably be removed when refactoring the serialization of models
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3250
# pylint: disable-next=abstract-method,too-many-instance-attributes
class QuantizedTorchEstimatorMixin(BaseEstimator):
"""Mixin that provides quantization for a torch module and follows the Estimator API."""
def __init_subclass__(cls, *args, **kwargs):
super().__init_subclass__(*args, **kwargs)
for klass in cls.__mro__:
# pylint: disable-next=protected-access
if getattr(klass, "_is_a_public_cml_model", False):
_NEURALNET_MODELS.add(cls)
_ALL_SKLEARN_MODELS.add(cls)
def __init__(self):
#: The quantized module used to store the quantized parameters. Is empty if the model is
#: not fitted.
self.quantized_module_ = QuantizedModule()
#: The input dimension in the underlying model
self.module__input_dim: Optional[int] = None
#: The number of outputs in the underlying model
self.module__n_outputs: Optional[int] = None
BaseEstimator.__init__(self)
@property
def base_module(self) -> SparseQuantNeuralNetwork:
"""Get the Torch module.
Returns:
SparseQuantNeuralNetwork: The fitted underlying module.
"""
assert self.sklearn_model is not None, self._sklearn_model_is_not_fitted_error_message()
return self.sklearn_model.module_
@property
def input_quantizers(self) -> List[UniformQuantizer]:
"""Get the input quantizers.
Returns:
List[UniformQuantizer]: The input quantizers.
"""
return self.quantized_module_.input_quantizers
@input_quantizers.setter
def input_quantizers(self, value: List[UniformQuantizer]) -> None:
self.quantized_module_.input_quantizers = value
@property
def output_quantizers(self) -> List[UniformQuantizer]:
"""Get the output quantizers.
Returns:
List[UniformQuantizer]: The output quantizers.
"""
return self.quantized_module_.output_quantizers
@output_quantizers.setter
def output_quantizers(self, value: List[UniformQuantizer]) -> None:
self.quantized_module_.output_quantizers = value
@property
def fhe_circuit(self) -> Circuit:
return self.quantized_module_.fhe_circuit
def get_params(self, deep: bool = True) -> dict:
"""Get parameters for this estimator.
This method is overloaded in order to make sure that auto-computed parameters are not
        considered when cloning the model (e.g. during a GridSearchCV call).
Args:
deep (bool): If True, will return the parameters for this estimator and
contained subobjects that are estimators.
Returns:
params (dict): Parameter names mapped to their values.
"""
# Retrieve the skorch estimator's init parameters
# Here, the `get_params` method is the `NeuralNet.get_params` method from skorch, which
# will become available once a subclass inherits from it. We therefore disable both pylint
# and mypy as this behavior is expected
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3373
# pylint: disable-next=no-member
params = super().get_params(deep) # type: ignore[misc]
        # Remove `module` since it is automatically set to SparseQuantNeuralNetwork. Therefore,
# we don't need to pass module again when cloning this estimator
params.pop("module", None)
# Remove the parameters that are auto-computed by `fit` as well
for kwarg in QNN_AUTO_KWARGS:
params.pop(kwarg, None)
return params
def get_sklearn_params(self, deep: bool = True) -> Dict:
# Retrieve the skorch estimator's init parameters
# Here, the `get_params` method is the `NeuralNet.get_params` method from skorch, which
# will become available once a subclass inherits from it. We therefore disable both pylint
# and mypy as this behavior is expected
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3373
# pylint: disable-next=no-member
params = super().get_params(deep=deep) # type: ignore[misc]
# Set the quantized module to SparseQuantNeuralNetwork
params["module"] = SparseQuantNeuralNetwork
return params
def _fit_sklearn_model(self, X: Data, y: Target, **fit_parameters):
super()._fit_sklearn_model(X, y, **fit_parameters)
# Make pruning permanent by removing weights associated to pruned neurons
self.base_module.make_pruning_permanent()
    def fit(self, X: Data, y: Target, **fit_parameters):
        """Fit the estimator.
If the module was already initialized, the module will be re-initialized unless
`warm_start` is set to True. In addition to the torch training step, this method performs
quantization of the trained Torch model using Quantization Aware Training (QAT).
        Values of dtype float64 are not supported and will be cast to float32.
Args:
X (Data): The training data, as a Numpy array, Torch tensor, Pandas DataFrame or List.
y (Target): The target data, as a Numpy array, Torch tensor, Pandas DataFrame, Pandas
Series or List.
**fit_parameters: Keyword arguments to pass to skorch's fit method.
Returns:
The fitted estimator.
"""
# Reset for double fit
self._is_fitted = False
# Reset the quantized module since quantization is lost during refit
# This will make the .infer() function call into the Torch nn.Module
        # instead of the quantized module
self.quantized_module_ = QuantizedModule()
X, y = check_X_y_and_assert_multi_output(X, y)
# A helper for users so they don't need to import torch directly
args_to_convert_to_tensor = ["criterion__weight"]
for arg_name in args_to_convert_to_tensor:
if hasattr(self, arg_name):