diff --git a/tfhe/src/high_level_api/backward_compatibility/tag.rs b/tfhe/src/high_level_api/backward_compatibility/tag.rs index 883a1de772..728f830d4e 100644 --- a/tfhe/src/high_level_api/backward_compatibility/tag.rs +++ b/tfhe/src/high_level_api/backward_compatibility/tag.rs @@ -1,12 +1,6 @@ -use crate::high_level_api::tag::{SmallVec, Tag}; +use crate::high_level_api::tag::Tag; use tfhe_versionable::VersionsDispatch; -#[derive(VersionsDispatch)] -pub(in crate::high_level_api) enum SmallVecVersions { - #[allow(unused)] // Unused because V1 does not exists yet - V0(SmallVec), -} - #[derive(VersionsDispatch)] pub enum TagVersions { V0(Tag), diff --git a/tfhe/src/high_level_api/tag.rs b/tfhe/src/high_level_api/tag.rs index 8998e4f3d8..42a5b8fcf7 100644 --- a/tfhe/src/high_level_api/tag.rs +++ b/tfhe/src/high_level_api/tag.rs @@ -1,16 +1,17 @@ -use crate::high_level_api::backward_compatibility::tag::{SmallVecVersions, TagVersions}; -use tfhe_versionable::Versionize; +use crate::high_level_api::backward_compatibility::tag::TagVersions; +use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; + +const STACK_ARRAY_SIZE: usize = std::mem::size_of::>() - 1; /// Simple short optimized vec, where if the data is small enough /// (<= std::mem::size_of::>() - 1) the data will be stored on the stack /// /// Once a true heap allocated Vec was needed, it won't be deallocated in favor /// of stack data. -#[derive(Clone, Debug, Versionize)] -#[versionize(SmallVecVersions)] +#[derive(Clone, Debug)] pub(in crate::high_level_api) enum SmallVec { Stack { - bytes: [u8; std::mem::size_of::>() - 1], + bytes: [u8; STACK_ARRAY_SIZE], // The array has a fixed size, but the user may not use all of it // so we keep track of the actual len len: u8, @@ -68,6 +69,17 @@ impl SmallVec { } } + pub fn as_slice(&self) -> &[u8] { + self.data() + } + + pub fn as_mut_slice(&mut self) -> &mut [u8] { + match self { + Self::Stack { bytes, len } => &mut bytes[..usize::from(*len)], + Self::Heap(vec) => vec.as_mut_slice(), + } + } + pub fn len(&self) -> usize { match self { Self::Stack { len, .. } => usize::from(*len), @@ -134,6 +146,23 @@ impl SmallVec { let le_bytes = value.to_le_bytes(); self.set_data(le_bytes.as_slice()); } + + // Creates a SmallVec from the vec, but, only re-uses the vec + // if its len would not fit on the stack part. + // + // Meant for versioning and deserializing + fn from_vec_conservative(vec: Vec) -> Self { + // We only re-use the versioned vec, if the SmallVec would actually + // have had its data on the heap, otherwise we prefer to keep data on stack + // as its cheaper in memory and copies + if vec.len() > STACK_ARRAY_SIZE { + Self::Heap(vec) + } else { + let mut data = Self::default(); + data.set_data(vec.as_slice()); + data + } + } } impl serde::Serialize for SmallVec { @@ -167,7 +196,32 @@ impl<'de> serde::de::Visitor<'de> for SmallVecVisitor { where E: serde::de::Error, { - Ok(SmallVec::Heap(bytes)) + Ok(SmallVec::from_vec_conservative(bytes)) + } +} + +impl Versionize for SmallVec { + type Versioned<'vers> = &'vers [u8] where Self: 'vers; + + fn versionize(&self) -> Self::Versioned<'_> { + self.data() + } +} + +impl VersionizeOwned for SmallVec { + type VersionedOwned = Vec; + + fn versionize_owned(self) -> Self::VersionedOwned { + match self { + Self::Stack { bytes, len } => bytes[..usize::from(len)].to_vec(), + Self::Heap(vec) => vec, + } + } +} + +impl Unversionize for SmallVec { + fn unversionize(versioned: Self::VersionedOwned) -> Result { + Ok(Self::from_vec_conservative(versioned)) } } @@ -206,6 +260,16 @@ impl Tag { self.inner.data() } + /// Returns a slice to the bytes stored (same a [Self::data]) + pub fn as_slice(&self) -> &[u8] { + self.inner.as_slice() + } + + /// Returns a mutable slice to the bytes stored + pub fn as_mut_slice(&mut self) -> &mut [u8] { + self.inner.as_mut_slice() + } + /// Returns the len, i.e. the number of bytes stored in the tag pub fn len(&self) -> usize { self.inner.len() diff --git a/tfhe/src/high_level_api/tests/tags_on_entities.rs b/tfhe/src/high_level_api/tests/tags_on_entities.rs index 755cc6385b..a4acce2641 100644 --- a/tfhe/src/high_level_api/tests/tags_on_entities.rs +++ b/tfhe/src/high_level_api/tests/tags_on_entities.rs @@ -96,7 +96,7 @@ fn test_tag_propagation_zk_pok() { #[cfg(feature = "gpu")] fn test_tag_propagation_gpu() { test_tag_propagation( - Device::Gpu, + Device::CudaGpu, PARAM_MESSAGE_2_CARRY_2, Some(COMP_PARAM_MESSAGE_2_CARRY_2), )