From 3184c9c4737fa81b2f8aa2490bc89a28d1030438 Mon Sep 17 00:00:00 2001 From: Linwei Shang Date: Fri, 10 Jan 2025 15:14:22 -0500 Subject: [PATCH] feat: macros that decoding arguments can set custom decoder using decode_with (#544) * cleanup dfn_macro internal * no arg decoding if function sig has no args * name check * decode_with: set custom arg decoder --- e2e-tests/src/bin/macros.rs | 28 ++++++ e2e-tests/tests/macros.rs | 31 ++++++ ic-cdk-macros/src/export.rs | 184 ++++++++++++++++++++++-------------- 3 files changed, 171 insertions(+), 72 deletions(-) create mode 100644 e2e-tests/src/bin/macros.rs create mode 100644 e2e-tests/tests/macros.rs diff --git a/e2e-tests/src/bin/macros.rs b/e2e-tests/src/bin/macros.rs new file mode 100644 index 00000000..4e6f7c06 --- /dev/null +++ b/e2e-tests/src/bin/macros.rs @@ -0,0 +1,28 @@ +use candid::utils::{decode_args, decode_one}; +use ic_cdk::api::msg_arg_data; +use ic_cdk::update; + +#[update(decode_with = "decode_u0")] +fn u0() {} +fn decode_u0() {} + +#[update(decode_with = "decode_u1")] +fn u1(a: u32) { + assert_eq!(a, 1) +} +fn decode_u1() -> u32 { + let arg_bytes = msg_arg_data(); + decode_one(&arg_bytes).unwrap() +} + +#[update(decode_with = "decode_u2")] +fn u2(a: u32, b: u32) { + assert_eq!(a, 1); + assert_eq!(b, 2); +} +fn decode_u2() -> (u32, u32) { + let arg_bytes = msg_arg_data(); + decode_args(&arg_bytes).unwrap() +} + +fn main() {} diff --git a/e2e-tests/tests/macros.rs b/e2e-tests/tests/macros.rs new file mode 100644 index 00000000..d0a3b85d --- /dev/null +++ b/e2e-tests/tests/macros.rs @@ -0,0 +1,31 @@ +use pocket_ic::call_candid; +use pocket_ic::common::rest::RawEffectivePrincipal; + +mod test_utilities; +use test_utilities::{cargo_build_canister, pocket_ic}; + +#[test] +fn call_macros() { + let pic = pocket_ic(); + let wasm = cargo_build_canister("macros"); + let canister_id = pic.create_canister(); + pic.add_cycles(canister_id, 100_000_000_000_000); + pic.install_canister(canister_id, wasm, vec![], None); + let _: () = call_candid(&pic, canister_id, RawEffectivePrincipal::None, "u0", ()).unwrap(); + let _: () = call_candid( + &pic, + canister_id, + RawEffectivePrincipal::None, + "u1", + (1u32,), + ) + .unwrap(); + let _: () = call_candid( + &pic, + canister_id, + RawEffectivePrincipal::None, + "u2", + (1u32, 2u32), + ) + .unwrap(); +} diff --git a/ic-cdk-macros/src/export.rs b/ic-cdk-macros/src/export.rs index 2b80cf84..b43be818 100644 --- a/ic-cdk-macros/src/export.rs +++ b/ic-cdk-macros/src/export.rs @@ -10,6 +10,7 @@ use syn::{spanned::Spanned, FnArg, ItemFn, Pat, PatIdent, PatType, ReturnType, S struct ExportAttributes { pub name: Option, pub guard: Option, + pub decode_with: Option, #[serde(default)] pub manual_reply: bool, #[serde(default)] @@ -31,6 +32,12 @@ enum MethodType { } impl MethodType { + /// A lifecycle method is a method that is called by the system and not by the user. + /// So far, `update` and `query` are the only methods that are not lifecycle methods. + /// + /// We have a few assumptions for lifecycle methods: + /// - They cannot have a return value. + /// - The export name is prefixed with `canister_`, e.g. `init` => `canister_init`. pub fn is_lifecycle(&self) -> bool { match self { MethodType::Init @@ -42,6 +49,19 @@ impl MethodType { MethodType::Update | MethodType::Query => false, } } + + /// init, post_upgrade, update, query can have arguments. + pub fn can_have_args(&self) -> bool { + match self { + MethodType::Init | MethodType::PostUpgrade | MethodType::Update | MethodType::Query => { + true + } + MethodType::PreUpgrade + | MethodType::Heartbeat + | MethodType::InspectMessage + | MethodType::OnLowWasmMemory => false, + } + } } impl std::fmt::Display for MethodType { @@ -121,79 +141,38 @@ fn dfn_macro( )); } - let is_async = signature.asyncness.is_some(); - - let return_length = match &signature.output { - ReturnType::Default => 0, - ReturnType::Type(_, ty) => match ty.as_ref() { - Type::Tuple(tuple) => tuple.elems.len(), - _ => 1, - }, - }; - - if method.is_lifecycle() && return_length > 0 { - return Err(Error::new( - Span::call_site(), - format!("#[{}] function cannot have a return value.", method), - )); - } - - let (arg_tuple, _): (Vec, Vec>) = - get_args(method, signature)?.iter().cloned().unzip(); + // 1. function name(s) let name = &signature.ident; - let outer_function_ident = format_ident!("__canister_method_{name}"); - - let function_name = attrs.name.unwrap_or_else(|| name.to_string()); - let export_name = if method.is_lifecycle() { - format!("canister_{}", method) - } else if method == MethodType::Query && attrs.composite { - format!("canister_composite_query {function_name}",) - } else { - if function_name.starts_with("") { + let function_name = if let Some(custom_name) = attrs.name { + if method.is_lifecycle() { return Err(Error::new( - Span::call_site(), + attr.span(), + format!("#[{0}] cannot have a custom name.", method), + )); + } + if custom_name.starts_with("") { + return Err(Error::new( + attr.span(), "Functions starting with `` are reserved for CDK internal use.", )); } - format!("canister_{method} {function_name}") - }; - let host_compatible_name = export_name.replace(' ', ".").replace(['-', '<', '>'], "_"); - - let function_call = if is_async { - quote! { #name ( #(#arg_tuple),* ) .await } + custom_name } else { - quote! { #name ( #(#arg_tuple),* ) } + name.to_string() }; - - let arg_count = arg_tuple.len(); - - let return_encode = if method.is_lifecycle() || attrs.manual_reply { - quote! {} - } else { - let return_bytes = match return_length { - 0 => quote! { ::candid::utils::encode_one(()).unwrap() }, - 1 => quote! { ::candid::utils::encode_one(result).unwrap() }, - _ => quote! { ::candid::utils::encode_args(result).unwrap() }, - }; - quote! { - ::ic_cdk::api::msg_reply(#return_bytes); - } - }; - - // On initialization we can actually not receive any input and it's okay, only if - // we don't have any arguments either. - // If the data we receive is not empty, then try to unwrap it as if it's DID. - let arg_decode = if method.is_lifecycle() && arg_count == 0 { - quote! {} + let export_name = if method.is_lifecycle() { + format!("canister_{}", method) + } else if method == MethodType::Query && attrs.composite { + format!("canister_composite_query {function_name}",) } else { - quote! { - let arg_bytes = ::ic_cdk::api::msg_arg_data(); - let ( #( #arg_tuple, )* ) = ::candid::utils::decode_args(&arg_bytes).unwrap(); } + format!("canister_{method} {function_name}") }; + let host_compatible_name = export_name.replace(' ', ".").replace(['-', '<', '>'], "_"); + // 2. guard let guard = if let Some(guard_name) = attrs.guard { - // ic_cdk::api::call::reject calls ic0::msg_reject which is only allowed in update/query + // ic0.msg_reject is only allowed in update/query if method.is_lifecycle() { return Err(Error::new( attr.span(), @@ -213,6 +192,78 @@ fn dfn_macro( quote! {} }; + // 3. decode arguments + let (arg_tuple, _): (Vec, Vec>) = + get_args(method, signature)?.iter().cloned().unzip(); + if !method.can_have_args() { + if !arg_tuple.is_empty() { + return Err(Error::new( + Span::call_site(), + format!("#[{}] function cannot have arguments.", method), + )); + } + if attrs.decode_with.is_some() { + return Err(Error::new( + attr.span(), + format!( + "#[{}] function cannot have a decode_with attribute.", + method + ), + )); + } + } + let arg_decode = if let Some(decode_with) = attrs.decode_with { + let decode_with_ident = syn::Ident::new(&decode_with, Span::call_site()); + if arg_tuple.len() == 1 { + let arg_one = &arg_tuple[0]; + quote! { let #arg_one = #decode_with_ident(); } + } else { + quote! { let ( #( #arg_tuple, )* ) = #decode_with_ident(); } + } + } else if arg_tuple.is_empty() { + quote! {} + } else { + quote! { + let arg_bytes = ::ic_cdk::api::msg_arg_data(); + let ( #( #arg_tuple, )* ) = ::candid::utils::decode_args(&arg_bytes).unwrap(); + } + }; + + // 4. function call + let function_call = if signature.asyncness.is_some() { + quote! { #name ( #(#arg_tuple),* ) .await } + } else { + quote! { #name ( #(#arg_tuple),* ) } + }; + + // 5. return + let return_length = match &signature.output { + ReturnType::Default => 0, + ReturnType::Type(_, ty) => match ty.as_ref() { + Type::Tuple(tuple) => tuple.elems.len(), + _ => 1, + }, + }; + if method.is_lifecycle() && return_length > 0 { + return Err(Error::new( + Span::call_site(), + format!("#[{}] function cannot have a return value.", method), + )); + } + let return_encode = if method.is_lifecycle() || attrs.manual_reply { + quote! {} + } else { + let return_bytes = match return_length { + 0 => quote! { ::candid::utils::encode_one(()).unwrap() }, + 1 => quote! { ::candid::utils::encode_one(result).unwrap() }, + _ => quote! { ::candid::utils::encode_args(result).unwrap() }, + }; + quote! { + ::ic_cdk::api::msg_reply(#return_bytes); + } + }; + + // 6. candid attributes for export_candid!() let candid_method_attr = if attrs.hidden { quote! {} } else { @@ -262,9 +313,6 @@ pub(crate) fn ic_update(attr: TokenStream, item: TokenStream) -> Result Result { dfn_macro(MethodType::Init, attr, item) } @@ -320,8 +368,6 @@ mod test { fn #fn_name() { ::ic_cdk::setup(); ::ic_cdk::spawn(async { - let arg_bytes = ::ic_cdk::api::msg_arg_data(); - let () = ::candid::utils::decode_args(&arg_bytes).unwrap(); let result = query(); ::ic_cdk::api::msg_reply(::candid::utils::encode_one(()).unwrap()); }); @@ -359,8 +405,6 @@ mod test { fn #fn_name() { ::ic_cdk::setup(); ::ic_cdk::spawn(async { - let arg_bytes = ::ic_cdk::api::msg_arg_data(); - let () = ::candid::utils::decode_args(&arg_bytes).unwrap(); let result = query(); ::ic_cdk::api::msg_reply(::candid::utils::encode_one(result).unwrap()); }); @@ -398,8 +442,6 @@ mod test { fn #fn_name() { ::ic_cdk::setup(); ::ic_cdk::spawn(async { - let arg_bytes = ::ic_cdk::api::msg_arg_data(); - let () = ::candid::utils::decode_args(&arg_bytes).unwrap(); let result = query(); ::ic_cdk::api::msg_reply(::candid::utils::encode_args(result).unwrap()); }); @@ -553,8 +595,6 @@ mod test { fn #fn_name() { ::ic_cdk::setup(); ::ic_cdk::spawn(async { - let arg_bytes = ::ic_cdk::api::msg_arg_data(); - let () = ::candid::utils::decode_args(&arg_bytes).unwrap(); let result = query(); ::ic_cdk::api::msg_reply(::candid::utils::encode_one(()).unwrap()); });