diff --git a/Cargo.lock b/Cargo.lock index 18cab70b21..b07a69e6cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -808,6 +808,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" dependencies = [ "const-oid", + "pem-rfc7468", "zeroize", ] @@ -1977,6 +1978,7 @@ dependencies = [ "rand_chacha 0.3.1", "rand_core 0.6.4", "regex", + "rsa", "serde", "serde_bytes", "serde_cbor", @@ -2134,6 +2136,9 @@ name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] [[package]] name = "leb128" @@ -2147,6 +2152,12 @@ version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" +[[package]] +name = "libm" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" + [[package]] name = "libredox" version = "0.1.3" @@ -2339,6 +2350,23 @@ dependencies = [ "serde", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand 0.8.5", + "smallvec", + "zeroize", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -2354,6 +2382,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg 1.3.0", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -2361,6 +2400,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg 1.3.0", + "libm", ] [[package]] @@ -2449,6 +2489,15 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -2512,6 +2561,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + [[package]] name = "pkcs8" version = "0.10.2" @@ -2987,6 +3047,26 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rsa" +version = "0.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47c75d7c5c6b673e58bf54d8544a9f432e3a925b0e80f7cd3602ab5c50c55519" +dependencies = [ + "const-oid", + "digest 0.10.7", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core 0.6.4", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "rustc-demangle" version = "0.1.24" diff --git a/Cargo.toml b/Cargo.toml index d3fa406327..303bf79027 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,3 +48,4 @@ serde = "1" serde_bytes = "0.11" serde_cbor = "0.11" sha2 = "0.10" +rsa = "0.9.7" diff --git a/dfx.json b/dfx.json index c2c2c01ba8..6e86ee0410 100644 --- a/dfx.json +++ b/dfx.json @@ -6,7 +6,7 @@ "wasm": "internet_identity.wasm.gz", "build": "bash -c 'II_DEV_CSP=1 II_FETCH_ROOT_KEY=1 II_DUMMY_CAPTCHA=${II_DUMMY_CAPTCHA:-1} scripts/build'", "init_arg": "(opt record { captcha_config = opt record { max_unsolved_captchas= 50:nat64; captcha_trigger = variant {Static = variant {CaptchaDisabled}}}})", - "shrink" : false + "shrink": false }, "test_app": { "type": "custom", @@ -20,7 +20,7 @@ "wasm": "demos/vc_issuer/vc_demo_issuer.wasm.gz", "build": "demos/vc_issuer/build.sh", "post_install": "bash -c 'demos/vc_issuer/provision'", - "dependencies": [ "internet_identity" ] + "dependencies": ["internet_identity"] } }, "defaults": { diff --git a/package.json b/package.json index 02f7ae2d09..9ba44988de 100644 --- a/package.json +++ b/package.json @@ -7,8 +7,8 @@ "private": true, "license": "SEE LICENSE IN LICENSE.md", "scripts": { - "dev": "II_FETCH_ROOT_KEY=1 II_DUMMY_CAPTCHA=1 II_OPENID_GOOGLE_CLIENT_ID=45431994619-cbbfgtn7o0pp0dpfcg2l66bc4rcg7qbu.apps.googleusercontent.com vite", - "host": "II_FETCH_ROOT_KEY=1 II_DUMMY_CAPTCHA=1 II_OPENID_GOOGLE_CLIENT_ID=45431994619-cbbfgtn7o0pp0dpfcg2l66bc4rcg7qbu.apps.googleusercontent.com vite --host", + "dev": "II_FETCH_ROOT_KEY=1 II_DUMMY_CAPTCHA=1 II_OPENID_GOOGLE_CLIENT_ID=\"45431994619-cbbfgtn7o0pp0dpfcg2l66bc4rcg7qbu.apps.googleusercontent.com\" vite", + "host": "II_FETCH_ROOT_KEY=1 II_DUMMY_CAPTCHA=1 II_OPENID_GOOGLE_CLIENT_ID=\"45431994619-cbbfgtn7o0pp0dpfcg2l66bc4rcg7qbu.apps.googleusercontent.com\" vite --host", "showcase": "astro dev --root ./src/showcase", "build": "tsc --noEmit && vite build", "check": "tsc --project ./tsconfig.all.json --noEmit", diff --git a/src/frontend/generated/internet_identity_idl.js b/src/frontend/generated/internet_identity_idl.js index fb954bd217..a6e4fe832a 100644 --- a/src/frontend/generated/internet_identity_idl.js +++ b/src/frontend/generated/internet_identity_idl.js @@ -31,6 +31,7 @@ export const idlFactory = ({ IDL }) => { 'canister_creation_cycles_cost' : IDL.Opt(IDL.Nat64), 'related_origins' : IDL.Opt(IDL.Vec(IDL.Text)), 'captcha_config' : IDL.Opt(CaptchaConfig), + 'openid_google_client_id' : IDL.Opt(IDL.Text), 'register_rate_limit' : IDL.Opt(RateLimitConfig), }); const UserNumber = IDL.Nat64; @@ -561,6 +562,7 @@ export const init = ({ IDL }) => { 'canister_creation_cycles_cost' : IDL.Opt(IDL.Nat64), 'related_origins' : IDL.Opt(IDL.Vec(IDL.Text)), 'captcha_config' : IDL.Opt(CaptchaConfig), + 'openid_google_client_id' : IDL.Opt(IDL.Text), 'register_rate_limit' : IDL.Opt(RateLimitConfig), }); return [IDL.Opt(InternetIdentityInit)]; diff --git a/src/frontend/generated/internet_identity_types.d.ts b/src/frontend/generated/internet_identity_types.d.ts index 89f7567c0d..b5bd09031e 100644 --- a/src/frontend/generated/internet_identity_types.d.ts +++ b/src/frontend/generated/internet_identity_types.d.ts @@ -205,6 +205,7 @@ export interface InternetIdentityInit { 'canister_creation_cycles_cost' : [] | [bigint], 'related_origins' : [] | [Array], 'captcha_config' : [] | [CaptchaConfig], + 'openid_google_client_id' : [] | [string], 'register_rate_limit' : [] | [RateLimitConfig], } export interface InternetIdentityStats { diff --git a/src/internet_identity/Cargo.toml b/src/internet_identity/Cargo.toml index 736ac96da9..d488893930 100644 --- a/src/internet_identity/Cargo.toml +++ b/src/internet_identity/Cargo.toml @@ -14,8 +14,9 @@ serde.workspace = true serde_bytes.workspace = true serde_cbor.workspace = true serde_json = { version = "1.0", default-features = false, features = ["std"] } -sha2.workspace = true +sha2 = { workspace = true, features = ["oid"]} base64.workspace = true +rsa.workspace = true # Captcha deps lodepng = "*" diff --git a/src/internet_identity/internet_identity.did b/src/internet_identity/internet_identity.did index bdaf682540..978eac4570 100644 --- a/src/internet_identity/internet_identity.did +++ b/src/internet_identity/internet_identity.did @@ -259,6 +259,8 @@ type InternetIdentityInit = record { captcha_config: opt CaptchaConfig; // Configuration for Related Origins Requests related_origins: opt vec text; + // Configuration for OpenID Google client + openid_google_client_id: opt text; }; type ChallengeKey = text; diff --git a/src/internet_identity/src/main.rs b/src/internet_identity/src/main.rs index 3bd9ee59e8..fb5fa036fa 100644 --- a/src/internet_identity/src/main.rs +++ b/src/internet_identity/src/main.rs @@ -348,6 +348,7 @@ fn config() -> InternetIdentityInit { register_rate_limit: Some(persistent_state.registration_rate_limit.clone()), captcha_config: Some(persistent_state.captcha_config.clone()), related_origins: persistent_state.related_origins.clone(), + openid_google_client_id: persistent_state.openid_google_client_id.clone(), }) } @@ -387,15 +388,20 @@ fn post_upgrade(maybe_arg: Option) { } fn initialize(maybe_arg: Option) { - let state_related_origins = state::persistent_state(|storage| storage.related_origins.clone()); - let related_origins = maybe_arg - .clone() - .map(|arg| arg.related_origins) - .unwrap_or(state_related_origins); + let related_origins = maybe_arg.as_ref().map_or_else( + || persistent_state(|storage| storage.related_origins.clone()), + |arg| arg.related_origins.clone(), + ); + let openid_google_client_id = maybe_arg.as_ref().map_or_else( + || persistent_state(|storage| storage.openid_google_client_id.clone()), + |arg| arg.openid_google_client_id.clone(), + ); init_assets(related_origins); apply_install_arg(maybe_arg); update_root_hash(); - openid::setup_timers(); + if let Some(client_id) = openid_google_client_id { + openid::setup_google(client_id); + } } fn apply_install_arg(maybe_arg: Option) { @@ -428,6 +434,11 @@ fn apply_install_arg(maybe_arg: Option) { persistent_state.related_origins = Some(related_origins); }) } + if let Some(openid_google_client_id) = arg.openid_google_client_id { + state::persistent_state_mut(|persistent_state| { + persistent_state.openid_google_client_id = Some(openid_google_client_id); + }) + } } } diff --git a/src/internet_identity/src/openid.rs b/src/internet_identity/src/openid.rs index 3bab9c8d79..565411706c 100644 --- a/src/internet_identity/src/openid.rs +++ b/src/internet_identity/src/openid.rs @@ -1,5 +1,125 @@ +use candid::{Deserialize, Principal}; +use identity_jose::jws::Decoder; +use internet_identity_interface::internet_identity::types::{MetadataEntryV2, Timestamp}; +use std::cell::RefCell; +use std::collections::HashMap; + mod google; -pub fn setup_timers() { - google::setup_timers(); +#[derive(Debug, PartialEq)] +pub struct OpenIdCredential { + pub iss: String, + pub sub: String, + pub aud: String, + pub principal: Principal, + pub last_usage_timestamp: Timestamp, + pub metadata: HashMap, +} + +trait OpenIdProvider { + fn issuer(&self) -> &'static str; + + fn verify(&self, jwt: &str, salt: &[u8; 32]) -> Result; +} + +#[derive(Deserialize)] +struct PartialClaims { + iss: String, +} + +thread_local! { + static OPEN_ID_PROVIDERS: RefCell>> = RefCell::new(vec![]); +} + +pub fn setup_google(client_id: String) { + OPEN_ID_PROVIDERS + .with_borrow_mut(|providers| providers.push(Box::new(google::Provider::create(client_id)))); +} + +#[allow(unused)] +pub fn verify(jwt: &str, salt: &[u8; 32]) -> Result { + let validation_item = Decoder::new() + .decode_compact_serialization(jwt.as_bytes(), None) + .map_err(|_| "Failed to decode JWT")?; + let claims: PartialClaims = + serde_json::from_slice(validation_item.claims()).map_err(|_| "Unable to decode claims")?; + + OPEN_ID_PROVIDERS.with_borrow(|providers| { + match providers + .iter() + .find(|provider| provider.issuer() == claims.iss) + { + Some(provider) => provider.verify(jwt, salt), + None => Err(format!("Unsupported issuer: {}", claims.iss)), + } + }) +} + +#[cfg(test)] +struct ExampleProvider; + +#[cfg(test)] +impl OpenIdProvider for ExampleProvider { + fn issuer(&self) -> &'static str { + "https://example.com" + } + + fn verify(&self, _: &str, _: &[u8; 32]) -> Result { + Ok(self.credential()) + } +} + +#[cfg(test)] +impl ExampleProvider { + fn credential(&self) -> OpenIdCredential { + OpenIdCredential { + iss: self.issuer().into(), + sub: "example-sub".into(), + aud: "example-aud".into(), + principal: Principal::anonymous(), + last_usage_timestamp: 0, + metadata: HashMap::new(), + } + } +} + +#[test] +fn should_return_credential() { + let provider = ExampleProvider {}; + let credential = provider.credential(); + OPEN_ID_PROVIDERS.replace(vec![Box::new(provider)]); + let jwt = "eyJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIn0.SBeD7pV65F98wStsBuC_VRn-yjLoyf6iojJl9Y__wN0"; + + assert_eq!(verify(jwt, &[0u8; 32]), Ok(credential)); +} + +#[test] +fn should_return_error_unsupported_issuer() { + OPEN_ID_PROVIDERS.replace(vec![]); + let jwt = "eyJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJodHRwczovL2V4YW1wbGUuY29tIn0.SBeD7pV65F98wStsBuC_VRn-yjLoyf6iojJl9Y__wN0"; + + assert_eq!( + verify(jwt, &[0u8; 32]), + Err("Unsupported issuer: https://example.com".into()) + ); +} + +#[test] +fn should_return_error_when_encoding_invalid() { + let invalid_jwt = "invalid-jwt"; + + assert_eq!( + verify(invalid_jwt, &[0u8; 32]), + Err("Failed to decode JWT".to_string()) + ); +} + +#[test] +fn should_return_error_when_claims_invalid() { + let jwt_without_issuer = "eyJhbGciOiJIUzI1NiJ9.e30.ZRrHA1JJJW8opsbCGfG_HACGpVUMN_a9IV7pAx_Zmeo"; + + assert_eq!( + verify(jwt_without_issuer, &[0u8; 32]), + Err("Unable to decode claims".to_string()) + ); } diff --git a/src/internet_identity/src/openid/google.rs b/src/internet_identity/src/openid/google.rs index 4d56eb9590..5b50c67d36 100644 --- a/src/internet_identity/src/openid/google.rs +++ b/src/internet_identity/src/openid/google.rs @@ -1,59 +1,165 @@ +use crate::openid::OpenIdCredential; +use crate::openid::OpenIdProvider; +use base64::prelude::BASE64_URL_SAFE_NO_PAD; +use base64::Engine; +use candid::Principal; use candid::{Deserialize, Nat}; -use ic_cdk::api::management_canister::http_request::{ - http_request_with_closure, CanisterHttpRequestArgument, HttpHeader, HttpMethod, HttpResponse, +use ic_cdk::api::management_canister::http_request::{HttpHeader, HttpResponse}; +use ic_cdk::trap; +use ic_stable_structures::Storable; +use identity_jose::jwk::{Jwk, JwkParamsRsa}; +use identity_jose::jws::JwsAlgorithm::RS256; +use identity_jose::jws::{ + Decoder, JwsVerifierFn, SignatureVerificationError, SignatureVerificationErrorKind, + VerificationInput, }; -use ic_cdk::{spawn, trap}; -use ic_cdk_timers::set_timer; -use identity_jose::jwk::Jwk; +use internet_identity_interface::internet_identity::types::MetadataEntryV2; +use rsa::{Pkcs1v15Sign, RsaPublicKey}; use serde::Serialize; +use sha2::{Digest, Sha256}; +#[cfg(test)] +use std::cell::Cell; use std::cell::RefCell; -use std::cmp::min; +use std::collections::HashMap; use std::convert::Into; -use std::time::Duration; +use std::rc::Rc; +const ISSUER: &str = "https://accounts.google.com"; + +#[cfg(not(test))] const CERTS_URL: &str = "https://www.googleapis.com/oauth2/v3/certs"; // The amount of cycles needed to make the HTTP outcall with a large enough margin +#[cfg(not(test))] const CERTS_CALL_CYCLES: u128 = 30_000_000_000; const HTTP_STATUS_OK: u8 = 200; // Fetch the Google certs every hour, the responses are always // valid for at least 5 hours so that should be enough margin. -const FETCH_CERTS_INTERVAL: u64 = 60 * 60; +#[cfg(not(test))] +const FETCH_CERTS_INTERVAL: u64 = 60 * 60; // 1 hour in seconds + +const NANOSECONDS_PER_SECOND: u64 = 1_000_000_000; + +// A JWT is only valid for a very small window, even if the JWT itself says it's valid for longer, +// we only need it right after it's being issued to create a JWT delegation with its own expiry. +const MAX_VALIDITY_WINDOW: u64 = 5 * 60 * NANOSECONDS_PER_SECOND; // 5 minutes in nanos, same as ingress expiry #[derive(Serialize, Deserialize)] -struct GoogleCerts { +struct Certs { keys: Vec, } -thread_local! { - static CERTS: RefCell> = const { RefCell::new(vec![]) }; +#[derive(Deserialize)] +struct Claims { + iss: String, + sub: String, + aud: String, + nonce: String, + iat: u64, + // Optional Google specific claims + email: Option, + name: Option, + picture: Option, +} + +pub struct Provider { + client_id: String, + certs: Rc>>, +} + +impl OpenIdProvider for Provider { + fn issuer(&self) -> &'static str { + ISSUER + } + + fn verify(&self, jwt: &str, salt: &[u8; 32]) -> Result { + // Decode JWT and verify claims + let validation_item = Decoder::new() + .decode_compact_serialization(jwt.as_bytes(), None) + .map_err(|_| "Unable to decode JWT")?; + let claims: Claims = serde_json::from_slice(validation_item.claims()) + .map_err(|_| "Unable to decode claims or expected claims are missing")?; + verify_claims(&self.client_id, &claims, salt)?; + + // Verify JWT signature + let kid = validation_item.kid().ok_or("JWT is missing kid")?; + let certs = self.certs.borrow(); + let cert = certs + .iter() + .find(|cert| cert.kid().is_some_and(|v| v == kid)) + .ok_or(format!("Certificate not found for {kid}"))?; + validation_item + .verify(&JwsVerifierFn::from(verify_signature), cert) + .map_err(|_| "Invalid signature")?; + + // Return credential with Google specific metadata + let mut metadata: HashMap = HashMap::new(); + if let Some(email) = claims.email { + metadata.insert("email".into(), MetadataEntryV2::String(email)); + } + if let Some(name) = claims.name { + metadata.insert("name".into(), MetadataEntryV2::String(name)); + } + if let Some(picture) = claims.picture { + metadata.insert("picture".into(), MetadataEntryV2::String(picture)); + } + Ok(OpenIdCredential { + iss: claims.iss, + sub: claims.sub, + aud: claims.aud, + principal: Principal::anonymous(), + last_usage_timestamp: time(), + metadata, + }) + } } -pub fn setup_timers() { - // Fetch the certs directly after canister initialization. - schedule_fetch_certs(None); +impl Provider { + pub fn create(client_id: String) -> Provider { + #[cfg(test)] + let certs = Rc::new(RefCell::new(TEST_CERTS.take())); + + #[cfg(not(test))] + let certs: Rc>> = Rc::new(RefCell::new(vec![])); + + #[cfg(not(test))] + schedule_fetch_certs(Rc::clone(&certs), None); + + Provider { client_id, certs } + } } -fn schedule_fetch_certs(delay: Option) { +#[cfg(not(test))] +fn schedule_fetch_certs(certs_reference: Rc>>, delay: Option) { + use ic_cdk::spawn; + use ic_cdk_timers::set_timer; + use std::cmp::min; + use std::time::Duration; + set_timer(Duration::from_secs(delay.unwrap_or(0)), move || { spawn(async move { let new_delay = match fetch_certs().await { Ok(google_certs) => { - CERTS.replace(google_certs); + certs_reference.replace(google_certs); FETCH_CERTS_INTERVAL } // Try again earlier with backoff if fetch failed, the HTTP outcall responses // aren't the same across nodes when we fetch at the moment of key rotation. Err(_) => min(FETCH_CERTS_INTERVAL, delay.unwrap_or(60) * 2), }; - schedule_fetch_certs(Some(new_delay)); + schedule_fetch_certs(certs_reference, Some(new_delay)); }); }); } +#[cfg(not(test))] async fn fetch_certs() -> Result, String> { + use ic_cdk::api::management_canister::http_request::{ + http_request_with_closure, CanisterHttpRequestArgument, HttpMethod, + }; + let request = CanisterHttpRequestArgument { url: CERTS_URL.into(), method: HttpMethod::GET, @@ -76,7 +182,7 @@ async fn fetch_certs() -> Result, String> { .await .map_err(|(_, err)| err)?; - serde_json::from_slice::(response.body.as_slice()) + serde_json::from_slice::(response.body.as_slice()) .map_err(|_| "Invalid JSON".into()) .map(|res| res.keys) } @@ -91,14 +197,14 @@ fn transform_certs(response: HttpResponse) -> HttpResponse { trap("Invalid response status") }; - let certs: GoogleCerts = + let certs: Certs = serde_json::from_slice(response.body.as_slice()).unwrap_or_else(|_| trap("Invalid JSON")); let mut sorted_keys = certs.keys.clone(); sorted_keys.sort_by_key(|key| key.kid().unwrap_or_else(|| trap("Invalid JSON")).to_owned()); - let body = serde_json::to_vec(&GoogleCerts { keys: sorted_keys }) - .unwrap_or_else(|_| trap("Invalid JSON")); + let body = + serde_json::to_vec(&Certs { keys: sorted_keys }).unwrap_or_else(|_| trap("Invalid JSON")); // All headers are ignored including the Cache-Control header, instead we fetch the certs // hourly since responses are always valid for at least 5 hours based on analysis of the @@ -110,6 +216,117 @@ fn transform_certs(response: HttpResponse) -> HttpResponse { } } +fn create_rsa_public_key(jwk: &Jwk) -> Result { + // Extract the RSA parameters (modulus 'n' and exponent 'e') from the JWK. + let JwkParamsRsa { n, e, .. } = jwk + .try_rsa_params() + .map_err(|_| "Unable to extract modulus and exponent")?; + + // Decode the base64-url encoded modulus 'n' of the RSA public key. + let n = BASE64_URL_SAFE_NO_PAD + .decode(n) + .map_err(|_| "Unable to decode modulus")?; + + // Decode the base64-url encoded public exponent 'e' of the RSA public key. + let e = BASE64_URL_SAFE_NO_PAD + .decode(e) + .map_err(|_| "Unable to decode exponent")?; + + // Construct the RSA public key using the decoded modulus and exponent. + RsaPublicKey::new( + rsa::BigUint::from_bytes_be(&n), + rsa::BigUint::from_bytes_be(&e), + ) + .map_err(|_| "Unable to construct RSA public key".into()) +} + +/// Verifier implementation for `identity_jose` that verifies the signature of a JWT. +/// +/// - `input`: A `VerificationInput` struct containing the JWT's algorithm (`alg`), +/// the signing input (payload to be hashed and verified), and the decoded signature. +/// - `jwk`: A reference to a `Jwk` (JSON Web Key) that contains the RSA public key +/// parameters (`n` and `e`) used to verify the JWT signature. +#[allow(clippy::needless_pass_by_value)] +fn verify_signature(input: VerificationInput, jwk: &Jwk) -> Result<(), SignatureVerificationError> { + // Ensure the algorithm specified in the JWT header matches the expected algorithm (RS256). + // JSON Web Keys (JWK) returned from Google API (v3) always use RSA with SHA-256. + // If the algorithm does not match, return an UnsupportedAlg error. + if input.alg != RS256 { + return Err(SignatureVerificationErrorKind::UnsupportedAlg.into()); + } + + // Compute the SHA-256 hash of the JWT payload (the signing input). + // This hashed value will be used for signature verification. + let hashed_input = Sha256::digest(input.signing_input); + + // Define the signature scheme to be used for verification (RSA PKCS#1 v1.5 with SHA-256). + let scheme = Pkcs1v15Sign::new::(); + + // Create RSA public key from JWK + let public_key = create_rsa_public_key(jwk).map_err(|_| { + SignatureVerificationError::new(SignatureVerificationErrorKind::KeyDecodingFailure) + })?; + + // Verify the JWT signature using the RSA public key and the defined signature scheme. + // If the signature is invalid, return an InvalidSignature error. + public_key + .verify(scheme, &hashed_input, input.decoded_signature.as_ref()) + .map_err(|_| SignatureVerificationErrorKind::InvalidSignature.into()) +} + +fn verify_claims(client_id: &String, claims: &Claims, salt: &[u8; 32]) -> Result<(), String> { + let now = time(); + let mut hasher = Sha256::new(); + hasher.update(salt); + hasher.update(caller().to_bytes()); + let hash: [u8; 32] = hasher.finalize().into(); + let expected_nonce = BASE64_URL_SAFE_NO_PAD.encode(hash); + + if claims.iss != ISSUER { + return Err(format!("Invalid issuer: {}", claims.iss)); + } + if &claims.aud != client_id { + return Err(format!("Invalid audience: {}", claims.aud)); + } + if claims.nonce != expected_nonce { + return Err(format!("Invalid nonce: {}", claims.nonce)); + } + if now > claims.iat * NANOSECONDS_PER_SECOND + MAX_VALIDITY_WINDOW { + return Err("JWT is no longer valid".into()); + } + if now < claims.iat * NANOSECONDS_PER_SECOND { + return Err("JWT is not valid yet".into()); + } + + Ok(()) +} + +#[cfg(test)] +thread_local! { + static TEST_CALLER: Cell = Cell::new(Principal::from_text("x4gp4-hxabd-5jt4d-wc6uw-qk4qo-5am4u-mncv3-wz3rt-usgjp-od3c2-oae").unwrap()); + static TEST_TIME: Cell = const { Cell::new(1_736_794_102 * NANOSECONDS_PER_SECOND) }; + static TEST_CERTS: Cell> = Cell::new(serde_json::from_str::(r#"{"keys":[{"n": "jwstqI4w2drqbTTVRDriFqepwVVI1y05D5TZCmGvgMK5hyOsVW0tBRiY9Jk9HKDRue3vdXiMgarwqZEDOyOA0rpWh-M76eauFhRl9lTXd5gkX0opwh2-dU1j6UsdWmMa5OpVmPtqXl4orYr2_3iAxMOhHZ_vuTeD0KGeAgbeab7_4ijyLeJ-a8UmWPVkglnNb5JmG8To77tSXGcPpBcAFpdI_jftCWr65eL1vmAkPNJgUTgI4sGunzaybf98LSv_w4IEBc3-nY5GfL-mjPRqVCRLUtbhHO_5AYDpqGj6zkKreJ9-KsoQUP6RrAVxkNuOHV9g1G-CHihKsyAifxNN2Q","use": "sig","kty": "RSA","alg": "RS256","kid": "dd125d5f462fbc6014aedab81ddf3bcedab70847","e": "AQAB"}]}"#).unwrap().keys); +} + +#[cfg(not(test))] +fn caller() -> Principal { + ic_cdk::caller() +} + +#[cfg(test)] +fn caller() -> Principal { + TEST_CALLER.get() +} + +#[cfg(not(test))] +fn time() -> u64 { + ic_cdk::api::time() +} +#[cfg(test)] +fn time() -> u64 { + TEST_TIME.get() +} + #[test] fn should_transform_certs_to_same() { let input = HttpResponse { @@ -128,3 +345,166 @@ fn should_transform_certs_to_same() { assert_eq!(transform_certs(input), expected); } + +#[cfg(test)] +fn test_data() -> (String, [u8; 32], Claims) { + // This JWT is for testing purposes, it's already been expired before this commit has been made, + // additionally the audience of this JWT is a test Google client registration, not production. + let jwt = "eyJhbGciOiJSUzI1NiIsImtpZCI6ImRkMTI1ZDVmNDYyZmJjNjAxNGFlZGFiODFkZGYzYmNlZGFiNzA4NDciLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI0NTQzMTk5NDYxOS1jYmJmZ3RuN28wcHAwZHBmY2cybDY2YmM0cmNnN3FidS5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbSIsImF1ZCI6IjQ1NDMxOTk0NjE5LWNiYmZndG43bzBwcDBkcGZjZzJsNjZiYzRyY2c3cWJ1LmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29tIiwic3ViIjoiMTE1MTYwNzE2MzM4ODEzMDA2OTAyIiwiaGQiOiJkZmluaXR5Lm9yZyIsImVtYWlsIjoidGhvbWFzLmdsYWRkaW5lc0BkZmluaXR5Lm9yZyIsImVtYWlsX3ZlcmlmaWVkIjp0cnVlLCJub25jZSI6ImV0aURhTEdjUmRtNS1yY3FlMFpRVWVNZ3BmcDR2OVRPT1lVUGJoUng3bkkiLCJuYmYiOjE3MzY3OTM4MDIsIm5hbWUiOiJUaG9tYXMgR2xhZGRpbmVzIiwicGljdHVyZSI6Imh0dHBzOi8vbGgzLmdvb2dsZXVzZXJjb250ZW50LmNvbS9hL0FDZzhvY0lTTWxja0M1RjZxaGlOWnpfREZtWGp5OTY4LXlPaEhPTjR4TGhRdXVNSDNuQlBXQT1zOTYtYyIsImdpdmVuX25hbWUiOiJUaG9tYXMiLCJmYW1pbHlfbmFtZSI6IkdsYWRkaW5lcyIsImlhdCI6MTczNjc5NDEwMiwiZXhwIjoxNzM2Nzk3NzAyLCJqdGkiOiIwMWM1NmYyMGM1MzFkNDhhYjU0ZDMwY2I4ZmRiNzU0MmM0ZjdmNjg4In0.f47b0HNskm-85sT5XtoRzORnfobK2nzVFG8jTH6eS_qAyu0ojNDqVsBtGN4A7HdjDDCOIMSu-R5e413xuGJIWLadKrLwXmguRFo3SzLrXeja-A-rP-axJsb5QUJZx1mwYd1vUNzLB9bQojU3Na6Hdvq09bMtTwaYdCn8Q9v3RErN-5VUxELmSbSXbf10A-IsS7jtzPjxHV6ueq687Ppeww6Q7AGGFB4t9H8qcDbI1unSdugX3-MfMWJLzVHbVxDgfAcLem1c2iAspvv_D5aPLeJF5HLRR2zg-Jil1BFTOoEPAAPFr1MEsvDMWSTt5jLyuMrnS4jiMGudGGPV4DDDww"; + let salt: [u8; 32] = [ + 143, 79, 158, 224, 218, 125, 157, 169, 98, 43, 205, 227, 243, 123, 173, 255, 132, 83, 81, + 139, 161, 18, 224, 243, 4, 129, 26, 123, 229, 242, 200, 189, + ]; + let validation_item = Decoder::new() + .decode_compact_serialization(jwt.as_bytes(), None) + .unwrap(); + let claims: Claims = serde_json::from_slice(validation_item.claims()).unwrap(); + + (jwt.into(), salt, claims) +} + +#[test] +fn should_return_credential() { + let (jwt, salt, claims) = test_data(); + let provider = Provider::create(claims.aud.clone()); + let credential = OpenIdCredential { + iss: claims.iss, + sub: claims.sub, + aud: claims.aud, + principal: Principal::anonymous(), + last_usage_timestamp: time(), + metadata: HashMap::from([ + ( + "email".into(), + MetadataEntryV2::String(claims.email.unwrap()), + ), + ("name".into(), MetadataEntryV2::String(claims.name.unwrap())), + ( + "picture".into(), + MetadataEntryV2::String(claims.picture.unwrap()), + ), + ]), + }; + + assert_eq!(provider.verify(&jwt, &salt), Ok(credential)); +} + +#[test] +fn should_return_error_when_encoding_invalid() { + let (_, salt, claims) = test_data(); + let provider = Provider::create(claims.aud.clone()); + let invalid_jwt = "invalid-jwt"; + + assert_eq!( + provider.verify(invalid_jwt, &salt), + Err("Unable to decode JWT".into()) + ); +} + +#[test] +fn should_return_error_when_cert_missing() { + TEST_CERTS.replace(vec![]); + let (jwt, salt, claims) = test_data(); + let provider = Provider::create(claims.aud.clone()); + + assert_eq!( + provider.verify(&jwt, &salt), + Err("Certificate not found for dd125d5f462fbc6014aedab81ddf3bcedab70847".into()) + ); +} + +#[test] +fn should_return_error_when_signature_invalid() { + let (jwt, salt, claims) = test_data(); + let provider = Provider::create(claims.aud.clone()); + let chunks: Vec<&str> = jwt.split('.').collect(); + let header = chunks[0]; + let payload = chunks[1]; + let invalid_signature = "f47b0sNskm-85sT5XtoRzORnfobK2nzVFF8jTH6eS_qAyu0ojNDqVsBtGN4A7HdjDDCOIMSu-R5e413xuGJIWLadKrLwXmguRFo3SzLrXeja-A-rP-axJsb5QUJZx1mwYd1vUNzLB9bQojU3Na6Hdvq09bMtTwaYdCn8Q9v3RErN-5VUxELmSbSXbf10A-IsS7jtzPjxHV6ueq687Ppeww5Q7AGGFB4t9H8qcDbI1unSdugX3-MfMWJLzVHbVxDgfAcLem1c2iAspvv_D5aPLeJF5HLRR2zg-Jil1BFTOoEPAAPFr1MEsvDMWSTt5jLyuMrnS4jiMGudGGPV4DDDww"; + let invalid_jwt = [header, payload, invalid_signature].join("."); + + assert_eq!( + provider.verify(&invalid_jwt, &salt), + Err("Invalid signature".into()) + ); +} + +#[test] +fn should_return_error_when_invalid_issuer() { + let (_, salt, claims) = test_data(); + let client_id = claims.aud.clone(); + let mut invalid_claims = claims; + invalid_claims.iss = "invalid-issuer".into(); + + assert_eq!( + verify_claims(&client_id, &invalid_claims, &salt), + Err(format!("Invalid issuer: {}", invalid_claims.iss)) + ); +} + +#[test] +fn should_return_error_when_invalid_audience() { + let (_, salt, claims) = test_data(); + let client_id = claims.aud.clone(); + let mut invalid_claims = claims; + invalid_claims.aud = "invalid-audience".into(); + + assert_eq!( + verify_claims(&client_id, &invalid_claims, &salt), + Err(format!("Invalid audience: {}", invalid_claims.aud)) + ); +} + +#[test] +fn should_return_error_when_invalid_salt() { + let (_, _, claims) = test_data(); + let client_id = &claims.aud; + let invalid_salt: [u8; 32] = [ + 143, 79, 58, 224, 18, 15, 157, 169, 98, 43, 205, 227, 243, 123, 173, 255, 132, 83, 81, 139, + 161, 218, 224, 243, 4, 120, 26, 123, 229, 242, 200, 189, + ]; + + assert_eq!( + verify_claims(client_id, &claims, &invalid_salt), + Err("Invalid nonce: etiDaLGcRdm5-rcqe0ZQUeMgpfp4v9TOOYUPbhRx7nI".into()) + ); +} + +#[test] +fn should_return_error_when_invalid_caller() { + TEST_CALLER.replace( + Principal::from_text("necp6-24oof-6e2i2-xg7fk-pawxw-nlol2-by5bb-mltvt-sazk6-nqrzz-zae") + .unwrap(), + ); + let (_, salt, claims) = test_data(); + let client_id = &claims.aud; + + assert_eq!( + verify_claims(client_id, &claims, &salt), + Err("Invalid nonce: etiDaLGcRdm5-rcqe0ZQUeMgpfp4v9TOOYUPbhRx7nI".into()) + ); +} + +#[test] +fn should_return_error_when_no_longer_valid() { + TEST_TIME.replace(time() + MAX_VALIDITY_WINDOW + 1); + let (_, salt, claims) = test_data(); + let client_id = &claims.aud; + + assert_eq!( + verify_claims(client_id, &claims, &salt), + Err("JWT is no longer valid".into()) + ); +} + +#[test] +fn should_return_error_when_not_valid_yet() { + TEST_TIME.replace(time() - 1); + let (_, salt, claims) = test_data(); + let client_id = &claims.aud; + + assert_eq!( + verify_claims(client_id, &claims, &salt), + Err("JWT is not valid yet".into()) + ); +} diff --git a/src/internet_identity/src/state.rs b/src/internet_identity/src/state.rs index b56c6e475c..03082e8bc3 100644 --- a/src/internet_identity/src/state.rs +++ b/src/internet_identity/src/state.rs @@ -105,6 +105,8 @@ pub struct PersistentState { pub captcha_config: CaptchaConfig, // Configuration for Related Origins Requests pub related_origins: Option>, + // Configuration for OpenID Google client id + pub openid_google_client_id: Option, // Key into the event_data BTreeMap where the 24h tracking window starts. // This key is used to remove old entries from the 24h event aggregations. // If it is `none`, then the 24h window starts from the newest entry in the event_data @@ -124,6 +126,7 @@ impl Default for PersistentState { active_authn_method_stats: ActivityStats::new(time), captcha_config: DEFAULT_CAPTCHA_CONFIG, related_origins: None, + openid_google_client_id: None, event_stats_24h_start: None, } } diff --git a/src/internet_identity/src/storage/storable_persistent_state.rs b/src/internet_identity/src/storage/storable_persistent_state.rs index f61b4a3a06..87b34acc31 100644 --- a/src/internet_identity/src/storage/storable_persistent_state.rs +++ b/src/internet_identity/src/storage/storable_persistent_state.rs @@ -33,6 +33,7 @@ pub struct StorablePersistentState { event_stats_24h_start: Option, captcha_config: Option, related_origins: Option>, + openid_google_client_id: Option, } impl Storable for StorablePersistentState { @@ -71,6 +72,7 @@ impl From for StorablePersistentState { event_stats_24h_start: s.event_stats_24h_start, captcha_config: Some(s.captcha_config), related_origins: s.related_origins, + openid_google_client_id: s.openid_google_client_id, } } } @@ -86,6 +88,7 @@ impl From for PersistentState { active_authn_method_stats: s.active_authn_method_stats, captcha_config: s.captcha_config.unwrap_or(DEFAULT_CAPTCHA_CONFIG), related_origins: s.related_origins, + openid_google_client_id: s.openid_google_client_id, event_stats_24h_start: s.event_stats_24h_start, } } @@ -131,6 +134,7 @@ mod tests { captcha_trigger: CaptchaTrigger::Static(StaticCaptchaTrigger::CaptchaEnabled), }), related_origins: None, + openid_google_client_id: None, }; assert_eq!(StorablePersistentState::default(), expected_defaults); @@ -150,6 +154,7 @@ mod tests { captcha_trigger: CaptchaTrigger::Static(StaticCaptchaTrigger::CaptchaEnabled), }, related_origins: None, + openid_google_client_id: None, event_stats_24h_start: None, }; assert_eq!(PersistentState::default(), expected_defaults); diff --git a/src/internet_identity/tests/integration/config.rs b/src/internet_identity/tests/integration/config.rs index 960b559052..59430eeb01 100644 --- a/src/internet_identity/tests/integration/config.rs +++ b/src/internet_identity/tests/integration/config.rs @@ -32,6 +32,7 @@ fn should_retain_anchor_on_user_range_change() -> Result<(), CallError> { }, }), related_origins: None, + openid_google_client_id: None, }; let canister_id = install_ii_canister_with_arg(&env, II_WASM.clone(), Some(config.clone())); @@ -43,12 +44,13 @@ fn should_retain_anchor_on_user_range_change() -> Result<(), CallError> { #[test] fn should_retain_config_after_none() -> Result<(), CallError> { let env = env(); - let related_origins: Vec = [ + let related_origins = [ "https://identity.internetcomputer.org".to_string(), "https://identity.ic0.app".to_string(), "https://identity.icp0.io".to_string(), ] .to_vec(); + let openid_google_client_id = "https://example.com".to_string(); let config = InternetIdentityInit { assigned_user_number_range: Some((3456, 798977)), archive_config: Some(ArchiveConfig { @@ -71,6 +73,7 @@ fn should_retain_config_after_none() -> Result<(), CallError> { }, }), related_origins: Some(related_origins), + openid_google_client_id: Some(openid_google_client_id), }; let canister_id = install_ii_canister_with_arg(&env, II_WASM.clone(), Some(config.clone())); @@ -87,12 +90,13 @@ fn should_retain_config_after_none() -> Result<(), CallError> { #[test] fn should_override_partially() -> Result<(), CallError> { let env = env(); - let related_origins: Vec = [ + let related_origins = [ "https://identity.internetcomputer.org".to_string(), "https://identity.ic0.app".to_string(), "https://identity.icp0.io".to_string(), ] .to_vec(); + let openid_google_client_id = "https://example.com".to_string(); let config = InternetIdentityInit { assigned_user_number_range: Some((3456, 798977)), archive_config: Some(ArchiveConfig { @@ -115,6 +119,7 @@ fn should_override_partially() -> Result<(), CallError> { }, }), related_origins: Some(related_origins), + openid_google_client_id: Some(openid_google_client_id), }; let canister_id = install_ii_canister_with_arg(&env, II_WASM.clone(), Some(config.clone())); @@ -136,6 +141,7 @@ fn should_override_partially() -> Result<(), CallError> { register_rate_limit: None, captcha_config: Some(new_captcha.clone()), related_origins: None, + openid_google_client_id: None, }; let _ = @@ -153,6 +159,7 @@ fn should_override_partially() -> Result<(), CallError> { "https://identity.ic0.app".to_string(), ] .to_vec(); + let openid_google_client_id_2 = "https://example2.com".to_string(); let config_3 = InternetIdentityInit { assigned_user_number_range: None, archive_config: None, @@ -160,6 +167,7 @@ fn should_override_partially() -> Result<(), CallError> { register_rate_limit: None, captcha_config: None, related_origins: Some(related_origins_2.clone()), + openid_google_client_id: Some(openid_google_client_id_2.clone()), }; let _ = @@ -167,6 +175,7 @@ fn should_override_partially() -> Result<(), CallError> { let expected_config_3 = InternetIdentityInit { related_origins: Some(related_origins_2.clone()), + openid_google_client_id: Some(openid_google_client_id_2.clone()), ..expected_config_2 }; diff --git a/src/internet_identity/tests/integration/http.rs b/src/internet_identity/tests/integration/http.rs index 4fb94e9ac3..7688676c8e 100644 --- a/src/internet_identity/tests/integration/http.rs +++ b/src/internet_identity/tests/integration/http.rs @@ -93,6 +93,7 @@ fn ii_canister_serves_webauthn_assets() -> Result<(), CallError> { register_rate_limit: None, captcha_config: None, related_origins: Some(related_origins.clone()), + openid_google_client_id: None, }; let canister_id = install_ii_canister_with_arg(&env, II_WASM.clone(), Some(config)); @@ -154,6 +155,7 @@ fn ii_canister_serves_webauthn_assets_after_upgrade() -> Result<(), CallError> { register_rate_limit: None, captcha_config: None, related_origins: Some(related_origins.clone()), + openid_google_client_id: None, }; let canister_id = install_ii_canister_with_arg(&env, II_WASM.clone(), Some(config)); @@ -191,6 +193,7 @@ fn ii_canister_serves_webauthn_assets_after_upgrade() -> Result<(), CallError> { register_rate_limit: None, captcha_config: None, related_origins: Some(related_origins_2.clone()), + openid_google_client_id: None, }; let _ = upgrade_ii_canister_with_arg(&env, canister_id, II_WASM.clone(), Some(config_2)); @@ -573,6 +576,7 @@ fn must_not_cache_well_known_webauthn() -> Result<(), CallError> { register_rate_limit: None, captcha_config: None, related_origins: Some(related_origins.clone()), + openid_google_client_id: None, }; let canister_id = install_ii_canister_with_arg(&env, II_WASM.clone(), Some(config)); diff --git a/src/internet_identity_interface/src/internet_identity/types.rs b/src/internet_identity_interface/src/internet_identity/types.rs index 4d5fc66dee..ab41b1bd42 100644 --- a/src/internet_identity_interface/src/internet_identity/types.rs +++ b/src/internet_identity_interface/src/internet_identity/types.rs @@ -196,6 +196,7 @@ pub struct InternetIdentityInit { pub register_rate_limit: Option, pub captcha_config: Option, pub related_origins: Option>, + pub openid_google_client_id: Option, } #[derive(Clone, Debug, CandidType, Deserialize, Eq, PartialEq)]