From 408f82c3110cfb1f9491a4b7b2a5b30dc25e9816 Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" <69792125+mayeul-zama@users.noreply.github.com> Date: Fri, 13 Dec 2024 11:30:05 +0100 Subject: [PATCH] chore(hlapi): stabilize FheTypes --- tfhe/Cargo.toml | 4 +- tfhe/src/c_api/high_level_api/mod.rs | 71 ++++++++++++++++------------ tfhe/src/high_level_api/mod.rs | 63 ++++++++++++------------ 3 files changed, 77 insertions(+), 61 deletions(-) diff --git a/tfhe/Cargo.toml b/tfhe/Cargo.toml index 673a23ba8f..0eb968c2cb 100644 --- a/tfhe/Cargo.toml +++ b/tfhe/Cargo.toml @@ -91,11 +91,13 @@ console_error_panic_hook = { version = "0.1.7", optional = true } serde-wasm-bindgen = { version = "0.6.0", optional = true } getrandom = { version = "0.2.8", optional = true } bytemuck = { workspace = true } +strum = { version = "0.26", features = ["derive"], optional = true } + [features] boolean = [] shortint = ["dep:sha3"] -integer = ["shortint"] +integer = ["shortint", "strum"] strings = ["integer"] internal-keycache = ["dep:lazy_static", "dep:fs2"] gpu = ["dep:tfhe-cuda-backend"] diff --git a/tfhe/src/c_api/high_level_api/mod.rs b/tfhe/src/c_api/high_level_api/mod.rs index 6aba60b66d..690f231950 100644 --- a/tfhe/src/c_api/high_level_api/mod.rs +++ b/tfhe/src/c_api/high_level_api/mod.rs @@ -22,36 +22,36 @@ mod zk; #[repr(C)] #[allow(non_camel_case_types)] pub enum FheTypes { - Type_FheBool, - Type_FheUint2, - Type_FheUint4, - Type_FheUint6, - Type_FheUint8, - Type_FheUint10, - Type_FheUint12, - Type_FheUint14, - Type_FheUint16, - Type_FheUint32, - Type_FheUint64, - Type_FheUint128, - Type_FheUint160, - Type_FheUint256, - Type_FheUint512, - Type_FheUint1024, - Type_FheUint2048, - Type_FheInt2, - Type_FheInt4, - Type_FheInt6, - Type_FheInt8, - Type_FheInt10, - Type_FheInt12, - Type_FheInt14, - Type_FheInt16, - Type_FheInt32, - Type_FheInt64, - Type_FheInt128, - Type_FheInt160, - Type_FheInt256, + Type_FheBool = 0, + Type_FheUint4 = 1, + Type_FheUint8 = 2, + Type_FheUint16 = 3, + Type_FheUint32 = 4, + Type_FheUint64 = 5, + Type_FheUint128 = 6, + Type_FheUint160 = 7, + Type_FheUint256 = 8, + Type_FheUint512 = 9, + Type_FheUint1024 = 10, + Type_FheUint2048 = 11, + Type_FheUint2 = 12, + Type_FheUint6 = 13, + Type_FheUint10 = 14, + Type_FheUint12 = 15, + Type_FheUint14 = 16, + Type_FheInt2 = 17, + Type_FheInt4 = 18, + Type_FheInt6 = 19, + Type_FheInt8 = 20, + Type_FheInt10 = 21, + Type_FheInt12 = 22, + Type_FheInt14 = 23, + Type_FheInt16 = 24, + Type_FheInt32 = 25, + Type_FheInt64 = 26, + Type_FheInt128 = 27, + Type_FheInt160 = 28, + Type_FheInt256 = 29, } impl From for FheTypes { @@ -90,3 +90,14 @@ impl From for FheTypes { } } } + +#[test] +fn fhe_types_enum_to_int_compatible2() { + use strum::IntoEnumIterator; + + for rust_value in crate::FheTypes::iter() { + let c_value = FheTypes::from(rust_value); + + assert_eq!(rust_value as i32, c_value as i32) + } +} diff --git a/tfhe/src/high_level_api/mod.rs b/tfhe/src/high_level_api/mod.rs index 4caf58e251..3db641e5e9 100644 --- a/tfhe/src/high_level_api/mod.rs +++ b/tfhe/src/high_level_api/mod.rs @@ -58,6 +58,7 @@ pub use keys::{ generate_keys, ClientKey, CompactPublicKey, CompressedCompactPublicKey, CompressedPublicKey, CompressedServerKey, KeySwitchingKey, PublicKey, ServerKey, }; +use strum::EnumIter; #[cfg(test)] mod tests; @@ -129,35 +130,37 @@ pub enum Device { } #[derive(Copy, Clone, PartialEq, Eq, Debug)] +#[repr(i32)] +#[derive(EnumIter)] pub enum FheTypes { - Bool, - Uint2, - Uint4, - Uint6, - Uint8, - Uint10, - Uint12, - Uint14, - Uint16, - Uint32, - Uint64, - Uint128, - Uint160, - Uint256, - Uint512, - Uint1024, - Uint2048, - Int2, - Int4, - Int6, - Int8, - Int10, - Int12, - Int14, - Int16, - Int32, - Int64, - Int128, - Int160, - Int256, + Bool = 0, + Uint4 = 1, + Uint8 = 2, + Uint16 = 3, + Uint32 = 4, + Uint64 = 5, + Uint128 = 6, + Uint160 = 7, + Uint256 = 8, + Uint512 = 9, + Uint1024 = 10, + Uint2048 = 11, + Uint2 = 12, + Uint6 = 13, + Uint10 = 14, + Uint12 = 15, + Uint14 = 16, + Int2 = 17, + Int4 = 18, + Int6 = 19, + Int8 = 20, + Int10 = 21, + Int12 = 22, + Int14 = 23, + Int16 = 24, + Int32 = 25, + Int64 = 26, + Int128 = 27, + Int160 = 28, + Int256 = 29, }