diff --git a/CHANGELOG.md b/CHANGELOG.md index 780e3df93..41b9cf914 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,9 +22,10 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Added - `ZkProgram` to support non-pure provable types as inputs and outputs https://github.com/o1-labs/o1js/pull/1828 -- API for recursively proving a ZkProgram method from within another https://github.com/o1-labs/o1js/pull/1931 +- APIs for recursively proving a ZkProgram method from within another https://github.com/o1-labs/o1js/pull/1931 https://github.com/o1-labs/o1js/pull/1932 - `let recursive = Experimental.Recursive(program);` - `recursive.(...args): Promise` + - `recursive..if(condition, ...args): Promise` - This also works within the same program, as long as the return value is type-annotated - Add `enforceTransactionLimits` parameter on Network https://github.com/o1-labs/o1js/issues/1910 - Method for optional types to assert none https://github.com/o1-labs/o1js/pull/1922 diff --git a/src/examples/zkprogram/hash-chain.ts b/src/examples/zkprogram/hash-chain.ts new file mode 100644 index 000000000..53bc574cf --- /dev/null +++ b/src/examples/zkprogram/hash-chain.ts @@ -0,0 +1,72 @@ +/** + * This shows how to prove an arbitrarily long chain of hashes using ZkProgram, i.e. + * `hash^n(x) = y`. + * + * We implement this as a self-recursive ZkProgram, using `proveRecursivelyIf()` + */ +import { + assert, + Bool, + Experimental, + Field, + Poseidon, + Provable, + Struct, + ZkProgram, +} from 'o1js'; + +const HASHES_PER_PROOF = 30; + +class HashChainSpec extends Struct({ x: Field, n: Field }) {} + +const hashChain = ZkProgram({ + name: 'hash-chain', + publicInput: HashChainSpec, + publicOutput: Field, + + methods: { + chain: { + privateInputs: [], + + async method({ x, n }: HashChainSpec) { + Provable.log('hashChain (start method)', n); + let y = x; + let k = Field(0); + let reachedN = Bool(false); + + for (let i = 0; i < HASHES_PER_PROOF; i++) { + reachedN = k.equals(n); + y = Provable.if(reachedN, y, Poseidon.hash([y])); + k = Provable.if(reachedN, n, k.add(1)); + } + + // we have y = hash^k(x) + // now do z = hash^(n-k)(y) = hash^n(x) by calling this method recursively + // except if we have k = n, then ignore the output and use y + let z: Field = await hashChainRecursive.chain.if(reachedN.not(), { + x: y, + n: n.sub(k), + }); + z = Provable.if(reachedN, y, z); + Provable.log('hashChain (start proving)', n); + return { publicOutput: z }; + }, + }, + }, +}); +let hashChainRecursive = Experimental.Recursive(hashChain); + +await hashChain.compile(); + +let n = 100; +let x = Field.random(); + +let { proof } = await hashChain.chain({ x, n: Field(n) }); + +assert(await hashChain.verify(proof), 'Proof invalid'); + +// check that the output is correct +let z = Array.from({ length: n }, () => 0).reduce((y) => Poseidon.hash([y]), x); +proof.publicOutput.assertEquals(z, 'Output is incorrect'); + +console.log('Finished hash chain proof'); diff --git a/src/lib/proof-system/recursive.ts b/src/lib/proof-system/recursive.ts index f7f0a4a3f..bdc6f1031 100644 --- a/src/lib/proof-system/recursive.ts +++ b/src/lib/proof-system/recursive.ts @@ -5,6 +5,7 @@ import { Tuple } from '../util/types.js'; import { Proof } from './proof.js'; import { mapObject, mapToObject, zip } from '../util/arrays.js'; import { Undefined, Void } from './zkprogram.js'; +import { Bool } from '../provable/bool.js'; export { Recursive }; @@ -25,6 +26,7 @@ function Recursive< ...args: any ) => Promise<{ publicOutput: InferProvable }>; }; + maxProofsVerified: () => Promise<0 | 1 | 2>; } & { [Key in keyof PrivateInputs]: (...args: any) => Promise<{ proof: Proof< @@ -38,7 +40,13 @@ function Recursive< InferProvable, InferProvable, PrivateInputs[Key] - >; + > & { + if: ConditionalRecursiveProver< + InferProvable, + InferProvable, + PrivateInputs[Key] + >; + }; } { type PublicInput = InferProvable; type PublicOutput = InferProvable; @@ -64,9 +72,15 @@ function Recursive< let regularRecursiveProvers = mapToObject(methodKeys, (key) => { return async function proveRecursively_( + conditionAndConfig: Bool | { condition: Bool; domainLog2?: number }, publicInput: PublicInput, ...args: TupleToInstances - ) { + ): Promise { + let condition = + conditionAndConfig instanceof Bool + ? conditionAndConfig + : conditionAndConfig.condition; + // create the base proof in a witness block let proof = await Provable.witnessAsync(SelfProof, async () => { // move method args to constants @@ -78,6 +92,20 @@ function Recursive< Provable.toConstant(type, arg) ); + if (!condition.toBoolean()) { + let publicOutput: PublicOutput = + ProvableType.synthesize(publicOutputType); + let maxProofsVerified = await zkprogram.maxProofsVerified(); + return SelfProof.dummy( + publicInput, + publicOutput, + maxProofsVerified, + conditionAndConfig instanceof Bool + ? undefined + : conditionAndConfig.domainLog2 + ); + } + let prover = zkprogram[key]; if (hasPublicInput) { @@ -96,32 +124,48 @@ function Recursive< // declare and verify the proof, and return its public output proof.declare(); - proof.verify(); + proof.verifyIf(condition); return proof.publicOutput; }; }); - type RecursiveProver_ = RecursiveProver< - PublicInput, - PublicOutput, - PrivateInputs[K] - >; - type RecursiveProvers = { - [K in MethodKey]: RecursiveProver_; - }; - let proveRecursively: RecursiveProvers = mapToObject( - methodKeys, - (key: MethodKey) => { + return mapObject( + regularRecursiveProvers, + ( + prover + ): RecursiveProver & { + if: ConditionalRecursiveProver< + PublicInput, + PublicOutput, + PrivateInputs[MethodKey] + >; + } => { if (!hasPublicInput) { - return ((...args: any) => - regularRecursiveProvers[key](undefined as any, ...args)) as any; + return Object.assign( + ((...args: any) => + prover(new Bool(true), undefined as any, ...args)) as any, + { + if: ( + condition: Bool | { condition: Bool; domainLog2?: number }, + ...args: any + ) => prover(condition, undefined as any, ...args), + } + ); } else { - return regularRecursiveProvers[key] as any; + return Object.assign( + ((pi: PublicInput, ...args: any) => + prover(new Bool(true), pi, ...args)) as any, + { + if: ( + condition: Bool | { condition: Bool; domainLog2?: number }, + pi: PublicInput, + ...args: any + ) => prover(condition, pi, ...args), + } + ); } } ); - - return proveRecursively; } type RecursiveProver< @@ -135,6 +179,21 @@ type RecursiveProver< ...args: TupleToInstances ) => Promise; +type ConditionalRecursiveProver< + PublicInput, + PublicOutput, + Args extends Tuple +> = PublicInput extends undefined + ? ( + condition: Bool | { condition: Bool; domainLog2?: number }, + ...args: TupleToInstances + ) => Promise + : ( + condition: Bool | { condition: Bool; domainLog2?: number }, + publicInput: PublicInput, + ...args: TupleToInstances + ) => Promise; + type TupleToInstances = { [I in keyof T]: InferProvable; }; diff --git a/src/lib/proof-system/zkprogram.ts b/src/lib/proof-system/zkprogram.ts index b6e39ee96..01c04322c 100644 --- a/src/lib/proof-system/zkprogram.ts +++ b/src/lib/proof-system/zkprogram.ts @@ -30,7 +30,6 @@ import { unsetSrsCache, } from '../../bindings/crypto/bindings/srs.js'; import { - ProvablePure, ProvableType, ProvableTypePure, ToProvable, @@ -55,7 +54,7 @@ import { import { emptyWitness } from '../provable/types/util.js'; import { InferValue } from '../../bindings/lib/provable-generic.js'; import { DeclaredProof, ZkProgramContext } from './zkprogram-context.js'; -import { mapObject, mapToObject, zip } from '../util/arrays.js'; +import { mapObject, mapToObject } from '../util/arrays.js'; // public API export {