diff --git a/crates/swss-common/src/lib.rs b/crates/swss-common/src/lib.rs index cfe08f3..a988bc8 100644 --- a/crates/swss-common/src/lib.rs +++ b/crates/swss-common/src/lib.rs @@ -4,17 +4,16 @@ mod bindings { } mod types; -use bindings::*; pub use types::*; /// Rust wrapper around `swss::SonicDBConfig::initialize`. -pub fn sonic_db_config_initialize(path: &str) { +pub fn sonic_db_config_initialize(path: &str) -> Result<(), Exception> { let path = cstr(path); - unsafe { bindings::SWSSSonicDBConfig_initialize(path.as_ptr()) } + unsafe { Exception::try0(bindings::SWSSSonicDBConfig_initialize(path.as_ptr())) } } /// Rust wrapper around `swss::SonicDBConfig::initializeGlobalConfig`. -pub fn sonic_db_config_initialize_global(path: &str) { +pub fn sonic_db_config_initialize_global(path: &str) -> Result<(), Exception> { let path = cstr(path); - unsafe { bindings::SWSSSonicDBConfig_initializeGlobalConfig(path.as_ptr()) } + unsafe { Exception::try0(bindings::SWSSSonicDBConfig_initializeGlobalConfig(path.as_ptr())) } } diff --git a/crates/swss-common/src/types.rs b/crates/swss-common/src/types.rs index 51e631a..a059d59 100644 --- a/crates/swss-common/src/types.rs +++ b/crates/swss-common/src/types.rs @@ -6,6 +6,7 @@ mod cxxstring; mod dbconnector; mod producerstatetable; mod subscriberstatetable; +mod table; mod zmqclient; mod zmqconsumerstatetable; mod zmqproducerstatetable; @@ -16,12 +17,13 @@ pub use cxxstring::{CxxStr, CxxString}; pub use dbconnector::{DbConnectionInfo, DbConnector}; pub use producerstatetable::ProducerStateTable; pub use subscriberstatetable::SubscriberStateTable; +pub use table::Table; pub use zmqclient::ZmqClient; pub use zmqconsumerstatetable::ZmqConsumerStateTable; pub use zmqproducerstatetable::ZmqProducerStateTable; pub use zmqserver::ZmqServer; -use crate::*; +use crate::bindings::*; use cxxstring::RawMutableSWSSString; use std::{ any::Any, @@ -29,18 +31,86 @@ use std::{ error::Error, ffi::{CStr, CString}, fmt::Display, + mem::MaybeUninit, slice, str::FromStr, }; pub(crate) fn cstr(s: impl AsRef<[u8]>) -> CString { - CString::new(s.as_ref()).unwrap() + CString::new(s.as_ref()).expect("Bytes being converted to a C string already contains a null byte") } -pub(crate) unsafe fn str(p: *const i8) -> String { - CStr::from_ptr(p).to_str().unwrap().to_string() +/// Take a malloc'd c string and convert it to a native String +pub(crate) unsafe fn take_cstr(p: *const i8) -> String { + let s = CStr::from_ptr(p) + .to_str() + .expect("C string being converted to Rust String contains invalid UTF-8") + .to_string(); + libc::free(p as *mut libc::c_void); + s } +pub type Result = std::result::Result; + +/// Rust version of a failed `SWSSResult`. +/// +/// When an `SWSSResult` is success/`SWSSException_None`, the rust function will return `Ok(..)`. +/// Otherwise, the rust function will return `Err(Exception)` +#[derive(Debug, Clone)] +pub struct Exception { + message: String, + location: String, +} + +impl Exception { + pub(crate) unsafe fn take_raw(res: &mut SWSSResult) -> Self { + Self { + message: CxxString::take_raw(&mut res.message) + .expect("SWSSResult missing message") + .to_string_lossy() + .into_owned(), + location: CxxString::take_raw(&mut res.location) + .expect("SWSSResult missing location") + .to_string_lossy() + .into_owned(), + } + } + + /// Call an SWSS function that takes one output pointer `*mut T` and returns an `SWSSResult`, and transform that into `Result`. + pub(crate) unsafe fn try1 SWSSResult>(f: F) -> Result { + let mut t: MaybeUninit = MaybeUninit::uninit(); + let mut result: SWSSResult = f(t.as_mut_ptr()); + if result.exception == SWSSException_SWSSException_None { + Ok(t.assume_init()) + } else { + Err(Exception::take_raw(&mut result)) + } + } + + /// Transform an `SWSSResult` into `Result<(), Exception>`, for SWSS functions that take no output pointer. + pub(crate) unsafe fn try0(res: SWSSResult) -> Result<(), Exception> { + Exception::try1(|_| res) + } + + /// Get an informational string about the error that occurred. + pub fn message(&self) -> &str { + &self.message + } + + /// Get an informational string about the where in the code the error occurred. + pub fn location(&self) -> &str { + &self.location + } +} + +impl Display for Exception { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "[{}] {}", self.location, self.message) + } +} + +impl Error for Exception {} + /// Rust version of the return type from `swss::Select::select`. /// /// This enum does not include the `swss::Select::ERROR` because errors are handled via a different @@ -130,12 +200,15 @@ impl Display for InvalidKeyOperationString { impl Error for InvalidKeyOperationString {} +/// Rust version of `vector`. +pub type FieldValues = HashMap; + /// Rust version of `swss::KeyOpFieldsValuesTuple`. #[derive(Debug, Clone, PartialEq, Eq)] pub struct KeyOpFieldValues { pub key: String, pub operation: KeyOperation, - pub field_values: HashMap, + pub field_values: FieldValues, } /// Intended for testing, ordered by key. @@ -153,12 +226,12 @@ impl Ord for KeyOpFieldValues { } /// Takes ownership of an `SWSSFieldValueArray` and turns it into a native representation. -pub(crate) unsafe fn take_field_value_array(arr: SWSSFieldValueArray) -> HashMap { +pub(crate) unsafe fn take_field_value_array(arr: SWSSFieldValueArray) -> FieldValues { let mut out = HashMap::with_capacity(arr.len as usize); if !arr.data.is_null() { let entries = slice::from_raw_parts_mut(arr.data, arr.len as usize); for fv in entries { - let field = str(fv.field); + let field = take_cstr(fv.field); let value = CxxString::take_raw(&mut fv.value).unwrap(); out.insert(field, value); } @@ -174,7 +247,7 @@ pub(crate) unsafe fn take_key_op_field_values_array(kfvs: SWSSKeyOpFieldValuesAr unsafe { let entries = slice::from_raw_parts_mut(kfvs.data, kfvs.len as usize); for kfv in entries { - let key = str(kfv.key); + let key = take_cstr(kfv.key); let operation = KeyOperation::from_raw(kfv.operation); let field_values = take_field_value_array(kfv.fieldValues); out.push(KeyOpFieldValues { @@ -189,6 +262,18 @@ pub(crate) unsafe fn take_key_op_field_values_array(kfvs: SWSSKeyOpFieldValuesAr out } +/// Takes ownership of an `SWSSStringArray` and turns it into a native representation. +pub(crate) unsafe fn take_string_array(arr: SWSSStringArray) -> Vec { + let out = if !arr.data.is_null() { + let entries = slice::from_raw_parts(arr.data, arr.len as usize); + Vec::from_iter(entries.iter().map(|&s| take_cstr(s))) + } else { + Vec::new() + }; + SWSSStringArray_free(arr); + out +} + pub(crate) fn make_field_value_array(fvs: I) -> (SWSSFieldValueArray, KeepAlive) where I: IntoIterator, diff --git a/crates/swss-common/src/types/async_util.rs b/crates/swss-common/src/types/async_util.rs index 689f2b5..154e115 100644 --- a/crates/swss-common/src/types/async_util.rs +++ b/crates/swss-common/src/types/async_util.rs @@ -35,10 +35,10 @@ macro_rules! impl_read_data_async { pub async fn read_data_async(&mut self) -> ::std::io::Result<()> { use ::tokio::io::{unix::AsyncFd, Interest}; - let _ready_guard = AsyncFd::with_interest(self.get_fd(), Interest::READABLE)? - .readable() - .await?; - self.read_data(Duration::from_secs(0), false); + let fd = self.get_fd().map_err(::std::io::Error::other)?; + let _ready_guard = AsyncFd::with_interest(fd, Interest::READABLE)?.readable().await?; + self.read_data(Duration::from_secs(0), false) + .map_err(::std::io::Error::other)?; Ok(()) } }; diff --git a/crates/swss-common/src/types/consumerstatetable.rs b/crates/swss-common/src/types/consumerstatetable.rs index ee2e060..a397c95 100644 --- a/crates/swss-common/src/types/consumerstatetable.rs +++ b/crates/swss-common/src/types/consumerstatetable.rs @@ -1,5 +1,5 @@ use super::*; -use crate::*; +use crate::bindings::*; use std::{os::fd::BorrowedFd, ptr::null, time::Duration}; /// Rust wrapper around `swss::ConsumerStateTable`. @@ -10,39 +10,47 @@ pub struct ConsumerStateTable { } impl ConsumerStateTable { - pub fn new(db: DbConnector, table_name: &str, pop_batch_size: Option, pri: Option) -> Self { + pub fn new(db: DbConnector, table_name: &str, pop_batch_size: Option, pri: Option) -> Result { let table_name = cstr(table_name); let pop_batch_size = pop_batch_size.as_ref().map(|n| n as *const i32).unwrap_or(null()); let pri = pri.as_ref().map(|n| n as *const i32).unwrap_or(null()); - let ptr = unsafe { SWSSConsumerStateTable_new(db.ptr, table_name.as_ptr(), pop_batch_size, pri) }; - Self { ptr, _db: db } + let ptr = unsafe { + Exception::try1(|p_cst| SWSSConsumerStateTable_new(db.ptr, table_name.as_ptr(), pop_batch_size, pri, p_cst)) + }?; + Ok(Self { ptr, _db: db }) } - pub fn pops(&self) -> Vec { + pub fn pops(&self) -> Result> { unsafe { - let ans = SWSSConsumerStateTable_pops(self.ptr); - take_key_op_field_values_array(ans) + let arr = Exception::try1(|p_arr| SWSSConsumerStateTable_pops(self.ptr, p_arr))?; + Ok(take_key_op_field_values_array(arr)) } } - pub fn get_fd(&self) -> BorrowedFd { - let fd = unsafe { SWSSConsumerStateTable_getFd(self.ptr) }; - + pub fn get_fd(&self) -> Result { // SAFETY: This fd represents the underlying redis connection, which should stay alive // as long as the DbConnector does. - unsafe { BorrowedFd::borrow_raw(fd.try_into().unwrap()) } + unsafe { + let fd = Exception::try1(|p_fd| SWSSConsumerStateTable_getFd(self.ptr, p_fd))?; + let fd = BorrowedFd::borrow_raw(fd.try_into().unwrap()); + Ok(fd) + } } - pub fn read_data(&self, timeout: Duration, interrupt_on_signal: bool) -> SelectResult { + pub fn read_data(&self, timeout: Duration, interrupt_on_signal: bool) -> Result { let timeout_ms = timeout.as_millis().try_into().unwrap(); - let res = unsafe { SWSSConsumerStateTable_readData(self.ptr, timeout_ms, interrupt_on_signal as u8) }; - SelectResult::from_raw(res) + let res = unsafe { + Exception::try1(|p_res| { + SWSSConsumerStateTable_readData(self.ptr, timeout_ms, interrupt_on_signal as u8, p_res) + })? + }; + Ok(SelectResult::from_raw(res)) } } impl Drop for ConsumerStateTable { fn drop(&mut self) { - unsafe { SWSSConsumerStateTable_free(self.ptr) }; + unsafe { Exception::try0(SWSSConsumerStateTable_free(self.ptr)).expect("Dropping ConsumerStateTable") }; } } diff --git a/crates/swss-common/src/types/cxxstring.rs b/crates/swss-common/src/types/cxxstring.rs index 312e5cd..30666f0 100644 --- a/crates/swss-common/src/types/cxxstring.rs +++ b/crates/swss-common/src/types/cxxstring.rs @@ -1,4 +1,4 @@ -use crate::*; +use crate::bindings::*; use std::{ borrow::{Borrow, Cow}, fmt::Debug, diff --git a/crates/swss-common/src/types/dbconnector.rs b/crates/swss-common/src/types/dbconnector.rs index 049980a..a22e72f 100644 --- a/crates/swss-common/src/types/dbconnector.rs +++ b/crates/swss-common/src/types/dbconnector.rs @@ -1,5 +1,5 @@ use super::*; -use crate::*; +use crate::bindings::*; use std::collections::HashMap; /// Rust wrapper around `swss::DBConnector`. @@ -22,25 +22,29 @@ impl DbConnector { /// Create a new DbConnector from [`DbConnectionInfo`]. /// /// Timeout of 0 means block indefinitely. - fn new(db_id: i32, connection: DbConnectionInfo, timeout_ms: u32) -> DbConnector { + fn new(db_id: i32, connection: DbConnectionInfo, timeout_ms: u32) -> Result { let ptr = match &connection { DbConnectionInfo::Tcp { hostname, port } => { let hostname = cstr(hostname); - unsafe { SWSSDBConnector_new_tcp(db_id, hostname.as_ptr(), *port, timeout_ms) } + unsafe { + Exception::try1(|p_db| SWSSDBConnector_new_tcp(db_id, hostname.as_ptr(), *port, timeout_ms, p_db))? + } } DbConnectionInfo::Unix { sock_path } => { let sock_path = cstr(sock_path); - unsafe { SWSSDBConnector_new_unix(db_id, sock_path.as_ptr(), timeout_ms) } + unsafe { + Exception::try1(|p_db| SWSSDBConnector_new_unix(db_id, sock_path.as_ptr(), timeout_ms, p_db))? + } } }; - Self { ptr, db_id, connection } + Ok(Self { ptr, db_id, connection }) } /// Create a DbConnector over a tcp socket. /// /// Timeout of 0 means block indefinitely. - pub fn new_tcp(db_id: i32, hostname: impl Into, port: u16, timeout_ms: u32) -> DbConnector { + pub fn new_tcp(db_id: i32, hostname: impl Into, port: u16, timeout_ms: u32) -> Result { let hostname = hostname.into(); Self::new(db_id, DbConnectionInfo::Tcp { hostname, port }, timeout_ms) } @@ -48,7 +52,7 @@ impl DbConnector { /// Create a DbConnector over a unix socket. /// /// Timeout of 0 means block indefinitely. - pub fn new_unix(db_id: i32, sock_path: impl Into, timeout_ms: u32) -> DbConnector { + pub fn new_unix(db_id: i32, sock_path: impl Into, timeout_ms: u32) -> Result { let sock_path = sock_path.into(); Self::new(db_id, DbConnectionInfo::Unix { sock_path }, timeout_ms) } @@ -56,7 +60,7 @@ impl DbConnector { /// Clone a DbConnector with a timeout. /// /// Timeout of 0 means block indefinitely. - pub fn clone_timeout(&self, timeout_ms: u32) -> Self { + pub fn clone_timeout(&self, timeout_ms: u32) -> Result { Self::new(self.db_id, self.connection.clone(), timeout_ms) } @@ -68,72 +72,88 @@ impl DbConnector { &self.connection } - pub fn del(&self, key: &str) -> bool { + pub fn del(&self, key: &str) -> Result { let key = cstr(key); - unsafe { SWSSDBConnector_del(self.ptr, key.as_ptr()) == 1 } + let status = unsafe { Exception::try1(|p_status| SWSSDBConnector_del(self.ptr, key.as_ptr(), p_status))? }; + Ok(status == 1) } - pub fn set(&self, key: &str, val: &CxxStr) { + pub fn set(&self, key: &str, val: &CxxStr) -> Result<()> { let key = cstr(key); - unsafe { SWSSDBConnector_set(self.ptr, key.as_ptr(), val.as_raw()) }; + unsafe { Exception::try0(SWSSDBConnector_set(self.ptr, key.as_ptr(), val.as_raw())) } } - pub fn get(&self, key: &str) -> Option { + pub fn get(&self, key: &str) -> Result> { let key = cstr(key); unsafe { - let mut ans = SWSSDBConnector_get(self.ptr, key.as_ptr()); - CxxString::take_raw(&mut ans) + let mut ans = Exception::try1(|p_ans| SWSSDBConnector_get(self.ptr, key.as_ptr(), p_ans))?; + Ok(CxxString::take_raw(&mut ans)) } } - pub fn exists(&self, key: &str) -> bool { + pub fn exists(&self, key: &str) -> Result { let key = cstr(key); - unsafe { SWSSDBConnector_exists(self.ptr, key.as_ptr()) == 1 } + let status = unsafe { Exception::try1(|p_status| SWSSDBConnector_exists(self.ptr, key.as_ptr(), p_status))? }; + Ok(status == 1) } - pub fn hdel(&self, key: &str, field: &str) -> bool { + pub fn hdel(&self, key: &str, field: &str) -> Result { let key = cstr(key); let field = cstr(field); - unsafe { SWSSDBConnector_hdel(self.ptr, key.as_ptr(), field.as_ptr()) == 1 } + let status = unsafe { + Exception::try1(|p_status| SWSSDBConnector_hdel(self.ptr, key.as_ptr(), field.as_ptr(), p_status))? + }; + Ok(status == 1) } - pub fn hset(&self, key: &str, field: &str, val: &CxxStr) { + pub fn hset(&self, key: &str, field: &str, val: &CxxStr) -> Result<()> { let key = cstr(key); let field = cstr(field); - unsafe { SWSSDBConnector_hset(self.ptr, key.as_ptr(), field.as_ptr(), val.as_raw()) }; + unsafe { + Exception::try0(SWSSDBConnector_hset( + self.ptr, + key.as_ptr(), + field.as_ptr(), + val.as_raw(), + )) + } } - pub fn hget(&self, key: &str, field: &str) -> Option { + pub fn hget(&self, key: &str, field: &str) -> Result> { let key = cstr(key); let field = cstr(field); unsafe { - let mut ans = SWSSDBConnector_hget(self.ptr, key.as_ptr(), field.as_ptr()); - CxxString::take_raw(&mut ans) + let mut ans = Exception::try1(|p_ans| SWSSDBConnector_hget(self.ptr, key.as_ptr(), field.as_ptr(), p_ans))?; + Ok(CxxString::take_raw(&mut ans)) } } - pub fn hgetall(&self, key: &str) -> HashMap { + pub fn hgetall(&self, key: &str) -> Result> { let key = cstr(key); unsafe { - let ans = SWSSDBConnector_hgetall(self.ptr, key.as_ptr()); - take_field_value_array(ans) + let arr = Exception::try1(|p_arr| SWSSDBConnector_hgetall(self.ptr, key.as_ptr(), p_arr))?; + Ok(take_field_value_array(arr)) } } - pub fn hexists(&self, key: &str, field: &str) -> bool { + pub fn hexists(&self, key: &str, field: &str) -> Result { let key = cstr(key); let field = cstr(field); - unsafe { SWSSDBConnector_hexists(self.ptr, key.as_ptr(), field.as_ptr()) == 1 } + let status = unsafe { + Exception::try1(|p_status| SWSSDBConnector_hexists(self.ptr, key.as_ptr(), field.as_ptr(), p_status))? + }; + Ok(status == 1) } - pub fn flush_db(&self) -> bool { - unsafe { SWSSDBConnector_flushdb(self.ptr) == 1 } + pub fn flush_db(&self) -> Result { + let status = unsafe { Exception::try1(|p_status| SWSSDBConnector_flushdb(self.ptr, p_status))? }; + Ok(status == 1) } } impl Drop for DbConnector { fn drop(&mut self) { - unsafe { SWSSDBConnector_free(self.ptr) }; + unsafe { Exception::try0(SWSSDBConnector_free(self.ptr)).expect("Dropping DbConnector") }; } } @@ -141,17 +161,17 @@ unsafe impl Send for DbConnector {} #[cfg(feature = "async")] impl DbConnector { - async_util::impl_basic_async_method!(new_tcp_async <= new_tcp(db_id: i32, hostname: &str, port: u16, timeout_ms: u32) -> DbConnector); - async_util::impl_basic_async_method!(new_unix_async <= new_unix(db_id: i32, sock_path: &str, timeout_ms: u32) -> DbConnector); - async_util::impl_basic_async_method!(clone_timeout_async <= clone_timeout(&self, timeout_ms: u32) -> DbConnector); - async_util::impl_basic_async_method!(del_async <= del(&self, key: &str) -> bool); - async_util::impl_basic_async_method!(set_async <= set(&self, key: &str, value: &CxxStr)); - async_util::impl_basic_async_method!(get_async <= get(&self, key: &str) -> Option); - async_util::impl_basic_async_method!(exists_async <= exists(&self, key: &str) -> bool); - async_util::impl_basic_async_method!(hdel_async <= hdel(&self, key: &str, field: &str) -> bool); - async_util::impl_basic_async_method!(hset_async <= hset(&self, key: &str, field: &str, value: &CxxStr)); - async_util::impl_basic_async_method!(hget_async <= hget(&self, key: &str, field: &str) -> Option); - async_util::impl_basic_async_method!(hgetall_async <= hgetall(&self, key: &str) -> HashMap); - async_util::impl_basic_async_method!(hexists_async <= hexists(&self, key: &str, field: &str) -> bool); - async_util::impl_basic_async_method!(flush_db_async <= flush_db(&self) -> bool); + async_util::impl_basic_async_method!(new_tcp_async <= new_tcp(db_id: i32, hostname: &str, port: u16, timeout_ms: u32) -> Result); + async_util::impl_basic_async_method!(new_unix_async <= new_unix(db_id: i32, sock_path: &str, timeout_ms: u32) -> Result); + async_util::impl_basic_async_method!(clone_timeout_async <= clone_timeout(&self, timeout_ms: u32) -> Result); + async_util::impl_basic_async_method!(del_async <= del(&self, key: &str) -> Result); + async_util::impl_basic_async_method!(set_async <= set(&self, key: &str, value: &CxxStr) -> Result<()>); + async_util::impl_basic_async_method!(get_async <= get(&self, key: &str) -> Result>); + async_util::impl_basic_async_method!(exists_async <= exists(&self, key: &str) -> Result); + async_util::impl_basic_async_method!(hdel_async <= hdel(&self, key: &str, field: &str) -> Result); + async_util::impl_basic_async_method!(hset_async <= hset(&self, key: &str, field: &str, value: &CxxStr) -> Result<()>); + async_util::impl_basic_async_method!(hget_async <= hget(&self, key: &str, field: &str) -> Result>); + async_util::impl_basic_async_method!(hgetall_async <= hgetall(&self, key: &str) -> Result>); + async_util::impl_basic_async_method!(hexists_async <= hexists(&self, key: &str, field: &str) -> Result); + async_util::impl_basic_async_method!(flush_db_async <= flush_db(&self) -> Result); } diff --git a/crates/swss-common/src/types/producerstatetable.rs b/crates/swss-common/src/types/producerstatetable.rs index 2f8f196..edca48c 100644 --- a/crates/swss-common/src/types/producerstatetable.rs +++ b/crates/swss-common/src/types/producerstatetable.rs @@ -1,5 +1,5 @@ use super::*; -use crate::*; +use crate::bindings::*; /// Rust wrapper around `swss::ProducerStateTable`. #[derive(Debug)] @@ -9,17 +9,17 @@ pub struct ProducerStateTable { } impl ProducerStateTable { - pub fn new(db: DbConnector, table_name: &str) -> Self { + pub fn new(db: DbConnector, table_name: &str) -> Result { let table_name = cstr(table_name); - let ptr = unsafe { SWSSProducerStateTable_new(db.ptr, table_name.as_ptr()) }; - Self { ptr, _db: db } + let ptr = unsafe { Exception::try1(|p_pst| SWSSProducerStateTable_new(db.ptr, table_name.as_ptr(), p_pst))? }; + Ok(Self { ptr, _db: db }) } - pub fn set_buffered(&self, buffered: bool) { - unsafe { SWSSProducerStateTable_setBuffered(self.ptr, buffered as u8) }; + pub fn set_buffered(&self, buffered: bool) -> Result<()> { + unsafe { Exception::try0(SWSSProducerStateTable_setBuffered(self.ptr, buffered as u8)) } } - pub fn set(&self, key: &str, fvs: I) + pub fn set(&self, key: &str, fvs: I) -> Result<()> where I: IntoIterator, F: AsRef<[u8]>, @@ -27,38 +27,38 @@ impl ProducerStateTable { { let key = cstr(key); let (arr, _k) = make_field_value_array(fvs); - unsafe { SWSSProducerStateTable_set(self.ptr, key.as_ptr(), arr) }; + unsafe { Exception::try0(SWSSProducerStateTable_set(self.ptr, key.as_ptr(), arr)) } } - pub fn del(&self, key: &str) { + pub fn del(&self, key: &str) -> Result<()> { let key = cstr(key); - unsafe { SWSSProducerStateTable_del(self.ptr, key.as_ptr()) }; + unsafe { Exception::try0(SWSSProducerStateTable_del(self.ptr, key.as_ptr())) } } - pub fn flush(&self) { - unsafe { SWSSProducerStateTable_flush(self.ptr) }; + pub fn flush(&self) -> Result<()> { + unsafe { Exception::try0(SWSSProducerStateTable_flush(self.ptr)) } } - pub fn count(&self) -> i64 { - unsafe { SWSSProducerStateTable_count(self.ptr) } + pub fn count(&self) -> Result { + unsafe { Exception::try1(|p_count| SWSSProducerStateTable_count(self.ptr, p_count)) } } - pub fn clear(&self) { - unsafe { SWSSProducerStateTable_clear(self.ptr) }; + pub fn clear(&self) -> Result<()> { + unsafe { Exception::try0(SWSSProducerStateTable_clear(self.ptr)) } } - pub fn create_temp_view(&self) { - unsafe { SWSSProducerStateTable_create_temp_view(self.ptr) }; + pub fn create_temp_view(&self) -> Result<()> { + unsafe { Exception::try0(SWSSProducerStateTable_create_temp_view(self.ptr)) } } - pub fn apply_temp_view(&self) { - unsafe { SWSSProducerStateTable_apply_temp_view(self.ptr) }; + pub fn apply_temp_view(&self) -> Result<()> { + unsafe { Exception::try0(SWSSProducerStateTable_apply_temp_view(self.ptr)) } } } impl Drop for ProducerStateTable { fn drop(&mut self) { - unsafe { SWSSProducerStateTable_free(self.ptr) }; + unsafe { Exception::try0(SWSSProducerStateTable_free(self.ptr)).expect("Dropping ProducerStateTable") }; } } @@ -67,12 +67,12 @@ unsafe impl Send for ProducerStateTable {} #[cfg(feature = "async")] impl ProducerStateTable { async_util::impl_basic_async_method!( - set_async <= set(&self, key: &str, fvs: I) + set_async <= set(&self, key: &str, fvs: I) -> Result<()> where I: IntoIterator + Send, F: AsRef<[u8]>, V: Into, ); - async_util::impl_basic_async_method!(del_async <= del(&self, key: &str)); - async_util::impl_basic_async_method!(flush_async <= flush(&self)); + async_util::impl_basic_async_method!(del_async <= del(&self, key: &str) -> Result<()>); + async_util::impl_basic_async_method!(flush_async <= flush(&self) -> Result<()>); } diff --git a/crates/swss-common/src/types/subscriberstatetable.rs b/crates/swss-common/src/types/subscriberstatetable.rs index 4ea271e..af361d3 100644 --- a/crates/swss-common/src/types/subscriberstatetable.rs +++ b/crates/swss-common/src/types/subscriberstatetable.rs @@ -1,5 +1,5 @@ use super::*; -use crate::*; +use crate::bindings::*; use std::{os::fd::BorrowedFd, ptr::null, time::Duration}; /// Rust wrapper around `swss::SubscriberStateTable`. @@ -10,39 +10,49 @@ pub struct SubscriberStateTable { } impl SubscriberStateTable { - pub fn new(db: DbConnector, table_name: &str, pop_batch_size: Option, pri: Option) -> Self { + pub fn new(db: DbConnector, table_name: &str, pop_batch_size: Option, pri: Option) -> Result { let table_name = cstr(table_name); let pop_batch_size = pop_batch_size.as_ref().map(|n| n as *const i32).unwrap_or(null()); let pri = pri.as_ref().map(|n| n as *const i32).unwrap_or(null()); - let ptr = unsafe { SWSSSubscriberStateTable_new(db.ptr, table_name.as_ptr(), pop_batch_size, pri) }; - Self { ptr, _db: db } + let ptr = unsafe { + Exception::try1(|p_sst| { + SWSSSubscriberStateTable_new(db.ptr, table_name.as_ptr(), pop_batch_size, pri, p_sst) + })? + }; + Ok(Self { ptr, _db: db }) } - pub fn pops(&self) -> Vec { + pub fn pops(&self) -> Result> { unsafe { - let ans = SWSSSubscriberStateTable_pops(self.ptr); - take_key_op_field_values_array(ans) + let arr = Exception::try1(|p_arr| SWSSSubscriberStateTable_pops(self.ptr, p_arr))?; + Ok(take_key_op_field_values_array(arr)) } } - pub fn read_data(&self, timeout: Duration, interrupt_on_signal: bool) -> SelectResult { + pub fn read_data(&self, timeout: Duration, interrupt_on_signal: bool) -> Result { let timeout_ms = timeout.as_millis().try_into().unwrap(); - let res = unsafe { SWSSSubscriberStateTable_readData(self.ptr, timeout_ms, interrupt_on_signal as u8) }; - SelectResult::from_raw(res) + let res = unsafe { + Exception::try1(|p_res| { + SWSSSubscriberStateTable_readData(self.ptr, timeout_ms, interrupt_on_signal as u8, p_res) + })? + }; + Ok(SelectResult::from_raw(res)) } - pub fn get_fd(&self) -> BorrowedFd { - let fd = unsafe { SWSSSubscriberStateTable_getFd(self.ptr) }; - - // SAFETY: This fd represents the underlying redis connection, which should remain open as - // long as the DbConnector is alive - unsafe { BorrowedFd::borrow_raw(fd.try_into().unwrap()) } + pub fn get_fd(&self) -> Result { + // SAFETY: This fd represents the underlying redis connection, which should stay alive + // as long as the DbConnector does. + unsafe { + let fd = Exception::try1(|p_fd| SWSSSubscriberStateTable_getFd(self.ptr, p_fd))?; + let fd = BorrowedFd::borrow_raw(fd.try_into().unwrap()); + Ok(fd) + } } } impl Drop for SubscriberStateTable { fn drop(&mut self) { - unsafe { SWSSSubscriberStateTable_free(self.ptr) }; + unsafe { Exception::try0(SWSSSubscriberStateTable_free(self.ptr)).expect("Dropping SubscriberStateTable") }; } } diff --git a/crates/swss-common/src/types/table.rs b/crates/swss-common/src/types/table.rs new file mode 100644 index 0000000..a15b4f1 --- /dev/null +++ b/crates/swss-common/src/types/table.rs @@ -0,0 +1,108 @@ +use super::*; +use crate::bindings::*; +use std::ptr; + +#[derive(Debug)] +pub struct Table { + ptr: SWSSTable, + _db: DbConnector, +} + +impl Table { + pub fn new(db: DbConnector, table_name: &str) -> Result { + let table_name = cstr(table_name); + let ptr = unsafe { Exception::try1(|p_tbl| SWSSTable_new(db.ptr, table_name.as_ptr(), p_tbl))? }; + Ok(Self { ptr, _db: db }) + } + + pub fn get(&self, key: &str) -> Result> { + let key = cstr(key); + let mut arr = SWSSFieldValueArray { + len: 0, + data: ptr::null_mut(), + }; + let exists = unsafe { Exception::try1(|p_exists| SWSSTable_get(self.ptr, key.as_ptr(), &mut arr, p_exists))? }; + let maybe_fvs = if exists == 1 { + Some(unsafe { take_field_value_array(arr) }) + } else { + None + }; + Ok(maybe_fvs) + } + + pub fn hget(&self, key: &str, field: &str) -> Result> { + let key = cstr(key); + let field = cstr(field); + let mut val: SWSSString = ptr::null_mut(); + let exists = unsafe { + Exception::try1(|p_exists| SWSSTable_hget(self.ptr, key.as_ptr(), field.as_ptr(), &mut val, p_exists))? + }; + let maybe_fvs = if exists == 1 { + Some(unsafe { CxxString::take_raw(&mut val).unwrap() }) + } else { + None + }; + Ok(maybe_fvs) + } + + pub fn set(&self, key: &str, fvs: I) -> Result<()> + where + I: IntoIterator, + F: AsRef<[u8]>, + V: Into, + { + let key = cstr(key); + let (arr, _k) = make_field_value_array(fvs); + unsafe { Exception::try0(SWSSTable_set(self.ptr, key.as_ptr(), arr)) } + } + + pub fn hset(&self, key: &str, field: &str, val: &CxxStr) -> Result<()> { + let key = cstr(key); + let field = cstr(field); + unsafe { Exception::try0(SWSSTable_hset(self.ptr, key.as_ptr(), field.as_ptr(), val.as_raw())) } + } + + pub fn del(&self, key: &str) -> Result<()> { + let key = cstr(key); + unsafe { Exception::try0(SWSSTable_del(self.ptr, key.as_ptr())) } + } + + pub fn hdel(&self, key: &str, field: &str) -> Result<()> { + let key = cstr(key); + let field = cstr(field); + unsafe { Exception::try0(SWSSTable_hdel(self.ptr, key.as_ptr(), field.as_ptr())) } + } + + pub fn get_keys(&self) -> Result> { + unsafe { + let arr = Exception::try1(|p_arr| SWSSTable_getKeys(self.ptr, p_arr))?; + Ok(take_string_array(arr)) + } + } +} + +impl Drop for Table { + fn drop(&mut self) { + unsafe { Exception::try0(SWSSTable_free(self.ptr)).expect("Dropping Table") }; + } +} + +unsafe impl Send for Table {} + +#[cfg(feature = "async")] +impl Table { + async_util::impl_basic_async_method!(new_async <= new(db: DbConnector, table_name: &str) -> Result); + async_util::impl_basic_async_method!(get_async <= get(&self, key: &str) -> Result>); + async_util::impl_basic_async_method!(hget_async <= hget(&self, key: &str, field: &str) -> Result>); + async_util::impl_basic_async_method!( + set_async <= set(&self, key: &str, fvs: I) -> Result<()> + where + I: IntoIterator + Send, + F: AsRef<[u8]>, + V: Into, + ); + async_util::impl_basic_async_method!(hset_async <= hset(&self, key: &str, field: &str, value: &CxxStr) -> Result<()>); + async_util::impl_basic_async_method!(del_async <= del(&self, key: &str) -> Result<()>); + async_util::impl_basic_async_method!(hdel_async <= hdel(&self, key: &str, field: &str) -> Result<()>); + async_util::impl_basic_async_method!(get_keys_async <= get_keys(&self) -> Result>); +} diff --git a/crates/swss-common/src/types/zmqclient.rs b/crates/swss-common/src/types/zmqclient.rs index 1c59c31..12d7c51 100644 --- a/crates/swss-common/src/types/zmqclient.rs +++ b/crates/swss-common/src/types/zmqclient.rs @@ -1,5 +1,5 @@ use super::*; -use crate::*; +use crate::bindings::*; /// Rust wrapper around `swss::ZmqClient`. #[derive(Debug)] @@ -8,34 +8,42 @@ pub struct ZmqClient { } impl ZmqClient { - pub fn new(endpoint: &str) -> Self { + pub fn new(endpoint: &str) -> Result { let endpoint = cstr(endpoint); - let ptr = unsafe { SWSSZmqClient_new(endpoint.as_ptr()) }; - Self { ptr } + let ptr = unsafe { Exception::try1(|p_zc| SWSSZmqClient_new(endpoint.as_ptr(), p_zc))? }; + Ok(Self { ptr }) } - pub fn is_connected(&self) -> bool { - unsafe { SWSSZmqClient_isConnected(self.ptr) == 1 } + pub fn is_connected(&self) -> Result { + let status = unsafe { Exception::try1(|p_status| SWSSZmqClient_isConnected(self.ptr, p_status))? }; + Ok(status == 1) } - pub fn connect(&self) { - unsafe { SWSSZmqClient_connect(self.ptr) } + pub fn connect(&self) -> Result<()> { + unsafe { Exception::try0(SWSSZmqClient_connect(self.ptr)) } } - pub fn send_msg(&self, db_name: &str, table_name: &str, kfvs: I) + pub fn send_msg(&self, db_name: &str, table_name: &str, kfvs: I) -> Result<()> where I: IntoIterator, { let db_name = cstr(db_name); let table_name = cstr(table_name); let (kfvs, _k) = make_key_op_field_values_array(kfvs); - unsafe { SWSSZmqClient_sendMsg(self.ptr, db_name.as_ptr(), table_name.as_ptr(), kfvs) }; + unsafe { + Exception::try0(SWSSZmqClient_sendMsg( + self.ptr, + db_name.as_ptr(), + table_name.as_ptr(), + kfvs, + )) + } } } impl Drop for ZmqClient { fn drop(&mut self) { - unsafe { SWSSZmqClient_free(self.ptr) }; + unsafe { Exception::try0(SWSSZmqClient_free(self.ptr)).expect("Dropping ZmqClient") }; } } @@ -43,10 +51,10 @@ unsafe impl Send for ZmqClient {} #[cfg(feature = "async")] impl ZmqClient { - async_util::impl_basic_async_method!(new_async <= new(endpoint: &str) -> Self); - async_util::impl_basic_async_method!(connect_async <= connect(&self)); + async_util::impl_basic_async_method!(new_async <= new(endpoint: &str) -> Result); + async_util::impl_basic_async_method!(connect_async <= connect(&self) -> Result<()>); async_util::impl_basic_async_method!( - send_msg_async <= send_msg(&self, db_name: &str, table_name: &str, kfvs: I) + send_msg_async <= send_msg(&self, db_name: &str, table_name: &str, kfvs: I) -> Result<()> where I: IntoIterator + Send, ); diff --git a/crates/swss-common/src/types/zmqconsumerstatetable.rs b/crates/swss-common/src/types/zmqconsumerstatetable.rs index 0eab7a0..ecc07bf 100644 --- a/crates/swss-common/src/types/zmqconsumerstatetable.rs +++ b/crates/swss-common/src/types/zmqconsumerstatetable.rs @@ -1,5 +1,5 @@ use super::*; -use crate::*; +use crate::bindings::*; use std::{os::fd::BorrowedFd, ptr::null, sync::Arc, time::Duration}; /// Rust wrapper around `swss::ZmqConsumerStateTable`. @@ -19,39 +19,49 @@ impl ZmqConsumerStateTable { zmqs: &mut ZmqServer, pop_batch_size: Option, pri: Option, - ) -> Self { + ) -> Result { let table_name = cstr(table_name); let pop_batch_size = pop_batch_size.as_ref().map(|n| n as *const i32).unwrap_or(null()); let pri = pri.as_ref().map(|n| n as *const i32).unwrap_or(null()); - let ptr = unsafe { SWSSZmqConsumerStateTable_new(db.ptr, table_name.as_ptr(), zmqs.ptr, pop_batch_size, pri) }; + let ptr = unsafe { + Exception::try1(|p_zs| { + SWSSZmqConsumerStateTable_new(db.ptr, table_name.as_ptr(), zmqs.ptr, pop_batch_size, pri, p_zs) + })? + }; let drop_guard = Arc::new(DropGuard(ptr)); zmqs.register_consumer_state_table(drop_guard.clone()); - Self { + Ok(Self { ptr, _db: db, _drop_guard: drop_guard, - } + }) } - pub fn pops(&self) -> Vec { + pub fn pops(&self) -> Result> { unsafe { - let ans = SWSSZmqConsumerStateTable_pops(self.ptr); - take_key_op_field_values_array(ans) + let arr = Exception::try1(|p_arr| SWSSZmqConsumerStateTable_pops(self.ptr, p_arr))?; + Ok(take_key_op_field_values_array(arr)) } } - pub fn get_fd(&self) -> BorrowedFd { - let fd = unsafe { SWSSZmqConsumerStateTable_getFd(self.ptr) }; - + pub fn get_fd(&self) -> Result { // SAFETY: This fd represents the underlying zmq socket, which should remain alive as long as there // is a listener (i.e. a ZmqConsumerStateTable) - unsafe { BorrowedFd::borrow_raw(fd.try_into().unwrap()) } + unsafe { + let fd = Exception::try1(|p_fd| SWSSZmqConsumerStateTable_getFd(self.ptr, p_fd))?; + let fd = BorrowedFd::borrow_raw(fd.try_into().unwrap()); + Ok(fd) + } } - pub fn read_data(&self, timeout: Duration, interrupt_on_signal: bool) -> SelectResult { + pub fn read_data(&self, timeout: Duration, interrupt_on_signal: bool) -> Result { let timeout_ms = timeout.as_millis().try_into().unwrap(); - let res = unsafe { SWSSZmqConsumerStateTable_readData(self.ptr, timeout_ms, interrupt_on_signal as u8) }; - SelectResult::from_raw(res) + let res = unsafe { + Exception::try1(|p_res| { + SWSSZmqConsumerStateTable_readData(self.ptr, timeout_ms, interrupt_on_signal as u8, p_res) + })? + }; + Ok(SelectResult::from_raw(res)) } } @@ -64,7 +74,7 @@ pub(crate) struct DropGuard(SWSSZmqConsumerStateTable); impl Drop for DropGuard { fn drop(&mut self) { - unsafe { SWSSZmqConsumerStateTable_free(self.0) }; + unsafe { Exception::try0(SWSSZmqConsumerStateTable_free(self.0)).expect("Dropping ZmqConsumerStateTable") }; } } diff --git a/crates/swss-common/src/types/zmqproducerstatetable.rs b/crates/swss-common/src/types/zmqproducerstatetable.rs index 945c82e..d8328bc 100644 --- a/crates/swss-common/src/types/zmqproducerstatetable.rs +++ b/crates/swss-common/src/types/zmqproducerstatetable.rs @@ -1,5 +1,5 @@ use super::*; -use crate::*; +use crate::bindings::*; /// Rust wrapper around `swss::ZmqProducerStateTable`. #[derive(Debug)] @@ -10,18 +10,22 @@ pub struct ZmqProducerStateTable { } impl ZmqProducerStateTable { - pub fn new(db: DbConnector, table_name: &str, zmqc: ZmqClient, db_persistence: bool) -> Self { + pub fn new(db: DbConnector, table_name: &str, zmqc: ZmqClient, db_persistence: bool) -> Result { let table_name = cstr(table_name); let db_persistence = db_persistence as u8; - let ptr = unsafe { SWSSZmqProducerStateTable_new(db.ptr, table_name.as_ptr(), zmqc.ptr, db_persistence) }; - Self { + let ptr = unsafe { + Exception::try1(|p_zpst| { + SWSSZmqProducerStateTable_new(db.ptr, table_name.as_ptr(), zmqc.ptr, db_persistence, p_zpst) + })? + }; + Ok(Self { ptr, _db: db, _zmqc: zmqc, - } + }) } - pub fn set(&self, key: &str, fvs: I) + pub fn set(&self, key: &str, fvs: I) -> Result<()> where I: IntoIterator, F: AsRef<[u8]>, @@ -29,16 +33,16 @@ impl ZmqProducerStateTable { { let key = cstr(key); let (arr, _k) = make_field_value_array(fvs); - unsafe { SWSSZmqProducerStateTable_set(self.ptr, key.as_ptr(), arr) }; + unsafe { Exception::try0(SWSSZmqProducerStateTable_set(self.ptr, key.as_ptr(), arr)) } } - pub fn del(&self, key: &str) { + pub fn del(&self, key: &str) -> Result<()> { let key = cstr(key); - unsafe { SWSSZmqProducerStateTable_del(self.ptr, key.as_ptr()) }; + unsafe { Exception::try0(SWSSZmqProducerStateTable_del(self.ptr, key.as_ptr())) } } - pub fn db_updater_queue_size(&self) -> u64 { - unsafe { SWSSZmqProducerStateTable_dbUpdaterQueueSize(self.ptr) } + pub fn db_updater_queue_size(&self) -> Result { + unsafe { Exception::try1(|p_size| SWSSZmqProducerStateTable_dbUpdaterQueueSize(self.ptr, p_size)) } } } @@ -53,11 +57,11 @@ unsafe impl Send for ZmqProducerStateTable {} #[cfg(feature = "async")] impl ZmqProducerStateTable { async_util::impl_basic_async_method!( - set_async <= set(&self, key: &str, fvs: I) + set_async <= set(&self, key: &str, fvs: I) -> Result<()> where I: IntoIterator + Send, F: AsRef<[u8]>, V: Into, ); - async_util::impl_basic_async_method!(del_async <= del(&self, key: &str)); + async_util::impl_basic_async_method!(del_async <= del(&self, key: &str) -> Result<()>); } diff --git a/crates/swss-common/src/types/zmqserver.rs b/crates/swss-common/src/types/zmqserver.rs index fa13cad..5902eac 100644 --- a/crates/swss-common/src/types/zmqserver.rs +++ b/crates/swss-common/src/types/zmqserver.rs @@ -1,5 +1,5 @@ use super::*; -use crate::*; +use crate::bindings::*; use std::sync::Arc; /// Rust wrapper around `swss::ZmqServer`. @@ -16,13 +16,13 @@ pub struct ZmqServer { } impl ZmqServer { - pub fn new(endpoint: &str) -> Self { + pub fn new(endpoint: &str) -> Result { let endpoint = cstr(endpoint); - let obj = unsafe { SWSSZmqServer_new(endpoint.as_ptr()) }; - Self { - ptr: obj, + let ptr = unsafe { Exception::try1(|p_zs| SWSSZmqServer_new(endpoint.as_ptr(), p_zs))? }; + Ok(Self { + ptr, message_handler_guards: Vec::new(), - } + }) } pub(crate) fn register_consumer_state_table(&mut self, tbl_dg: Arc) { diff --git a/crates/swss-common/tests/async.rs b/crates/swss-common/tests/async.rs index f80ee67..c6d91e0 100644 --- a/crates/swss-common/tests/async.rs +++ b/crates/swss-common/tests/async.rs @@ -21,7 +21,7 @@ macro_rules! define_tokio_test_fns { paste! { #[tokio::test] async fn [< $f _ >]() { - $f().await; + $f().await.unwrap(); } fn [< _assert_ $f _is_send >]() { @@ -33,53 +33,53 @@ macro_rules! define_tokio_test_fns { } define_tokio_test_fns!(dbconnector_async_api_basic_test); -async fn dbconnector_async_api_basic_test() { +async fn dbconnector_async_api_basic_test() -> Result<(), Exception> { let redis = Redis::start(); let mut db = redis.db_connector(); - drop(db.clone_timeout_async(10000).await); + drop(db.clone_timeout_async(10000).await?); - assert!(db.flush_db_async().await); + assert!(db.flush_db_async().await?); let random = random_cxx_string(); - db.set_async("hello", &CxxString::new("hello, world!")).await; - db.set_async("random", &random).await; - assert_eq!(db.get_async("hello").await.unwrap(), "hello, world!"); - assert_eq!(db.get_async("random").await.unwrap(), random); - assert_eq!(db.get_async("noexist").await, None); - - assert!(db.exists_async("hello").await); - assert!(!db.exists_async("noexist").await); - assert!(db.del_async("hello").await); - assert!(!db.del_async("hello").await); - assert!(db.del_async("random").await); - assert!(!db.del_async("random").await); - assert!(!db.del_async("noexist").await); - - db.hset_async("a", "hello", &CxxString::new("hello, world!")).await; - db.hset_async("a", "random", &random).await; - assert_eq!(db.hget_async("a", "hello").await.unwrap(), "hello, world!"); - assert_eq!(db.hget_async("a", "random").await.unwrap(), random); - assert_eq!(db.hget_async("a", "noexist").await, None); - assert_eq!(db.hget_async("noexist", "noexist").await, None); - assert!(db.hexists_async("a", "hello").await); - assert!(!db.hexists_async("a", "noexist").await); - assert!(!db.hexists_async("noexist", "hello").await); - assert!(db.hdel_async("a", "hello").await); - assert!(!db.hdel_async("a", "hello").await); - assert!(db.hdel_async("a", "random").await); - assert!(!db.hdel_async("a", "random").await); - assert!(!db.hdel_async("a", "noexist").await); - assert!(!db.hdel_async("noexist", "noexist").await); - assert!(!db.del_async("a").await); - - assert!(db.hgetall_async("a").await.is_empty()); - db.hset_async("a", "a", &CxxString::new("1")).await; - db.hset_async("a", "b", &CxxString::new("2")).await; - db.hset_async("a", "c", &CxxString::new("3")).await; + db.set_async("hello", &CxxString::new("hello, world!")).await?; + db.set_async("random", &random).await?; + assert_eq!(db.get_async("hello").await?.unwrap(), "hello, world!"); + assert_eq!(db.get_async("random").await?.unwrap(), random); + assert_eq!(db.get_async("noexist").await?, None); + + assert!(db.exists_async("hello").await?); + assert!(!db.exists_async("noexist").await?); + assert!(db.del_async("hello").await?); + assert!(!db.del_async("hello").await?); + assert!(db.del_async("random").await?); + assert!(!db.del_async("random").await?); + assert!(!db.del_async("noexist").await?); + + db.hset_async("a", "hello", &CxxString::new("hello, world!")).await?; + db.hset_async("a", "random", &random).await?; + assert_eq!(db.hget_async("a", "hello").await?.unwrap(), "hello, world!"); + assert_eq!(db.hget_async("a", "random").await?.unwrap(), random); + assert_eq!(db.hget_async("a", "noexist").await?, None); + assert_eq!(db.hget_async("noexist", "noexist").await?, None); + assert!(db.hexists_async("a", "hello").await?); + assert!(!db.hexists_async("a", "noexist").await?); + assert!(!db.hexists_async("noexist", "hello").await?); + assert!(db.hdel_async("a", "hello").await?); + assert!(!db.hdel_async("a", "hello").await?); + assert!(db.hdel_async("a", "random").await?); + assert!(!db.hdel_async("a", "random").await?); + assert!(!db.hdel_async("a", "noexist").await?); + assert!(!db.hdel_async("noexist", "noexist").await?); + assert!(!db.del_async("a").await?); + + assert!(db.hgetall_async("a").await?.is_empty()); + db.hset_async("a", "a", &CxxString::new("1")).await?; + db.hset_async("a", "b", &CxxString::new("2")).await?; + db.hset_async("a", "c", &CxxString::new("3")).await?; assert_eq!( - db.hgetall_async("a").await, + db.hgetall_async("a").await?, HashMap::from_iter([ ("a".into(), "1".into()), ("b".into(), "2".into()), @@ -87,56 +87,60 @@ async fn dbconnector_async_api_basic_test() { ]) ); - assert!(db.flush_db_async().await); + assert!(db.flush_db_async().await?); + + Ok(()) } define_tokio_test_fns!(consumer_producer_state_tables_async_api_basic_test); -async fn consumer_producer_state_tables_async_api_basic_test() { +async fn consumer_producer_state_tables_async_api_basic_test() -> Result<(), Exception> { let redis = Redis::start(); - let mut pst = ProducerStateTable::new(redis.db_connector(), "table_a"); - let mut cst = ConsumerStateTable::new(redis.db_connector(), "table_a", None, None); + let mut pst = ProducerStateTable::new(redis.db_connector(), "table_a")?; + let mut cst = ConsumerStateTable::new(redis.db_connector(), "table_a", None, None)?; - assert!(cst.pops().is_empty()); + assert!(cst.pops()?.is_empty()); let mut kfvs = random_kfvs(); for (i, kfv) in kfvs.iter().enumerate() { - assert_eq!(pst.count(), i as i64); + assert_eq!(pst.count()?, i as i64); match kfv.operation { - KeyOperation::Set => pst.set_async(&kfv.key, kfv.field_values.clone()).await, - KeyOperation::Del => pst.del_async(&kfv.key).await, + KeyOperation::Set => pst.set_async(&kfv.key, kfv.field_values.clone()).await?, + KeyOperation::Del => pst.del_async(&kfv.key).await?, } } timeout(2000, cst.read_data_async()).await.unwrap(); - let mut kfvs_cst = cst.pops(); - assert!(cst.pops().is_empty()); + let mut kfvs_cst = cst.pops()?; + assert!(cst.pops()?.is_empty()); kfvs.sort_unstable(); kfvs_cst.sort_unstable(); assert_eq!(kfvs_cst.len(), kfvs.len()); assert_eq!(kfvs_cst, kfvs); + + Ok(()) } define_tokio_test_fns!(subscriber_state_table_async_api_basic_test); -async fn subscriber_state_table_async_api_basic_test() { +async fn subscriber_state_table_async_api_basic_test() -> Result<(), Exception> { let redis = Redis::start(); let mut db = redis.db_connector(); - let mut sst = SubscriberStateTable::new(redis.db_connector(), "table_a", None, None); - assert!(sst.pops().is_empty()); + let mut sst = SubscriberStateTable::new(redis.db_connector(), "table_a", None, None)?; + assert!(sst.pops()?.is_empty()); db.hset_async("table_a:key_a", "field_a", &CxxString::new("value_a")) - .await; + .await?; db.hset_async("table_a:key_a", "field_b", &CxxString::new("value_b")) - .await; + .await?; timeout(300, sst.read_data_async()).await.unwrap(); - let mut kfvs = sst.pops(); + let mut kfvs = sst.pops()?; // SubscriberStateTable will pick up duplicate KeyOpFieldValues' after two SETs on the same // key. I'm not actually sure if this is intended. assert_eq!(kfvs.len(), 2); assert_eq!(kfvs[0], kfvs[1]); - assert!(sst.pops().is_empty()); + assert!(sst.pops()?.is_empty()); let KeyOpFieldValues { key, @@ -153,30 +157,66 @@ async fn subscriber_state_table_async_api_basic_test() { ("field_b".into(), "value_b".into()) ]) ); + + Ok(()) } define_tokio_test_fns!(zmq_consumer_producer_state_table_async_api_basic_test); -async fn zmq_consumer_producer_state_table_async_api_basic_test() { +async fn zmq_consumer_producer_state_table_async_api_basic_test() -> Result<(), Exception> { let (endpoint, _delete) = random_zmq_endpoint(); - let mut zmqs = ZmqServer::new(&endpoint); - let zmqc = ZmqClient::new(&endpoint); + let mut zmqs = ZmqServer::new(&endpoint)?; + let zmqc = ZmqClient::new(&endpoint)?; let redis = Redis::start(); - let mut zpst = ZmqProducerStateTable::new(redis.db_connector(), "table_a", zmqc, false); - let mut zcst = ZmqConsumerStateTable::new(redis.db_connector(), "table_a", &mut zmqs, None, None); + let mut zpst = ZmqProducerStateTable::new(redis.db_connector(), "table_a", zmqc, false)?; + let mut zcst = ZmqConsumerStateTable::new(redis.db_connector(), "table_a", &mut zmqs, None, None)?; let kfvs = random_kfvs(); for kfv in &kfvs { match kfv.operation { - KeyOperation::Set => zpst.set_async(&kfv.key, kfv.field_values.clone()).await, - KeyOperation::Del => zpst.del_async(&kfv.key).await, + KeyOperation::Set => zpst.set_async(&kfv.key, kfv.field_values.clone()).await?, + KeyOperation::Del => zpst.del_async(&kfv.key).await?, } } let mut kfvs_seen = Vec::new(); while kfvs_seen.len() != kfvs.len() { timeout(2000, zcst.read_data_async()).await.unwrap(); - kfvs_seen.extend(zcst.pops()); + kfvs_seen.extend(zcst.pops()?); } assert_eq!(kfvs, kfvs_seen); + + Ok(()) +} + +define_tokio_test_fns!(table_async_api_basic_test); +async fn table_async_api_basic_test() -> Result<(), Exception> { + let redis = Redis::start(); + let mut table = Table::new_async(redis.db_connector(), "mytable").await?; + assert!(table.get_keys_async().await?.is_empty()); + assert!(table.get_async("mykey").await?.is_none()); + + let fvs = random_fvs(); + table.set_async("mykey", fvs.clone()).await?; + assert_eq!(table.get_keys_async().await?, &["mykey"]); + assert_eq!(table.get_async("mykey").await?.as_ref(), Some(&fvs)); + + let (field, value) = fvs.iter().next().unwrap(); + assert_eq!(table.hget_async("mykey", field).await?.as_ref(), Some(value)); + table.hdel_async("mykey", field).await?; + assert_eq!(table.hget_async("mykey", field).await?, None); + + table + .hset_async("mykey", field, &CxxString::from("my special value")) + .await?; + assert_eq!( + table.hget_async("mykey", field).await?.unwrap().as_bytes(), + b"my special value" + ); + + table.del_async("mykey").await?; + assert!(table.get_keys_async().await?.is_empty()); + assert!(table.get_async("mykey").await?.is_none()); + + Ok(()) } diff --git a/crates/swss-common/tests/common.rs b/crates/swss-common/tests/common.rs index c823ddc..c2c9ce7 100644 --- a/crates/swss-common/tests/common.rs +++ b/crates/swss-common/tests/common.rs @@ -52,7 +52,7 @@ impl Redis { } pub fn db_connector(&self) -> DbConnector { - DbConnector::new_unix(0, &self.sock, 0) + DbConnector::new_unix(0, &self.sock, 0).unwrap() } } @@ -119,18 +119,24 @@ pub fn random_cxx_string() -> CxxString { CxxString::new(random_string()) } +pub fn random_fvs() -> FieldValues { + let mut field_values = HashMap::new(); + for _ in 0..rand::thread_rng().gen_range(100..1000) { + field_values.insert(random_string(), random_cxx_string()); + } + field_values +} + pub fn random_kfv() -> KeyOpFieldValues { let key = random_string(); let operation = if random() { KeyOperation::Set } else { KeyOperation::Del }; - let mut field_values = HashMap::new(); - - if operation == KeyOperation::Set { + let field_values = if operation == KeyOperation::Set { // We need at least one field-value pair, otherwise swss::BinarySerializer infers that // the operation is DEL even if the .operation field is SET - for _ in 0..rand::thread_rng().gen_range(100..1000) { - field_values.insert(random_string(), random_cxx_string()); - } - } + random_fvs() + } else { + HashMap::new() + }; KeyOpFieldValues { key, diff --git a/crates/swss-common/tests/sync.rs b/crates/swss-common/tests/sync.rs index 8bd5acb..71ba468 100644 --- a/crates/swss-common/tests/sync.rs +++ b/crates/swss-common/tests/sync.rs @@ -5,53 +5,53 @@ use std::{collections::HashMap, time::Duration}; use swss_common::*; #[test] -fn dbconnector_sync_api_basic_test() { +fn dbconnector_sync_api_basic_test() -> Result<(), Exception> { let redis = Redis::start(); let db = redis.db_connector(); drop(db.clone_timeout(10000)); - assert!(db.flush_db()); + assert!(db.flush_db()?); let random = random_cxx_string(); - db.set("hello", &CxxString::new("hello, world!")); - db.set("random", &random); - assert_eq!(db.get("hello").unwrap(), "hello, world!"); - assert_eq!(db.get("random").unwrap(), random); - assert_eq!(db.get("noexist"), None); - - assert!(db.exists("hello")); - assert!(!db.exists("noexist")); - assert!(db.del("hello")); - assert!(!db.del("hello")); - assert!(db.del("random")); - assert!(!db.del("random")); - assert!(!db.del("noexist")); - - db.hset("a", "hello", &CxxString::new("hello, world!")); - db.hset("a", "random", &random); - assert_eq!(db.hget("a", "hello").unwrap(), "hello, world!"); - assert_eq!(db.hget("a", "random").unwrap(), random); - assert_eq!(db.hget("a", "noexist"), None); - assert_eq!(db.hget("noexist", "noexist"), None); - assert!(db.hexists("a", "hello")); - assert!(!db.hexists("a", "noexist")); - assert!(!db.hexists("noexist", "hello")); - assert!(db.hdel("a", "hello")); - assert!(!db.hdel("a", "hello")); - assert!(db.hdel("a", "random")); - assert!(!db.hdel("a", "random")); - assert!(!db.hdel("a", "noexist")); - assert!(!db.hdel("noexist", "noexist")); - assert!(!db.del("a")); - - assert!(db.hgetall("a").is_empty()); - db.hset("a", "a", &CxxString::new("1")); - db.hset("a", "b", &CxxString::new("2")); - db.hset("a", "c", &CxxString::new("3")); + db.set("hello", &CxxString::new("hello, world!"))?; + db.set("random", &random)?; + assert_eq!(db.get("hello")?.unwrap(), "hello, world!"); + assert_eq!(db.get("random")?.unwrap(), random); + assert_eq!(db.get("noexist")?, None); + + assert!(db.exists("hello")?); + assert!(!db.exists("noexist")?); + assert!(db.del("hello")?); + assert!(!db.del("hello")?); + assert!(db.del("random")?); + assert!(!db.del("random")?); + assert!(!db.del("noexist")?); + + db.hset("a", "hello", &CxxString::new("hello, world!"))?; + db.hset("a", "random", &random)?; + assert_eq!(db.hget("a", "hello")?.unwrap(), "hello, world!"); + assert_eq!(db.hget("a", "random")?.unwrap(), random); + assert_eq!(db.hget("a", "noexist")?, None); + assert_eq!(db.hget("noexist", "noexist")?, None); + assert!(db.hexists("a", "hello")?); + assert!(!db.hexists("a", "noexist")?); + assert!(!db.hexists("noexist", "hello")?); + assert!(db.hdel("a", "hello")?); + assert!(!db.hdel("a", "hello")?); + assert!(db.hdel("a", "random")?); + assert!(!db.hdel("a", "random")?); + assert!(!db.hdel("a", "noexist")?); + assert!(!db.hdel("noexist", "noexist")?); + assert!(!db.del("a")?); + + assert!(db.hgetall("a")?.is_empty()); + db.hset("a", "a", &CxxString::new("1"))?; + db.hset("a", "b", &CxxString::new("2"))?; + db.hset("a", "c", &CxxString::new("3"))?; assert_eq!( - db.hgetall("a"), + db.hgetall("a")?, HashMap::from_iter([ ("a".into(), "1".into()), ("b".into(), "2".into()), @@ -59,56 +59,60 @@ fn dbconnector_sync_api_basic_test() { ]) ); - assert!(db.flush_db()); + assert!(db.flush_db()?); + + Ok(()) } #[test] -fn consumer_producer_state_tables_sync_api_basic_test() { +fn consumer_producer_state_tables_sync_api_basic_test() -> Result<(), Exception> { sonic_db_config_init_for_test(); let redis = Redis::start(); - let pst = ProducerStateTable::new(redis.db_connector(), "table_a"); - let cst = ConsumerStateTable::new(redis.db_connector(), "table_a", None, None); + let pst = ProducerStateTable::new(redis.db_connector(), "table_a")?; + let cst = ConsumerStateTable::new(redis.db_connector(), "table_a", None, None)?; - assert!(cst.pops().is_empty()); + assert!(cst.pops()?.is_empty()); let mut kfvs = random_kfvs(); for (i, kfv) in kfvs.iter().enumerate() { - assert_eq!(pst.count(), i as i64); + assert_eq!(pst.count()?, i as i64); match kfv.operation { - KeyOperation::Set => pst.set(&kfv.key, kfv.field_values.clone()), - KeyOperation::Del => pst.del(&kfv.key), + KeyOperation::Set => pst.set(&kfv.key, kfv.field_values.clone())?, + KeyOperation::Del => pst.del(&kfv.key)?, } } - assert_eq!(cst.read_data(Duration::from_millis(2000), true), SelectResult::Data); - let mut kfvs_cst = cst.pops(); - assert!(cst.pops().is_empty()); + assert_eq!(cst.read_data(Duration::from_millis(2000), true)?, SelectResult::Data); + let mut kfvs_cst = cst.pops()?; + assert!(cst.pops()?.is_empty()); kfvs.sort_unstable(); kfvs_cst.sort_unstable(); assert_eq!(kfvs_cst.len(), kfvs.len()); assert_eq!(kfvs_cst, kfvs); + + Ok(()) } #[test] -fn subscriber_state_table_sync_api_basic_test() { +fn subscriber_state_table_sync_api_basic_test() -> Result<(), Exception> { sonic_db_config_init_for_test(); let redis = Redis::start(); let db = redis.db_connector(); - let sst = SubscriberStateTable::new(redis.db_connector(), "table_a", None, None); - assert!(sst.pops().is_empty()); + let sst = SubscriberStateTable::new(redis.db_connector(), "table_a", None, None)?; + assert!(sst.pops()?.is_empty()); - db.hset("table_a:key_a", "field_a", &CxxString::new("value_a")); - db.hset("table_a:key_a", "field_b", &CxxString::new("value_b")); - assert_eq!(sst.read_data(Duration::from_millis(300), true), SelectResult::Data); - let mut kfvs = sst.pops(); + db.hset("table_a:key_a", "field_a", &CxxString::new("value_a"))?; + db.hset("table_a:key_a", "field_b", &CxxString::new("value_b"))?; + assert_eq!(sst.read_data(Duration::from_millis(300), true)?, SelectResult::Data); + let mut kfvs = sst.pops()?; // SubscriberStateTable will pick up duplicate KeyOpFieldValues' after two SETs on the same // key. I'm not actually sure if this is intended. assert_eq!(kfvs.len(), 2); assert_eq!(kfvs[0], kfvs[1]); - assert!(sst.pops().is_empty()); + assert!(sst.pops()?.is_empty()); let KeyOpFieldValues { key, @@ -125,59 +129,97 @@ fn subscriber_state_table_sync_api_basic_test() { ("field_b".into(), "value_b".into()) ]) ); + + Ok(()) } #[test] -fn zmq_consumer_state_table_sync_api_basic_test() { +fn zmq_consumer_state_table_sync_api_basic_test() -> Result<(), Exception> { use SelectResult::*; let (endpoint, _delete) = random_zmq_endpoint(); - let mut zmqs = ZmqServer::new(&endpoint); - let zmqc = ZmqClient::new(&endpoint); - assert!(zmqc.is_connected()); + let mut zmqs = ZmqServer::new(&endpoint)?; + let zmqc = ZmqClient::new(&endpoint)?; + assert!(zmqc.is_connected()?); let redis = Redis::start(); - let zcst_table_a = ZmqConsumerStateTable::new(redis.db_connector(), "table_a", &mut zmqs, None, None); - let zcst_table_b = ZmqConsumerStateTable::new(redis.db_connector(), "table_b", &mut zmqs, None, None); + let zcst_table_a = ZmqConsumerStateTable::new(redis.db_connector(), "table_a", &mut zmqs, None, None)?; + let zcst_table_b = ZmqConsumerStateTable::new(redis.db_connector(), "table_b", &mut zmqs, None, None)?; let kfvs = random_kfvs(); - zmqc.send_msg("", "table_a", kfvs.clone()); // db name is empty because we are using DbConnector::new_unix - assert_eq!(zcst_table_a.read_data(Duration::from_millis(1500), true), Data); + zmqc.send_msg("", "table_a", kfvs.clone())?; // db name is empty because we are using DbConnector::new_unix + assert_eq!(zcst_table_a.read_data(Duration::from_millis(1500), true)?, Data); - zmqc.send_msg("", "table_b", kfvs.clone()); - assert_eq!(zcst_table_b.read_data(Duration::from_millis(1500), true), Data); + zmqc.send_msg("", "table_b", kfvs.clone())?; + assert_eq!(zcst_table_b.read_data(Duration::from_millis(1500), true)?, Data); - let kfvs_a = zcst_table_a.pops(); - let kvfs_b = zcst_table_b.pops(); + let kfvs_a = zcst_table_a.pops()?; + let kvfs_b = zcst_table_b.pops()?; assert_eq!(kfvs_a, kvfs_b); assert_eq!(kfvs, kfvs_a); + + Ok(()) } #[test] -fn zmq_consumer_producer_state_tables_sync_api_basic_test() { +fn zmq_consumer_producer_state_tables_sync_api_basic_test() -> Result<(), Exception> { use SelectResult::*; let (endpoint, _delete) = random_zmq_endpoint(); - let mut zmqs = ZmqServer::new(&endpoint); - let zmqc = ZmqClient::new(&endpoint); + let mut zmqs = ZmqServer::new(&endpoint)?; + let zmqc = ZmqClient::new(&endpoint)?; let redis = Redis::start(); - let zpst = ZmqProducerStateTable::new(redis.db_connector(), "table_a", zmqc, false); - let zcst = ZmqConsumerStateTable::new(redis.db_connector(), "table_a", &mut zmqs, None, None); + let zpst = ZmqProducerStateTable::new(redis.db_connector(), "table_a", zmqc, false)?; + let zcst = ZmqConsumerStateTable::new(redis.db_connector(), "table_a", &mut zmqs, None, None)?; let kfvs = random_kfvs(); for kfv in &kfvs { match kfv.operation { - KeyOperation::Set => zpst.set(&kfv.key, kfv.field_values.clone()), - KeyOperation::Del => zpst.del(&kfv.key), + KeyOperation::Set => zpst.set(&kfv.key, kfv.field_values.clone())?, + KeyOperation::Del => zpst.del(&kfv.key)?, } } let mut kfvs_seen = Vec::new(); while kfvs_seen.len() != kfvs.len() { - assert_eq!(zcst.read_data(Duration::from_millis(2000), true), Data); - kfvs_seen.extend(zcst.pops()); + assert_eq!(zcst.read_data(Duration::from_millis(2000), true)?, Data); + kfvs_seen.extend(zcst.pops()?); } assert_eq!(kfvs, kfvs_seen); + + Ok(()) +} + +#[test] +fn table_sync_api_basic_test() -> Result<(), Exception> { + let redis = Redis::start(); + let table = Table::new(redis.db_connector(), "mytable")?; + assert!(table.get_keys()?.is_empty()); + assert!(table.get("mykey")?.is_none()); + + let fvs = random_fvs(); + table.set("mykey", fvs.clone())?; + assert_eq!(table.get_keys()?, &["mykey"]); + assert_eq!(table.get("mykey")?.as_ref(), Some(&fvs)); + + let (field, value) = fvs.iter().next().unwrap(); + assert_eq!(table.hget("mykey", field)?.as_ref(), Some(value)); + table.hdel("mykey", field)?; + assert_eq!(table.hget("mykey", field)?, None); + + table.hset("mykey", field, &CxxString::from("my special value"))?; + assert_eq!(table.hget("mykey", field)?.unwrap().as_bytes(), b"my special value"); + + table.del("mykey")?; + assert!(table.get_keys()?.is_empty()); + assert!(table.get("mykey")?.is_none()); + + Ok(()) +} + +#[test] +fn expected_exceptions() { + DbConnector::new_tcp(0, "127.0.0.1", 1, 10000).unwrap_err(); }