diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 38c9e8a8f..da9331c37 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -1,3 +1,3 @@ { - "recommendations": ["Vue.volar", "Vue.vscode-typescript-vue-plugin", "esbenp.prettier-vscode"] + "recommendations": ["Vue.volar", "esbenp.prettier-vscode"] } diff --git a/Cargo.lock b/Cargo.lock index 09114a53c..361e706ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -767,6 +767,7 @@ dependencies = [ "convert_case", "getrandom", "ic-cdk", + "ic-stable-structures", "rand_chacha", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index e557552af..6a9dd47f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,5 +55,5 @@ sha2 = "0.10" syn = { version = "2.0", features = ["extra-traits", "full"] } thiserror = "1.0.48" time = { version = "0.3", features = ["formatting", "parsing"] } -tokio = { version = "1.33.0", features = ["full"] } +tokio = { version = "1.33.0" } uuid = { version = "1.4.1", features = ["serde", "v4"] } diff --git a/canisters/wallet/api/src/proposal.rs b/canisters/wallet/api/src/proposal.rs index 7e942e758..8450c5cfc 100644 --- a/canisters/wallet/api/src/proposal.rs +++ b/canisters/wallet/api/src/proposal.rs @@ -30,14 +30,14 @@ pub enum ProposalStatusDTO { #[derive(CandidType, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] pub enum ProposalStatusCodeDTO { - Created, - Adopted, - Rejected, - Cancelled, - Scheduled, - Processing, - Completed, - Failed, + Created = 0, + Adopted = 1, + Rejected = 2, + Cancelled = 3, + Scheduled = 4, + Processing = 5, + Completed = 6, + Failed = 7, } #[derive(CandidType, Deserialize, Debug, Clone)] diff --git a/canisters/wallet/impl/Cargo.toml b/canisters/wallet/impl/Cargo.toml index c1fe28e16..6498df578 100644 --- a/canisters/wallet/impl/Cargo.toml +++ b/canisters/wallet/impl/Cargo.toml @@ -13,6 +13,9 @@ repository = "https://github.com/dfinity/orbit-wallet" crate-type = ["cdylib"] bench = false +[features] +canbench = ["canbench-rs"] + [dependencies] anyhow = { workspace = true } async-trait = { workspace = true } @@ -32,7 +35,7 @@ lazy_static = { workspace = true } num-bigint = { workspace = true } num-traits = { workspace = true } prometheus = { workspace = true } -serde = { workspace = true } +serde = { workspace = true, features = ["derive"] } serde_bytes = { workspace = true } serde_cbor = { workspace = true } sha2 = { workspace = true } diff --git a/canisters/wallet/impl/canbench.yml b/canisters/wallet/impl/canbench.yml index 148fde6a0..1cd86f935 100644 --- a/canisters/wallet/impl/canbench.yml +++ b/canisters/wallet/impl/canbench.yml @@ -1,3 +1,3 @@ -build_cmd: cargo build --locked --release --target wasm32-unknown-unknown --features canbench-rs +build_cmd: cargo build --locked --release --target wasm32-unknown-unknown --features canbench wasm_path: ../../../target/wasm32-unknown-unknown/release/wallet.wasm results_path: results.yml diff --git a/canisters/wallet/impl/results.yml b/canisters/wallet/impl/results.yml index 02fa2a195..8ade4caa6 100644 --- a/canisters/wallet/impl/results.yml +++ b/canisters/wallet/impl/results.yml @@ -1,20 +1,26 @@ benches: - batch_insert_100_proposals: + repository_batch_insert_100_proposals: total: - instructions: 3051795057 + instructions: 3107967048 heap_increase: 2 - stable_memory_increase: 898 + stable_memory_increase: 1410 scopes: {} - filter_all_proposals_by_default_filters: + repository_filter_all_proposal_ids_by_default_filters: total: - instructions: 4690513262 - heap_increase: 10 + instructions: 489667791 + heap_increase: 0 stable_memory_increase: 0 scopes: {} - list_all_proposals: + repository_list_all_proposals: total: - instructions: 3308284274 + instructions: 3300020748 heap_increase: 8 stable_memory_increase: 0 scopes: {} + service_filter_all_proposals_with_default_filters: + total: + instructions: 336959888 + heap_increase: 0 + stable_memory_increase: 0 + scopes: {} version: 0.1.1 diff --git a/canisters/wallet/impl/src/controllers/wallet.rs b/canisters/wallet/impl/src/controllers/wallet.rs index 9fb51ad27..a3a3711fe 100644 --- a/canisters/wallet/impl/src/controllers/wallet.rs +++ b/canisters/wallet/impl/src/controllers/wallet.rs @@ -17,7 +17,7 @@ use wallet_api::{ }; // Canister entrypoints for the controller. -#[cfg(any(not(feature = "canbench-rs"), test))] +#[cfg(any(not(feature = "canbench"), test))] #[ic_cdk_macros::init] async fn initialize(input: Option) { match input { @@ -29,7 +29,7 @@ async fn initialize(input: Option) { /// The init is overriden for benchmarking purposes. /// /// This is only used for benchmarking and is not included in the final canister. -#[cfg(feature = "canbench-rs")] +#[cfg(all(feature = "canbench", not(test)))] #[ic_cdk_macros::init] pub async fn mock_init() { // Initialize the random number generator with a fixed seed to ensure deterministic @@ -79,7 +79,7 @@ impl WalletController { Self { wallet_service } } - #[cfg(any(not(feature = "canbench-rs"), test))] + #[cfg(any(not(feature = "canbench"), test))] async fn initialize(&self, input: wallet_api::WalletInit) { let ctx = &call_context(); self.wallet_service diff --git a/canisters/wallet/impl/src/core/memory.rs b/canisters/wallet/impl/src/core/memory.rs index f139996b6..7c8ff2602 100644 --- a/canisters/wallet/impl/src/core/memory.rs +++ b/canisters/wallet/impl/src/core/memory.rs @@ -34,6 +34,10 @@ pub const ADDRESS_BOOK_MEMORY_ID: MemoryId = MemoryId::new(22); pub const ADDRESS_BOOK_INDEX_MEMORY_ID: MemoryId = MemoryId::new(23); pub const PROPOSAL_PROPOSER_INDEX_MEMORY_ID: MemoryId = MemoryId::new(24); pub const PROPOSAL_CREATION_TIME_INDEX_MEMORY_ID: MemoryId = MemoryId::new(25); +pub const PROPOSAL_KEY_CREATION_TIME_INDEX_MEMORY_ID: MemoryId = MemoryId::new(26); +pub const PROPOSAL_KEY_EXPIRATION_TIME_INDEX_MEMORY_ID: MemoryId = MemoryId::new(27); +pub const PROPOSAL_SORT_INDEX_MEMORY_ID: MemoryId = MemoryId::new(28); +pub const PROPOSAL_STATUS_MODIFICATION_INDEX_MEMORY_ID: MemoryId = MemoryId::new(29); thread_local! { /// Static configuration of the canister. diff --git a/canisters/wallet/impl/src/core/utils.rs b/canisters/wallet/impl/src/core/utils.rs index 9866b9d41..25bf965a3 100644 --- a/canisters/wallet/impl/src/core/utils.rs +++ b/canisters/wallet/impl/src/core/utils.rs @@ -81,18 +81,6 @@ pub fn calculate_minimum_threshold(percentage: &Percentage, total_value: &usize) } } -/// Matches a date against a date range. -/// -/// If the provided range is `None`, then the date is considered to be within the range. -pub(crate) fn match_date_range(date: &u64, start_dt: &Option, to_dt: &Option) -> bool { - match (start_dt, to_dt) { - (Some(start_dt), Some(to_dt)) => date >= start_dt && date <= to_dt, - (Some(start_dt), None) => date >= start_dt, - (None, Some(to_dt)) => date <= to_dt, - (None, None) => true, - } -} - /// Retains items based on the result of an access control evaluation. /// /// This function will evaluate the access control for each item in the list and retain only the diff --git a/canisters/wallet/impl/src/jobs/cancel_expired_proposals.rs b/canisters/wallet/impl/src/jobs/cancel_expired_proposals.rs index 2d13bd5ef..48f16253a 100644 --- a/canisters/wallet/impl/src/jobs/cancel_expired_proposals.rs +++ b/canisters/wallet/impl/src/jobs/cancel_expired_proposals.rs @@ -1,4 +1,8 @@ -use crate::{core::ic_cdk::api::time, models::ProposalStatus, repositories::ProposalRepository}; +use crate::{ + core::ic_cdk::api::time, + models::{ProposalStatus, ProposalStatusCode}, + repositories::ProposalRepository, +}; use async_trait::async_trait; use ic_canister_core::repository::Repository; @@ -27,7 +31,7 @@ impl Job { let mut proposals = self.proposal_repository.find_by_expiration_dt_and_status( None, Some(current_time), - ProposalStatus::Created.to_string(), + ProposalStatusCode::Created.to_string(), ); for proposal in proposals.iter_mut() { diff --git a/canisters/wallet/impl/src/jobs/schedule_adopted_proposals.rs b/canisters/wallet/impl/src/jobs/schedule_adopted_proposals.rs index dd0fad47a..2305762a1 100644 --- a/canisters/wallet/impl/src/jobs/schedule_adopted_proposals.rs +++ b/canisters/wallet/impl/src/jobs/schedule_adopted_proposals.rs @@ -1,6 +1,6 @@ use crate::{ core::ic_cdk::api::time, - models::{ProposalExecutionPlan, ProposalStatus}, + models::{ProposalExecutionPlan, ProposalStatus, ProposalStatusCode}, repositories::ProposalRepository, }; use async_trait::async_trait; @@ -34,7 +34,7 @@ impl Job { async fn process_adopted_proposals(&self) { let current_time = time(); let mut proposals = self.proposal_repository.find_by_status( - ProposalStatus::Adopted.to_string(), + ProposalStatusCode::Adopted, None, Some(current_time), ); diff --git a/canisters/wallet/impl/src/mappers/account.rs b/canisters/wallet/impl/src/mappers/account.rs index 2f7022f29..77b11d281 100644 --- a/canisters/wallet/impl/src/mappers/account.rs +++ b/canisters/wallet/impl/src/mappers/account.rs @@ -7,7 +7,7 @@ use crate::{ }, repositories::policy::PROPOSAL_POLICY_REPOSITORY, }; -use ic_canister_core::{repository::Repository, types::UUID, utils::timestamp_to_rfc3339}; +use ic_canister_core::{repository::Repository, utils::timestamp_to_rfc3339}; use uuid::Uuid; use wallet_api::{AccountBalanceDTO, AccountBalanceInfoDTO, AccountDTO, CriteriaDTO}; @@ -47,7 +47,7 @@ impl AccountMapper { pub fn from_create_input( input: AddAccountOperationInput, - account_id: UUID, + account_id: AccountId, address: Option, ) -> Result { if !input diff --git a/canisters/wallet/impl/src/mappers/proposal_status.rs b/canisters/wallet/impl/src/mappers/proposal_status.rs index 56c00342a..d88871688 100644 --- a/canisters/wallet/impl/src/mappers/proposal_status.rs +++ b/canisters/wallet/impl/src/mappers/proposal_status.rs @@ -1,4 +1,4 @@ -use crate::models::ProposalStatus; +use crate::models::{ProposalStatus, ProposalStatusCode}; use ic_canister_core::utils::{rfc3339_to_timestamp, timestamp_to_rfc3339}; use wallet_api::{ProposalStatusCodeDTO, ProposalStatusDTO}; @@ -59,6 +59,21 @@ impl From for ProposalStatus { } } +impl From for ProposalStatusCode { + fn from(status: ProposalStatusCodeDTO) -> Self { + match status { + ProposalStatusCodeDTO::Created => ProposalStatusCode::Created, + ProposalStatusCodeDTO::Adopted => ProposalStatusCode::Adopted, + ProposalStatusCodeDTO::Rejected => ProposalStatusCode::Rejected, + ProposalStatusCodeDTO::Completed => ProposalStatusCode::Completed, + ProposalStatusCodeDTO::Failed => ProposalStatusCode::Failed, + ProposalStatusCodeDTO::Processing => ProposalStatusCode::Processing, + ProposalStatusCodeDTO::Scheduled => ProposalStatusCode::Scheduled, + ProposalStatusCodeDTO::Cancelled => ProposalStatusCode::Cancelled, + } + } +} + #[derive(Debug)] pub struct ProposalStatusMapper; diff --git a/canisters/wallet/impl/src/models/access_control.rs b/canisters/wallet/impl/src/models/access_control.rs index f5313de7e..018745202 100644 --- a/canisters/wallet/impl/src/models/access_control.rs +++ b/canisters/wallet/impl/src/models/access_control.rs @@ -414,13 +414,14 @@ mod tests { } } -#[cfg(test)] +#[cfg(any(test, feature = "canbench"))] pub mod access_control_test_utils { use super::*; + use uuid::Uuid; pub fn mock_access_policy() -> AccessControlPolicy { AccessControlPolicy { - id: [0; 16], + id: *Uuid::new_v4().as_bytes(), user: UserSpecifier::Any, resource: ResourceSpecifier::Common( ResourceType::Account, diff --git a/canisters/wallet/impl/src/models/account.rs b/canisters/wallet/impl/src/models/account.rs index eff0b2c84..ff127064d 100644 --- a/canisters/wallet/impl/src/models/account.rs +++ b/canisters/wallet/impl/src/models/account.rs @@ -291,10 +291,11 @@ pub mod account_test_utils { use super::*; use crate::repositories::ACCOUNT_REPOSITORY; use ic_canister_core::repository::Repository; + use uuid::Uuid; pub fn mock_account() -> Account { Account { - id: [0; 16], + id: *Uuid::new_v4().as_bytes(), address: "0x1234".to_string(), balance: None, blockchain: Blockchain::InternetComputer, diff --git a/canisters/wallet/impl/src/models/indexes/mod.rs b/canisters/wallet/impl/src/models/indexes/mod.rs index a4612c071..15e7f113a 100644 --- a/canisters/wallet/impl/src/models/indexes/mod.rs +++ b/canisters/wallet/impl/src/models/indexes/mod.rs @@ -6,9 +6,13 @@ pub mod notification_user_index; pub mod proposal_account_index; pub mod proposal_creation_time_index; pub mod proposal_expiration_time_index; +pub mod proposal_key_creation_time_index; +pub mod proposal_key_expiration_time_index; pub mod proposal_proposer_index; pub mod proposal_scheduled_index; +pub mod proposal_sort_index; pub mod proposal_status_index; +pub mod proposal_status_modification_index; pub mod proposal_voter_index; pub mod transfer_account_index; pub mod transfer_status_index; diff --git a/canisters/wallet/impl/src/models/indexes/proposal_account_index.rs b/canisters/wallet/impl/src/models/indexes/proposal_account_index.rs index 66b8f99df..50be17a31 100644 --- a/canisters/wallet/impl/src/models/indexes/proposal_account_index.rs +++ b/canisters/wallet/impl/src/models/indexes/proposal_account_index.rs @@ -1,6 +1,5 @@ use crate::models::{AccountId, Proposal, ProposalId, ProposalOperation}; use candid::{CandidType, Deserialize}; -use ic_canister_core::types::Timestamp; use ic_canister_macros::stable_object; /// Index of proposals by account id. @@ -9,8 +8,6 @@ use ic_canister_macros::stable_object; pub struct ProposalAccountIndex { /// The account id that is associated with this proposal. pub account_id: AccountId, - /// The time when the proposal was created. - pub created_at: Timestamp, /// The proposal id, which is a UUID. pub proposal_id: ProposalId, } @@ -18,8 +15,6 @@ pub struct ProposalAccountIndex { #[derive(Clone, Debug)] pub struct ProposalAccountIndexCriteria { pub account_id: AccountId, - pub from_dt: Option, - pub to_dt: Option, } impl Proposal { @@ -27,7 +22,6 @@ impl Proposal { if let ProposalOperation::Transfer(ctx) = &self.operation { return Some(ProposalAccountIndex { proposal_id: self.id.to_owned(), - created_at: self.created_timestamp.to_owned(), account_id: ctx.input.from_account_id.to_owned(), }); } @@ -50,9 +44,8 @@ mod tests { let account_id = [0; 16]; let proposal_id = [1; 16]; let model = ProposalAccountIndex { - proposal_id, account_id, - created_at: 0, + proposal_id, }; let serialized_model = model.to_bytes(); diff --git a/canisters/wallet/impl/src/models/indexes/proposal_creation_time_index.rs b/canisters/wallet/impl/src/models/indexes/proposal_creation_time_index.rs index 9a3156cc1..acf77a160 100644 --- a/canisters/wallet/impl/src/models/indexes/proposal_creation_time_index.rs +++ b/canisters/wallet/impl/src/models/indexes/proposal_creation_time_index.rs @@ -1,12 +1,11 @@ use crate::models::Proposal; -use candid::{CandidType, Deserialize}; use ic_canister_core::types::{Timestamp, UUID}; -use ic_canister_macros::stable_object; +use ic_canister_macros::storable; use std::hash::Hash; -/// Represents a proposal index by execution time. -#[stable_object] -#[derive(CandidType, Deserialize, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +/// Represents a proposal index by creation time. +#[storable] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ProposalCreationTimeIndex { /// The time the proposal was created. pub created_at: Timestamp, diff --git a/canisters/wallet/impl/src/models/indexes/proposal_expiration_time_index.rs b/canisters/wallet/impl/src/models/indexes/proposal_expiration_time_index.rs index 0ef6b7495..fbb73996c 100644 --- a/canisters/wallet/impl/src/models/indexes/proposal_expiration_time_index.rs +++ b/canisters/wallet/impl/src/models/indexes/proposal_expiration_time_index.rs @@ -4,11 +4,11 @@ use ic_canister_core::types::{Timestamp, UUID}; use ic_canister_macros::stable_object; use std::hash::Hash; -/// Represents a proposal index by execution time. +/// Represents a proposal index by expiration time. #[stable_object] #[derive(CandidType, Deserialize, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ProposalExpirationTimeIndex { - /// The time the proposal is scheduled to be set as expired if not executed. + /// The time the proposal is scheduled to be set as expired if still pending. pub expiration_dt: Timestamp, /// The proposal id, which is a UUID. pub proposal_id: UUID, diff --git a/canisters/wallet/impl/src/models/indexes/proposal_key_creation_time_index.rs b/canisters/wallet/impl/src/models/indexes/proposal_key_creation_time_index.rs new file mode 100644 index 000000000..7d267fcba --- /dev/null +++ b/canisters/wallet/impl/src/models/indexes/proposal_key_creation_time_index.rs @@ -0,0 +1,47 @@ +use crate::models::Proposal; +use candid::{CandidType, Deserialize}; +use ic_canister_core::types::{Timestamp, UUID}; +use ic_canister_macros::stable_object; +use std::hash::Hash; + +/// Represents a proposal index by creation time prefixed by the proposal id. +#[stable_object] +#[derive(CandidType, Deserialize, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ProposalKeyCreationTimeIndex { + /// The proposal id, which is a UUID. + pub proposal_id: UUID, + /// The time the proposal was created. + pub created_at: Timestamp, +} + +#[derive(Clone, Debug)] +pub struct ProposalKeyCreationTimeIndexCriteria { + pub proposal_id: UUID, + pub from_dt: Option, + pub to_dt: Option, +} + +impl Proposal { + pub fn to_index_by_key_and_creation_dt(&self) -> ProposalKeyCreationTimeIndex { + ProposalKeyCreationTimeIndex { + proposal_id: self.id, + created_at: self.created_timestamp, + } + } +} + +#[cfg(test)] +mod tests { + use crate::models::proposal_test_utils::mock_proposal; + + #[test] + fn test_proposal_to_index_by_key_and_creation_dt() { + let mut proposal = mock_proposal(); + proposal.created_timestamp = 5; + + let index = proposal.to_index_by_key_and_creation_dt(); + + assert_eq!(index.proposal_id, proposal.id); + assert_eq!(index.created_at, 5); + } +} diff --git a/canisters/wallet/impl/src/models/indexes/proposal_key_expiration_time_index.rs b/canisters/wallet/impl/src/models/indexes/proposal_key_expiration_time_index.rs new file mode 100644 index 000000000..06eac976d --- /dev/null +++ b/canisters/wallet/impl/src/models/indexes/proposal_key_expiration_time_index.rs @@ -0,0 +1,47 @@ +use crate::models::Proposal; +use candid::{CandidType, Deserialize}; +use ic_canister_core::types::{Timestamp, UUID}; +use ic_canister_macros::stable_object; +use std::hash::Hash; + +/// Represents a proposal index by expiration time prefixed by the proposal id. +#[stable_object] +#[derive(CandidType, Deserialize, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ProposalKeyExpirationTimeIndex { + /// The proposal id, which is a UUID. + pub proposal_id: UUID, + /// The time the proposal is scheduled to be set as expired if still pending. + pub expiration_dt: Timestamp, +} + +#[derive(Clone, Debug)] +pub struct ProposalKeyExpirationTimeIndexCriteria { + pub proposal_id: UUID, + pub from_dt: Option, + pub to_dt: Option, +} + +impl Proposal { + pub fn to_index_by_key_and_expiration_dt(&self) -> ProposalKeyExpirationTimeIndex { + ProposalKeyExpirationTimeIndex { + proposal_id: self.id, + expiration_dt: self.expiration_dt, + } + } +} + +#[cfg(test)] +mod tests { + use crate::models::proposal_test_utils::mock_proposal; + + #[test] + fn test_proposal_to_index_by_key_and_expiration_dt() { + let mut proposal = mock_proposal(); + proposal.expiration_dt = 5; + + let index = proposal.to_index_by_key_and_expiration_dt(); + + assert_eq!(index.proposal_id, proposal.id); + assert_eq!(index.expiration_dt, 5); + } +} diff --git a/canisters/wallet/impl/src/models/indexes/proposal_proposer_index.rs b/canisters/wallet/impl/src/models/indexes/proposal_proposer_index.rs index 6f216b026..a953af79b 100644 --- a/canisters/wallet/impl/src/models/indexes/proposal_proposer_index.rs +++ b/canisters/wallet/impl/src/models/indexes/proposal_proposer_index.rs @@ -1,6 +1,6 @@ use crate::models::Proposal; use candid::{CandidType, Deserialize}; -use ic_canister_core::types::{Timestamp, UUID}; +use ic_canister_core::types::UUID; use ic_canister_macros::stable_object; /// Index of proposals by the proposer user id. @@ -8,26 +8,21 @@ use ic_canister_macros::stable_object; #[derive(CandidType, Deserialize, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ProposalProposerIndex { /// The user who proposed this proposal. - pub user_id: UUID, - /// The time when the proposal was created. - pub created_at: Timestamp, + pub proposer_id: UUID, /// The proposal id, which is a UUID. pub proposal_id: UUID, } #[derive(Clone, Debug)] pub struct ProposalProposerIndexCriteria { - pub user_id: UUID, - pub from_dt: Option, - pub to_dt: Option, + pub proposer_id: UUID, } impl Proposal { pub fn to_index_for_proposer(&self) -> ProposalProposerIndex { ProposalProposerIndex { - user_id: self.proposed_by.to_owned(), + proposer_id: self.proposed_by.to_owned(), proposal_id: self.id.to_owned(), - created_at: self.created_timestamp.to_owned(), } } } @@ -43,16 +38,15 @@ mod tests { let proposal_id = [1; 16]; let user_id = [u8::MAX; 16]; let model = ProposalProposerIndex { + proposer_id: user_id, proposal_id, - user_id, - created_at: 0, }; let serialized_model = model.to_bytes(); let deserialized_model = ProposalProposerIndex::from_bytes(serialized_model); assert_eq!(model.proposal_id, deserialized_model.proposal_id); - assert_eq!(model.user_id, deserialized_model.user_id); + assert_eq!(model.proposer_id, deserialized_model.proposer_id); } #[test] @@ -79,7 +73,7 @@ mod tests { let index = proposal.to_index_for_proposer(); - assert_eq!(index.created_at, proposal.created_timestamp); - assert_eq!(index.user_id, proposal.proposed_by); + assert_eq!(index.proposal_id, proposal.id); + assert_eq!(index.proposer_id, proposal.proposed_by); } } diff --git a/canisters/wallet/impl/src/models/indexes/proposal_sort_index.rs b/canisters/wallet/impl/src/models/indexes/proposal_sort_index.rs new file mode 100644 index 000000000..cfcd4da5a --- /dev/null +++ b/canisters/wallet/impl/src/models/indexes/proposal_sort_index.rs @@ -0,0 +1,107 @@ +use crate::models::{Proposal, ProposalId}; +use ic_canister_core::types::Timestamp; +use ic_canister_macros::storable; + +/// Index of proposals to use for sorting. +#[storable] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ProposalSortIndex { + /// The proposal id, which is a UUID. + pub key: ProposalSortIndexKey, + /// The proposal's last modification timestamp. + pub value: ProposalSortIndexValue, +} + +#[storable] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ProposalSortIndexKey { + pub proposal_id: ProposalId, +} + +#[storable] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ProposalSortIndexValue { + /// The proposal's last modification timestamp. + pub modification_timestamp: Timestamp, + /// The proposal's creation timestamp. + pub creation_timestamp: Timestamp, + /// The proposal's expiration_dt. + pub expiration_timestamp: Timestamp, +} + +#[derive(Clone, Debug)] +pub struct ProposalSortIndexCriteria { + pub proposal_id: ProposalId, +} + +impl Proposal { + pub fn to_index_for_sorting(&self) -> ProposalSortIndex { + ProposalSortIndex { + key: ProposalSortIndexKey { + proposal_id: self.id, + }, + value: ProposalSortIndexValue { + modification_timestamp: self.last_modification_timestamp, + creation_timestamp: self.created_timestamp, + expiration_timestamp: self.expiration_dt, + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::models::proposal_test_utils::mock_proposal; + use ic_stable_structures::Storable; + + #[test] + fn valid_model_serialization() { + let proposal_id = [1; 16]; + let model = ProposalSortIndex { + key: ProposalSortIndexKey { proposal_id }, + value: ProposalSortIndexValue { + creation_timestamp: 1, + modification_timestamp: 2, + expiration_timestamp: 3, + }, + }; + + let serialized_model = model.to_bytes(); + let deserialized_model = ProposalSortIndex::from_bytes(serialized_model); + + assert_eq!(model.key.proposal_id, deserialized_model.key.proposal_id); + assert_eq!( + model.value.creation_timestamp, + deserialized_model.value.creation_timestamp + ); + assert_eq!( + model.value.modification_timestamp, + deserialized_model.value.modification_timestamp + ); + assert_eq!( + model.value.expiration_timestamp, + deserialized_model.value.expiration_timestamp + ); + } + + #[test] + fn valid_user_voter_indexes() { + let mut proposal = mock_proposal(); + proposal.id = [1; 16]; + proposal.proposed_by = [u8::MAX; 16]; + proposal.created_timestamp = 1; + proposal.last_modification_timestamp = 2; + proposal.expiration_dt = 3; + + let index = proposal.to_index_for_sorting(); + + assert_eq!(index.key.proposal_id, proposal.id); + assert_eq!(index.value.creation_timestamp, proposal.created_timestamp); + assert_eq!( + index.value.modification_timestamp, + proposal.last_modification_timestamp + ); + assert_eq!(index.value.expiration_timestamp, proposal.expiration_dt); + } +} diff --git a/canisters/wallet/impl/src/models/indexes/proposal_status_index.rs b/canisters/wallet/impl/src/models/indexes/proposal_status_index.rs index 7a8f3c3b1..d2b6b9cfe 100644 --- a/canisters/wallet/impl/src/models/indexes/proposal_status_index.rs +++ b/canisters/wallet/impl/src/models/indexes/proposal_status_index.rs @@ -1,33 +1,27 @@ -use crate::models::Proposal; -use candid::{CandidType, Deserialize}; -use ic_canister_core::types::{Timestamp, UUID}; -use ic_canister_macros::stable_object; +use crate::models::{Proposal, ProposalStatusCode}; +use ic_canister_core::types::UUID; +use ic_canister_macros::storable; use std::hash::Hash; /// Represents a proposal index by its status. -#[stable_object] -#[derive(CandidType, Deserialize, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[storable] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ProposalStatusIndex { /// The status of the proposal. - pub status: String, - /// The last time the proposal was modified. - pub last_modification_timestamp: Timestamp, + pub status: ProposalStatusCode, /// The proposal id, which is a UUID. pub proposal_id: UUID, } #[derive(Clone, Debug)] pub struct ProposalStatusIndexCriteria { - pub status: String, - pub from_dt: Option, - pub to_dt: Option, + pub status: ProposalStatusCode, } impl Proposal { pub fn to_index_by_status(&self) -> ProposalStatusIndex { ProposalStatusIndex { - status: self.status.to_string(), - last_modification_timestamp: self.last_modification_timestamp, + status: self.status.to_type(), proposal_id: self.id, } } @@ -35,7 +29,7 @@ impl Proposal { #[cfg(test)] mod tests { - use crate::models::{proposal_test_utils::mock_proposal, ProposalStatus}; + use crate::models::{proposal_test_utils::mock_proposal, ProposalStatus, ProposalStatusCode}; #[test] fn test_proposal_to_index_by_status() { @@ -46,7 +40,6 @@ mod tests { let index = proposal.to_index_by_status(); assert_eq!(index.proposal_id, proposal.id); - assert_eq!(index.last_modification_timestamp, 5); - assert_eq!(index.status, "created"); + assert_eq!(index.status, ProposalStatusCode::Created); } } diff --git a/canisters/wallet/impl/src/models/indexes/proposal_status_modification_index.rs b/canisters/wallet/impl/src/models/indexes/proposal_status_modification_index.rs new file mode 100644 index 000000000..02914b85a --- /dev/null +++ b/canisters/wallet/impl/src/models/indexes/proposal_status_modification_index.rs @@ -0,0 +1,51 @@ +use crate::models::{Proposal, ProposalId, ProposalStatusCode}; +use ic_canister_core::types::Timestamp; +use ic_canister_macros::storable; +use std::hash::Hash; + +/// Represents a proposal index by its status and the last modification timestamp. +#[storable] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ProposalStatusModificationIndex { + /// The status of the proposal. + pub status: ProposalStatusCode, + /// The last modification timestamp of the proposal. + pub modification_timestamp: Timestamp, + /// The proposal id, which is a UUID. + pub proposal_id: ProposalId, +} + +#[derive(Clone, Debug)] +pub struct ProposalStatusModificationIndexCriteria { + pub status: ProposalStatusCode, + pub from_dt: Option, + pub to_dt: Option, +} + +impl Proposal { + pub fn to_index_by_status_and_modification(&self) -> ProposalStatusModificationIndex { + ProposalStatusModificationIndex { + status: self.status.to_type(), + modification_timestamp: self.last_modification_timestamp, + proposal_id: self.id, + } + } +} + +#[cfg(test)] +mod tests { + use crate::models::{proposal_test_utils::mock_proposal, ProposalStatus, ProposalStatusCode}; + + #[test] + fn test_proposal_to_index_by_status_and_modification() { + let mut proposal = mock_proposal(); + proposal.last_modification_timestamp = 5; + proposal.status = ProposalStatus::Created; + + let index = proposal.to_index_by_status_and_modification(); + + assert_eq!(index.proposal_id, proposal.id); + assert_eq!(index.status, ProposalStatusCode::Created); + assert_eq!(index.modification_timestamp, 5); + } +} diff --git a/canisters/wallet/impl/src/models/indexes/proposal_voter_index.rs b/canisters/wallet/impl/src/models/indexes/proposal_voter_index.rs index e1332b6d5..8a33f6177 100644 --- a/canisters/wallet/impl/src/models/indexes/proposal_voter_index.rs +++ b/canisters/wallet/impl/src/models/indexes/proposal_voter_index.rs @@ -1,6 +1,6 @@ use crate::models::Proposal; use candid::{CandidType, Deserialize}; -use ic_canister_core::types::{Timestamp, UUID}; +use ic_canister_core::types::UUID; use ic_canister_macros::stable_object; use std::collections::HashSet; @@ -9,18 +9,14 @@ use std::collections::HashSet; #[derive(CandidType, Deserialize, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ProposalVoterIndex { /// The user that has voted on this proposal. - pub user_id: UUID, - /// The time when the proposal was created. - pub created_at: Timestamp, + pub voter_id: UUID, /// The proposal id, which is a UUID. pub proposal_id: UUID, } #[derive(Clone, Debug)] pub struct ProposalVoterIndexCriteria { - pub user_id: UUID, - pub from_dt: Option, - pub to_dt: Option, + pub voter_id: UUID, } impl Proposal { @@ -33,9 +29,8 @@ impl Proposal { users .iter() .map(|user_id| ProposalVoterIndex { + voter_id: user_id.to_owned(), proposal_id: self.id.to_owned(), - created_at: self.created_timestamp.to_owned(), - user_id: user_id.to_owned(), }) .collect() } @@ -52,16 +47,15 @@ mod tests { let proposal_id = [1; 16]; let user_id = [u8::MAX; 16]; let model = ProposalVoterIndex { + voter_id: user_id, proposal_id, - user_id, - created_at: 0, }; let serialized_model = model.to_bytes(); let deserialized_model = ProposalVoterIndex::from_bytes(serialized_model); assert_eq!(model.proposal_id, deserialized_model.proposal_id); - assert_eq!(model.user_id, deserialized_model.user_id); + assert_eq!(model.voter_id, deserialized_model.voter_id); } #[test] @@ -89,7 +83,7 @@ mod tests { let indexes = proposal.to_index_for_voters(); assert_eq!(indexes.len(), 2); - assert!(indexes.iter().any(|i| i.user_id == [1; 16])); - assert!(indexes.iter().any(|i| i.user_id == [2; 16])); + assert!(indexes.iter().any(|i| i.voter_id == [1; 16])); + assert!(indexes.iter().any(|i| i.voter_id == [2; 16])); } } diff --git a/canisters/wallet/impl/src/models/proposal.rs b/canisters/wallet/impl/src/models/proposal.rs index f4f3d0fb6..7bed4b113 100644 --- a/canisters/wallet/impl/src/models/proposal.rs +++ b/canisters/wallet/impl/src/models/proposal.rs @@ -259,16 +259,16 @@ mod tests { } } -#[cfg(any(test, feature = "canbench-rs"))] +#[cfg(any(test, feature = "canbench"))] pub mod proposal_test_utils { - use num_bigint::BigUint; - use super::*; use crate::models::{Metadata, ProposalVoteStatus, TransferOperation, TransferOperationInput}; + use num_bigint::BigUint; + use uuid::Uuid; pub fn mock_proposal() -> Proposal { Proposal { - id: [0; 16], + id: *Uuid::new_v4().as_bytes(), title: "foo".to_string(), summary: Some("bar".to_string()), proposed_by: [1; 16], diff --git a/canisters/wallet/impl/src/models/proposal_status.rs b/canisters/wallet/impl/src/models/proposal_status.rs index c621e0dfa..0d09c18ee 100644 --- a/canisters/wallet/impl/src/models/proposal_status.rs +++ b/canisters/wallet/impl/src/models/proposal_status.rs @@ -1,8 +1,7 @@ use candid::{CandidType, Deserialize}; use ic_canister_core::types::Timestamp; -use ic_canister_macros::stable_object; +use ic_canister_macros::{stable_object, storable}; use std::fmt::{Display, Formatter}; -use wallet_api::ProposalStatusCodeDTO; #[stable_object] #[derive(CandidType, Deserialize, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -17,19 +16,70 @@ pub enum ProposalStatus { Failed { reason: Option }, } -pub type ProposalStatusCode = ProposalStatusCodeDTO; +#[storable] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(u8)] +pub enum ProposalStatusCode { + Created = 0, + Adopted = 1, + Rejected = 2, + Cancelled = 3, + Scheduled = 4, + Processing = 5, + Completed = 6, + Failed = 7, +} + +impl From for u8 { + fn from(status: ProposalStatusCode) -> Self { + status as u8 + } +} -impl Display for ProposalStatus { +impl TryFrom for ProposalStatusCode { + type Error = (); + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(ProposalStatusCode::Created), + 1 => Ok(ProposalStatusCode::Adopted), + 2 => Ok(ProposalStatusCode::Rejected), + 3 => Ok(ProposalStatusCode::Cancelled), + 4 => Ok(ProposalStatusCode::Scheduled), + 5 => Ok(ProposalStatusCode::Processing), + 6 => Ok(ProposalStatusCode::Completed), + 7 => Ok(ProposalStatusCode::Failed), + _ => Err(()), + } + } +} + +impl Display for ProposalStatusCode { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - ProposalStatus::Created => write!(f, "created"), - ProposalStatus::Adopted => write!(f, "adopted"), - ProposalStatus::Rejected => write!(f, "rejected"), - ProposalStatus::Scheduled { .. } => write!(f, "scheduled"), - ProposalStatus::Processing { .. } => write!(f, "processing"), - ProposalStatus::Completed { .. } => write!(f, "completed"), - ProposalStatus::Failed { .. } => write!(f, "failed"), - ProposalStatus::Cancelled { .. } => write!(f, "cancelled"), + ProposalStatusCode::Created => write!(f, "created"), + ProposalStatusCode::Adopted => write!(f, "adopted"), + ProposalStatusCode::Rejected => write!(f, "rejected"), + ProposalStatusCode::Scheduled => write!(f, "scheduled"), + ProposalStatusCode::Processing => write!(f, "processing"), + ProposalStatusCode::Completed => write!(f, "completed"), + ProposalStatusCode::Failed => write!(f, "failed"), + ProposalStatusCode::Cancelled => write!(f, "cancelled"), + } + } +} + +impl ProposalStatus { + pub fn to_type(&self) -> ProposalStatusCode { + match self { + ProposalStatus::Created => ProposalStatusCode::Created, + ProposalStatus::Adopted => ProposalStatusCode::Adopted, + ProposalStatus::Rejected => ProposalStatusCode::Rejected, + ProposalStatus::Scheduled { .. } => ProposalStatusCode::Scheduled, + ProposalStatus::Processing { .. } => ProposalStatusCode::Processing, + ProposalStatus::Completed { .. } => ProposalStatusCode::Completed, + ProposalStatus::Failed { .. } => ProposalStatusCode::Failed, + ProposalStatus::Cancelled { .. } => ProposalStatusCode::Cancelled, } } } @@ -40,28 +90,61 @@ mod tests { #[test] fn test_status_string_representation() { - assert_eq!(ProposalStatus::Created.to_string(), "created"); - assert_eq!(ProposalStatus::Adopted.to_string(), "adopted"); - assert_eq!(ProposalStatus::Rejected.to_string(), "rejected"); + assert_eq!(ProposalStatusCode::Created.to_string(), "created"); + assert_eq!(ProposalStatusCode::Adopted.to_string(), "adopted"); + assert_eq!(ProposalStatusCode::Rejected.to_string(), "rejected"); + assert_eq!(ProposalStatusCode::Scheduled.to_string(), "scheduled"); + assert_eq!(ProposalStatusCode::Processing.to_string(), "processing"); + assert_eq!(ProposalStatusCode::Completed.to_string(), "completed"); + assert_eq!(ProposalStatusCode::Failed.to_string(), "failed"); + assert_eq!(ProposalStatusCode::Cancelled.to_string(), "cancelled"); + } + + #[test] + fn test_to_status_u8_representation() { + assert_eq!(u8::from(ProposalStatusCode::Created), 0); + assert_eq!(u8::from(ProposalStatusCode::Adopted), 1); + assert_eq!(u8::from(ProposalStatusCode::Rejected), 2); + assert_eq!(u8::from(ProposalStatusCode::Scheduled), 4); + assert_eq!(u8::from(ProposalStatusCode::Processing), 5); + assert_eq!(u8::from(ProposalStatusCode::Completed), 6); + assert_eq!(u8::from(ProposalStatusCode::Failed), 7); + assert_eq!(u8::from(ProposalStatusCode::Cancelled), 3); + } + + #[test] + fn test_from_status_u8_representation() { + assert_eq!( + ProposalStatusCode::try_from(0), + Ok(ProposalStatusCode::Created) + ); + assert_eq!( + ProposalStatusCode::try_from(1), + Ok(ProposalStatusCode::Adopted) + ); + assert_eq!( + ProposalStatusCode::try_from(2), + Ok(ProposalStatusCode::Rejected) + ); assert_eq!( - ProposalStatus::Scheduled { scheduled_at: 0 }.to_string(), - "scheduled" + ProposalStatusCode::try_from(4), + Ok(ProposalStatusCode::Scheduled) ); assert_eq!( - ProposalStatus::Processing { started_at: 0 }.to_string(), - "processing" + ProposalStatusCode::try_from(5), + Ok(ProposalStatusCode::Processing) ); assert_eq!( - ProposalStatus::Completed { completed_at: 0 }.to_string(), - "completed" + ProposalStatusCode::try_from(6), + Ok(ProposalStatusCode::Completed) ); assert_eq!( - ProposalStatus::Failed { reason: None }.to_string(), - "failed" + ProposalStatusCode::try_from(7), + Ok(ProposalStatusCode::Failed) ); assert_eq!( - ProposalStatus::Cancelled { reason: None }.to_string(), - "cancelled" + ProposalStatusCode::try_from(3), + Ok(ProposalStatusCode::Cancelled) ); } } diff --git a/canisters/wallet/impl/src/models/transfer.rs b/canisters/wallet/impl/src/models/transfer.rs index 1eb5af742..a501ec3a7 100644 --- a/canisters/wallet/impl/src/models/transfer.rs +++ b/canisters/wallet/impl/src/models/transfer.rs @@ -281,10 +281,11 @@ mod tests { #[cfg(test)] pub mod transfer_test_utils { use super::*; + use uuid::Uuid; pub fn mock_transfer() -> Transfer { Transfer { - id: [1; 16], + id: *Uuid::new_v4().as_bytes(), initiator_user: [0; 16], from_account: [0; 16], proposal_id: [2; 16], diff --git a/canisters/wallet/impl/src/models/user.rs b/canisters/wallet/impl/src/models/user.rs index 6566ea4dc..03909662b 100644 --- a/canisters/wallet/impl/src/models/user.rs +++ b/canisters/wallet/impl/src/models/user.rs @@ -199,15 +199,16 @@ mod tests { } } -#[cfg(test)] +#[cfg(any(test, feature = "canbench"))] pub mod user_test_utils { use super::*; use crate::repositories::USER_REPOSITORY; use ic_canister_core::repository::Repository; + use uuid::Uuid; pub fn mock_user() -> User { User { - id: [0; 16], + id: *Uuid::new_v4().as_bytes(), identities: vec![Principal::anonymous()], groups: vec![], name: None, diff --git a/canisters/wallet/impl/src/repositories/indexes/mod.rs b/canisters/wallet/impl/src/repositories/indexes/mod.rs index a4612c071..15e7f113a 100644 --- a/canisters/wallet/impl/src/repositories/indexes/mod.rs +++ b/canisters/wallet/impl/src/repositories/indexes/mod.rs @@ -6,9 +6,13 @@ pub mod notification_user_index; pub mod proposal_account_index; pub mod proposal_creation_time_index; pub mod proposal_expiration_time_index; +pub mod proposal_key_creation_time_index; +pub mod proposal_key_expiration_time_index; pub mod proposal_proposer_index; pub mod proposal_scheduled_index; +pub mod proposal_sort_index; pub mod proposal_status_index; +pub mod proposal_status_modification_index; pub mod proposal_voter_index; pub mod transfer_account_index; pub mod transfer_status_index; diff --git a/canisters/wallet/impl/src/repositories/indexes/proposal_account_index.rs b/canisters/wallet/impl/src/repositories/indexes/proposal_account_index.rs index 68b8a8c26..a43cb2497 100644 --- a/canisters/wallet/impl/src/repositories/indexes/proposal_account_index.rs +++ b/canisters/wallet/impl/src/repositories/indexes/proposal_account_index.rs @@ -40,12 +40,10 @@ impl IndexRepository for ProposalAccountIndexR DB.with(|db| { let start_key = ProposalAccountIndex { account_id: criteria.account_id.to_owned(), - created_at: criteria.from_dt.to_owned().unwrap_or(u64::MIN), proposal_id: [u8::MIN; 16], }; let end_key = ProposalAccountIndex { account_id: criteria.account_id.to_owned(), - created_at: criteria.to_dt.to_owned().unwrap_or(u64::MAX), proposal_id: [u8::MAX; 16], }; @@ -66,7 +64,6 @@ mod tests { let repository = ProposalAccountIndexRepository::default(); let index = ProposalAccountIndex { proposal_id: [0; 16], - created_at: 10, account_id: [1; 16], }; @@ -84,7 +81,6 @@ mod tests { let repository = ProposalAccountIndexRepository::default(); let index = ProposalAccountIndex { proposal_id: [0; 16], - created_at: 10, account_id: [1; 16], }; @@ -92,8 +88,6 @@ mod tests { let criteria = ProposalAccountIndexCriteria { account_id: [1; 16], - from_dt: None, - to_dt: None, }; let result = repository.find_by_criteria(criteria); diff --git a/canisters/wallet/impl/src/repositories/indexes/proposal_key_creation_time_index.rs b/canisters/wallet/impl/src/repositories/indexes/proposal_key_creation_time_index.rs new file mode 100644 index 000000000..272e3535e --- /dev/null +++ b/canisters/wallet/impl/src/repositories/indexes/proposal_key_creation_time_index.rs @@ -0,0 +1,119 @@ +use crate::{ + core::{with_memory_manager, Memory, PROPOSAL_KEY_CREATION_TIME_INDEX_MEMORY_ID}, + models::indexes::proposal_key_creation_time_index::{ + ProposalKeyCreationTimeIndex, ProposalKeyCreationTimeIndexCriteria, + }, +}; +use ic_canister_core::{repository::IndexRepository, types::UUID}; +use ic_stable_structures::{memory_manager::VirtualMemory, StableBTreeMap}; +use std::{cell::RefCell, collections::HashSet}; + +thread_local! { + static DB: RefCell>> = with_memory_manager(|memory_manager| { + RefCell::new( + StableBTreeMap::init(memory_manager.get(PROPOSAL_KEY_CREATION_TIME_INDEX_MEMORY_ID)) + ) + }) +} + +#[derive(Default, Debug)] +pub struct ProposalKeyCreationTimeIndexRepository {} + +impl IndexRepository + for ProposalKeyCreationTimeIndexRepository +{ + type FindByCriteria = ProposalKeyCreationTimeIndexCriteria; + + fn exists(&self, index: &ProposalKeyCreationTimeIndex) -> bool { + DB.with(|m| m.borrow().contains_key(index)) + } + + fn insert(&self, index: ProposalKeyCreationTimeIndex) { + DB.with(|m| m.borrow_mut().insert(index, ())); + } + + fn remove(&self, index: &ProposalKeyCreationTimeIndex) -> bool { + DB.with(|m| m.borrow_mut().remove(index).is_some()) + } + + fn find_by_criteria(&self, criteria: Self::FindByCriteria) -> HashSet { + DB.with(|db| { + let start_key = ProposalKeyCreationTimeIndex { + proposal_id: criteria.proposal_id.to_owned(), + created_at: criteria.from_dt.to_owned().unwrap_or(u64::MIN), + }; + let end_key = ProposalKeyCreationTimeIndex { + proposal_id: criteria.proposal_id, + created_at: criteria.to_dt.to_owned().unwrap_or(u64::MAX), + }; + + db.borrow() + .range(start_key..=end_key) + .map(|(index, _)| index.proposal_id) + .collect::>() + }) + } +} + +impl ProposalKeyCreationTimeIndexRepository { + pub fn exists_by_criteria(&self, criteria: ProposalKeyCreationTimeIndexCriteria) -> bool { + let start_key = ProposalKeyCreationTimeIndex { + proposal_id: criteria.proposal_id.to_owned(), + created_at: criteria.from_dt.to_owned().unwrap_or(u64::MIN), + }; + let end_key = ProposalKeyCreationTimeIndex { + proposal_id: criteria.proposal_id, + created_at: criteria.to_dt.to_owned().unwrap_or(u64::MAX), + }; + + DB.with(|db| db.borrow().range(start_key..=end_key).next().is_some()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_repository_crud() { + let repository = ProposalKeyCreationTimeIndexRepository::default(); + let index = ProposalKeyCreationTimeIndex { + created_at: 10, + proposal_id: [1; 16], + }; + + assert!(!repository.exists(&index)); + + repository.insert(index.clone()); + + assert!(repository.exists(&index)); + assert!(repository.remove(&index)); + assert!(!repository.exists(&index)); + } + + #[test] + fn test_find_by_criteria() { + let repository = ProposalKeyCreationTimeIndexRepository::default(); + let index = ProposalKeyCreationTimeIndex { + created_at: 10, + proposal_id: [1; 16], + }; + + repository.insert(index.clone()); + repository.insert(ProposalKeyCreationTimeIndex { + created_at: 11, + proposal_id: [2; 16], + }); + + let criteria = ProposalKeyCreationTimeIndexCriteria { + proposal_id: [1; 16], + from_dt: None, + to_dt: Some(10), + }; + + let result = repository.find_by_criteria(criteria); + + assert_eq!(result.len(), 1); + assert!(result.contains(&index.proposal_id)); + } +} diff --git a/canisters/wallet/impl/src/repositories/indexes/proposal_key_expiration_time_index.rs b/canisters/wallet/impl/src/repositories/indexes/proposal_key_expiration_time_index.rs new file mode 100644 index 000000000..0baa84c7b --- /dev/null +++ b/canisters/wallet/impl/src/repositories/indexes/proposal_key_expiration_time_index.rs @@ -0,0 +1,119 @@ +use crate::{ + core::{with_memory_manager, Memory, PROPOSAL_KEY_EXPIRATION_TIME_INDEX_MEMORY_ID}, + models::indexes::proposal_key_expiration_time_index::{ + ProposalKeyExpirationTimeIndex, ProposalKeyExpirationTimeIndexCriteria, + }, +}; +use ic_canister_core::{repository::IndexRepository, types::UUID}; +use ic_stable_structures::{memory_manager::VirtualMemory, StableBTreeMap}; +use std::{cell::RefCell, collections::HashSet}; + +thread_local! { + static DB: RefCell>> = with_memory_manager(|memory_manager| { + RefCell::new( + StableBTreeMap::init(memory_manager.get(PROPOSAL_KEY_EXPIRATION_TIME_INDEX_MEMORY_ID)) + ) + }) +} + +#[derive(Default, Debug)] +pub struct ProposalKeyExpirationTimeIndexRepository {} + +impl IndexRepository + for ProposalKeyExpirationTimeIndexRepository +{ + type FindByCriteria = ProposalKeyExpirationTimeIndexCriteria; + + fn exists(&self, index: &ProposalKeyExpirationTimeIndex) -> bool { + DB.with(|m| m.borrow().get(index).is_some()) + } + + fn insert(&self, index: ProposalKeyExpirationTimeIndex) { + DB.with(|m| m.borrow_mut().insert(index, ())); + } + + fn remove(&self, index: &ProposalKeyExpirationTimeIndex) -> bool { + DB.with(|m| m.borrow_mut().remove(index).is_some()) + } + + fn find_by_criteria(&self, criteria: Self::FindByCriteria) -> HashSet { + DB.with(|db| { + let start_key = ProposalKeyExpirationTimeIndex { + proposal_id: criteria.proposal_id.to_owned(), + expiration_dt: criteria.from_dt.to_owned().unwrap_or(u64::MIN), + }; + let end_key = ProposalKeyExpirationTimeIndex { + proposal_id: criteria.proposal_id, + expiration_dt: criteria.to_dt.to_owned().unwrap_or(u64::MAX), + }; + + db.borrow() + .range(start_key..=end_key) + .map(|(index, _)| index.proposal_id) + .collect::>() + }) + } +} + +impl ProposalKeyExpirationTimeIndexRepository { + pub fn exists_by_criteria(&self, criteria: ProposalKeyExpirationTimeIndexCriteria) -> bool { + let start_key = ProposalKeyExpirationTimeIndex { + proposal_id: criteria.proposal_id.to_owned(), + expiration_dt: criteria.from_dt.to_owned().unwrap_or(u64::MIN), + }; + let end_key = ProposalKeyExpirationTimeIndex { + proposal_id: criteria.proposal_id, + expiration_dt: criteria.to_dt.to_owned().unwrap_or(u64::MAX), + }; + + DB.with(|db| db.borrow().range(start_key..=end_key).next().is_some()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_repository_crud() { + let repository = ProposalKeyExpirationTimeIndexRepository::default(); + let index = ProposalKeyExpirationTimeIndex { + expiration_dt: 10, + proposal_id: [1; 16], + }; + + assert!(!repository.exists(&index)); + + repository.insert(index.clone()); + + assert!(repository.exists(&index)); + assert!(repository.remove(&index)); + assert!(!repository.exists(&index)); + } + + #[test] + fn test_find_by_criteria() { + let repository = ProposalKeyExpirationTimeIndexRepository::default(); + let index = ProposalKeyExpirationTimeIndex { + expiration_dt: 10, + proposal_id: [1; 16], + }; + + repository.insert(index.clone()); + repository.insert(ProposalKeyExpirationTimeIndex { + expiration_dt: 11, + proposal_id: [2; 16], + }); + + let criteria = ProposalKeyExpirationTimeIndexCriteria { + proposal_id: [1; 16], + from_dt: None, + to_dt: Some(10), + }; + + let result = repository.find_by_criteria(criteria); + + assert_eq!(result.len(), 1); + assert!(result.contains(&index.proposal_id)); + } +} diff --git a/canisters/wallet/impl/src/repositories/indexes/proposal_proposer_index.rs b/canisters/wallet/impl/src/repositories/indexes/proposal_proposer_index.rs index 908ae554a..56298cfe7 100644 --- a/canisters/wallet/impl/src/repositories/indexes/proposal_proposer_index.rs +++ b/canisters/wallet/impl/src/repositories/indexes/proposal_proposer_index.rs @@ -38,13 +38,11 @@ impl IndexRepository for ProposalProposerIndexRepos fn find_by_criteria(&self, criteria: Self::FindByCriteria) -> HashSet { DB.with(|db| { let start_key = ProposalProposerIndex { - user_id: criteria.user_id.to_owned(), - created_at: criteria.from_dt.to_owned().unwrap_or(u64::MIN), + proposer_id: criteria.proposer_id.to_owned(), proposal_id: [u8::MIN; 16], }; let end_key = ProposalProposerIndex { - user_id: criteria.user_id.to_owned(), - created_at: criteria.to_dt.to_owned().unwrap_or(u64::MAX), + proposer_id: criteria.proposer_id.to_owned(), proposal_id: [u8::MAX; 16], }; @@ -65,8 +63,7 @@ mod tests { let repository = ProposalProposerIndexRepository::default(); let index = ProposalProposerIndex { proposal_id: [0; 16], - created_at: 10, - user_id: [1; 16], + proposer_id: [1; 16], }; assert!(!repository.exists(&index)); @@ -83,16 +80,13 @@ mod tests { let repository = ProposalProposerIndexRepository::default(); let index = ProposalProposerIndex { proposal_id: [0; 16], - created_at: 10, - user_id: [1; 16], + proposer_id: [1; 16], }; repository.insert(index.clone()); let criteria = ProposalProposerIndexCriteria { - user_id: [1; 16], - from_dt: None, - to_dt: None, + proposer_id: [1; 16], }; let result = repository.find_by_criteria(criteria); diff --git a/canisters/wallet/impl/src/repositories/indexes/proposal_sort_index.rs b/canisters/wallet/impl/src/repositories/indexes/proposal_sort_index.rs new file mode 100644 index 000000000..e19d074c0 --- /dev/null +++ b/canisters/wallet/impl/src/repositories/indexes/proposal_sort_index.rs @@ -0,0 +1,120 @@ +use crate::{ + core::{with_memory_manager, Memory, PROPOSAL_SORT_INDEX_MEMORY_ID}, + models::indexes::proposal_sort_index::{ + ProposalSortIndex, ProposalSortIndexCriteria, ProposalSortIndexKey, ProposalSortIndexValue, + }, +}; +use ic_canister_core::repository::IndexRepository; +use ic_stable_structures::{memory_manager::VirtualMemory, StableBTreeMap}; +use std::{cell::RefCell, collections::HashSet}; + +thread_local! { + static DB: RefCell>> = with_memory_manager(|memory_manager| { + RefCell::new( + StableBTreeMap::init(memory_manager.get(PROPOSAL_SORT_INDEX_MEMORY_ID)) + ) + }) +} + +/// A repository that enables finding proposals based on the voter in stable memory. +#[derive(Default, Debug)] +pub struct ProposalSortIndexRepository {} + +impl IndexRepository for ProposalSortIndexRepository { + type FindByCriteria = ProposalSortIndexCriteria; + + fn exists(&self, index: &ProposalSortIndex) -> bool { + DB.with(|m| m.borrow().contains_key(&index.key)) + } + + fn insert(&self, index: ProposalSortIndex) { + DB.with(|m| m.borrow_mut().insert(index.key, index.value)); + } + + fn remove(&self, index: &ProposalSortIndex) -> bool { + DB.with(|m| m.borrow_mut().remove(&index.key).is_some()) + } + + fn find_by_criteria(&self, criteria: Self::FindByCriteria) -> HashSet { + let value = self.get(&ProposalSortIndexKey { + proposal_id: criteria.proposal_id, + }); + + match value { + Some(value) => { + let mut set = HashSet::new(); + set.insert(value); + set + } + None => HashSet::new(), + } + } +} + +impl ProposalSortIndexRepository { + pub fn get(&self, key: &ProposalSortIndexKey) -> Option { + DB.with(|m| m.borrow().get(key)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_repository_crud() { + let repository = ProposalSortIndexRepository::default(); + let index = ProposalSortIndex { + key: ProposalSortIndexKey { + proposal_id: [0; 16], + }, + value: ProposalSortIndexValue { + creation_timestamp: 1, + modification_timestamp: 2, + expiration_timestamp: 3, + }, + }; + + assert!(!repository.exists(&index)); + + repository.insert(index.clone()); + + assert!(repository.exists(&index)); + assert!(repository.remove(&index)); + assert!(!repository.exists(&index)); + } + + #[test] + fn test_find_by_criteria() { + let repository = ProposalSortIndexRepository::default(); + let index = ProposalSortIndex { + key: ProposalSortIndexKey { + proposal_id: [0; 16], + }, + value: ProposalSortIndexValue { + creation_timestamp: 1, + modification_timestamp: 2, + expiration_timestamp: 3, + }, + }; + + repository.insert(index.clone()); + + let criteria = ProposalSortIndexCriteria { + proposal_id: [0; 16], + }; + + let result = repository.find_by_criteria(criteria); + + assert_eq!(result.len(), 1); + + let value = result.into_iter().next().unwrap(); + + assert_eq!(value.creation_timestamp, index.value.creation_timestamp); + assert_eq!( + value.modification_timestamp, + index.value.modification_timestamp + ); + assert_eq!(value.expiration_timestamp, index.value.expiration_timestamp); + } +} diff --git a/canisters/wallet/impl/src/repositories/indexes/proposal_status_index.rs b/canisters/wallet/impl/src/repositories/indexes/proposal_status_index.rs index 66ae11517..ef3768826 100644 --- a/canisters/wallet/impl/src/repositories/indexes/proposal_status_index.rs +++ b/canisters/wallet/impl/src/repositories/indexes/proposal_status_index.rs @@ -21,7 +21,7 @@ impl IndexRepository for ProposalStatusIndexRepositor type FindByCriteria = ProposalStatusIndexCriteria; fn exists(&self, index: &ProposalStatusIndex) -> bool { - DB.with(|m| m.borrow().get(index).is_some()) + DB.with(|m| m.borrow().contains_key(index)) } fn insert(&self, index: ProposalStatusIndex) { @@ -36,12 +36,10 @@ impl IndexRepository for ProposalStatusIndexRepositor DB.with(|db| { let start_key = ProposalStatusIndex { status: criteria.status.to_owned(), - last_modification_timestamp: criteria.from_dt.to_owned().unwrap_or(u64::MIN), proposal_id: [std::u8::MIN; 16], }; let end_key = ProposalStatusIndex { status: criteria.status.to_owned(), - last_modification_timestamp: criteria.to_dt.to_owned().unwrap_or(u64::MAX), proposal_id: [std::u8::MAX; 16], }; @@ -55,16 +53,14 @@ impl IndexRepository for ProposalStatusIndexRepositor #[cfg(test)] mod tests { - use crate::models::ProposalStatus; - use super::*; + use crate::models::ProposalStatusCode; #[test] fn test_repository_crud() { let repository = ProposalStatusIndexRepository::default(); let index = ProposalStatusIndex { - status: ProposalStatus::Created.to_string(), - last_modification_timestamp: 10, + status: ProposalStatusCode::Created, proposal_id: [1; 16], }; @@ -81,27 +77,23 @@ mod tests { fn test_find_by_criteria() { let repository = ProposalStatusIndexRepository::default(); let index = ProposalStatusIndex { - status: ProposalStatus::Created.to_string(), - last_modification_timestamp: 10, + status: ProposalStatusCode::Adopted, proposal_id: [1; 16], }; repository.insert(index.clone()); repository.insert(ProposalStatusIndex { - status: ProposalStatus::Created.to_string(), - last_modification_timestamp: 11, + status: ProposalStatusCode::Created, proposal_id: [2; 16], }); let criteria = ProposalStatusIndexCriteria { - status: ProposalStatus::Created.to_string(), - from_dt: None, - to_dt: Some(10), + status: ProposalStatusCode::Created, }; let result = repository.find_by_criteria(criteria); assert_eq!(result.len(), 1); - assert!(result.contains(&index.proposal_id)); + assert!(result.contains(&[2; 16])); } } diff --git a/canisters/wallet/impl/src/repositories/indexes/proposal_status_modification_index.rs b/canisters/wallet/impl/src/repositories/indexes/proposal_status_modification_index.rs new file mode 100644 index 000000000..a9e54b1bd --- /dev/null +++ b/canisters/wallet/impl/src/repositories/indexes/proposal_status_modification_index.rs @@ -0,0 +1,110 @@ +use crate::{ + core::{with_memory_manager, Memory, PROPOSAL_STATUS_MODIFICATION_INDEX_MEMORY_ID}, + models::indexes::proposal_status_modification_index::{ + ProposalStatusModificationIndex, ProposalStatusModificationIndexCriteria, + }, +}; +use ic_canister_core::{repository::IndexRepository, types::UUID}; +use ic_stable_structures::{memory_manager::VirtualMemory, StableBTreeMap}; +use std::{cell::RefCell, collections::HashSet}; + +thread_local! { + static DB: RefCell>> = with_memory_manager(|memory_manager| { + RefCell::new( + StableBTreeMap::init(memory_manager.get(PROPOSAL_STATUS_MODIFICATION_INDEX_MEMORY_ID)) + ) + }) +} + +#[derive(Default, Debug)] +pub struct ProposalStatusModificationIndexRepository; + +impl IndexRepository + for ProposalStatusModificationIndexRepository +{ + type FindByCriteria = ProposalStatusModificationIndexCriteria; + + fn exists(&self, index: &ProposalStatusModificationIndex) -> bool { + DB.with(|m| m.borrow().contains_key(index)) + } + + fn insert(&self, index: ProposalStatusModificationIndex) { + DB.with(|m| m.borrow_mut().insert(index, ())); + } + + fn remove(&self, index: &ProposalStatusModificationIndex) -> bool { + DB.with(|m| m.borrow_mut().remove(index).is_some()) + } + + fn find_by_criteria(&self, criteria: Self::FindByCriteria) -> HashSet { + DB.with(|db| { + let start_key = ProposalStatusModificationIndex { + status: criteria.status.to_owned(), + modification_timestamp: criteria.from_dt.unwrap_or(u64::MIN), + proposal_id: [std::u8::MIN; 16], + }; + let end_key = ProposalStatusModificationIndex { + status: criteria.status.to_owned(), + modification_timestamp: criteria.to_dt.unwrap_or(u64::MAX), + proposal_id: [std::u8::MAX; 16], + }; + + db.borrow() + .range(start_key..=end_key) + .map(|(index, _)| index.proposal_id) + .collect::>() + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::models::ProposalStatusCode; + + #[test] + fn test_repository_crud() { + let repository = ProposalStatusModificationIndexRepository; + let index = ProposalStatusModificationIndex { + status: ProposalStatusCode::Created, + modification_timestamp: 1, + proposal_id: [1; 16], + }; + + assert!(!repository.exists(&index)); + + repository.insert(index.clone()); + + assert!(repository.exists(&index)); + assert!(repository.remove(&index)); + assert!(!repository.exists(&index)); + } + + #[test] + fn test_find_by_criteria() { + let repository = ProposalStatusModificationIndexRepository; + let index = ProposalStatusModificationIndex { + status: ProposalStatusCode::Created, + modification_timestamp: 1, + proposal_id: [1; 16], + }; + + repository.insert(index.clone()); + repository.insert(ProposalStatusModificationIndex { + status: ProposalStatusCode::Created, + modification_timestamp: 2, + proposal_id: [2; 16], + }); + + let criteria = ProposalStatusModificationIndexCriteria { + status: ProposalStatusCode::Created, + from_dt: Some(0), + to_dt: Some(1), + }; + + let result = repository.find_by_criteria(criteria); + + assert_eq!(result.len(), 1); + assert!(result.contains(&index.proposal_id)); + } +} diff --git a/canisters/wallet/impl/src/repositories/indexes/proposal_voter_index.rs b/canisters/wallet/impl/src/repositories/indexes/proposal_voter_index.rs index 84f607df1..42a986285 100644 --- a/canisters/wallet/impl/src/repositories/indexes/proposal_voter_index.rs +++ b/canisters/wallet/impl/src/repositories/indexes/proposal_voter_index.rs @@ -39,13 +39,11 @@ impl IndexRepository for ProposalVoterIndexRepos fn find_by_criteria(&self, criteria: Self::FindByCriteria) -> HashSet { DB.with(|db| { let start_key = ProposalVoterIndex { - user_id: criteria.user_id.to_owned(), - created_at: criteria.from_dt.to_owned().unwrap_or(u64::MIN), + voter_id: criteria.voter_id.to_owned(), proposal_id: [u8::MIN; 16], }; let end_key = ProposalVoterIndex { - user_id: criteria.user_id.to_owned(), - created_at: criteria.to_dt.to_owned().unwrap_or(u64::MAX), + voter_id: criteria.voter_id.to_owned(), proposal_id: [u8::MAX; 16], }; @@ -66,8 +64,7 @@ mod tests { let repository = ProposalVoterIndexRepository::default(); let index = ProposalVoterIndex { proposal_id: [0; 16], - created_at: 10, - user_id: [1; 16], + voter_id: [1; 16], }; assert!(!repository.exists(&index)); @@ -84,17 +81,12 @@ mod tests { let repository = ProposalVoterIndexRepository::default(); let index = ProposalVoterIndex { proposal_id: [0; 16], - created_at: 10, - user_id: [1; 16], + voter_id: [1; 16], }; repository.insert(index.clone()); - let criteria = ProposalVoterIndexCriteria { - user_id: [1; 16], - from_dt: None, - to_dt: None, - }; + let criteria = ProposalVoterIndexCriteria { voter_id: [1; 16] }; let result = repository.find_by_criteria(criteria); diff --git a/canisters/wallet/impl/src/repositories/proposal.rs b/canisters/wallet/impl/src/repositories/proposal.rs index cf6b0f735..0f06a35da 100644 --- a/canisters/wallet/impl/src/repositories/proposal.rs +++ b/canisters/wallet/impl/src/repositories/proposal.rs @@ -2,30 +2,41 @@ use super::indexes::{ proposal_account_index::ProposalAccountIndexRepository, proposal_creation_time_index::ProposalCreationTimeIndexRepository, proposal_expiration_time_index::ProposalExpirationTimeIndexRepository, + proposal_key_creation_time_index::ProposalKeyCreationTimeIndexRepository, + proposal_key_expiration_time_index::ProposalKeyExpirationTimeIndexRepository, proposal_proposer_index::ProposalProposerIndexRepository, proposal_scheduled_index::ProposalScheduledIndexRepository, + proposal_sort_index::ProposalSortIndexRepository, proposal_status_index::ProposalStatusIndexRepository, + proposal_status_modification_index::ProposalStatusModificationIndexRepository, proposal_voter_index::ProposalVoterIndexRepository, }; use crate::{ - core::{utils::match_date_range, with_memory_manager, Memory, PROPOSAL_MEMORY_ID}, + core::{with_memory_manager, Memory, PROPOSAL_MEMORY_ID}, errors::{MapperError, RepositoryError}, - mappers::{HelperMapper, ProposalStatusMapper}, + mappers::HelperMapper, models::{ indexes::{ - proposal_account_index::ProposalAccountIndexCriteria, + proposal_account_index::{ProposalAccountIndex, ProposalAccountIndexCriteria}, proposal_creation_time_index::ProposalCreationTimeIndexCriteria, proposal_expiration_time_index::ProposalExpirationTimeIndexCriteria, - proposal_proposer_index::ProposalProposerIndexCriteria, + proposal_key_creation_time_index::ProposalKeyCreationTimeIndexCriteria, + proposal_key_expiration_time_index::ProposalKeyExpirationTimeIndexCriteria, + proposal_proposer_index::{ProposalProposerIndex, ProposalProposerIndexCriteria}, proposal_scheduled_index::ProposalScheduledIndexCriteria, - proposal_status_index::ProposalStatusIndexCriteria, - proposal_voter_index::ProposalVoterIndexCriteria, + proposal_sort_index::ProposalSortIndexKey, + proposal_status_index::{ProposalStatusIndex, ProposalStatusIndexCriteria}, + proposal_status_modification_index::ProposalStatusModificationIndexCriteria, + proposal_voter_index::{ProposalVoterIndex, ProposalVoterIndexCriteria}, }, - AccountId, Proposal, ProposalKey, ProposalStatusCode, UserId, + AccountId, Proposal, ProposalKey, ProposalStatusCode, }, }; use ic_canister_core::{ - repository::{IndexRepository, RefreshIndexMode, Repository}, + repository::{ + IndexRepository, OrSelectionFilter, RefreshIndexMode, Repository, SelectionFilter, + SortDirection, SortingStrategy, + }, types::{Timestamp, UUID}, }; use ic_stable_structures::{memory_manager::VirtualMemory, StableBTreeMap}; @@ -56,6 +67,10 @@ pub struct ProposalRepository { status_index: ProposalStatusIndexRepository, scheduled_index: ProposalScheduledIndexRepository, proposer_index: ProposalProposerIndexRepository, + status_modification_index: ProposalStatusModificationIndexRepository, + prefixed_creation_time_index: ProposalKeyCreationTimeIndexRepository, + prefixed_expiration_time_index: ProposalKeyExpirationTimeIndexRepository, + sort_index: ProposalSortIndexRepository, } impl Repository for ProposalRepository { @@ -97,16 +112,42 @@ impl Repository for ProposalRepository { previous: prev.clone().map(|prev| prev.to_index_by_creation_dt()), current: Some(value.to_index_by_creation_dt()), }); + self.prefixed_creation_time_index + .refresh_index_on_modification(RefreshIndexMode::Value { + previous: prev + .clone() + .map(|prev| prev.to_index_by_key_and_creation_dt()), + current: Some(value.to_index_by_key_and_creation_dt()), + }); self.expiration_dt_index .refresh_index_on_modification(RefreshIndexMode::Value { previous: prev.clone().map(|prev| prev.to_index_by_expiration_dt()), current: Some(value.to_index_by_expiration_dt()), }); + self.prefixed_expiration_time_index + .refresh_index_on_modification(RefreshIndexMode::Value { + previous: prev + .clone() + .map(|prev| prev.to_index_by_key_and_expiration_dt()), + current: Some(value.to_index_by_key_and_expiration_dt()), + }); self.status_index .refresh_index_on_modification(RefreshIndexMode::Value { previous: prev.clone().clone().map(|prev| prev.to_index_by_status()), current: Some(value.to_index_by_status()), }); + self.sort_index + .refresh_index_on_modification(RefreshIndexMode::Value { + previous: prev.clone().map(|prev| prev.to_index_for_sorting()), + current: Some(value.to_index_for_sorting()), + }); + self.status_modification_index + .refresh_index_on_modification(RefreshIndexMode::Value { + previous: prev + .clone() + .map(|prev| prev.to_index_by_status_and_modification()), + current: Some(value.to_index_by_status_and_modification()), + }); prev }) @@ -142,10 +183,32 @@ impl Repository for ProposalRepository { current: prev.clone().map(|prev| prev.to_index_by_expiration_dt()), }, ); + self.prefixed_creation_time_index + .refresh_index_on_modification(RefreshIndexMode::CleanupValue { + current: prev + .clone() + .map(|prev| prev.to_index_by_key_and_creation_dt()), + }); + self.prefixed_expiration_time_index + .refresh_index_on_modification(RefreshIndexMode::CleanupValue { + current: prev + .clone() + .map(|prev| prev.to_index_by_key_and_expiration_dt()), + }); self.status_index .refresh_index_on_modification(RefreshIndexMode::CleanupValue { current: prev.clone().map(|prev| prev.to_index_by_status()), }); + self.sort_index + .refresh_index_on_modification(RefreshIndexMode::CleanupValue { + current: prev.clone().map(|prev| prev.to_index_for_sorting()), + }); + self.status_modification_index + .refresh_index_on_modification(RefreshIndexMode::CleanupValue { + current: prev + .clone() + .map(|prev| prev.to_index_by_status_and_modification()), + }); prev }) @@ -176,6 +239,7 @@ impl ProposalRepository { Some(proposal) => { if proposal .status + .to_type() .to_string() .eq_ignore_ascii_case(status.as_str()) { @@ -191,20 +255,19 @@ impl ProposalRepository { pub fn find_by_status( &self, - status: String, - from_last_update_dt: Option, - to_last_update_dt: Option, + status: ProposalStatusCode, + from_last_modified_dt: Option, + to_last_modified_dt: Option, ) -> Vec { - let proposals = self - .status_index - .find_by_criteria(ProposalStatusIndexCriteria { - status: status.to_owned(), - from_dt: from_last_update_dt, - to_dt: to_last_update_dt, - }); + let ids = self.status_modification_index.find_by_criteria( + ProposalStatusModificationIndexCriteria { + status, + from_dt: from_last_modified_dt, + to_dt: to_last_modified_dt, + }, + ); - proposals - .iter() + ids.iter() .filter_map(|id| self.get(&Proposal::key(*id))) .collect::>() } @@ -224,298 +287,182 @@ impl ProposalRepository { .collect::>() } - pub fn find_by_account( - &self, - account_id: AccountId, - created_from_dt: Option, - created_to_dt: Option, - ) -> Vec { - let filtered_by_accounts = - self.account_index - .find_by_criteria(ProposalAccountIndexCriteria { - account_id: account_id.to_owned(), - from_dt: created_from_dt.to_owned(), - to_dt: created_to_dt.to_owned(), - }); - - filtered_by_accounts - .iter() - .filter_map(|id| self.get(&Proposal::key(*id))) - .collect() - } - - pub fn find_by_account_where( - &self, - account_id: AccountId, - condition: ProposalWhereClause, - ) -> Vec { - self.account_index - .find_by_criteria(ProposalAccountIndexCriteria { - account_id: account_id.to_owned(), - from_dt: condition.created_dt_from.to_owned(), - to_dt: condition.created_dt_to.to_owned(), - }) - .iter() - .filter_map(|id| self.check_condition(*id, &condition)) - .collect() - } - - pub fn find_by_voter_where( - &self, - user_id: UserId, - condition: ProposalWhereClause, - ) -> Vec { - self.voter_index - .find_by_criteria(ProposalVoterIndexCriteria { - user_id: user_id.to_owned(), - from_dt: condition.created_dt_from, - to_dt: condition.created_dt_to, - }) - .iter() - .filter_map(|id| self.check_condition(*id, &condition)) - .collect() - } - - pub fn find_where( + pub fn find_ids_where( &self, condition: ProposalWhereClause, sort_by: Option, - ) -> Result, RepositoryError> { - let strategy = self.pick_most_selective_where_filter(&condition); - let proposal_ids = self.find_with_strategy(strategy, &condition)?; - let mut proposals = proposal_ids - .iter() - .filter_map(|id| self.check_condition(*id, &condition)) - .collect::>(); + ) -> Result, RepositoryError> { + let filters = self.build_where_filtering_strategy(condition); + let proposal_ids = self.find_with_filters(filters); + let mut ids = proposal_ids.into_iter().collect::>(); - self.sort_proposals(&mut proposals, &sort_by); + self.sort_ids_with_strategy(&mut ids, &sort_by); - Ok(proposals) + Ok(ids) } - fn sort_proposals(&self, proposals: &mut [Proposal], sort_by: &Option) { + /// Sorts the proposal IDs based on the provided sort strategy. + /// + /// If no sort strategy is provided, it defaults to sorting by creation timestamp descending. + fn sort_ids_with_strategy( + &self, + proposal_ids: &mut [UUID], + sort_by: &Option, + ) { match sort_by { Some(wallet_api::ListProposalsSortBy::CreatedAt(direction)) => { - proposals.sort_by(|a, b| match direction { - wallet_api::SortDirection::Asc => a.created_timestamp.cmp(&b.created_timestamp), - wallet_api::SortDirection::Desc => { - b.created_timestamp.cmp(&a.created_timestamp) - } - }); + let sort_strategy = TimestampSortingStrategy { + index: &self.sort_index, + timestamp_type: TimestampType::Creation, + direction: match direction { + wallet_api::SortDirection::Asc => Some(SortDirection::Ascending), + wallet_api::SortDirection::Desc => Some(SortDirection::Descending), + }, + }; + + sort_strategy.sort(proposal_ids); } Some(wallet_api::ListProposalsSortBy::ExpirationDt(direction)) => { - proposals.sort_by(|a, b| match direction { - wallet_api::SortDirection::Asc => a.expiration_dt.cmp(&b.expiration_dt), - wallet_api::SortDirection::Desc => b.expiration_dt.cmp(&a.expiration_dt), - }); + let sort_strategy = TimestampSortingStrategy { + index: &self.sort_index, + timestamp_type: TimestampType::Expiration, + direction: match direction { + wallet_api::SortDirection::Asc => Some(SortDirection::Ascending), + wallet_api::SortDirection::Desc => Some(SortDirection::Descending), + }, + }; + + sort_strategy.sort(proposal_ids); } Some(wallet_api::ListProposalsSortBy::LastModificationDt(direction)) => { - proposals.sort_by(|a, b| match direction { - wallet_api::SortDirection::Asc => a - .last_modification_timestamp - .cmp(&b.last_modification_timestamp), - wallet_api::SortDirection::Desc => b - .last_modification_timestamp - .cmp(&a.last_modification_timestamp), - }); + let sort_strategy = TimestampSortingStrategy { + index: &self.sort_index, + timestamp_type: TimestampType::Modification, + direction: match direction { + wallet_api::SortDirection::Asc => Some(SortDirection::Ascending), + wallet_api::SortDirection::Desc => Some(SortDirection::Descending), + }, + }; + + sort_strategy.sort(proposal_ids); } None => { - // Default sort by created timestamp descending - proposals.sort_by(|a, b| b.created_timestamp.cmp(&a.created_timestamp)); + // Default sort by creation timestamp descending + let sort_strategy = TimestampSortingStrategy { + index: &self.sort_index, + timestamp_type: TimestampType::Creation, + direction: Some(SortDirection::Descending), + }; + + sort_strategy.sort(proposal_ids); } } } - fn pick_most_selective_where_filter( - &self, - condition: &ProposalWhereClause, - ) -> WhereSelectionStrategy { - let mut strategy = WhereSelectionStrategy::CreationDt; + fn build_where_filtering_strategy<'a>( + &'a self, + condition: ProposalWhereClause, + ) -> Vec + 'a>> { + let mut filters = Vec::new(); - if condition.expiration_dt_from.is_some() || condition.expiration_dt_to.is_some() { - strategy = WhereSelectionStrategy::ExpirationDt; + if condition.created_dt_from.is_some() || condition.created_dt_to.is_some() { + filters.push(Box::new(CreationDtSelectionFilter { + repository: &self.creation_dt_index, + prefixed_repository: &self.prefixed_creation_time_index, + from: condition.created_dt_from, + to: condition.created_dt_to, + }) as Box>); } - if condition.created_dt_from.is_some() || condition.created_dt_to.is_some() { - strategy = WhereSelectionStrategy::CreationDt; + if condition.expiration_dt_from.is_some() || condition.expiration_dt_to.is_some() { + filters.push(Box::new(ExpirationDtSelectionFilter { + repository: &self.expiration_dt_index, + prefixed_repository: &self.prefixed_expiration_time_index, + from: condition.expiration_dt_from, + to: condition.expiration_dt_to, + }) as Box>); } if !condition.statuses.is_empty() { - strategy = WhereSelectionStrategy::Status; + let includes_status = Box::new(OrSelectionFilter { + filters: condition + .statuses + .iter() + .map(|status| { + Box::new(StatusSelectionFilter { + repository: &self.status_index, + status: status.to_owned(), + }) as Box> + }) + .collect(), + }) as Box>; + + filters.push(includes_status); } if !condition.account_ids().unwrap_or_default().is_empty() { - strategy = WhereSelectionStrategy::Account; + let includes_account = Box::new(OrSelectionFilter { + filters: condition + .account_ids() + .unwrap_or_default() + .iter() + .map(|account_id| { + Box::new(AccountSelectionFilter { + repository: &self.account_index, + account_id: *account_id, + }) as Box> + }) + .collect(), + }) as Box>; + + filters.push(includes_account); } if !condition.voters.is_empty() { - strategy = WhereSelectionStrategy::Voter; - } + let includes_voter = Box::new(OrSelectionFilter { + filters: condition + .voters + .iter() + .map(|voter_id| { + Box::new(VoterSelectionFilter { + repository: &self.voter_index, + voter_id: *voter_id, + }) as Box> + }) + .collect(), + }) as Box>; - if !condition.proposers.is_empty() { - strategy = WhereSelectionStrategy::Proposer; + filters.push(includes_voter); } - strategy - } - - fn find_with_strategy( - &self, - strategy: WhereSelectionStrategy, - condition: &ProposalWhereClause, - ) -> Result, RepositoryError> { - let ids = match strategy { - WhereSelectionStrategy::Account => { - let mut proposal_ids = HashSet::::new(); - let account_ids = condition.account_ids().map_err(|e| { - RepositoryError::CriteriaValidationError { - reason: e.to_string(), - } - })?; - - for account_id in account_ids { - proposal_ids.extend(self.account_index.find_by_criteria( - ProposalAccountIndexCriteria { - account_id, - from_dt: condition.created_dt_from, - to_dt: condition.created_dt_to, - }, - )); - } - - proposal_ids - } - WhereSelectionStrategy::Voter => { - let mut proposal_ids = HashSet::::new(); - let user_ids: HashSet<_> = condition.voters.iter().collect(); - - for user_id in user_ids { - proposal_ids.extend(self.voter_index.find_by_criteria( - ProposalVoterIndexCriteria { - user_id: *user_id, - from_dt: condition.created_dt_from, - to_dt: condition.created_dt_to, - }, - )); - } - - proposal_ids - } - WhereSelectionStrategy::Proposer => { - let mut proposal_ids = HashSet::::new(); - let user_ids: HashSet<_> = condition.proposers.iter().collect(); - - for user_id in user_ids { - proposal_ids.extend(self.proposer_index.find_by_criteria( - ProposalProposerIndexCriteria { - user_id: *user_id, - from_dt: condition.created_dt_from, - to_dt: condition.created_dt_to, - }, - )); - } - - proposal_ids - } - WhereSelectionStrategy::Status => { - let mut proposal_ids = HashSet::::new(); - let statuses: HashSet<_> = condition - .statuses + if !condition.proposers.is_empty() { + let includes_proposer = Box::new(OrSelectionFilter { + filters: condition + .proposers .iter() - .map(ProposalStatusMapper::from_status_code_dto) - .collect(); - - for status in statuses { - proposal_ids.extend(self.status_index.find_by_criteria( - ProposalStatusIndexCriteria { - status: status.to_string(), - from_dt: condition.created_dt_from, - to_dt: condition.created_dt_to, - }, - )); - } - - proposal_ids - } - WhereSelectionStrategy::ExpirationDt => { - self.expiration_dt_index - .find_by_criteria(ProposalExpirationTimeIndexCriteria { - from_dt: condition.expiration_dt_from, - to_dt: condition.expiration_dt_to, - }) - } - WhereSelectionStrategy::CreationDt => { - self.creation_dt_index - .find_by_criteria(ProposalCreationTimeIndexCriteria { - from_dt: condition.created_dt_from, - to_dt: condition.created_dt_to, + .map(|proposer_id| { + Box::new(ProposerSelectionFilter { + repository: &self.proposer_index, + proposer_id: *proposer_id, + }) as Box> }) - } - }; - - Ok(ids) - } - - fn check_condition( - &self, - proposal_id: UUID, - condition: &ProposalWhereClause, - ) -> Option { - match self.get(&Proposal::key(proposal_id)) { - Some(proposal) => { - let mut match_operation_types = true; - let mut match_statuses = true; - let mut match_voters = true; - let mut match_proposers = true; - let match_creation_dt_range = match_date_range( - &proposal.created_timestamp, - &condition.created_dt_from, - &condition.created_dt_to, - ); - let match_expiration_dt_range = match_date_range( - &proposal.expiration_dt, - &condition.expiration_dt_from, - &condition.expiration_dt_to, - ); - - if !condition.operation_types.is_empty() { - match_operation_types = condition - .operation_types - .iter() - .any(|operation_type| proposal.operation.is_of_type(operation_type)); - } - - if !condition.statuses.is_empty() { - match_statuses = condition - .statuses - .iter() - .any(|s| ProposalStatusCode::from(proposal.status.clone()) == *s); - } + .collect(), + }) as Box>; - if !condition.voters.is_empty() { - match_voters = condition - .voters - .iter() - .any(|v| proposal.voters().contains(v)); - } - - if !condition.proposers.is_empty() { - match_proposers = condition.proposers.contains(&proposal.proposed_by); - } + filters.push(includes_proposer); + } - match match_expiration_dt_range - && match_creation_dt_range - && match_operation_types - && match_statuses - && match_proposers - && match_voters - { - true => Some(proposal), - false => None, - } - } - None => None, + if filters.is_empty() { + // If no filters are provided, return all + filters.push(Box::new(CreationDtSelectionFilter { + repository: &self.creation_dt_index, + prefixed_repository: &self.prefixed_creation_time_index, + from: None, + to: None, + }) as Box>); } + + filters } } @@ -547,14 +494,210 @@ impl ProposalWhereClause { } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum WhereSelectionStrategy { - Account, - Voter, - Proposer, - Status, - ExpirationDt, - CreationDt, +#[derive(Debug, Clone)] +pub(crate) struct CreationDtSelectionFilter<'a> { + repository: &'a ProposalCreationTimeIndexRepository, + prefixed_repository: &'a ProposalKeyCreationTimeIndexRepository, + from: Option, + to: Option, +} + +impl<'a> SelectionFilter<'a> for CreationDtSelectionFilter<'a> { + type IdType = UUID; + + fn matches(&self, id: &Self::IdType) -> bool { + self.prefixed_repository + .exists_by_criteria(ProposalKeyCreationTimeIndexCriteria { + proposal_id: *id, + from_dt: self.from, + to_dt: self.to, + }) + } + + fn select(&self) -> HashSet { + self.repository + .find_by_criteria(ProposalCreationTimeIndexCriteria { + from_dt: self.from, + to_dt: self.to, + }) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ExpirationDtSelectionFilter<'a> { + repository: &'a ProposalExpirationTimeIndexRepository, + prefixed_repository: &'a ProposalKeyExpirationTimeIndexRepository, + from: Option, + to: Option, +} + +impl<'a> SelectionFilter<'a> for ExpirationDtSelectionFilter<'a> { + type IdType = UUID; + + fn matches(&self, id: &Self::IdType) -> bool { + self.prefixed_repository + .exists_by_criteria(ProposalKeyExpirationTimeIndexCriteria { + proposal_id: *id, + from_dt: self.from, + to_dt: self.to, + }) + } + + fn select(&self) -> HashSet { + self.repository + .find_by_criteria(ProposalExpirationTimeIndexCriteria { + from_dt: self.from, + to_dt: self.to, + }) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct AccountSelectionFilter<'a> { + repository: &'a ProposalAccountIndexRepository, + account_id: AccountId, +} + +impl<'a> SelectionFilter<'a> for AccountSelectionFilter<'a> { + type IdType = UUID; + + fn matches(&self, id: &Self::IdType) -> bool { + self.repository.exists(&ProposalAccountIndex { + account_id: self.account_id, + proposal_id: *id, + }) + } + + fn select(&self) -> HashSet { + self.repository + .find_by_criteria(ProposalAccountIndexCriteria { + account_id: self.account_id, + }) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct VoterSelectionFilter<'a> { + repository: &'a ProposalVoterIndexRepository, + voter_id: UUID, +} + +impl<'a> SelectionFilter<'a> for VoterSelectionFilter<'a> { + type IdType = UUID; + + fn matches(&self, id: &Self::IdType) -> bool { + self.repository.exists(&ProposalVoterIndex { + voter_id: self.voter_id, + proposal_id: *id, + }) + } + + fn select(&self) -> HashSet { + self.repository + .find_by_criteria(ProposalVoterIndexCriteria { + voter_id: self.voter_id, + }) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct ProposerSelectionFilter<'a> { + repository: &'a ProposalProposerIndexRepository, + proposer_id: UUID, +} + +impl<'a> SelectionFilter<'a> for ProposerSelectionFilter<'a> { + type IdType = UUID; + + fn matches(&self, id: &Self::IdType) -> bool { + self.repository.exists(&ProposalProposerIndex { + proposer_id: self.proposer_id, + proposal_id: *id, + }) + } + + fn select(&self) -> HashSet { + self.repository + .find_by_criteria(ProposalProposerIndexCriteria { + proposer_id: self.proposer_id, + }) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct StatusSelectionFilter<'a> { + repository: &'a ProposalStatusIndexRepository, + status: ProposalStatusCode, +} + +impl<'a> SelectionFilter<'a> for StatusSelectionFilter<'a> { + type IdType = UUID; + + fn matches(&self, id: &Self::IdType) -> bool { + self.repository.exists(&ProposalStatusIndex { + status: self.status.to_owned(), + proposal_id: *id, + }) + } + + fn select(&self) -> HashSet { + self.repository + .find_by_criteria(ProposalStatusIndexCriteria { + status: self.status.to_owned(), + }) + } +} + +#[derive(Debug, Clone)] +enum TimestampType { + Creation, + Expiration, + Modification, +} + +#[derive(Debug, Clone)] +struct TimestampSortingStrategy<'a> { + index: &'a ProposalSortIndexRepository, + timestamp_type: TimestampType, + direction: Option, +} + +impl<'a> SortingStrategy<'a> for TimestampSortingStrategy<'a> { + type IdType = UUID; + + fn sort(&self, ids: &mut [Self::IdType]) { + let direction = self.direction.unwrap_or(SortDirection::Ascending); + let mut id_with_timestamps: Vec<(Timestamp, Self::IdType)> = ids + .iter() + .map(|id| { + let key = ProposalSortIndexKey { proposal_id: *id }; + let timestamp = self + .index + .get(&key) + .map(|index| match self.timestamp_type { + TimestampType::Creation => index.creation_timestamp, + TimestampType::Expiration => index.expiration_timestamp, + TimestampType::Modification => index.modification_timestamp, + }) + .unwrap_or_default(); + (timestamp, *id) + }) + .collect(); + + id_with_timestamps.sort_by(|a, b| { + { + let ord = a.0.cmp(&b.0); // Compare timestamps + match direction { + SortDirection::Ascending => ord, + SortDirection::Descending => ord.reverse(), + } + } + .then_with(|| a.1.cmp(&b.1)) // Compare proposal IDs if timestamps are equal + }); + + let sorted_ids: Vec = id_with_timestamps.into_iter().map(|(_, id)| id).collect(); + ids.copy_from_slice(&sorted_ids); + } } #[cfg(test)] @@ -562,9 +705,8 @@ mod tests { use super::*; use crate::models::{ proposal_test_utils::{self, mock_proposal}, - Metadata, ProposalOperation, ProposalStatus, TransferOperation, TransferOperationInput, + ProposalStatus, }; - use num_bigint::BigUint; use uuid::Uuid; #[test] @@ -581,33 +723,6 @@ mod tests { assert!(repository.get(&proposal.to_key()).is_none()); } - #[test] - fn find_by_account() { - let repository = ProposalRepository::default(); - let mut proposal = mock_proposal(); - let user_id = Uuid::new_v4(); - let account_id = Uuid::new_v4(); - proposal.proposed_by = *user_id.as_bytes(); - proposal.operation = ProposalOperation::Transfer(TransferOperation { - transfer_id: None, - input: TransferOperationInput { - amount: candid::Nat(BigUint::from(100u32)), - fee: None, - metadata: Metadata::default(), - network: "mainnet".to_string(), - to: "0x1234".to_string(), - from_account_id: *account_id.as_bytes(), - }, - }); - - repository.insert(proposal.to_key(), proposal.clone()); - - assert_eq!( - repository.find_by_account(*account_id.as_bytes(), None, None), - vec![proposal] - ); - } - #[test] fn find_by_expiration_dt_and_status() { let repository = ProposalRepository::default(); @@ -622,19 +737,19 @@ mod tests { let last_six = repository.find_by_expiration_dt_and_status( Some(45), None, - ProposalStatus::Created.to_string(), + ProposalStatusCode::Created.to_string(), ); let middle_eleven = repository.find_by_expiration_dt_and_status( Some(30), Some(40), - ProposalStatus::Created.to_string(), + ProposalStatusCode::Created.to_string(), ); let first_three = repository.find_by_expiration_dt_and_status( None, Some(2), - ProposalStatus::Created.to_string(), + ProposalStatusCode::Created.to_string(), ); assert_eq!(last_six.len(), 6); @@ -643,7 +758,7 @@ mod tests { } #[test] - fn no_of_future_expiration_dt() { + fn no_future_expiration_dt() { let repository = ProposalRepository::default(); let mut proposal = proposal_test_utils::mock_proposal(); proposal.expiration_dt = 10; @@ -653,70 +768,14 @@ mod tests { let proposals = repository.find_by_expiration_dt_and_status( Some(20), None, - proposal.status.to_string(), + proposal.status.to_type().to_string(), ); assert!(proposals.is_empty()); } #[test] - fn pick_optmized_lookup_strategy() { - let mut condition = ProposalWhereClause { - created_dt_from: None, - created_dt_to: None, - expiration_dt_from: Some(10), - expiration_dt_to: None, - operation_types: vec![], - statuses: vec![], - voters: vec![], - proposers: vec![], - }; - - assert_eq!( - WhereSelectionStrategy::ExpirationDt, - PROPOSAL_REPOSITORY.pick_most_selective_where_filter(&condition) - ); - - condition.created_dt_from = Some(10); - - assert_eq!( - WhereSelectionStrategy::CreationDt, - PROPOSAL_REPOSITORY.pick_most_selective_where_filter(&condition) - ); - - condition.statuses = vec![ProposalStatusCode::Created]; - - assert_eq!( - WhereSelectionStrategy::Status, - PROPOSAL_REPOSITORY.pick_most_selective_where_filter(&condition) - ); - - condition.operation_types = vec![ListProposalsOperationTypeDTO::Transfer(Some( - Uuid::new_v4().to_string(), - ))]; - - assert_eq!( - WhereSelectionStrategy::Account, - PROPOSAL_REPOSITORY.pick_most_selective_where_filter(&condition) - ); - - condition.voters = vec![[0; 16]]; - - assert_eq!( - WhereSelectionStrategy::Voter, - PROPOSAL_REPOSITORY.pick_most_selective_where_filter(&condition) - ); - - condition.proposers = vec![[0; 16]]; - - assert_eq!( - WhereSelectionStrategy::Proposer, - PROPOSAL_REPOSITORY.pick_most_selective_where_filter(&condition) - ); - } - - #[test] - fn find_where_with_expiration_dt() { + fn find_with_expiration_dt() { let mut proposal = proposal_test_utils::mock_proposal(); proposal.id = *Uuid::new_v4().as_bytes(); proposal.created_timestamp = 5; @@ -743,23 +802,28 @@ mod tests { }; let proposals = PROPOSAL_REPOSITORY - .find_where(condition.clone(), None) + .find_ids_where(condition.clone(), None) .unwrap(); assert_eq!(proposals.len(), 1); - assert_eq!(proposals[0], proposal); + + let found_proposal = PROPOSAL_REPOSITORY + .get(&ProposalKey { id: proposals[0] }) + .unwrap(); + + assert_eq!(found_proposal, proposal); condition.expiration_dt_from = Some(11); let proposals = PROPOSAL_REPOSITORY - .find_where(condition.clone(), None) + .find_ids_where(condition.clone(), None) .unwrap(); assert!(proposals.is_empty()); } #[test] - fn find_where_with_creation_dt() { + fn find_with_creation_dt() { let mut proposal = proposal_test_utils::mock_proposal(); proposal.id = *Uuid::new_v4().as_bytes(); proposal.created_timestamp = 10; @@ -784,38 +848,152 @@ mod tests { }; let proposals = PROPOSAL_REPOSITORY - .find_where(condition.clone(), None) + .find_ids_where(condition.clone(), None) .unwrap(); assert_eq!(proposals.len(), 1); - assert_eq!(proposals[0], proposal); + + let found_proposal = PROPOSAL_REPOSITORY + .get(&ProposalKey { id: proposals[0] }) + .unwrap(); + + assert_eq!(found_proposal, proposal); condition.created_dt_from = Some(8); condition.created_dt_to = Some(9); let proposals = PROPOSAL_REPOSITORY - .find_where(condition.clone(), None) + .find_ids_where(condition.clone(), None) .unwrap(); assert!(proposals.is_empty()); } + + #[test] + fn find_with_default_filters() { + for i in 0..100 { + let mut proposal = mock_proposal(); + proposal.id = *Uuid::new_v4().as_bytes(); + proposal.created_timestamp = i; + proposal.expiration_dt = i + 100; + proposal.status = match i % 2 { + 0 => ProposalStatus::Created, + 1 => ProposalStatus::Adopted, + _ => ProposalStatus::Rejected, + }; + + PROPOSAL_REPOSITORY.insert(proposal.to_key(), proposal.to_owned()); + } + + let condition = ProposalWhereClause { + created_dt_from: Some(50), + created_dt_to: Some(100), + expiration_dt_from: None, + expiration_dt_to: None, + operation_types: Vec::new(), + proposers: Vec::new(), + voters: Vec::new(), + statuses: vec![ProposalStatusCode::Created], + }; + + let proposals = PROPOSAL_REPOSITORY + .find_ids_where(condition.clone(), None) + .unwrap(); + + assert_eq!(proposals.len(), 25); + + let condition = ProposalWhereClause { + created_dt_from: Some(0), + created_dt_to: Some(100), + expiration_dt_from: None, + expiration_dt_to: None, + operation_types: Vec::new(), + proposers: Vec::new(), + voters: Vec::new(), + statuses: vec![ProposalStatusCode::Adopted], + }; + + let proposals = PROPOSAL_REPOSITORY + .find_ids_where(condition.clone(), None) + .unwrap(); + + assert_eq!(proposals.len(), 50); + + let condition = ProposalWhereClause { + created_dt_from: Some(0), + created_dt_to: Some(100), + expiration_dt_from: None, + expiration_dt_to: None, + operation_types: Vec::new(), + proposers: Vec::new(), + voters: Vec::new(), + statuses: vec![ProposalStatusCode::Adopted, ProposalStatusCode::Created], + }; + + let proposals = PROPOSAL_REPOSITORY + .find_ids_where(condition.clone(), None) + .unwrap(); + + assert_eq!(proposals.len(), 100); + + let condition = ProposalWhereClause { + created_dt_from: Some(0), + created_dt_to: Some(100), + expiration_dt_from: Some(110), + expiration_dt_to: Some(120), + operation_types: Vec::new(), + proposers: Vec::new(), + voters: Vec::new(), + statuses: vec![ProposalStatusCode::Adopted], + }; + + let proposals = PROPOSAL_REPOSITORY + .find_ids_where(condition.clone(), None) + .unwrap(); + + assert_eq!(proposals.len(), 5); + } + + #[test] + fn find_with_empty_where_clause_should_return_all() { + proposal_repository_test_utils::add_proposals_to_repository(100); + + let condition = ProposalWhereClause { + created_dt_from: None, + created_dt_to: None, + expiration_dt_from: None, + expiration_dt_to: None, + operation_types: vec![], + statuses: vec![], + voters: vec![], + proposers: vec![], + }; + + let proposals = PROPOSAL_REPOSITORY + .find_ids_where(condition.clone(), None) + .unwrap(); + + assert_eq!(proposals.len(), 100); + } } -#[cfg(feature = "canbench-rs")] +#[cfg(feature = "canbench")] mod benchs { use super::*; - use crate::models::proposal_test_utils::mock_proposal; + use crate::models::{proposal_test_utils::mock_proposal, ProposalStatus}; use canbench_rs::{bench, BenchResult}; use uuid::Uuid; - #[bench] - fn batch_insert_100_proposals() { - add_proposals_to_repository(100); + #[bench(raw)] + fn repository_batch_insert_100_proposals() -> BenchResult { + canbench_rs::bench_fn(|| { + proposal_repository_test_utils::add_proposals_to_repository(100); + }) } #[bench(raw)] - fn list_all_proposals() -> BenchResult { - add_proposals_to_repository(1_000); + fn repository_list_all_proposals() -> BenchResult { + proposal_repository_test_utils::add_proposals_to_repository(1_000); canbench_rs::bench_fn(|| { let _ = PROPOSAL_REPOSITORY.list(); @@ -823,27 +1001,45 @@ mod benchs { } #[bench(raw)] - fn filter_all_proposals_by_default_filters() -> BenchResult { - add_proposals_to_repository(1_000); + fn repository_filter_all_proposal_ids_by_default_filters() -> BenchResult { + for i in 0..2_500 { + let mut proposal = mock_proposal(); + proposal.id = *Uuid::new_v4().as_bytes(); + proposal.created_timestamp = i; + proposal.status = match i % 2 { + 0 => ProposalStatus::Created, + 1 => ProposalStatus::Adopted, + _ => ProposalStatus::Rejected, + }; + + PROPOSAL_REPOSITORY.insert(proposal.to_key(), proposal.to_owned()); + } canbench_rs::bench_fn(|| { - let _ = PROPOSAL_REPOSITORY.find_where( + let _ = PROPOSAL_REPOSITORY.find_ids_where( ProposalWhereClause { - created_dt_from: None, - created_dt_to: None, + created_dt_from: Some(500), + created_dt_to: Some(1500), expiration_dt_from: None, expiration_dt_to: None, operation_types: Vec::new(), proposers: Vec::new(), voters: Vec::new(), - statuses: Vec::new(), + statuses: vec![ProposalStatusCode::Created], }, None, ); }) } +} + +#[cfg(any(test, feature = "canbench"))] +mod proposal_repository_test_utils { + use super::*; + use crate::models::proposal_test_utils::mock_proposal; + use uuid::Uuid; - fn add_proposals_to_repository(count: usize) { + pub fn add_proposals_to_repository(count: usize) { for _ in 0..count { let mut proposal = mock_proposal(); proposal.id = *Uuid::new_v4().as_bytes(); diff --git a/canisters/wallet/impl/src/services/proposal.rs b/canisters/wallet/impl/src/services/proposal.rs index b3f3a02c1..0dbd69878 100644 --- a/canisters/wallet/impl/src/services/proposal.rs +++ b/canisters/wallet/impl/src/services/proposal.rs @@ -16,8 +16,6 @@ use crate::{ repositories::{ProposalRepository, ProposalWhereClause, PROPOSAL_REPOSITORY}, services::{NotificationService, UserService, NOTIFICATION_SERVICE, USER_SERVICE}, }; -use futures::stream; -use futures::StreamExt; use ic_canister_core::utils::rfc3339_to_timestamp; use ic_canister_core::{api::ServiceResult, model::ModelValidator}; use ic_canister_core::{repository::Repository, types::UUID}; @@ -130,7 +128,7 @@ impl ProposalService { }) .transpose()?; - let mut proposals = self.proposal_repository.find_where( + let mut proposal_ids = self.proposal_repository.find_ids_where( ProposalWhereClause { created_dt_from: input .created_from_dt @@ -145,7 +143,10 @@ impl ProposalService { .expiration_to_dt .map(|dt| rfc3339_to_timestamp(dt.as_str())), operation_types: input.operation_types.unwrap_or_default(), - statuses: input.statuses.unwrap_or_default(), + statuses: input + .statuses + .map(|statuses| statuses.into_iter().map(Into::into).collect::<_>()) + .unwrap_or_default(), proposers: filter_by_proposers.unwrap_or_default(), voters: filter_by_voters.unwrap_or_default(), }, @@ -154,33 +155,41 @@ impl ProposalService { // filter out proposals that the caller does not have access to read if let Some(ctx) = ctx { - proposals = stream::iter(proposals.iter()) - .filter_map(|proposal| async move { - match evaluate_caller_access( - ctx, - &ResourceSpecifier::Proposal(ProposalActionSpecifier::Read( - CommonSpecifier::Id(vec![proposal.id.to_owned()]), - )), - ) - .await - { - Ok(_) => Some(proposal.to_owned()), - Err(_) => None, - } - }) - .collect() + let mut ids_with_access = Vec::new(); + for proposal_id in &proposal_ids { + if evaluate_caller_access( + ctx, + &ResourceSpecifier::Proposal(ProposalActionSpecifier::Read( + CommonSpecifier::Id(vec![proposal_id.to_owned()]), + )), + ) .await + .is_ok() + { + ids_with_access.push(*proposal_id); + } + } + + proposal_ids = ids_with_access; } - let paginated_proposals = paginated_items(PaginatedItemsArgs { + let paginated_ids = paginated_items(PaginatedItemsArgs { offset: input.paginate.to_owned().and_then(|p| p.offset), limit: input.paginate.and_then(|p| p.limit), default_limit: Some(Self::DEFAULT_PROPOSAL_LIST_LIMIT), max_limit: Some(Self::MAX_PROPOSAL_LIST_LIMIT), - items: &proposals, + items: &proposal_ids, })?; - Ok(paginated_proposals) + Ok(PaginatedData { + total: paginated_ids.total, + next_offset: paginated_ids.next_offset, + items: paginated_ids + .items + .into_iter() + .map(|id| self.get_proposal(&id).expect("Failed to get proposal")) + .collect::>(), + }) } pub async fn edit_proposal(&self, input: ProposalEditInput) -> ServiceResult { @@ -547,3 +556,98 @@ mod tests { assert_eq!(result.unwrap().items.len(), 1); } } + +#[cfg(feature = "canbench")] +mod benchs { + use super::*; + use crate::{ + core::ic_cdk::spawn, + models::{ + access_control::{access_control_test_utils::mock_access_policy, UserSpecifier}, + proposal_test_utils::mock_proposal, + user_test_utils::mock_user, + UserStatus, + }, + repositories::{access_control::ACCESS_CONTROL_REPOSITORY, USER_REPOSITORY}, + }; + use canbench_rs::{bench, BenchResult}; + use candid::Principal; + use ic_canister_core::utils::timestamp_to_rfc3339; + use wallet_api::ProposalStatusCodeDTO; + + #[bench(raw)] + fn service_filter_all_proposals_with_default_filters() -> BenchResult { + let proposals_to_insert = 1000u64; + let end_creation_time = proposals_to_insert * 1_000_000_000; + // this emulates a real world scenario where the proposals are created in a time span and + // the filter is used to fetch the proposals created in the last half of the time span + let start_creation_time = end_creation_time / 2; + + for i in 0..proposals_to_insert { + let mut proposal = mock_proposal(); + proposal.created_timestamp = i * 1_000_000_000; + proposal.status = match i % 2 { + 0 => ProposalStatus::Created, + 1 => ProposalStatus::Adopted, + _ => ProposalStatus::Rejected, + }; + + PROPOSAL_REPOSITORY.insert(proposal.to_key(), proposal.to_owned()); + } + + let mut users = Vec::new(); + // adding some users that will be added to the access control repository later + for i in 0..10 { + let mut user = mock_user(); + user.identities = vec![Principal::from_slice(&[i; 29])]; + user.status = UserStatus::Active; + + USER_REPOSITORY.insert(user.to_key(), user.to_owned()); + + users.push(user); + } + + // adding some access policies since the filter will check for access + for user in users.iter() { + let mut access_policy = mock_access_policy(); + access_policy.resource = + ResourceSpecifier::Proposal(ProposalActionSpecifier::Read(CommonSpecifier::Any)); + access_policy.user = UserSpecifier::Id(vec![user.id]); + + ACCESS_CONTROL_REPOSITORY.insert(access_policy.id, access_policy.to_owned()); + } + + canbench_rs::bench_fn(|| { + spawn(async move { + let result = PROPOSAL_SERVICE + .list_proposals( + wallet_api::ListProposalsInput { + created_from_dt: Some(timestamp_to_rfc3339(&start_creation_time)), + created_to_dt: Some(timestamp_to_rfc3339(&end_creation_time)), + statuses: Some(vec![ProposalStatusCodeDTO::Created]), + voter_ids: None, + proposer_ids: None, + operation_types: None, + expiration_from_dt: None, + expiration_to_dt: None, + paginate: Some(wallet_api::PaginationInput { + limit: Some(25), + offset: None, + }), + sort_by: Some(wallet_api::ListProposalsSortBy::CreatedAt( + wallet_api::SortDirection::Asc, + )), + }, + None, + ) + .await; + + let paginated_data = result.unwrap(); + + if paginated_data.total == 0 { + panic!("No proposals were found with the given filters"); + } + }); + }) + } +} diff --git a/canisters/wallet/impl/src/services/transfer.rs b/canisters/wallet/impl/src/services/transfer.rs index d190033b1..baf0b1c67 100644 --- a/canisters/wallet/impl/src/services/transfer.rs +++ b/canisters/wallet/impl/src/services/transfer.rs @@ -92,8 +92,6 @@ impl TransferService { #[cfg(test)] mod tests { - use candid::Principal; - use super::*; use crate::{ core::test_utils, @@ -101,8 +99,9 @@ mod tests { account_test_utils::mock_account, transfer_test_utils::mock_transfer, user_test_utils::mock_user, User, }, - repositories::UserRepository, + repositories::{ACCOUNT_REPOSITORY, TRANSFER_REPOSITORY, USER_REPOSITORY}, }; + use candid::Principal; struct TestContext { repository: TransferRepository, @@ -119,12 +118,12 @@ mod tests { let mut user = mock_user(); user.identities = vec![call_context.caller()]; - UserRepository::default().insert(user.to_key(), user.clone()); + USER_REPOSITORY.insert(user.to_key(), user.clone()); let mut account = mock_account(); account.owners.push(user.id); - AccountRepository::default().insert(account.to_key(), account.clone()); + ACCOUNT_REPOSITORY.insert(account.to_key(), account.clone()); TestContext { repository: TransferRepository::default(), @@ -153,13 +152,20 @@ mod tests { fn fail_get_transfer_not_allowed() { let ctx = setup(); let mut user = mock_user(); - user.identities = vec![Principal::anonymous()]; - UserRepository::default().insert(user.to_key(), user.clone()); + user.identities = vec![Principal::from_slice(&[10; 29])]; + + USER_REPOSITORY.insert(user.to_key(), user.clone()); + + let mut account = mock_account(); + account.owners.push(user.id); + + ACCOUNT_REPOSITORY.insert(account.to_key(), account.clone()); + let mut transfer = mock_transfer(); - transfer.from_account = ctx.account.id; + transfer.from_account = account.id; transfer.initiator_user = user.id; - ctx.repository.insert(transfer.to_key(), transfer.clone()); + TRANSFER_REPOSITORY.insert(transfer.to_key(), transfer.clone()); let result = ctx.service.get_transfer(&transfer.id, &ctx.call_context); diff --git a/libs/ic-canister-core/Cargo.toml b/libs/ic-canister-core/Cargo.toml index cdcb8a2dc..3b038804a 100644 --- a/libs/ic-canister-core/Cargo.toml +++ b/libs/ic-canister-core/Cargo.toml @@ -16,6 +16,7 @@ candid = { workspace = true } convert_case = { workspace = true } getrandom = { workspace = true, features = ["custom"] } ic-cdk = { workspace = true } +ic-stable-structures = { workspace = true } rand_chacha = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/libs/ic-canister-core/src/repository.rs b/libs/ic-canister-core/src/repository.rs index e3ddf2df4..09810f268 100644 --- a/libs/ic-canister-core/src/repository.rs +++ b/libs/ic-canister-core/src/repository.rs @@ -1,5 +1,7 @@ use std::collections::HashSet; +use crate::types::UUID; + /// A repository is a generic interface for storing and retrieving data. pub trait Repository { /// Returns the list of records from the repository. @@ -25,10 +27,23 @@ pub trait Repository { fn refresh_indexes(&self, _current: Value, _previous: Option) { // no-op } + + fn find_with_filters<'a>( + &self, + filters: Vec + 'a>>, + ) -> HashSet { + let mut found_ids = None; + + for filter in filters { + found_ids = Some(filter.apply(found_ids.as_ref())); + } + + found_ids.unwrap_or_default() + } } /// An index repository is a generic interface for storing and retrieving data based on an index. -pub trait IndexRepository { +pub trait IndexRepository { type FindByCriteria; /// Checks if an index exists. @@ -41,7 +56,7 @@ pub trait IndexRepository { fn remove(&self, index: &Index) -> bool; /// Returns all records from the repository that match a set criteria. - fn find_by_criteria(&self, criteria: Self::FindByCriteria) -> HashSet; + fn find_by_criteria(&self, criteria: Self::FindByCriteria) -> HashSet; fn refresh_index_on_modification(&self, mode: RefreshIndexMode) where @@ -105,3 +120,148 @@ pub enum RefreshIndexMode { current: Vec, }, } + +/// A filter that can be applied to a set of ids to select or filter them down based on some criteria. +/// +/// By default, the filter is meant to filter down the set of IDs and not select them, unless +/// the `is_selective` method is overridden. +pub trait SelectionFilter<'a> +where + Self::IdType: Clone + Eq + std::hash::Hash, +{ + /// The type of the IDs that the filter operates on + type IdType; + + /// Applies the filter to the existing set of IDs and returns the new set of IDs + fn apply(&self, existing_ids: Option<&HashSet>) -> HashSet { + match (existing_ids, self.is_selective()) { + (Some(ids), true) => { + let new_ids = self.select(); + new_ids.intersection(ids).cloned().collect() + } + (Some(ids), false) => ids.iter().filter(|id| self.matches(id)).cloned().collect(), + // If the existing set of IDs is None, then we are meant to select all IDs that match the filter + (None, true) | (None, false) => self.select(), + } + } + + /// Returns true if the item matches the filter criteria + /// + /// By default it is true for all items, which means that the filter is a no-op. + /// + /// A no-op match is useful for logical operations (e.g. AND and OR), where the filter is meant to + /// rely on other filters to do the actual filtering. + fn matches(&self, _item_id: &Self::IdType) -> bool { + true + } + + /// Returns true if the filter is meant to select rather than filter down + /// + /// By default it is false + fn is_selective(&self) -> bool { + false + } + + /// Returns the initial set of IDs for the filter + /// + /// By default it is an empty set + fn select(&self) -> HashSet { + HashSet::new() + } +} + +/// A filter that combines multiple filters using a logical AND operation. +pub struct AndSelectionFilter<'a> { + pub filters: Vec + 'a>>, +} + +impl<'a> SelectionFilter<'a> for AndSelectionFilter<'a> { + type IdType = UUID; + + fn apply(&self, existing_ids: Option<&HashSet>) -> HashSet { + let mut found_ids: Option> = None; + + for filter in &self.filters { + found_ids = Some(filter.apply(found_ids.as_ref())); + + if found_ids.is_some() && found_ids.as_ref().unwrap().is_empty() { + break; + } + } + + let newly_found_ids = found_ids.unwrap_or_default(); + + match existing_ids { + Some(ids) => { + let mut new_ids = newly_found_ids; + new_ids.retain(|id| ids.contains(id)); + + new_ids + } + None => newly_found_ids, + } + } +} + +/// A filter that combines multiple filters using a logical OR operation. +pub struct OrSelectionFilter<'a> { + pub filters: Vec + 'a>>, +} + +impl<'a> SelectionFilter<'a> for OrSelectionFilter<'a> { + type IdType = UUID; + + fn apply(&self, existing_ids: Option<&HashSet>) -> HashSet { + let mut found_ids = HashSet::new(); + + for filter in &self.filters { + let new_ids = filter.apply(existing_ids); + + found_ids.extend(new_ids); + } + + match existing_ids { + Some(ids) => { + let mut new_ids = found_ids; + new_ids.retain(|id| ids.contains(id)); + + new_ids + } + None => found_ids, + } + } +} + +/// The sorting direction for a list of items in a repository. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SortDirection { + Ascending, + Descending, +} + +/// A strategy for sorting a list of loaded items. +pub trait SortingStrategy<'a> { + type IdType; + + fn sort(&self, ids: &mut [Self::IdType]); +} + +/// The default sorting strategy that sorts items based on their natural ordering. +/// +/// When the sorting direction is not specified, it defaults to ascending. +pub struct DefaultSortingStrategy { + pub direction: Option, +} + +impl<'a> SortingStrategy<'a> for DefaultSortingStrategy { + type IdType = UUID; + + fn sort(&self, ids: &mut [Self::IdType]) { + let direction = self.direction.unwrap_or(SortDirection::Ascending); + + ids.sort_by(|a, b| match direction { + SortDirection::Ascending => a.cmp(b), + SortDirection::Descending => b.cmp(a), + }); + } +} diff --git a/libs/ic-canister-macros/src/lib.rs b/libs/ic-canister-macros/src/lib.rs index d6b75fcd9..f4e6cb913 100644 --- a/libs/ic-canister-macros/src/lib.rs +++ b/libs/ic-canister-macros/src/lib.rs @@ -101,3 +101,58 @@ pub fn with_middleware(input_args: TokenStream, input: TokenStream) -> TokenStre input, ) } + +/// The `storable` procedural macro is designed to generate serialization and deserialization +/// implementations for a given struct or enum. It allows for the serialization format to be +/// specified, as well as, optionally the maximum byte size that the object can be serialized to. +/// +/// This macro is compatible with the `ic_stable_strcutures` crate and it can be used to generate +/// objects that can be stored in the stable memory of a canister. +/// +/// # Parameters +/// +/// - `size`: The maximum byte size that the object can be serialized to. +/// - `format`: The serialization format to be used. Supported formats are `cbor` and `candid`. +/// +/// # Usage +/// +/// Annotate a struct or enum with `#[storable]`. +/// +/// # Examples +/// +/// Basic usage with default parameters: +/// +/// ```ignore +/// #[storable] +/// struct MyStruct { +/// field1: u32, +/// field2: String, +/// } +/// ``` +/// +/// Usage with custom parameters: +/// +/// ```ignore +/// #[storable(size = 1000, format = "cbor")] +/// struct MyStruct { +/// field1: u32, +/// field2: String, +/// } +/// ``` +/// +/// # Notes +/// +/// - The macro currently supports only struct and enum items. +#[proc_macro_attribute] +pub fn storable(input_args: TokenStream, input: TokenStream) -> TokenStream { + utils::handle_macro_errors( + |input_args, input| { + let macro_impl = macros::storable::StorableMacro::new(input_args, input); + + macro_impl.build() + }, + macros::storable::StorableMacro::MACRO_NAME, + input_args, + input, + ) +} diff --git a/libs/ic-canister-macros/src/macros/mod.rs b/libs/ic-canister-macros/src/macros/mod.rs index c5f751635..1f2b90860 100644 --- a/libs/ic-canister-macros/src/macros/mod.rs +++ b/libs/ic-canister-macros/src/macros/mod.rs @@ -3,6 +3,7 @@ use proc_macro::TokenStream; pub use stable_object::*; use syn::Error; +pub mod storable; pub mod with_middleware; pub trait MacroDefinition { diff --git a/libs/ic-canister-macros/src/macros/storable.rs b/libs/ic-canister-macros/src/macros/storable.rs new file mode 100644 index 000000000..20e019bed --- /dev/null +++ b/libs/ic-canister-macros/src/macros/storable.rs @@ -0,0 +1,256 @@ +use super::MacroDefinition; +use proc_macro::TokenStream; +use quote::quote; +use std::str::FromStr; +use syn::{parse::Parser, parse2, DeriveInput, Error, Token}; + +/// The arguments passed to the `storable` macro. +/// +/// The macro accepts a list of arguments separated by `,`. +#[derive(Debug)] +struct MacroArguments { + /// The maximum byte size that the object can be serialized to. + /// + /// This should only be used when the object is of a fixed size and the size is known at compile time, otherwise, + /// the bytes will be reserved but not used. + pub size: Option, + + /// The name of the serializer to use for the object. + /// + /// This should be the serialization format that the object should be serialized to and deserialized from. + /// + /// If this is not provided, the object will be serialized to and deserialized from the default serialization format. + pub serializer: SerializerFormat, +} + +#[derive(Debug)] +pub struct StorableMacro { + input_args: TokenStream, + input: TokenStream, +} + +impl MacroDefinition for StorableMacro { + const MACRO_NAME: &'static str = "storable"; + + fn new(input_args: TokenStream, input: TokenStream) -> Self { + Self { input, input_args } + } + + fn build(&self) -> Result { + let args: MacroArguments = self.parse_input_arguments()?; + let expanded_input = self.expand_implementation(&args)?; + + Ok(expanded_input) + } +} + +impl StorableMacro { + const MACRO_ARG_KEY_SIZE: &'static str = "size"; + const MACRO_ARG_KEY_SERIALIZER: &'static str = "serializer"; + + fn expand_implementation(&self, args: &MacroArguments) -> Result { + let parsed_input: DeriveInput = parse2(self.input.clone().into())?; + + match parsed_input.data { + syn::Data::Struct(_) | syn::Data::Enum(_) => { + let input = parsed_input.clone(); + let size_value: Option = args.size; + + match args.serializer { + SerializerFormat::Candid => self.expand_candid_impl(&input, size_value), + SerializerFormat::Cbor => self.expand_cbor_impl(&input, size_value), + } + } + _ => Err(Error::new_spanned( + parsed_input, + "Only structs and enums are supported by the storable macro", + )), + } + } + + fn parse_input_arguments(&self) -> Result { + let parser = syn::punctuated::Punctuated::::parse_terminated; + let args = parser.parse(self.input_args.clone())?; + + let mut size: Option = None; + let mut serializer: SerializerFormat = SerializerFormat::Cbor; + + for expr in args { + let syn::ExprAssign { + left, + right, + attrs: _, + eq_token: _, + } = expr; + + if let syn::Expr::Path(expr_path) = *left { + match expr_path.path.get_ident().unwrap().to_string().as_str() { + Self::MACRO_ARG_KEY_SIZE => { + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Int(lit_int), + .. + }) = *right + { + size = Some(lit_int.base10_parse()?); + } + } + + Self::MACRO_ARG_KEY_SERIALIZER => { + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. + }) = *right + { + serializer = + SerializerFormat::from_str(&lit_str.value()).map_err(|err| { + Error::new( + lit_str.span(), + format!( + "Invalid value for the \"{}\" argument: {}", + Self::MACRO_ARG_KEY_SERIALIZER, + err + ), + ) + })?; + } + } + + unknown_arg => { + return Err(Error::new( + expr_path.path.get_ident().unwrap().span(), + format!( + "Unknown argument \"{}\" passed to the \"{}\" macro", + unknown_arg, + Self::MACRO_NAME + ), + )); + } + } + } + } + + Ok(MacroArguments { size, serializer }) + } + + fn expand_cbor_impl( + &self, + input: &DeriveInput, + size: Option, + ) -> Result { + let object_name = input.ident.clone(); + + let expanded = match size { + Some(size) => quote! { + #[derive(serde::Serialize, serde::Deserialize)] + #input + + impl ic_stable_structures::Storable for #object_name { + fn to_bytes(&self) -> std::borrow::Cow<[u8]> { + std::borrow::Cow::Owned(serde_cbor::to_vec(self).unwrap()) + } + + fn from_bytes(bytes: std::borrow::Cow<[u8]>) -> Self { + serde_cbor::from_slice(bytes.as_ref()).unwrap() + } + + const BOUND: ic_stable_structures::storable::Bound = ic_stable_structures::storable::Bound::Bounded { + max_size: #size, + is_fixed_size: false, + }; + } + }, + None => quote! { + #[derive(serde::Serialize, serde::Deserialize)] + #input + + impl ic_stable_structures::Storable for #object_name { + fn to_bytes(&self) -> std::borrow::Cow<[u8]> { + std::borrow::Cow::Owned(serde_cbor::to_vec(self).unwrap()) + } + + fn from_bytes(bytes: std::borrow::Cow<[u8]>) -> Self { + serde_cbor::from_slice(bytes.as_ref()).unwrap() + } + + const BOUND: ic_stable_structures::storable::Bound = ic_stable_structures::storable::Bound::Unbounded; + } + }, + }; + + Ok(expanded.into()) + } + + fn expand_candid_impl( + &self, + input: &DeriveInput, + size: Option, + ) -> Result { + let object_name = input.ident.clone(); + + let expanded = match size { + Some(size) => quote! { + #[derive(candid::CandidType, candid::Deserialize)] + #input + + impl ic_stable_structures::Storable for #object_name { + fn to_bytes(&self) -> std::borrow::Cow<[u8]> { + use candid::Encode; + + std::borrow::Cow::Owned(candid::Encode!(self).unwrap()) + } + + fn from_bytes(bytes: std::borrow::Cow<[u8]>) -> Self { + use candid::Decode; + + candid::Decode!(bytes.as_ref(), Self).unwrap() + } + + const BOUND: ic_stable_structures::storable::Bound = ic_stable_structures::storable::Bound::Bounded { + max_size: #size, + is_fixed_size: false, + }; + } + }, + None => quote! { + #[derive(candid::CandidType, candid::Deserialize)] + #input + + impl ic_stable_structures::Storable for #object_name { + fn to_bytes(&self) -> std::borrow::Cow<[u8]> { + use candid::Encode; + + std::borrow::Cow::Owned(candid::Encode!(self).unwrap()) + } + + fn from_bytes(bytes: std::borrow::Cow<[u8]>) -> Self { + use candid::Decode; + + candid::Decode!(bytes.as_ref(), Self).unwrap() + } + + const BOUND: ic_stable_structures::storable::Bound = ic_stable_structures::storable::Bound::Unbounded; + } + }, + }; + + Ok(expanded.into()) + } +} + +#[derive(Debug)] +enum SerializerFormat { + Candid, + Cbor, +} + +impl std::str::FromStr for SerializerFormat { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "candid" => Ok(Self::Candid), + "cbor" => Ok(Self::Cbor), + _ => Err(format!("Unknown serializer format \"{}\"", s)), + } + } +}