From 63a73d1a90965033134b6cc2810f160e3eb9f639 Mon Sep 17 00:00:00 2001 From: Jonathan LEI Date: Thu, 2 Jan 2025 18:55:42 +0800 Subject: [PATCH] feat: snip-12 typed data parsing and hashing --- Cargo.lock | 26 +- README.md | 18 +- examples/snip_12_json.rs | 44 + starknet-core/Cargo.toml | 4 +- starknet-core/src/codec.rs | 20 + starknet-core/src/types/mod.rs | 4 + starknet-core/src/types/typed_data/domain.rs | 170 +++ starknet-core/src/types/typed_data/error.rs | 162 +++ starknet-core/src/types/typed_data/hasher.rs | 40 + starknet-core/src/types/typed_data/mod.rs | 999 ++++++++++++++++++ .../src/types/typed_data/revision.rs | 66 ++ .../src/types/typed_data/shortstring.rs | 78 ++ .../src/types/typed_data/type_definition.rs | 369 +++++++ .../src/types/typed_data/type_reference.rs | 448 ++++++++ starknet-core/src/types/typed_data/types.rs | 371 +++++++ starknet-core/src/types/typed_data/value.rs | 229 ++++ 16 files changed, 3030 insertions(+), 18 deletions(-) create mode 100644 examples/snip_12_json.rs create mode 100644 starknet-core/src/types/typed_data/domain.rs create mode 100644 starknet-core/src/types/typed_data/error.rs create mode 100644 starknet-core/src/types/typed_data/hasher.rs create mode 100644 starknet-core/src/types/typed_data/mod.rs create mode 100644 starknet-core/src/types/typed_data/revision.rs create mode 100644 starknet-core/src/types/typed_data/shortstring.rs create mode 100644 starknet-core/src/types/typed_data/type_definition.rs create mode 100644 starknet-core/src/types/typed_data/type_reference.rs create mode 100644 starknet-core/src/types/typed_data/types.rs create mode 100644 starknet-core/src/types/typed_data/value.rs diff --git a/Cargo.lock b/Cargo.lock index f2df0c1c..9edeeb94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -770,6 +770,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -877,7 +883,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 2.5.0", + "indexmap 2.7.0", "slab", "tokio", "tokio-util", @@ -902,9 +908,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.14.5" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" [[package]] name = "hermit-abi" @@ -1116,12 +1122,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.5.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" +checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "serde", ] @@ -1924,7 +1930,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.5.0", + "indexmap 2.7.0", "serde", "serde_derive", "serde_json", @@ -2091,8 +2097,10 @@ dependencies = [ "criterion", "crypto-bigint", "flate2", + "foldhash", "hex", "hex-literal", + "indexmap 2.7.0", "num-traits", "serde", "serde_json", @@ -2462,7 +2470,7 @@ version = "0.22.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b072cee73c449a636ffd6f32bd8de3a9f7119139aff882f44943ce2986dc5cf" dependencies = [ - "indexmap 2.5.0", + "indexmap 2.7.0", "toml_datetime", "winnow", ] diff --git a/README.md b/README.md index c6d4e2a8..a996f290 100644 --- a/README.md +++ b/README.md @@ -95,21 +95,23 @@ Examples can be found in the [examples folder](./examples): 7. [Encoding and decoding Cairo types](./examples/serde.rs) -8. [Batched JSON-RPC requests](./examples/batch.rs) +8. [Parse a SNIP-12 message and compute its hash](./examples/snip_12_json.rs) -9. [Call a contract view function](./examples/erc20_balance.rs) +9. [Batched JSON-RPC requests](./examples/batch.rs) -10. [Deploy an Argent X account to a pre-funded address](./examples/deploy_argent_account.rs) +10. [Call a contract view function](./examples/erc20_balance.rs) -11. [Inspect public key with Ledger](./examples/ledger_public_key.rs) +11. [Deploy an Argent X account to a pre-funded address](./examples/deploy_argent_account.rs) -12. [Deploy an OpenZeppelin account with Ledger](./examples/deploy_account_with_ledger.rs) +12. [Inspect public key with Ledger](./examples/ledger_public_key.rs) -13. [Transfer ERC20 tokens with Ledger](./examples/transfer_with_ledger.rs) +13. [Deploy an OpenZeppelin account with Ledger](./examples/deploy_account_with_ledger.rs) -14. [Parsing a JSON-RPC request on the server side](./examples/parse_jsonrpc_request.rs) +14. [Transfer ERC20 tokens with Ledger](./examples/transfer_with_ledger.rs) -15. [Inspecting a erased provider-specific error type](./examples/downcast_provider_error.rs) +15. [Parsing a JSON-RPC request on the server side](./examples/parse_jsonrpc_request.rs) + +16. [Inspecting a erased provider-specific error type](./examples/downcast_provider_error.rs) ## License diff --git a/examples/snip_12_json.rs b/examples/snip_12_json.rs new file mode 100644 index 00000000..790fe06f --- /dev/null +++ b/examples/snip_12_json.rs @@ -0,0 +1,44 @@ +use starknet::{core::types::TypedData, macros::felt}; + +fn main() { + let raw = r#"{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example Message": [ + { "name": "Name", "type": "string" }, + { "name": "Some Array", "type": "u128*" }, + { "name": "Some Object", "type": "My Object" } + ], + "My Object": [ + { "name": "Some Selector", "type": "selector" }, + { "name": "Some Contract Address", "type": "ContractAddress" } + ] + }, + "primaryType": "Example Message", + "domain": { + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN", + "revision": "1" + }, + "message": { + "Name": "some name", + "Some Array": [1, 2, 3, 4], + "Some Object": { + "Some Selector": "transfer", + "Some Contract Address": "0x0123" + } + } +}"#; + + let typed_data = serde_json::from_str::(raw).unwrap(); + println!("SNIP-12 revision: {}", typed_data.revision()); + + let message_hash = typed_data.message_hash(felt!("0x1234")).unwrap(); + println!("SNIP-12 hash: {:#064x}", message_hash); +} diff --git a/starknet-core/Cargo.toml b/starknet-core/Cargo.toml index f12658f7..16d243c0 100644 --- a/starknet-core/Cargo.toml +++ b/starknet-core/Cargo.toml @@ -22,7 +22,9 @@ starknet-core-derive = { version = "0.1.0", path = "../starknet-core-derive" } base64 = { version = "0.21.0", default-features = false, features = ["alloc"] } crypto-bigint = { version = "0.5.1", default-features = false } flate2 = { version = "1.0.25", optional = true } +foldhash = { version = "0.1.4", default-features = false } hex = { version = "0.4.3", default-features = false, features = ["alloc"] } +indexmap = { version = "2.7.0", default-features = false, features = ["serde"] } num-traits = { version = "0.2.19", default-features = false } serde = { version = "1.0.160", default-features = false, features = ["derive"] } serde_json = { version = "1.0.96", default-features = false, features = ["alloc", "raw_value"] } @@ -42,7 +44,7 @@ wasm-bindgen-test = "0.3.34" [features] default = ["std"] -std = ["dep:flate2", "starknet-crypto/std", "starknet-types-core/std"] +std = ["dep:flate2", "starknet-crypto/std", "starknet-types-core/std", "indexmap/std"] no_unknown_fields = [] [[bench]] diff --git a/starknet-core/src/codec.rs b/starknet-core/src/codec.rs index 5f1eadda..ebdefe1d 100644 --- a/starknet-core/src/codec.rs +++ b/starknet-core/src/codec.rs @@ -1,5 +1,6 @@ use alloc::{boxed::Box, fmt::Formatter, format, string::*, vec::*}; use core::{fmt::Display, mem::MaybeUninit}; +use starknet_crypto::{PedersenHasher, PoseidonHasher}; use num_traits::ToPrimitive; @@ -117,12 +118,31 @@ pub struct Error { repr: Box, } +// This implementation is useful for encoding single-felt types. +impl FeltWriter for Felt { + fn write(&mut self, felt: Felt) { + *self = felt; + } +} + impl FeltWriter for Vec { fn write(&mut self, felt: Felt) { self.push(felt); } } +impl FeltWriter for PedersenHasher { + fn write(&mut self, felt: Felt) { + self.update(felt); + } +} + +impl FeltWriter for PoseidonHasher { + fn write(&mut self, felt: Felt) { + self.update(felt); + } +} + impl Encode for Felt { fn encode(&self, writer: &mut W) -> Result<(), Error> { writer.write(*self); diff --git a/starknet-core/src/types/mod.rs b/starknet-core/src/types/mod.rs index fcb43b3a..bfb3dd9a 100644 --- a/starknet-core/src/types/mod.rs +++ b/starknet-core/src/types/mod.rs @@ -11,6 +11,10 @@ mod conversions; mod serde_impls; +/// SNIP-12 typed data. +pub mod typed_data; +pub use typed_data::TypedData; + // TODO: better namespacing of exports? mod codegen; pub use codegen::{ diff --git a/starknet-core/src/types/typed_data/domain.rs b/starknet-core/src/types/typed_data/domain.rs new file mode 100644 index 00000000..d736fe6e --- /dev/null +++ b/starknet-core/src/types/typed_data/domain.rs @@ -0,0 +1,170 @@ +use serde::Deserialize; +use starknet_crypto::poseidon_hash_many; + +use crate::{crypto::compute_hash_on_elements, types::Felt}; + +use super::{revision::Revision, shortstring}; + +/// SNIP-12 type hash of the domain type of revision 0. +/// +/// Compuated as: +/// +/// ```ignore +/// starknet_keccak("StarkNetDomain(name:felt,version:felt,chainId:felt)") +/// ``` +const DOMAIN_TYPE_HASH_V0: Felt = Felt::from_raw([ + 454097714883350422, + 18110465409072164514, + 49961291536018317, + 11250613311408382492, +]); + +/// SNIP-12 type hash of the domain type of revision 1. +/// +/// Compuated as: +/// +/// ```ignore +/// starknet_keccak("\"StarknetDomain\"(\"name\":\"shortstring\",\"version\":\"shortstring\",\"chainId\":\"shortstring\",\"revision\":\"shortstring\")") +/// ``` +const DOMAIN_TYPE_HASH_V1: Felt = Felt::from_raw([ + 45164882192052528, + 3320515356094353366, + 7437117071726711362, + 6953663458211852539, +]); + +/// SNIP-12 typed data domain separator. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +pub struct Domain { + /// Domain name. + #[serde(deserialize_with = "shortstring::deserialize")] + pub name: Felt, + /// Domain version. + #[serde(deserialize_with = "shortstring::deserialize")] + pub version: Felt, + /// Domain chain ID. + #[serde(deserialize_with = "shortstring::deserialize", rename = "chainId")] + pub chain_id: Felt, + /// Domain revision. + #[serde(default = "default_revision")] + pub revision: Revision, +} + +impl Domain { + /// Computes the SNIP-12 hash of the encoded domain. + /// + /// The resulting hash is typically used in calculating the full typed data hash as per SNIP-12. + pub fn encoded_hash(&self) -> Felt { + match self.revision { + Revision::V0 => compute_hash_on_elements(&[ + DOMAIN_TYPE_HASH_V0, + self.name, + self.version, + self.chain_id, + ]), + Revision::V1 => poseidon_hash_many(&[ + DOMAIN_TYPE_HASH_V1, + self.name, + self.version, + self.chain_id, + Felt::ONE, + ]), + } + } +} + +const fn default_revision() -> Revision { + Revision::V0 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_implicit_v0_domain_deser() { + let raw = r###"{ + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN" +}"###; + + let domain = serde_json::from_str::(raw).unwrap(); + assert_eq!(domain.revision, Revision::V0); + + // `shortstring` spec deviation for `starknet.js` compatibility + assert_eq!(domain.version, Felt::ONE); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_explicit_v0_domain_deser() { + let raw = r###"{ + "name": "Starknet Example", + "version": 1, + "chainId": "SN_MAIN", + "revision": "0" +}"###; + + let domain = serde_json::from_str::(raw).unwrap(); + assert_eq!(domain.revision, Revision::V0); + + // `shortstring` spec deviation for `starknet.js` compatibility + assert_eq!(domain.version, Felt::ONE); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_explicit_v1_domain_deser() { + let raw = r###"{ + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN", + "revision": "1" +}"###; + + let domain = serde_json::from_str::(raw).unwrap(); + assert_eq!(domain.revision, Revision::V1); + + // `shortstring` spec deviation for `starknet.js` compatibility + assert_eq!(domain.version, Felt::ONE); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_v0_domain_hash() { + let raw = r###"{ + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN" +}"###; + + let domain = serde_json::from_str::(raw).unwrap(); + assert_eq!( + domain.encoded_hash(), + Felt::from_hex_unchecked( + "0x04f8ee4d303cd69ce9c78edadf62442865c89a1eec01fa413e126a058a69c28a" + ) + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_v1_domain_hash() { + let raw = r###"{ + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN", + "revision": "1" +}"###; + + let domain = serde_json::from_str::(raw).unwrap(); + assert_eq!( + domain.encoded_hash(), + Felt::from_hex_unchecked( + "0x03bfc3e1ff0f5c85c05bb8073a64a40b038eed00a449bc337c8cd2758f634640" + ) + ); + } +} diff --git a/starknet-core/src/types/typed_data/error.rs b/starknet-core/src/types/typed_data/error.rs new file mode 100644 index 00000000..05001679 --- /dev/null +++ b/starknet-core/src/types/typed_data/error.rs @@ -0,0 +1,162 @@ +use alloc::string::*; +use core::fmt::Display; + +use super::{revision::Revision, value::ValueKind}; + +/// Possible errors when processing [`TypedData`](super::TypedData) and its related types. +#[derive(Debug)] +pub enum TypedDataError { + /// Revision implied by `types` is differernt from revision specified by `domain`. + InconsistentRevision { + /// The revision implied from `types` with the domain type definition. + types: Revision, + /// The revision specified by `domain`. + domain: Revision, + }, + /// The type name is invalid. + InvalidTypeName( + /// Type name. + String, + ), + /// The `contains` field exists when it's expected to be absent. + UnexpectedContainsField, + /// A referenced custom type is not defined. + CustomTypeNotFound(String), + /// An expected field is not found. + FieldNotFound( + /// Field name. + String, + ), + /// The value is of a different type than expected> + UnexpectedValueType { + /// The list of expected value types. + expected: &'static [ValueKind], + /// The actual value type. + actual: ValueKind, + }, + /// The number of fields from struct definition is different from the one in value. + StructFieldCountMismatch { + /// The number of fields specificed by the struct definition. + expected: usize, + /// The actual number of fields found in value. + actual: usize, + }, + /// The number of elements from enum variant definition is different from the one in value. + EnumElementCountMismatch { + /// The number of elements specificed by the enum variant definition. + expected: usize, + /// The actual number of elements found in value. + actual: usize, + }, + /// The object representation of an enum value does not have exactly one field. + InvalidEnumFieldCount, + /// The variant name is not found in the enum definition. + EnumVariantNotFound( + /// Variant name. + String, + ), + /// Found a struct when an enum is expected. + UnexpectedStruct( + /// Name of the struct type. + String, + ), + /// Found an enum when a struct is expected. + UnexpectedEnum( + /// Name of the enum type. + String, + ), + /// A Cairo short string cannot be parsed. + InvalidShortString( + /// The Cairo short string. + String, + ), + /// Invalid function selector. + InvalidSelector( + /// The function selector. + String, + ), + /// The string value cannot be parsed into a number. + InvalidNumber( + /// The string value. + String, + ), + /// The Merkle tree is empty. + EmptyMerkleTree, +} + +#[cfg(feature = "std")] +impl std::error::Error for TypedDataError {} + +impl Display for TypedDataError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::InconsistentRevision { types, domain } => { + write!( + f, + "`types` implies revision {} but `domain` uses revision {}", + types, domain + ) + } + Self::InvalidTypeName(type_name) => write!(f, "invalid type name: {}", type_name), + Self::UnexpectedContainsField => { + write!(f, "unexpected presence of the `contains` field") + } + Self::CustomTypeNotFound(type_name) => { + write!(f, "type `{}` not defined", type_name) + } + Self::FieldNotFound(field_name) => { + write!(f, "field `{}` not found in value", field_name) + } + Self::UnexpectedValueType { expected, actual } => { + write!(f, "unexpected value type {}, expecting", actual)?; + + let mut kind_iter = expected.iter().peekable(); + while let Some(kind) = kind_iter.next() { + write!(f, " {}", kind)?; + if kind_iter.peek().is_some() { + write!(f, ",")?; + } + } + Ok(()) + } + Self::StructFieldCountMismatch { expected, actual } => { + write!( + f, + "expected {} fields in struct but found {}", + expected, actual + ) + } + Self::EnumElementCountMismatch { expected, actual } => { + write!( + f, + "expected {} elements in enum variant but found {}", + expected, actual + ) + } + Self::InvalidEnumFieldCount => { + write!(f, "enum values must have 1 and only 1 field") + } + Self::EnumVariantNotFound(variant_name) => { + write!(f, "enum variant `{}` not defined", variant_name) + } + Self::UnexpectedStruct(type_name) => { + write!(f, "expected type `{}` to be enum but is struct", type_name) + } + Self::UnexpectedEnum(type_name) => { + write!(f, "expected type `{}` to be struct but is enum", type_name) + } + Self::InvalidShortString(short_string) => { + write!(f, "\"{}\" is not a valid Cairo short string", short_string) + } + Self::InvalidSelector(selector) => { + write!(f, "\"{}\" is not a valid function selector", selector) + } + Self::InvalidNumber(string_repr) => { + write!(f, "\"{}\" is not a valid number", string_repr) + } + Self::EmptyMerkleTree => { + write!(f, "`merkletree` values must not be empty") + } + } + } +} diff --git a/starknet-core/src/types/typed_data/hasher.rs b/starknet-core/src/types/typed_data/hasher.rs new file mode 100644 index 00000000..e2003375 --- /dev/null +++ b/starknet-core/src/types/typed_data/hasher.rs @@ -0,0 +1,40 @@ +use starknet_crypto::{pedersen_hash, poseidon_hash, PedersenHasher, PoseidonHasher}; + +use crate::{codec::FeltWriter, types::Felt}; + +/// SNIP-12 revision-dependant hasher that can be used to encode data. +pub trait TypedDataHasher: FeltWriter + Default { + fn update(&mut self, msg: Felt); + + fn finalize(self) -> Felt; + + fn hash_two_elements(x: Felt, y: Felt) -> Felt; +} + +impl TypedDataHasher for PedersenHasher { + fn update(&mut self, msg: Felt) { + Self::update(self, msg); + } + + fn finalize(self) -> Felt { + Self::finalize(&self) + } + + fn hash_two_elements(x: Felt, y: Felt) -> Felt { + pedersen_hash(&x, &y) + } +} + +impl TypedDataHasher for PoseidonHasher { + fn update(&mut self, msg: Felt) { + Self::update(self, msg); + } + + fn finalize(self) -> Felt { + Self::finalize(self) + } + + fn hash_two_elements(x: Felt, y: Felt) -> Felt { + poseidon_hash(x, y) + } +} diff --git a/starknet-core/src/types/typed_data/mod.rs b/starknet-core/src/types/typed_data/mod.rs new file mode 100644 index 00000000..22d89987 --- /dev/null +++ b/starknet-core/src/types/typed_data/mod.rs @@ -0,0 +1,999 @@ +use alloc::{borrow::ToOwned, format, vec::*}; +use core::str::FromStr; + +use serde::Deserialize; +use starknet_crypto::{PedersenHasher, PoseidonHasher}; + +use crate::{ + codec::Encode, + types::Felt, + utils::{cairo_short_string_to_felt, get_selector_from_name}, +}; + +mod domain; +pub use domain::Domain; + +mod error; +pub use error::TypedDataError; + +mod hasher; +use hasher::TypedDataHasher; + +mod revision; +pub use revision::Revision; + +mod shortstring; + +mod type_definition; +use type_definition::{CompositeType, EnumDefinition, PresetType, TypeDefinition}; + +mod type_reference; +use type_reference::{CommonTypeReference, TypeReference}; +pub use type_reference::{ElementTypeReference, FullTypeReference, InlineTypeReference}; + +mod types; +pub use types::Types; + +mod value; +pub use value::{ArrayValue, ObjectValue, Value, ValueKind}; + +use super::ByteArray; + +/// Cairo short string encoding of `StarkNet Message`. +const STARKNET_MESSAGE_PREFIX: Felt = Felt::from_raw([ + 257012186512350467, + 18446744073709551605, + 10480951322775611302, + 16156019428408348868, +]); + +/// SNIP-12 typed data for off-chain signatures. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TypedData { + /// Type definitions for the domain separator type and user-defined custom types. + types: Types, + /// Domain separator. + domain: Domain, + /// Reference to the primary/entrypoint type that the `message` field represents. + primary_type: InlineTypeReference, + /// The main message data to be signed, structured as per `primary_type`'s definition. + message: Value, +} + +impl TypedData { + /// Creates a new [`TypedDataError`]. Returns `Err` if `types` and `domain` use + /// different revisions. + pub fn new( + types: Types, + domain: Domain, + primary_type: InlineTypeReference, + message: Value, + ) -> Result { + if types.revision() == domain.revision { + Ok(Self { + types, + domain, + primary_type, + message, + }) + } else { + Err(TypedDataError::InconsistentRevision { + types: types.revision(), + domain: domain.revision, + }) + } + } + + /// Gets the SNIP-12 revision of this [`TypedData`]. + pub const fn revision(&self) -> Revision { + // No need to check against `self.types` as revision consistency is maintained as an + // invariant. + self.domain.revision + } + + /// Computes the SNIP-12 typed data hash to be used for message signing and verification. + /// + /// On-chain signature verification usually involves calling the `is_valid_signature()` function + /// with this hash. + pub fn message_hash(&self, address: Felt) -> Result { + match self.revision() { + Revision::V0 => self.message_hash_with_hasher::(address), + Revision::V1 => self.message_hash_with_hasher::(address), + } + } + + fn message_hash_with_hasher(&self, address: Felt) -> Result + where + H: TypedDataHasher, + { + let mut hasher = H::default(); + hasher.update(STARKNET_MESSAGE_PREFIX); + hasher.update(self.domain.encoded_hash()); + hasher.update(address); + hasher.update(self.encode_value::(&self.primary_type, &self.message)?); + Ok(hasher.finalize()) + } + + fn encode_value(&self, type_ref: &R, value: &Value) -> Result + where + H: TypedDataHasher, + R: TypeReference, + { + let encoded = match type_ref.common() { + CommonTypeReference::Custom(name) => { + // This is either an enum or struct. Depending on the type of the type reference we + // may or may not care which one it is. + + let type_def = self + .types + .get_type(name) + .ok_or_else(|| TypedDataError::CustomTypeNotFound(name.to_owned()))?; + let type_hash = self.types.get_type_hash(name)?; + + // Both struct and enum require the value to be represented as an object + let obj_value = match value { + Value::Object(obj_value) => obj_value, + Value::String(_) + | Value::UnsignedInteger(_) + | Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Array(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::Object], + actual: value.kind(), + }); + } + }; + + match type_def { + TypeDefinition::Struct(struct_def) => { + if type_ref.must_be_enum() { + return Err(TypedDataError::UnexpectedStruct(name.to_owned())); + } + + self.encode_composite::(type_hash, struct_def, obj_value)? + } + TypeDefinition::Enum(enum_def) => { + if type_ref.must_be_struct() { + return Err(TypedDataError::UnexpectedEnum(name.to_owned())); + } + + self.encode_enum::(type_hash, enum_def, obj_value)? + } + } + } + CommonTypeReference::Array(element_type) => { + let arr_value = match value { + Value::Array(arr_value) => arr_value, + Value::String(_) + | Value::UnsignedInteger(_) + | Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Object(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::Array], + actual: value.kind(), + }); + } + }; + + let mut hasher = H::default(); + + for element in &arr_value.elements { + hasher.update(self.encode_value::(element_type, element)?); + } + + hasher.finalize() + } + // Technically, SNIP-12 specifies that `felt` and `shortstring` should behave + // differently. Unfortunately, `starknet.js` ships a buggy implementation that treats + // both types the same. We deviate from the spec here to be compatible: + // + // https://github.com/starknet-io/starknet.js/issues/1039 + CommonTypeReference::Felt | CommonTypeReference::ShortString => match value { + Value::String(str_value) => { + // This is to reimplement the `starknet.js` bug + let decoded_as_raw = match str_value.strip_prefix("0x") { + Some(hexadecimal) => { + if hexadecimal.chars().all(|c| c.is_ascii_hexdigit()) { + Felt::from_hex(str_value).ok() + } else { + None + } + } + None => { + if str_value.chars().all(|c| c.is_ascii_digit()) { + Felt::from_dec_str(str_value).ok() + } else { + None + } + } + }; + + match decoded_as_raw { + Some(raw) => raw, + None => cairo_short_string_to_felt(str_value).map_err(|_| { + TypedDataError::InvalidShortString(str_value.to_owned()) + })?, + } + } + Value::UnsignedInteger(int_value) => (*int_value).into(), + Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Object(_) + | Value::Array(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::String, ValueKind::UnsignedInteger], + actual: value.kind(), + }); + } + }, + CommonTypeReference::Bool => match value { + Value::Boolean(false) => Felt::ZERO, + Value::Boolean(true) => Felt::ONE, + Value::String(_) + | Value::UnsignedInteger(_) + | Value::SignedInteger(_) + | Value::Object(_) + | Value::Array(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::Boolean], + actual: value.kind(), + }); + } + }, + CommonTypeReference::String => { + let str_value = match value { + Value::String(str_value) => str_value, + Value::UnsignedInteger(_) + | Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Object(_) + | Value::Array(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::String], + actual: value.kind(), + }); + } + }; + + match self.revision() { + Revision::V0 => { + // In revision 0 `string` is treated as short string. + + cairo_short_string_to_felt(str_value) + .map_err(|_| TypedDataError::InvalidShortString(str_value.to_owned()))? + } + Revision::V1 => { + // In revision 1 `string` is treated as `ByteArray`. + + let mut hasher = H::default(); + + // `ByteArray` encoding never fails + ByteArray::from(str_value.as_str()) + .encode(&mut hasher) + .unwrap(); + + hasher.finalize() + } + } + } + CommonTypeReference::Selector => { + let str_value = match value { + Value::String(str_value) => str_value, + Value::UnsignedInteger(_) + | Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Object(_) + | Value::Array(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::String], + actual: value.kind(), + }); + } + }; + + get_selector_from_name(str_value) + .map_err(|_| TypedDataError::InvalidSelector(str_value.to_owned()))? + } + CommonTypeReference::MerkleTree(leaf) => { + let arr_value = match value { + Value::Array(arr_value) => arr_value, + Value::String(_) + | Value::UnsignedInteger(_) + | Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Object(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::Array], + actual: value.kind(), + }); + } + }; + + self.encode_merkletree::(leaf, arr_value)? + } + // Technically `timestamp` should be restricted to `u64` range but `starknet.js` allows + // it to be treated the same way as `u128`. + CommonTypeReference::Timestamp | CommonTypeReference::U128 => { + let int_value = match value { + Value::UnsignedInteger(int_value) => *int_value, + // Technically SNIP-12 does not allow strings here but `starknet.js` does, so we + // do it here to be compatible. + Value::String(str_value) => match str_value.strip_prefix("0x") { + Some(hex_str) => u128::from_str_radix(hex_str, 16), + None => str_value.parse::(), + } + .map_err(|_| TypedDataError::InvalidNumber(str_value.to_owned()))?, + Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Object(_) + | Value::Array(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::UnsignedInteger, ValueKind::String], + actual: value.kind(), + }); + } + }; + + int_value.into() + } + CommonTypeReference::I128 => { + let int_value = match value { + Value::SignedInteger(int_value) => *int_value, + Value::String(_) + | Value::UnsignedInteger(_) + | Value::Boolean(_) + | Value::Object(_) + | Value::Array(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::UnsignedInteger, ValueKind::String], + actual: value.kind(), + }); + } + }; + + let mut encoded = Felt::ZERO; + + // Encoding `i128` never fails + int_value.encode(&mut encoded).unwrap(); + + encoded + } + CommonTypeReference::ContractAddress | CommonTypeReference::ClassHash => { + let str_value = match value { + Value::String(str_value) => str_value, + Value::UnsignedInteger(_) + | Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Object(_) + | Value::Array(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::String], + actual: value.kind(), + }); + } + }; + + Felt::from_str(str_value) + .map_err(|_| TypedDataError::InvalidNumber(str_value.to_owned()))? + } + CommonTypeReference::U256 => { + let obj_value = match value { + Value::Object(obj_value) => obj_value, + Value::String(_) + | Value::UnsignedInteger(_) + | Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Array(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::Object], + actual: value.kind(), + }); + } + }; + + self.encode_composite::( + PresetType::U256.type_hash(self.revision()), + &PresetType::U256, + obj_value, + )? + } + CommonTypeReference::TokenAmount => { + let obj_value = match value { + Value::Object(obj_value) => obj_value, + Value::String(_) + | Value::UnsignedInteger(_) + | Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Array(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::Object], + actual: value.kind(), + }); + } + }; + + self.encode_composite::( + PresetType::TokenAmount.type_hash(self.revision()), + &PresetType::TokenAmount, + obj_value, + )? + } + CommonTypeReference::NftId => { + let obj_value = match value { + Value::Object(obj_value) => obj_value, + Value::String(_) + | Value::UnsignedInteger(_) + | Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Array(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::Object], + actual: value.kind(), + }); + } + }; + + self.encode_composite::( + PresetType::NftId.type_hash(self.revision()), + &PresetType::NftId, + obj_value, + )? + } + }; + + Ok(encoded) + } + + fn encode_composite( + &self, + type_hash: Felt, + struct_def: &T, + value: &ObjectValue, + ) -> Result + where + H: TypedDataHasher, + T: CompositeType, + { + let mut hasher = H::default(); + hasher.update(type_hash); + + if value.fields.len() != struct_def.field_len() { + return Err(TypedDataError::StructFieldCountMismatch { + expected: struct_def.field_len(), + actual: value.fields.len(), + }); + } + + for (field_name, field_type) in struct_def.field_iter() { + let value = value + .fields + .get(field_name) + .ok_or_else(|| TypedDataError::FieldNotFound(field_name.to_owned()))?; + hasher.update(self.encode_value::(field_type, value)?); + } + + Ok(hasher.finalize()) + } + + fn encode_enum( + &self, + type_hash: Felt, + enum_def: &EnumDefinition, + value: &ObjectValue, + ) -> Result + where + H: TypedDataHasher, + { + let mut hasher = H::default(); + hasher.update(type_hash); + + let mut value_field_iter = value.fields.iter(); + + let (variant_name, variant_value) = value_field_iter + .next() + .ok_or(TypedDataError::InvalidEnumFieldCount)?; + let tuple_values = match variant_value { + Value::Array(arr_value) => arr_value, + Value::String(_) + | Value::UnsignedInteger(_) + | Value::SignedInteger(_) + | Value::Boolean(_) + | Value::Object(_) => { + return Err(TypedDataError::UnexpectedValueType { + expected: &[ValueKind::Array], + actual: variant_value.kind(), + }); + } + }; + + let (variant_ind, variant_def) = enum_def + .variants + .iter() + .enumerate() + .find(|(_, variant)| &variant.name == variant_name) + .ok_or_else(|| TypedDataError::EnumVariantNotFound(variant_name.to_owned()))?; + hasher.update(variant_ind.into()); + + if variant_def.tuple_types.len() != tuple_values.elements.len() { + return Err(TypedDataError::EnumElementCountMismatch { + expected: variant_def.tuple_types.len(), + actual: tuple_values.elements.len(), + }); + } + + for (tuple_slot_def, tuple_slot_value) in variant_def + .tuple_types + .iter() + .zip(tuple_values.elements.iter()) + { + hasher.update(self.encode_value::(tuple_slot_def, tuple_slot_value)?); + } + + // Enum repr must have only one field + if value_field_iter.next().is_some() { + return Err(TypedDataError::InvalidEnumFieldCount); + } + + Ok(hasher.finalize()) + } + + fn encode_merkletree( + &self, + leaf_type_def: &InlineTypeReference, + value: &ArrayValue, + ) -> Result + where + H: TypedDataHasher, + { + // It's unclear how an empty Merkle tree should be hashed. Interestingly, `starknet.js` gets + // stuck in an infinite recursion loop when fed with an empty list of leaves. So it should + // be safe to reject empty Merkle trees here. + if value.elements.is_empty() { + return Err(TypedDataError::EmptyMerkleTree); + } + + let element_hashes = value + .elements + .iter() + .map(|element| self.encode_value::(leaf_type_def, element)) + .collect::, _>>()?; + + Ok(Self::compute_merkle_root::(&element_hashes)) + } + + fn compute_merkle_root(layer: &[Felt]) -> Felt + where + H: TypedDataHasher, + { + let mut new_layer = Vec::with_capacity((layer.len() + 1) / 2); + for chunk in layer.chunks(2) { + new_layer.push(if chunk.len() == 2 { + if chunk[0] <= chunk[1] { + H::hash_two_elements(chunk[0], chunk[1]) + } else { + H::hash_two_elements(chunk[1], chunk[0]) + } + } else { + H::hash_two_elements(Felt::ZERO, chunk[0]) + }) + } + + // TODO: refactor to remove recursion and reuse a single buffer + if new_layer.len() == 1 { + new_layer[0] + } else { + Self::compute_merkle_root::(&new_layer) + } + } +} + +impl<'de> Deserialize<'de> for TypedData { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct Raw { + types: Types, + domain: Domain, + #[serde(rename = "primaryType")] + primary_type: InlineTypeReference, + message: Value, + } + + let raw = Raw::deserialize(deserializer)?; + Self::new(raw.types, raw.domain, raw.primary_type, raw.message) + .map_err(|err| serde::de::Error::custom(format!("{}", err))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const VALID_V0_DATA: &str = r###"{ + "types": { + "StarkNetDomain": [ + { "name": "name", "type": "felt" }, + { "name": "version", "type": "felt" }, + { "name": "chainId", "type": "felt" } + ], + "Example Message": [ + { "name": "Name", "type": "string" }, + { "name": "Some Array", "type": "u128*" }, + { "name": "Some Object", "type": "My Object" } + ], + "My Object": [ + { "name": "Some Selector", "type": "selector" }, + { "name": "Some Contract Address", "type": "ContractAddress" } + ] + }, + "primaryType": "Example Message", + "domain": { + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN" + }, + "message": { + "Name": "some name", + "Some Array": [1, 2, 3, 4], + "Some Object": { + "Some Selector": "transfer", + "Some Contract Address": "0x0123" + } + } +}"###; + + const VALID_V1_DATA: &str = r###"{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example Message": [ + { "name": "Name", "type": "string" }, + { "name": "Some Array", "type": "u128*" }, + { "name": "Some Object", "type": "My Object" } + ], + "My Object": [ + { "name": "Some Selector", "type": "selector" }, + { "name": "Some Contract Address", "type": "ContractAddress" } + ] + }, + "primaryType": "Example Message", + "domain": { + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN", + "revision": "1" + }, + "message": { + "Name": "some name", + "Some Array": [1, 2, 3, 4], + "Some Object": { + "Some Selector": "transfer", + "Some Contract Address": "0x0123" + } + } +}"###; + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_successful_deser_v0() { + serde_json::from_str::(VALID_V0_DATA).unwrap(); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_successful_deser_v1() { + serde_json::from_str::(VALID_V1_DATA).unwrap(); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_inconsistent_revision_deser() { + let raw = r###"{ + "types": { + "StarkNetDomain": [ + { "name": "name", "type": "felt" }, + { "name": "version", "type": "felt" }, + { "name": "chainId", "type": "felt" } + ], + "Example Message": [ + { "name": "Name", "type": "string" }, + { "name": "Some Array", "type": "u128*" }, + { "name": "Some Object", "type": "My Object" } + ], + "My Object": [ + { "name": "Some Selector", "type": "selector" }, + { "name": "Some Contract Address", "type": "ContractAddress" } + ] + }, + "primaryType": "Example Message", + "domain": { + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN", + "revision": "1" + }, + "message": { + "Name": "some name", + "Some Array": [1, 2, 3, 4], + "Some Object": { + "Some Selector": "transfer", + "Some Contract Address": "0x0123" + } + } +}"###; + + assert_eq!( + serde_json::from_str::(raw) + .unwrap_err() + .to_string(), + "`types` implies revision 0 but `domain` uses revision 1" + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_message_hash_v0() { + let data = serde_json::from_str::(VALID_V0_DATA).unwrap(); + + assert_eq!( + data.message_hash(Felt::from_hex_unchecked("0x1234")) + .unwrap(), + Felt::from_hex_unchecked( + "0x0778d68fe2baf73ee78a6711c29bad4722680984c1553a8035c8cb3feb5310c9" + ) + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_message_hash_v1_with_struct() { + let data = serde_json::from_str::(VALID_V1_DATA).unwrap(); + + assert_eq!( + data.message_hash(Felt::from_hex_unchecked("0x1234")) + .unwrap(), + Felt::from_hex_unchecked( + "0x045bca39274d2b7fdf7dc7c4ecf75f6549f614ce44359cc62ec106f4e5cc87b4" + ) + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_message_hash_v1_with_basic_types() { + let raw = r###"{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example Message": [ + { "name": "Bool", "type": "bool" }, + { "name": "I128", "type": "i128" }, + { "name": "Classhash", "type": "ClassHash" }, + { "name": "Timestamp", "type": "timestamp" }, + { "name": "Short1", "type": "shortstring" }, + { "name": "Short2", "type": "shortstring" }, + { "name": "Short3", "type": "shortstring" } + ] + }, + "primaryType": "Example Message", + "domain": { + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN", + "revision": "1" + }, + "message": { + "Bool": true, + "I128": -123, + "Classhash": "0x1234", + "Timestamp": 1234, + "Short1": 123, + "Short2": "0x123", + "Short3": "hello" + } +}"###; + + let data = serde_json::from_str::(raw).unwrap(); + + assert_eq!( + data.message_hash(Felt::from_hex_unchecked("0x1234")) + .unwrap(), + Felt::from_hex_unchecked( + "0x0795c7e03a0ef83c4e3dee6942ef64d4126a91cafbda207356dae1de3bed4063" + ) + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_message_hash_v1_with_preset() { + let raw = r###"{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example Message": [ + { "name": "Uint", "type": "u256" }, + { "name": "Amount", "type": "TokenAmount" }, + { "name": "Id", "type": "NftId" } + ] + }, + "primaryType": "Example Message", + "domain": { + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN", + "revision": "1" + }, + "message": { + "Uint": { + "low": "1234", + "high": "0x5678" + }, + "Amount": { + "token_address": "0x11223344", + "amount": { + "low": 1000000, + "high": 0 + } + }, + "Id": { + "collection_address": "0x55667788", + "token_id": { + "low": "0x12345678", + "high": 0 + } + } + } +}"###; + + let data = serde_json::from_str::(raw).unwrap(); + + assert_eq!( + data.message_hash(Felt::from_hex_unchecked("0x1234")) + .unwrap(), + Felt::from_hex_unchecked( + "0x068b85f4061d8155c0445f7e3c6bae1e7641b88b1d3b7c034c0b4f6c30eb5049" + ) + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_message_hash_v1_with_enum() { + let raw = r###"{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example Message": [ + { "name": "Value", "type": "enum", "contains": "My Enum" } + ], + "My Enum": [ + { "name": "Variant 1", "type": "()" }, + { "name": "Variant 2", "type": "(string,My Object*)" }, + { "name": "Variant 3", "type": "(u128)" } + ], + "My Object": [ + { "name": "Some Selector", "type": "selector" }, + { "name": "Some Contract Address", "type": "ContractAddress" } + ] + }, + "primaryType": "Example Message", + "domain": { + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN", + "revision": "1" + }, + "message": { + "Value": { + "Variant 2": [ + "tuple element", + [ + { + "Some Selector": "transfer", + "Some Contract Address": "0x1234" + }, + { + "Some Selector": "approve", + "Some Contract Address": "0x5678" + } + ] + ] + } + } +}"###; + + let data = serde_json::from_str::(raw).unwrap(); + + assert_eq!( + data.message_hash(Felt::from_hex_unchecked("0x1234")) + .unwrap(), + Felt::from_hex_unchecked( + "0x03745761c0f8ab5f0dbbba52b448f7db6ebfecbf74069073dcbf4fc5a6608125" + ) + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_message_hash_v1_with_merkletree() { + let raw = r###"{ + "types": { + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example Message": [ + { "name": "Value", "type": "merkletree", "contains": "My Object" } + ], + "My Object": [ + { "name": "Some Selector", "type": "selector" }, + { "name": "Some Contract Address", "type": "ContractAddress" } + ] + }, + "primaryType": "Example Message", + "domain": { + "name": "Starknet Example", + "version": "1", + "chainId": "SN_MAIN", + "revision": "1" + }, + "message": { + "Value": [ + { + "Some Selector": "selector1", + "Some Contract Address": "0x1111" + }, + { + "Some Selector": "selector2", + "Some Contract Address": "0x2222" + }, + { + "Some Selector": "selector3", + "Some Contract Address": "0x3333" + }, + { + "Some Selector": "selector4", + "Some Contract Address": "0x4444" + }, + { + "Some Selector": "selector5", + "Some Contract Address": "0x5555" + } + ] + } +}"###; + + let data = serde_json::from_str::(raw).unwrap(); + + assert_eq!( + data.message_hash(Felt::from_hex_unchecked("0x1234")) + .unwrap(), + Felt::from_hex_unchecked( + "0x064bd27eb802de8c83ff1437394c142bbe771530a248c548fab27ac3bcd2a503" + ) + ); + } +} diff --git a/starknet-core/src/types/typed_data/revision.rs b/starknet-core/src/types/typed_data/revision.rs new file mode 100644 index 00000000..b47ae3ce --- /dev/null +++ b/starknet-core/src/types/typed_data/revision.rs @@ -0,0 +1,66 @@ +use serde::{de::Visitor, Deserialize}; + +/// Revision of SNIP-12. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Revision { + /// The legacy, deprecated revision of SNIP-12. + V0, + /// The current active revision of SNIP-12. + V1, +} + +impl core::fmt::Display for Revision { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::V0 => write!(f, "0"), + Self::V1 => write!(f, "1"), + } + } +} + +struct RevisionVisitor; + +impl Visitor<'_> for RevisionVisitor { + type Value = Revision; + + fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(formatter, "string or integer") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + match v { + "0" => Ok(Revision::V0), + "1" => Ok(Revision::V1), + _ => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(v), + &"\"0\" or \"1\"", + )), + } + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + match v { + 0 => Ok(Revision::V0), + 1 => Ok(Revision::V1), + _ => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Unsigned(v), + &"0 or 1", + )), + } + } +} + +impl<'de> Deserialize<'de> for Revision { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(RevisionVisitor) + } +} diff --git a/starknet-core/src/types/typed_data/shortstring.rs b/starknet-core/src/types/typed_data/shortstring.rs new file mode 100644 index 00000000..0bd440b8 --- /dev/null +++ b/starknet-core/src/types/typed_data/shortstring.rs @@ -0,0 +1,78 @@ +//! Module for handling `shortstring` serialization/desesrialization. +//! +//! Technically this module shouldn't exist, or at least should be straightforward, as a very simple +//! Cairo short string encoding/decoding step would suffice. Unfortunately, starknet.js ships a bug: +//! +//! +//! +//! Since starknet.js is widely used, it's essentially the de facto spec. We must reimplement the +//! bug here by conditionally encoding as Cairo short string only when the source string is not a +//! valid integer or decimal/hexadecimal repr. + +use serde::de::Visitor; + +use crate::{types::Felt, utils::cairo_short_string_to_felt}; + +struct ShortStringVisitor; + +impl Visitor<'_> for ShortStringVisitor { + type Value = Felt; + + fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(formatter, "string or integer") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + // This is to reimplement the `starknet.js` bug + let decoded_as_raw = match v.strip_prefix("0x") { + Some(hexadecimal) => { + if hexadecimal.chars().all(|c| c.is_ascii_hexdigit()) { + Felt::from_hex(v).ok() + } else { + None + } + } + None => { + if v.chars().all(|c| c.is_ascii_digit()) { + Felt::from_dec_str(v).ok() + } else { + None + } + } + }; + + match decoded_as_raw { + Some(raw) => Ok(raw), + None => cairo_short_string_to_felt(v).map_err(|_| { + serde::de::Error::invalid_value( + serde::de::Unexpected::Str(v), + &"valid Cairo short string", + ) + }), + } + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + Ok(v.into()) + } + + fn visit_u128(self, v: u128) -> Result + where + E: serde::de::Error, + { + Ok(v.into()) + } +} + +pub fn deserialize<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + deserializer.deserialize_any(ShortStringVisitor) +} diff --git a/starknet-core/src/types/typed_data/type_definition.rs b/starknet-core/src/types/typed_data/type_definition.rs new file mode 100644 index 00000000..080cf6d2 --- /dev/null +++ b/starknet-core/src/types/typed_data/type_definition.rs @@ -0,0 +1,369 @@ +use alloc::{format, string::*, vec::*}; +use core::str::FromStr; + +use serde::{de::Unexpected, Deserialize}; + +use crate::{types::Felt, utils::starknet_keccak}; + +use super::{ + revision::Revision, + type_reference::{FullTypeReference, InlineTypeReference}, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TypeDefinition { + Struct(StructDefinition), + Enum(EnumDefinition), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StructDefinition { + pub fields: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FieldDefinition { + pub name: String, + pub r#type: FullTypeReference, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EnumDefinition { + pub variants: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct VariantDefinition { + pub name: String, + pub tuple_types: Vec, +} + +/// Internal trait for working with both user-defined types and preset types at the same time. +pub(crate) trait CompositeType { + fn field_iter(&self) -> impl Iterator; + + fn field_len(&self) -> usize; +} + +/// Internal type for type signature generation for preset types. +pub(crate) enum PresetType { + U256, + TokenAmount, + NftId, +} + +/// Internal type for implementing [`TypeDefinition`] deserialization. +enum FieldOrVariantDefinition { + Field(FieldDefinition), + Variant(VariantDefinition), +} + +impl TypeDefinition { + pub(crate) fn is_v0_domain(&self) -> bool { + match self { + Self::Struct(def) => { + def.fields.len() == 3 + && def.fields[0].name == "name" + && def.fields[0].r#type == FullTypeReference::Felt + && def.fields[1].name == "version" + && def.fields[1].r#type == FullTypeReference::Felt + && def.fields[2].name == "chainId" + && def.fields[2].r#type == FullTypeReference::Felt + } + Self::Enum(_) => false, + } + } + + pub(crate) fn is_v1_domain(&self) -> bool { + match self { + Self::Struct(def) => { + def.fields.len() == 4 + && def.fields[0].name == "name" + && def.fields[0].r#type == FullTypeReference::ShortString + && def.fields[1].name == "version" + && def.fields[1].r#type == FullTypeReference::ShortString + && def.fields[2].name == "chainId" + && def.fields[2].r#type == FullTypeReference::ShortString + && def.fields[3].name == "revision" + && def.fields[3].r#type == FullTypeReference::ShortString + } + Self::Enum(_) => false, + } + } +} + +impl PresetType { + pub const fn name(&self) -> &'static str { + match self { + Self::U256 => "u256", + Self::TokenAmount => "TokenAmount", + Self::NftId => "NftId", + } + } + + pub const fn type_signature(&self, revision: Revision) -> &'static str { + match self { + Self::U256 => match revision { + Revision::V0 => "u256(low:u128,high:u128)", + Revision::V1 => "\"u256\"(\"low\":\"u128\",\"high\":\"u128\")", + }, + Self::TokenAmount => match revision { + Revision::V0 => "TokenAmount(token_address:ContractAddress,amount:u256)", + Revision::V1 => { + "\"TokenAmount\"(\"token_address\":\"ContractAddress\",\"amount\":\"u256\")" + } + }, + Self::NftId => match revision { + Revision::V0 => "NftId(collection_address:ContractAddress,token_id:u256)", + Revision::V1 => { + "\"NftId\"(\"collection_address\":\"ContractAddress\",\"token_id\":\"u256\")" + } + }, + } + } + + // TODO: make this a const fn + pub fn type_hash(&self, revision: Revision) -> Felt { + match self { + Self::U256 => starknet_keccak(self.type_signature(revision).as_bytes()), + Self::TokenAmount | Self::NftId => starknet_keccak( + format!( + "{}{}", + self.type_signature(revision), + Self::U256.type_signature(revision) + ) + .as_bytes(), + ), + } + } +} + +impl<'de> Deserialize<'de> for TypeDefinition { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let elements = Vec::::deserialize(deserializer)?; + + match elements.first() { + Some(FieldOrVariantDefinition::Field(_)) => { + // This is a struct definition + let mut fields = Vec::new(); + for element in elements { + match element { + FieldOrVariantDefinition::Field(field) => fields.push(field), + FieldOrVariantDefinition::Variant(_) => { + return Err(serde::de::Error::invalid_type( + Unexpected::Other("enum variant definition"), + &"struct field definition", + )) + } + } + } + Ok(Self::Struct(StructDefinition { fields })) + } + Some(FieldOrVariantDefinition::Variant(_)) => { + // This is an enum definition + let mut variants = Vec::new(); + for element in elements { + match element { + FieldOrVariantDefinition::Variant(variant) => variants.push(variant), + FieldOrVariantDefinition::Field(_) => { + return Err(serde::de::Error::invalid_type( + Unexpected::Other("struct field definition"), + &"enum variant definition", + )) + } + } + } + Ok(Self::Enum(EnumDefinition { variants })) + } + None => Err(serde::de::Error::invalid_length( + 0, + &"at least 1 field or variant", + )), + } + } +} + +impl<'de> Deserialize<'de> for FieldOrVariantDefinition { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(deny_unknown_fields)] + struct Raw { + name: String, + r#type: String, + contains: Option, + } + + let raw = Raw::deserialize(deserializer)?; + if raw.name.is_empty() { + return Err(serde::de::Error::invalid_value( + Unexpected::Str(""), + &"non-empty name", + )); + } + + if raw.r#type.starts_with('(') { + // Enum variant definition + + if !raw.r#type.ends_with(')') { + return Err(serde::de::Error::invalid_value( + Unexpected::Str(&raw.r#type), + &"enclosing parentheses", + )); + } + if raw.contains.is_some() { + // Enum variants have no `contains` field + return Err(serde::de::Error::unknown_field( + "contains", + &["name", "type"], + )); + } + + let joined_tuple_types = &raw.r#type[1..(raw.r#type.len() - 1)]; + if joined_tuple_types.is_empty() { + Ok(Self::Variant(VariantDefinition { + name: raw.name, + tuple_types: Vec::new(), + })) + } else { + let tuple_types = joined_tuple_types + .split(',') + .map(|raw_type| { + // Trimming here feels weird but the example from SNIP-12 has a space after + // `,` so it seems that whitespaces are allowed. + InlineTypeReference::from_str(raw_type.trim()).map_err(|err| { + serde::de::Error::custom(format!( + "invalid inline type reference: {}", + err + )) + }) + }) + .collect::, _>>()?; + + Ok(Self::Variant(VariantDefinition { + name: raw.name, + tuple_types, + })) + } + } else { + // Struct field definition + Ok(Self::Field(FieldDefinition { + name: raw.name, + r#type: FullTypeReference::from_parts(raw.r#type, raw.contains).map_err(|err| { + serde::de::Error::custom(format!("invalid full type reference: {}", err)) + })?, + })) + } + } +} + +impl CompositeType for StructDefinition { + fn field_iter(&self) -> impl Iterator { + self.fields + .iter() + .map(|field| (field.name.as_str(), &field.r#type)) + } + + fn field_len(&self) -> usize { + self.fields.len() + } +} + +impl CompositeType for PresetType { + fn field_iter(&self) -> impl Iterator { + match self { + Self::U256 => [ + ("low", &FullTypeReference::U128), + ("high", &FullTypeReference::U128), + ] + .into_iter(), + Self::TokenAmount => [ + ("token_address", &FullTypeReference::ContractAddress), + ("amount", &FullTypeReference::U256), + ] + .into_iter(), + Self::NftId => [ + ("collection_address", &FullTypeReference::ContractAddress), + ("token_id", &FullTypeReference::U256), + ] + .into_iter(), + } + } + + fn field_len(&self) -> usize { + 2 + } +} + +#[cfg(test)] +mod tests { + use super::super::type_reference::ElementTypeReference; + use super::*; + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_struct_def_deser() { + let raw = r###"[ + { "name": "Name", "type": "string" }, + { "name": "Some Array", "type": "u128*" }, + { "name": "Some Object", "type": "My Object" }, + { "name": "Some Enum", "type": "enum", "contains": "My Enum" } +]"###; + + let def = serde_json::from_str::(raw).unwrap(); + match def { + TypeDefinition::Struct(struct_def) => { + assert_eq!(struct_def.fields.len(), 4); + assert_eq!(struct_def.fields[0].r#type, FullTypeReference::String); + assert_eq!( + struct_def.fields[1].r#type, + FullTypeReference::Array(ElementTypeReference::U128) + ); + assert_eq!( + struct_def.fields[2].r#type, + FullTypeReference::Object("My Object".into()) + ); + assert_eq!( + struct_def.fields[3].r#type, + FullTypeReference::Enum("My Enum".into()) + ); + } + _ => panic!("unexpected definition type"), + } + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_enum_def_deser() { + let raw = r###"[ + { "name": "Variant 1", "type": "()" }, + { "name": "Variant 2", "type": "(u128, u128*)" }, + { "name": "Variant N", "type": "(u128)" } +]"###; + + let def = serde_json::from_str::(raw).unwrap(); + match def { + TypeDefinition::Enum(enum_def) => { + assert_eq!(enum_def.variants.len(), 3); + assert_eq!(enum_def.variants[0].tuple_types, vec![]); + assert_eq!( + enum_def.variants[1].tuple_types, + vec![ + InlineTypeReference::U128, + InlineTypeReference::Array(ElementTypeReference::U128) + ] + ); + assert_eq!( + enum_def.variants[2].tuple_types, + vec![InlineTypeReference::U128] + ); + } + _ => panic!("unexpected definition type"), + } + } +} diff --git a/starknet-core/src/types/typed_data/type_reference.rs b/starknet-core/src/types/typed_data/type_reference.rs new file mode 100644 index 00000000..d34884a2 --- /dev/null +++ b/starknet-core/src/types/typed_data/type_reference.rs @@ -0,0 +1,448 @@ +use alloc::{borrow::ToOwned, format, string::*}; +use core::str::FromStr; + +use serde::{de::Visitor, Deserialize}; + +use super::error::TypedDataError; + +/// A full type reference is used for defining custom struct fields and enum variants. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FullTypeReference { + /// Reference to a struct type. + Object(String), + /// Reference to an enum type. + Enum(String), + /// Reference to an array type. + Array(ElementTypeReference), + /// Reference to the basic type `felt`. + Felt, + /// Reference to the basic type `bool`. + Bool, + /// Reference to the basic type `string`. + String, + /// Reference to the basic type `selector`. + Selector, + /// Reference to the basic type `merkletree`. + MerkleTree(InlineTypeReference), + /// Reference to the basic type `u128`. + U128, + /// Reference to the basic type `i128`. + I128, + /// Reference to the basic type `ContractAddress`. + ContractAddress, + /// Reference to the basic type `ClassHash`. + ClassHash, + /// Reference to the basic type `timestamp`. + Timestamp, + /// Reference to the preset type `u256`. + U256, + /// Reference to the preset type `TokenAmount`. + TokenAmount, + /// Reference to the preset type `NftId`. + NftId, + /// Reference to the basic type `shortstring`. + ShortString, +} + +/// A type reference that can be canonically represented as a single string. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum InlineTypeReference { + /// Reference to a user-defined type. With an inline reference it's impossible to tell whether + /// the pointee is a struct or enum. + Custom(String), + /// Reference to an array type. + Array(ElementTypeReference), + /// Reference to the basic type `felt`. + Felt, + /// Reference to the basic type `bool`. + Bool, + /// Reference to the basic type `string`. + String, + /// Reference to the basic type `selector`. + Selector, + /// Reference to the basic type `u128`. + U128, + /// Reference to the basic type `i128`. + I128, + /// Reference to the basic type `ContractAddress`. + ContractAddress, + /// Reference to the basic type `ClassHash`. + ClassHash, + /// Reference to the basic type `timestamp`. + Timestamp, + /// Reference to the preset type `u256`. + U256, + /// Reference to the preset type `TokenAmount`. + TokenAmount, + /// Reference to the preset type `NftId`. + NftId, + /// Reference to the basic type `shortstring`. + ShortString, +} + +/// Reference to any type that can be used as array elements. +/// +/// This type is a strict subset of [`InlineTypeReference`]. +/// +/// SNIP-12 specifies that for an array: +/// +/// > The inner type could be any of the other types supported in this specification. +/// +/// Note the use of "other" here, implying that only one-dimensional arrays are supported. While +/// SNIP-12 does not have the most precise technical language, interpreting this way has the benefit +/// of avoiding unlimited nesting. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ElementTypeReference { + /// Reference to a user-defined type. With an inline reference it's impossible to tell whether + /// the pointee is a struct or enum. + Custom(String), + /// Reference to the basic type `felt`. + Felt, + /// Reference to the basic type `bool`. + Bool, + /// Reference to the basic type `string`. + String, + /// Reference to the basic type `selector`. + Selector, + /// Reference to the basic type `u128`. + U128, + /// Reference to the basic type `i128`. + I128, + /// Reference to the basic type `ContractAddress`. + ContractAddress, + /// Reference to the basic type `ClassHash`. + ClassHash, + /// Reference to the basic type `timestamp`. + Timestamp, + /// Reference to the preset type `u256`. + U256, + /// Reference to the preset type `TokenAmount`. + TokenAmount, + /// Reference to the preset type `NftId`. + NftId, + /// Reference to the basic type `shortstring`. + ShortString, +} + +/// An internal trait for working across the different type reference types defined above. +pub(crate) trait TypeReference { + /// Creates a common type reference representation useful for type transversal. + fn common(&self) -> CommonTypeReference<'_>; + + /// Gets the "canonical" string representation to be used in type signature encoding as field + /// type references. + fn signature_ref_repr(&self) -> String; + + /// Whether the referenced type must be a struct. + fn must_be_struct(&self) -> bool; + + /// Whether the referenced type must be an enum. + fn must_be_enum(&self) -> bool; +} + +/// An internal type reference type that can be created from all of: +/// +/// - [`FullTypeReference`] +/// - [`InlineTypeReference`] +/// - [`ElementTypeReference`] +/// +/// This type exists instead of just using [`FullTypeReference`] as when traversing user-defined +/// type definitions, only a type's name matters, not whether it's a struct or enum. +/// +/// It's *technically* possible to still use [`FullTypeReference`] as the common repr anyway, by +/// always using the [`FullTypeReference::Object`] variant. However, that would be far from ideal. +pub(crate) enum CommonTypeReference<'a> { + Custom(&'a str), + Array(&'a ElementTypeReference), + Felt, + Bool, + String, + Selector, + MerkleTree(&'a InlineTypeReference), + U128, + I128, + ContractAddress, + ClassHash, + Timestamp, + U256, + TokenAmount, + NftId, + ShortString, +} + +impl FullTypeReference { + pub(crate) fn from_parts( + r#type: String, + contains: Option, + ) -> Result { + Ok(match (r#type.as_str(), contains) { + ("felt", None) => Self::Felt, + ("bool", None) => Self::Bool, + ("string", None) => Self::String, + ("selector", None) => Self::Selector, + ("merkletree", Some(item)) => Self::MerkleTree(InlineTypeReference::from_str(&item)?), + ("u128", None) => Self::U128, + ("i128", None) => Self::I128, + ("ContractAddress", None) => Self::ContractAddress, + ("ClassHash", None) => Self::ClassHash, + ("timestamp", None) => Self::Timestamp, + ("u256", None) => Self::U256, + ("TokenAmount", None) => Self::TokenAmount, + ("NftId", None) => Self::NftId, + ("shortstring", None) => Self::ShortString, + ("enum", Some(enum_type)) => Self::Enum(enum_type), + (item, None) if item.ends_with('*') => Self::Array(ElementTypeReference::from_str( + &r#type[..(r#type.len() - 1)], + )?), + (type_name, None) if is_valid_type_name(type_name) => Self::Object(r#type), + (_, Some(_)) => { + return Err(TypedDataError::UnexpectedContainsField); + } + (type_name, _) => { + return Err(TypedDataError::InvalidTypeName(type_name.to_owned())); + } + }) + } +} + +impl TypeReference for FullTypeReference { + fn common(&self) -> CommonTypeReference<'_> { + match self { + Self::Object(name) | Self::Enum(name) => CommonTypeReference::Custom(name), + Self::Array(element) => CommonTypeReference::Array(element), + Self::Felt => CommonTypeReference::Felt, + Self::Bool => CommonTypeReference::Bool, + Self::String => CommonTypeReference::String, + Self::Selector => CommonTypeReference::Selector, + Self::MerkleTree(leaf) => CommonTypeReference::MerkleTree(leaf), + Self::U128 => CommonTypeReference::U128, + Self::I128 => CommonTypeReference::I128, + Self::ContractAddress => CommonTypeReference::ContractAddress, + Self::ClassHash => CommonTypeReference::ClassHash, + Self::Timestamp => CommonTypeReference::Timestamp, + Self::U256 => CommonTypeReference::U256, + Self::TokenAmount => CommonTypeReference::TokenAmount, + Self::NftId => CommonTypeReference::NftId, + Self::ShortString => CommonTypeReference::ShortString, + } + } + + fn signature_ref_repr(&self) -> String { + match self { + Self::Object(name) | Self::Enum(name) => name.to_owned(), + Self::Array(element) => format!("{}*", element.signature_ref_repr()), + Self::Felt => "felt".to_owned(), + Self::Bool => "bool".to_owned(), + Self::String => "string".to_owned(), + Self::Selector => "selector".to_owned(), + Self::MerkleTree(_) => "merkletree".to_owned(), + Self::U128 => "u128".to_owned(), + Self::I128 => "i128".to_owned(), + Self::ContractAddress => "ContractAddress".to_owned(), + Self::ClassHash => "ClassHash".to_owned(), + Self::Timestamp => "timestamp".to_owned(), + Self::U256 => "u256".to_owned(), + Self::TokenAmount => "TokenAmount".to_owned(), + Self::NftId => "NftId".to_owned(), + Self::ShortString => "shortstring".to_owned(), + } + } + + fn must_be_struct(&self) -> bool { + matches!(self, Self::Object(_)) + } + + fn must_be_enum(&self) -> bool { + matches!(self, Self::Enum(_)) + } +} + +impl TypeReference for InlineTypeReference { + fn common(&self) -> CommonTypeReference<'_> { + match self { + Self::Custom(name) => CommonTypeReference::Custom(name), + Self::Array(element) => CommonTypeReference::Array(element), + Self::Felt => CommonTypeReference::Felt, + Self::Bool => CommonTypeReference::Bool, + Self::String => CommonTypeReference::String, + Self::Selector => CommonTypeReference::Selector, + Self::U128 => CommonTypeReference::U128, + Self::I128 => CommonTypeReference::I128, + Self::ContractAddress => CommonTypeReference::ContractAddress, + Self::ClassHash => CommonTypeReference::ClassHash, + Self::Timestamp => CommonTypeReference::Timestamp, + Self::U256 => CommonTypeReference::U256, + Self::TokenAmount => CommonTypeReference::TokenAmount, + Self::NftId => CommonTypeReference::NftId, + Self::ShortString => CommonTypeReference::ShortString, + } + } + + fn signature_ref_repr(&self) -> String { + match self { + Self::Custom(name) => name.to_owned(), + Self::Array(element) => format!("{}*", element.signature_ref_repr()), + Self::Felt => "felt".to_owned(), + Self::Bool => "bool".to_owned(), + Self::String => "string".to_owned(), + Self::Selector => "selector".to_owned(), + Self::U128 => "u128".to_owned(), + Self::I128 => "i128".to_owned(), + Self::ContractAddress => "ContractAddress".to_owned(), + Self::ClassHash => "ClassHash".to_owned(), + Self::Timestamp => "timestamp".to_owned(), + Self::U256 => "u256".to_owned(), + Self::TokenAmount => "TokenAmount".to_owned(), + Self::NftId => "NftId".to_owned(), + Self::ShortString => "shortstring".to_owned(), + } + } + + fn must_be_struct(&self) -> bool { + false + } + + fn must_be_enum(&self) -> bool { + false + } +} + +impl TypeReference for ElementTypeReference { + fn common(&self) -> CommonTypeReference<'_> { + match self { + Self::Custom(name) => CommonTypeReference::Custom(name), + Self::Felt => CommonTypeReference::Felt, + Self::Bool => CommonTypeReference::Bool, + Self::String => CommonTypeReference::String, + Self::Selector => CommonTypeReference::Selector, + Self::U128 => CommonTypeReference::U128, + Self::I128 => CommonTypeReference::I128, + Self::ContractAddress => CommonTypeReference::ContractAddress, + Self::ClassHash => CommonTypeReference::ClassHash, + Self::Timestamp => CommonTypeReference::Timestamp, + Self::U256 => CommonTypeReference::U256, + Self::TokenAmount => CommonTypeReference::TokenAmount, + Self::NftId => CommonTypeReference::NftId, + Self::ShortString => CommonTypeReference::ShortString, + } + } + + fn signature_ref_repr(&self) -> String { + match self { + Self::Custom(name) => name.to_owned(), + Self::Felt => "felt".to_owned(), + Self::Bool => "bool".to_owned(), + Self::String => "string".to_owned(), + Self::Selector => "selector".to_owned(), + Self::U128 => "u128".to_owned(), + Self::I128 => "i128".to_owned(), + Self::ContractAddress => "ContractAddress".to_owned(), + Self::ClassHash => "ClassHash".to_owned(), + Self::Timestamp => "timestamp".to_owned(), + Self::U256 => "u256".to_owned(), + Self::TokenAmount => "TokenAmount".to_owned(), + Self::NftId => "NftId".to_owned(), + Self::ShortString => "shortstring".to_owned(), + } + } + + fn must_be_struct(&self) -> bool { + false + } + + fn must_be_enum(&self) -> bool { + false + } +} + +impl FromStr for InlineTypeReference { + type Err = TypedDataError; + + fn from_str(s: &str) -> Result { + // Same as `FullTypeReference::from_parts` except cases involving `contains`. + Ok(match s { + "felt" => Self::Felt, + "bool" => Self::Bool, + "string" => Self::String, + "selector" => Self::Selector, + "u128" => Self::U128, + "i128" => Self::I128, + "ContractAddress" => Self::ContractAddress, + "ClassHash" => Self::ClassHash, + "timestamp" => Self::Timestamp, + "u256" => Self::U256, + "TokenAmount" => Self::TokenAmount, + "NftId" => Self::NftId, + "shortstring" => Self::ShortString, + item if item.ends_with('*') => { + Self::Array(ElementTypeReference::from_str(&s[..(s.len() - 1)])?) + } + type_name if is_valid_type_name(type_name) => Self::Custom(s.to_owned()), + type_name => { + return Err(TypedDataError::InvalidTypeName(type_name.to_owned())); + } + }) + } +} + +impl FromStr for ElementTypeReference { + type Err = TypedDataError; + + fn from_str(s: &str) -> Result { + // Same as `InlineTypeReference::from_parts` except the array case. + Ok(match s { + "felt" => Self::Felt, + "bool" => Self::Bool, + "string" => Self::String, + "selector" => Self::Selector, + "u128" => Self::U128, + "i128" => Self::I128, + "ContractAddress" => Self::ContractAddress, + "ClassHash" => Self::ClassHash, + "timestamp" => Self::Timestamp, + "u256" => Self::U256, + "TokenAmount" => Self::TokenAmount, + "NftId" => Self::NftId, + "shortstring" => Self::ShortString, + type_name if is_valid_type_name(type_name) => Self::Custom(s.to_owned()), + type_name => { + return Err(TypedDataError::InvalidTypeName(type_name.to_owned())); + } + }) + } +} + +struct InlineTypeReferenceVisitor; + +impl Visitor<'_> for InlineTypeReferenceVisitor { + type Value = InlineTypeReference; + + fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(formatter, "string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + InlineTypeReference::from_str(v).map_err(|err| { + serde::de::Error::custom(format!("invalid inline type reference: {}", err)) + }) + } +} + +impl<'de> Deserialize<'de> for InlineTypeReference { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(InlineTypeReferenceVisitor) + } +} + +fn is_valid_type_name(type_name: &str) -> bool { + !(type_name.is_empty() + || type_name.contains(',') + || type_name.contains('(') + || type_name.contains(')')) +} diff --git a/starknet-core/src/types/typed_data/types.rs b/starknet-core/src/types/typed_data/types.rs new file mode 100644 index 00000000..848c147e --- /dev/null +++ b/starknet-core/src/types/typed_data/types.rs @@ -0,0 +1,371 @@ +use alloc::{borrow::ToOwned, collections::BTreeMap, string::*}; + +use indexmap::IndexMap; +use serde::Deserialize; + +use crate::{ + types::{typed_data::CommonTypeReference, Felt}, + utils::starknet_keccak, +}; + +use super::{ + error::TypedDataError, + revision::Revision, + type_definition::{PresetType, TypeDefinition}, + TypeReference, +}; + +#[cfg(feature = "std")] +type RandomState = std::hash::RandomState; +#[cfg(not(feature = "std"))] +type RandomState = foldhash::fast::RandomState; + +const DOMAIN_TYPE_NAME_V0: &str = "StarkNetDomain"; +const DOMAIN_TYPE_NAME_V1: &str = "StarknetDomain"; + +/// The user-defined types section of a SNIP-12 [`TypedData`](super::TypedData) instance. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Types { + /// The SNIP-12 revision as inferred from the supplied domain type definition. The actual + /// domain type definition is not stored as it's universal for all instances. + revision: Revision, + /// Definitions for types other than the domain type. + user_defined_types: IndexMap, +} + +enum SignatureGenerator<'a> { + UserDefinedType(&'a TypeDefinition), + PresetType(&'static PresetType), +} + +impl Types { + /// Gets the revision implied from the definition of the domain type. + /// + /// Returns [`Revision::V0`] if and only if only `StarkNetDomain` is defined. + /// + /// Returns [`Revision::V1`] if and only if only `StarknetDomain` is defined. + pub const fn revision(&self) -> Revision { + self.revision + } + + /// Gets the type definition by name. Returns `None` if not defined. + pub fn get_type(&self, type_name: &str) -> Option<&TypeDefinition> { + self.user_defined_types.get(type_name) + } + + /// Gets the SNIP-12 type hash of a user-defined type. + /// + /// Returns `Err` if the type or any of its dependancies are not defined. + pub fn get_type_hash(&self, type_name: &str) -> Result { + let type_def = self + .get_type(type_name) + .ok_or_else(|| TypedDataError::CustomTypeNotFound(type_name.to_owned()))?; + + let mut full_signature = String::new(); + SignatureGenerator::UserDefinedType(type_def).write_signature( + type_name, + &mut full_signature, + self.revision(), + ); + + let mut dependency_signatures: BTreeMap<&str, SignatureGenerator<'_>> = BTreeMap::new(); + self.collect_dep_sigs_from_type_def(&mut dependency_signatures, type_def)?; + + for (name, sig) in dependency_signatures { + sig.write_signature(name, &mut full_signature, self.revision()); + } + + Ok(starknet_keccak(full_signature.as_bytes())) + } + + fn collect_dep_sigs_from_type_ref<'a, R>( + &'a self, + signatures: &mut BTreeMap<&'a str, SignatureGenerator<'a>>, + type_ref: &'a R, + ) -> Result<(), TypedDataError> + where + R: TypeReference, + { + #[allow(clippy::match_same_arms)] + match type_ref.common() { + CommonTypeReference::Custom(name) => { + let type_def = self + .get_type(name) + .ok_or_else(|| TypedDataError::CustomTypeNotFound(name.to_owned()))?; + + // No need to advance further if the type has already been visited + if signatures + .insert(name, SignatureGenerator::UserDefinedType(type_def)) + .is_none() + { + self.collect_dep_sigs_from_type_def(signatures, type_def)?; + } + } + CommonTypeReference::Array(element) => { + self.collect_dep_sigs_from_type_ref(signatures, element)?; + } + CommonTypeReference::MerkleTree(_) => { + // SNIP-12 is a bit vague on whether the leaf type here should be collected as + // dependency, as it's unclear whether `merkletree`'s leaf type counts as being + // "referenced" by the parent type, given that `merkletree`'s own type encoding does + // not include any information of the leaf type. + // + // Since the `starknet.js` implementation discards the leaf type,, we do the same + // here to be compatible. + } + // Preset types + CommonTypeReference::U256 => { + signatures.insert( + PresetType::U256.name(), + SignatureGenerator::PresetType(&PresetType::U256), + ); + } + CommonTypeReference::TokenAmount => { + signatures.insert( + PresetType::TokenAmount.name(), + SignatureGenerator::PresetType(&PresetType::TokenAmount), + ); + + // `TokenAmount` depends on `u256` + signatures.insert( + PresetType::U256.name(), + SignatureGenerator::PresetType(&PresetType::U256), + ); + } + CommonTypeReference::NftId => { + signatures.insert( + PresetType::NftId.name(), + SignatureGenerator::PresetType(&PresetType::NftId), + ); + + // `NftId` depends on `u256` + signatures.insert( + PresetType::U256.name(), + SignatureGenerator::PresetType(&PresetType::U256), + ); + } + // Basic types. Nothing to collect. + CommonTypeReference::Felt + | CommonTypeReference::Bool + | CommonTypeReference::String + | CommonTypeReference::Selector + | CommonTypeReference::U128 + | CommonTypeReference::I128 + | CommonTypeReference::ContractAddress + | CommonTypeReference::ClassHash + | CommonTypeReference::Timestamp + | CommonTypeReference::ShortString => {} + } + + Ok(()) + } + + fn collect_dep_sigs_from_type_def<'a>( + &'a self, + signatures: &mut BTreeMap<&'a str, SignatureGenerator<'a>>, + type_def: &'a TypeDefinition, + ) -> Result<(), TypedDataError> { + match type_def { + TypeDefinition::Struct(struct_def) => { + for field in &struct_def.fields { + self.collect_dep_sigs_from_type_ref(signatures, &field.r#type)?; + } + } + TypeDefinition::Enum(enum_def) => { + for variant in &enum_def.variants { + for tuple_type in &variant.tuple_types { + self.collect_dep_sigs_from_type_ref(signatures, tuple_type)?; + } + } + } + } + + Ok(()) + } +} + +impl SignatureGenerator<'_> { + fn write_signature(&self, name: &str, signature: &mut String, revision: Revision) { + match self { + Self::UserDefinedType(TypeDefinition::Struct(struct_def)) => { + Self::write_escaped_name(name, signature, revision); + signature.push('('); + + let mut field_iter = struct_def.fields.iter().peekable(); + while let Some(field) = field_iter.next() { + Self::write_escaped_name(&field.name, signature, revision); + signature.push(':'); + Self::write_escaped_name( + &field.r#type.signature_ref_repr(), + signature, + revision, + ); + + if field_iter.peek().is_some() { + signature.push(','); + }; + } + + signature.push(')'); + } + Self::UserDefinedType(TypeDefinition::Enum(enum_def)) => { + Self::write_escaped_name(name, signature, revision); + signature.push('('); + + let mut variant_iter = enum_def.variants.iter().peekable(); + while let Some(variant) = variant_iter.next() { + Self::write_escaped_name(&variant.name, signature, revision); + signature.push('('); + + let mut tuple_type_iter = variant.tuple_types.iter().peekable(); + while let Some(tuple_type) = tuple_type_iter.next() { + Self::write_escaped_name( + &tuple_type.signature_ref_repr(), + signature, + revision, + ); + if tuple_type_iter.peek().is_some() { + signature.push(',') + }; + } + + signature.push_str(if variant_iter.peek().is_some() { + ")," + } else { + ")" + }); + } + + signature.push(')'); + } + Self::PresetType(preset) => { + signature.push_str(preset.type_signature(revision)); + } + } + } + + fn write_escaped_name(name: &str, signature: &mut String, revision: Revision) { + match revision { + Revision::V0 => { + signature.push_str(name); + } + Revision::V1 => { + // TODO: check if this can ever fail + signature.push_str(&serde_json::to_string(name).unwrap()); + } + } + } +} + +impl<'de> Deserialize<'de> for Types { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let mut raw = IndexMap::::deserialize(deserializer)?; + + if let Some(domain_v1) = raw.shift_remove(DOMAIN_TYPE_NAME_V1) { + if raw.contains_key(DOMAIN_TYPE_NAME_V0) { + Err(serde::de::Error::custom( + "conflicting domain type definitions", + )) + } else if !domain_v1.is_v1_domain() { + Err(serde::de::Error::custom( + "invalid domain type definition for revision 1", + )) + } else { + Ok(Self { + revision: Revision::V1, + user_defined_types: raw, + }) + } + } else if let Some(domain_v0) = raw.shift_remove(DOMAIN_TYPE_NAME_V0) { + if domain_v0.is_v0_domain() { + Ok(Self { + revision: Revision::V0, + user_defined_types: raw, + }) + } else { + Err(serde::de::Error::custom( + "invalid domain type definition for revision 0", + )) + } + } else { + Err(serde::de::Error::custom("missing domain type definition")) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const VALID_V1_DATA: &str = r###"{ + "StarknetDomain": [ + { "name": "name", "type": "shortstring" }, + { "name": "version", "type": "shortstring" }, + { "name": "chainId", "type": "shortstring" }, + { "name": "revision", "type": "shortstring" } + ], + "Example Message": [ + { "name": "Name", "type": "string" }, + { "name": "Some Array", "type": "u128*" }, + { "name": "Some Object", "type": "My Object" } + ], + "My Object": [ + { "name": "Some Selector", "type": "selector" }, + { "name": "Some Contract Address", "type": "ContractAddress" } + ] +}"###; + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_revision_0_deser() { + let raw = r###"{ + "StarkNetDomain": [ + { "name": "name", "type": "felt" }, + { "name": "version", "type": "felt" }, + { "name": "chainId", "type": "felt" } + ], + "Example Message": [ + { "name": "Name", "type": "string" }, + { "name": "Some Array", "type": "u128*" }, + { "name": "Some Object", "type": "My Object" } + ], + "My Object": [ + { "name": "Some Selector", "type": "selector" }, + { "name": "Some Contract Address", "type": "ContractAddress" } + ] +}"###; + + let types = serde_json::from_str::(raw).unwrap(); + assert_eq!(types.revision, Revision::V0); + assert_eq!(types.user_defined_types.len(), 2); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_revision_1_deser() { + let types = serde_json::from_str::(VALID_V1_DATA).unwrap(); + assert_eq!(types.revision, Revision::V1); + assert_eq!(types.user_defined_types.len(), 2); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_revision_1_type_hash() { + let types = serde_json::from_str::(VALID_V1_DATA).unwrap(); + assert_eq!( + types.get_type_hash("Example Message").unwrap(), + Felt::from_hex_unchecked( + "0x01ef2892585a840aee9165aac7aaf811ba2f8619e43c119bd76a6109f81cecc3" + ) + ); + assert_eq!( + types.get_type_hash("My Object").unwrap(), + Felt::from_hex_unchecked( + "0x02f0ee9d399d4e7ccbc5d7e96df767296cc4b8a516600c121b393427ae3779f2" + ) + ); + } +} diff --git a/starknet-core/src/types/typed_data/value.rs b/starknet-core/src/types/typed_data/value.rs new file mode 100644 index 00000000..dfcec7e8 --- /dev/null +++ b/starknet-core/src/types/typed_data/value.rs @@ -0,0 +1,229 @@ +use alloc::{borrow::ToOwned, string::*, vec::*}; + +use indexmap::IndexMap; +use serde::{de::Visitor, Deserialize}; + +#[cfg(feature = "std")] +type RandomState = std::hash::RandomState; +#[cfg(not(feature = "std"))] +type RandomState = foldhash::fast::RandomState; + +const DEFAULT_INDEXMAP_CAPACITY: usize = 5; + +/// The primitive representation of the SNIP-12 message value. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Value { + /// String value. + String(String), + /// Unsigned integer value. + UnsignedInteger(u128), + /// Signed integer value. + SignedInteger(i128), + /// Boolean value. + Boolean(bool), + /// Map value. + Object(ObjectValue), + /// Sequence value. + Array(ArrayValue), +} + +/// A map/object value for SNIP-12 message representation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ObjectValue { + /// Fields of the object. + pub fields: IndexMap, +} + +/// A sequence/array value for SNIP-12 message representation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ArrayValue { + /// Elements of the array. + pub elements: Vec, +} + +/// The unit enum for identifying [`Value`] variants. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ValueKind { + /// String value. + String, + /// Unsigned integer value. + UnsignedInteger, + /// Signed integer value. + SignedInteger, + /// Boolean value. + Boolean, + /// Map value. + Object, + /// Sequence value. + Array, +} + +impl Value { + /// Gets the type of value. + pub const fn kind(&self) -> ValueKind { + match self { + Self::String(_) => ValueKind::String, + Self::UnsignedInteger(_) => ValueKind::UnsignedInteger, + Self::SignedInteger(_) => ValueKind::SignedInteger, + Self::Boolean(_) => ValueKind::Boolean, + Self::Object(_) => ValueKind::Object, + Self::Array(_) => ValueKind::Array, + } + } +} + +struct ValueVisitor; + +impl<'de> Visitor<'de> for ValueVisitor { + type Value = Value; + + fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(formatter, "integer, string, map or sequence") + } + + fn visit_bool(self, v: bool) -> Result + where + E: serde::de::Error, + { + Ok(Value::Boolean(v)) + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + Ok(Value::UnsignedInteger(v.into())) + } + + fn visit_u128(self, v: u128) -> Result + where + E: serde::de::Error, + { + Ok(Value::UnsignedInteger(v)) + } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + Ok(Value::SignedInteger(v.into())) + } + + fn visit_i128(self, v: i128) -> Result + where + E: serde::de::Error, + { + Ok(Value::SignedInteger(v)) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(Value::String(v.to_owned())) + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut elements = Vec::new(); + while let Some(element) = seq.next_element()? { + elements.push(element); + } + Ok(Value::Array(ArrayValue { elements })) + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut fields = + IndexMap::with_capacity_and_hasher(DEFAULT_INDEXMAP_CAPACITY, Default::default()); + while let Some((key, value)) = map.next_entry()? { + fields.insert(key, value); + } + Ok(Value::Object(ObjectValue { fields })) + } +} + +impl<'de> Deserialize<'de> for Value { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(ValueVisitor) + } +} + +impl core::fmt::Display for ValueKind { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::String => write!(f, "string"), + Self::UnsignedInteger => write!(f, "unsigned_integer"), + Self::SignedInteger => write!(f, "signed_integer"), + Self::Boolean => write!(f, "boolean"), + Self::Object => write!(f, "object"), + Self::Array => write!(f, "array"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_value_deser() { + let raw = r###"{ + "Name": "some name", + "Some Array": [1, 2, 3, 4], + "Some Object": { + "Some Selector": "transfer", + "Some Contract Address": "0x0123" + } +}"###; + + let value = serde_json::from_str::(raw).unwrap(); + + match value { + Value::Object(value) => { + assert_eq!(value.fields.len(), 3); + assert_eq!( + value.fields.get("Name").unwrap(), + &Value::String("some name".into()) + ); + assert_eq!( + value.fields.get("Some Array").unwrap(), + &Value::Array(ArrayValue { + elements: vec![ + Value::UnsignedInteger(1), + Value::UnsignedInteger(2), + Value::UnsignedInteger(3), + Value::UnsignedInteger(4), + ] + }) + ); + assert_eq!( + value.fields.get("Some Object").unwrap(), + &Value::Object(ObjectValue { + fields: [ + ( + String::from("Some Selector"), + Value::String("transfer".into()) + ), + ( + String::from("Some Contract Address"), + Value::String("0x0123".into()) + ), + ] + .into_iter() + .collect() + }) + ); + } + _ => panic!("unexpected value type"), + } + } +}