diff --git a/Cargo.lock b/Cargo.lock index 914038e03ffa..70906b5435ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -586,7 +586,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" dependencies = [ "termcolor", - "unicode-width", + "unicode-width 0.1.9", ] [[package]] @@ -642,7 +642,7 @@ dependencies = [ "encode_unicode", "lazy_static", "libc", - "unicode-width", + "unicode-width 0.1.9", "windows-sys 0.52.0", ] @@ -3136,7 +3136,7 @@ dependencies = [ "cargo_metadata", "heck 0.5.0", "wasmtime", - "wit-component", + "wit-component 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] @@ -3428,6 +3428,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973" +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + [[package]] name = "unicode-xid" version = "0.2.3" @@ -3528,7 +3534,7 @@ name = "verify-component-adapter" version = "28.0.0" dependencies = [ "anyhow", - "wasmparser", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wat", ] @@ -3620,7 +3626,7 @@ dependencies = [ "byte-array-literals", "object", "wasi", - "wasm-encoder", + "wasm-encoder 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wit-bindgen-rust-macro", ] @@ -3685,7 +3691,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29cbbd772edcb8e7d524a82ee8cef8dd046fc14033796a754c3ad246d019fa54" dependencies = [ "leb128", - "wasmparser", + "wasmparser 0.219.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "wasm-encoder" +version = "0.219.1" +source = "git+https://github.com/dicej/wasm-tools?branch=async#940de62bd278831d1134bde84caedc1adcce6a00" +dependencies = [ + "leb128", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] @@ -3700,36 +3715,49 @@ dependencies = [ "serde_derive", "serde_json", "spdx", - "wasm-encoder", - "wasmparser", + "wasm-encoder 0.219.1 (registry+https://github.com/rust-lang/crates.io-index)", + "wasmparser 0.219.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "wasm-metadata" +version = "0.219.1" +source = "git+https://github.com/dicej/wasm-tools?branch=async#940de62bd278831d1134bde84caedc1adcce6a00" +dependencies = [ + "anyhow", + "indexmap 2.2.6", + "serde", + "serde_derive", + "serde_json", + "spdx", + "wasm-encoder 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] name = "wasm-mutate" version = "0.219.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cf7f56e470118baa73eb16a0e8dc6d2af8766b7b32d61450c7a5e4fceb150a4" +source = "git+https://github.com/dicej/wasm-tools?branch=async#940de62bd278831d1134bde84caedc1adcce6a00" dependencies = [ "egg", "log", "rand", "thiserror", - "wasm-encoder", - "wasmparser", + "wasm-encoder 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] name = "wasm-smith" version = "0.219.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b526e4c6eed409b619960258ba5bd8a3b44dfb30c75c12fce80b750a4487fcc" +source = "git+https://github.com/dicej/wasm-tools?branch=async#940de62bd278831d1134bde84caedc1adcce6a00" dependencies = [ "anyhow", "arbitrary", "flagset", "indexmap 2.2.6", "leb128", - "wasm-encoder", + "wasm-encoder 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] @@ -3789,6 +3817,18 @@ name = "wasmparser" version = "0.219.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c771866898879073c53b565a6c7b49953795159836714ac56a5befb581227c5" +dependencies = [ + "ahash", + "bitflags 2.6.0", + "hashbrown 0.14.3", + "indexmap 2.2.6", + "semver", +] + +[[package]] +name = "wasmparser" +version = "0.219.1" +source = "git+https://github.com/dicej/wasm-tools?branch=async#940de62bd278831d1134bde84caedc1adcce6a00" dependencies = [ "ahash", "bitflags 2.6.0", @@ -3810,12 +3850,11 @@ dependencies = [ [[package]] name = "wasmprinter" version = "0.219.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "228cdc1f30c27816da225d239ce4231f28941147d34713dee8f1fff7cb330e54" +source = "git+https://github.com/dicej/wasm-tools?branch=async#940de62bd278831d1134bde84caedc1adcce6a00" dependencies = [ "anyhow", "termcolor", - "wasmparser", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] @@ -3831,6 +3870,7 @@ dependencies = [ "cfg-if", "encoding_rs", "env_logger 0.11.5", + "futures", "fxprof-processed-profile", "gimli", "hashbrown 0.14.3", @@ -3860,8 +3900,8 @@ dependencies = [ "target-lexicon", "tempfile", "wasi-common", - "wasm-encoder", - "wasmparser", + "wasm-encoder 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wasmtime-asm-macros", "wasmtime-cache", "wasmtime-component-macro", @@ -4002,7 +4042,7 @@ dependencies = [ "tracing", "walkdir", "wasi-common", - "wasmparser", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wasmtime", "wasmtime-cache", "wasmtime-cli-flags", @@ -4021,7 +4061,7 @@ dependencies = [ "wast 219.0.1", "wat", "windows-sys 0.59.0", - "wit-component", + "wit-component 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] @@ -4054,7 +4094,7 @@ dependencies = [ "wasmtime", "wasmtime-component-util", "wasmtime-wit-bindgen", - "wit-parser", + "wit-parser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] @@ -4079,7 +4119,7 @@ dependencies = [ "smallvec", "target-lexicon", "thiserror", - "wasmparser", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wasmtime-environ", "wasmtime-versioned-export-macros", ] @@ -4105,8 +4145,8 @@ dependencies = [ "serde_derive", "smallvec", "target-lexicon", - "wasm-encoder", - "wasmparser", + "wasm-encoder 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wasmprinter", "wasmtime-component-util", "wat", @@ -4120,7 +4160,7 @@ dependencies = [ "component-fuzz-util", "env_logger 0.11.5", "libfuzzer-sys", - "wasmparser", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wasmprinter", "wasmtime-environ", "wat", @@ -4179,7 +4219,7 @@ dependencies = [ "rand", "smallvec", "target-lexicon", - "wasmparser", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wasmtime", "wasmtime-fuzzing", ] @@ -4200,12 +4240,12 @@ dependencies = [ "target-lexicon", "tempfile", "v8", - "wasm-encoder", + "wasm-encoder 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wasm-mutate", "wasm-smith", "wasm-spec-interpreter", "wasmi", - "wasmparser", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wasmprinter", "wasmtime", "wasmtime-wast", @@ -4386,7 +4426,7 @@ dependencies = [ "gimli", "object", "target-lexicon", - "wasmparser", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wasmtime-cranelift", "wasmtime-environ", "winch-codegen", @@ -4399,7 +4439,7 @@ dependencies = [ "anyhow", "heck 0.5.0", "indexmap 2.2.6", - "wit-parser", + "wit-parser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] @@ -4418,21 +4458,19 @@ dependencies = [ [[package]] name = "wast" version = "219.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f79a9d9df79986a68689a6b40bcc8d5d40d807487b235bebc2ac69a242b54a1" +source = "git+https://github.com/dicej/wasm-tools?branch=async#940de62bd278831d1134bde84caedc1adcce6a00" dependencies = [ "bumpalo", "leb128", "memchr", - "unicode-width", - "wasm-encoder", + "unicode-width 0.2.0", + "wasm-encoder 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] name = "wat" version = "1.219.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bc3cf014fb336883a411cd662f987abf6a1d2a27f2f0008616a0070bbf6bd0d" +source = "git+https://github.com/dicej/wasm-tools?branch=async#940de62bd278831d1134bde84caedc1adcce6a00" dependencies = [ "wast 219.0.1", ] @@ -4575,7 +4613,7 @@ dependencies = [ "regalloc2", "smallvec", "target-lexicon", - "wasmparser", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", "wasmtime-cranelift", "wasmtime-environ", ] @@ -4814,7 +4852,7 @@ checksum = "163cee59d3d5ceec0b256735f3ab0dccac434afb0ec38c406276de9c5a11e906" dependencies = [ "anyhow", "heck 0.5.0", - "wit-parser", + "wit-parser 0.219.1 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -4837,9 +4875,9 @@ dependencies = [ "indexmap 2.2.6", "prettyplease", "syn 2.0.60", - "wasm-metadata", + "wasm-metadata 0.219.1 (registry+https://github.com/rust-lang/crates.io-index)", "wit-bindgen-core", - "wit-component", + "wit-component 0.219.1 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -4870,10 +4908,28 @@ dependencies = [ "serde", "serde_derive", "serde_json", - "wasm-encoder", - "wasm-metadata", - "wasmparser", - "wit-parser", + "wasm-encoder 0.219.1 (registry+https://github.com/rust-lang/crates.io-index)", + "wasm-metadata 0.219.1 (registry+https://github.com/rust-lang/crates.io-index)", + "wasmparser 0.219.1 (registry+https://github.com/rust-lang/crates.io-index)", + "wit-parser 0.219.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "wit-component" +version = "0.219.1" +source = "git+https://github.com/dicej/wasm-tools?branch=async#940de62bd278831d1134bde84caedc1adcce6a00" +dependencies = [ + "anyhow", + "bitflags 2.6.0", + "indexmap 2.2.6", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", + "wasm-metadata 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", + "wit-parser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] @@ -4891,7 +4947,24 @@ dependencies = [ "serde_derive", "serde_json", "unicode-xid", - "wasmparser", + "wasmparser 0.219.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "wit-parser" +version = "0.219.1" +source = "git+https://github.com/dicej/wasm-tools?branch=async#940de62bd278831d1134bde84caedc1adcce6a00" +dependencies = [ + "anyhow", + "id-arena", + "indexmap 2.2.6", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser 0.219.1 (git+https://github.com/dicej/wasm-tools?branch=async)", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 30f66681a513..ecd1ccb091a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -274,15 +274,15 @@ wit-bindgen = { version = "0.34.0", default-features = false } wit-bindgen-rust-macro = { version = "0.34.0", default-features = false } # wasm-tools family: -wasmparser = { version = "0.219.1", default-features = false } -wat = "1.219.1" -wast = "219.0.1" -wasmprinter = "0.219.1" -wasm-encoder = "0.219.1" -wasm-smith = "0.219.1" -wasm-mutate = "0.219.1" -wit-parser = "0.219.1" -wit-component = "0.219.1" +wasmparser = { git = "https://github.com/dicej/wasm-tools", branch = "async" } +wat = { git = "https://github.com/dicej/wasm-tools", branch = "async" } +wast = { git = "https://github.com/dicej/wasm-tools", branch = "async" } +wasmprinter = { git = "https://github.com/dicej/wasm-tools", branch = "async" } +wasm-encoder = { git = "https://github.com/dicej/wasm-tools", branch = "async" } +wasm-smith = { git = "https://github.com/dicej/wasm-tools", branch = "async" } +wasm-mutate = { git = "https://github.com/dicej/wasm-tools", branch = "async" } +wit-parser = { git = "https://github.com/dicej/wasm-tools", branch = "async" } +wit-component = { git = "https://github.com/dicej/wasm-tools", branch = "async" } # Non-Bytecode Alliance maintained dependencies: # -------------------------- diff --git a/crates/component-macro/Cargo.toml b/crates/component-macro/Cargo.toml index 79dbc6a27353..5e38dd2977ce 100644 --- a/crates/component-macro/Cargo.toml +++ b/crates/component-macro/Cargo.toml @@ -41,3 +41,4 @@ similar = { workspace = true } [features] async = [] std = ['wasmtime-wit-bindgen/std'] +component-model-async = ['std', 'async', 'wasmtime-wit-bindgen/component-model-async'] diff --git a/crates/component-macro/src/bindgen.rs b/crates/component-macro/src/bindgen.rs index 9c5b7dc88cb5..0661ad84e552 100644 --- a/crates/component-macro/src/bindgen.rs +++ b/crates/component-macro/src/bindgen.rs @@ -1,14 +1,15 @@ use proc_macro2::{Span, TokenStream}; use quote::ToTokens; -use std::collections::HashMap; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::env; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; use syn::parse::{Error, Parse, ParseStream, Result}; use syn::punctuated::Punctuated; use syn::{braced, token, Token}; -use wasmtime_wit_bindgen::{AsyncConfig, Opts, Ownership, TrappableError, TrappableImports}; +use wasmtime_wit_bindgen::{ + AsyncConfig, CallStyle, Opts, Ownership, TrappableError, TrappableImports, +}; use wit_parser::{PackageId, Resolve, UnresolvedPackageGroup, WorldId}; pub struct Config { @@ -20,13 +21,22 @@ pub struct Config { } pub fn expand(input: &Config) -> Result { - if !cfg!(feature = "async") && input.opts.async_.maybe_async() { + if let (CallStyle::Async | CallStyle::Concurrent, false) = + (input.opts.call_style(), cfg!(feature = "async")) + { return Err(Error::new( Span::call_site(), "cannot enable async bindings unless `async` crate feature is active", )); } + if input.opts.concurrent_imports && !cfg!(feature = "component-model-async") { + return Err(Error::new( + Span::call_site(), + "cannot enable `concurrent_imports` option unless `component-model-async` crate feature is active", + )); + } + let mut src = match input.opts.generate(&input.resolve, input.world) { Ok(s) => s, Err(e) => return Err(Error::new(Span::call_site(), e.to_string())), @@ -40,7 +50,10 @@ pub fn expand(input: &Config) -> Result { // place a formatted version of the expanded code into a file. This file // will then show up in rustc error messages for any codegen issues and can // be inspected manually. - if input.include_generated_code_from_file || std::env::var("WASMTIME_DEBUG_BINDGEN").is_ok() { + if input.include_generated_code_from_file + || input.opts.debug + || std::env::var("WASMTIME_DEBUG_BINDGEN").is_ok() + { static INVOCATION: AtomicUsize = AtomicUsize::new(0); let root = Path::new(env!("DEBUG_OUTPUT_DIR")); let world_name = &input.resolve.worlds[input.world].name; @@ -107,6 +120,7 @@ impl Parse for Config { } Opt::Tracing(val) => opts.tracing = val, Opt::VerboseTracing(val) => opts.verbose_tracing = val, + Opt::Debug(val) => opts.debug = val, Opt::Async(val, span) => { if async_configured { return Err(Error::new(span, "cannot specify second async config")); @@ -114,6 +128,7 @@ impl Parse for Config { async_configured = true; opts.async_ = val; } + Opt::ConcurrentImports(val) => opts.concurrent_imports = val, Opt::TrappableErrorType(val) => opts.trappable_error_type = val, Opt::TrappableImports(val) => opts.trappable_imports = val, Opt::Ownership(val) => opts.ownership = val, @@ -138,7 +153,7 @@ impl Parse for Config { "cannot specify a world with `interfaces`", )); } - world = Some("interfaces".to_string()); + world = Some("wasmtime:component-macro-synthesized/interfaces".to_string()); opts.only_interfaces = true; } @@ -205,7 +220,7 @@ fn parse_source( }; let (pkg, sources) = resolve.push_path(normalized_path)?; pkgs.push(pkg); - files.extend(sources); + files.extend(sources.package_paths(pkg).unwrap().map(|v| v.to_owned())); } Ok(()) }; @@ -281,6 +296,8 @@ mod kw { syn::custom_keyword!(require_store_data_send); syn::custom_keyword!(wasmtime_crate); syn::custom_keyword!(include_generated_code_from_file); + syn::custom_keyword!(concurrent_imports); + syn::custom_keyword!(debug); } enum Opt { @@ -301,12 +318,18 @@ enum Opt { RequireStoreDataSend(bool), WasmtimeCrate(syn::Path), IncludeGeneratedCodeFromFile(bool), + ConcurrentImports(bool), + Debug(bool), } impl Parse for Opt { fn parse(input: ParseStream<'_>) -> Result { let l = input.lookahead1(); - if l.peek(kw::path) { + if l.peek(kw::debug) { + input.parse::()?; + input.parse::()?; + Ok(Opt::Debug(input.parse::()?.value)) + } else if l.peek(kw::path) { input.parse::()?; input.parse::()?; @@ -380,6 +403,10 @@ impl Parse for Opt { span, )) } + } else if l.peek(kw::concurrent_imports) { + input.parse::()?; + input.parse::()?; + Ok(Opt::ConcurrentImports(input.parse::()?.value)) } else if l.peek(kw::ownership) { input.parse::()?; input.parse::()?; diff --git a/crates/cranelift/Cargo.toml b/crates/cranelift/Cargo.toml index 678457c35eca..6c19418f6bc0 100644 --- a/crates/cranelift/Cargo.toml +++ b/crates/cranelift/Cargo.toml @@ -45,3 +45,4 @@ gc = ["wasmtime-environ/gc"] gc-drc = ["gc", "wasmtime-environ/gc-drc"] gc-null = ["gc", "wasmtime-environ/gc-null"] threads = ["wasmtime-environ/threads"] + diff --git a/crates/cranelift/src/compiler/component.rs b/crates/cranelift/src/compiler/component.rs index cc0b5fe9ef26..adb4836842a4 100644 --- a/crates/cranelift/src/compiler/component.rs +++ b/crates/cranelift/src/compiler/component.rs @@ -2,7 +2,7 @@ use crate::{compiler::Compiler, TRAP_ALWAYS, TRAP_CANNOT_ENTER, TRAP_INTERNAL_ASSERT}; use anyhow::Result; -use cranelift_codegen::ir::{self, InstBuilder, MemFlags}; +use cranelift_codegen::ir::{self, InstBuilder, MemFlags, Value}; use cranelift_codegen::isa::{CallConv, TargetIsa}; use cranelift_frontend::FunctionBuilder; use std::any::Any; @@ -96,19 +96,379 @@ impl<'a> TrampolineCompiler<'a> { Trampoline::ResourceNew(ty) => self.translate_resource_new(*ty), Trampoline::ResourceRep(ty) => self.translate_resource_rep(*ty), Trampoline::ResourceDrop(ty) => self.translate_resource_drop(*ty), + Trampoline::TaskBackpressure { instance } => { + self.translate_task_backpressure_call(*instance) + } + Trampoline::TaskReturn => self.translate_task_return_call(), + Trampoline::TaskWait { async_, memory } => { + self.translate_task_wait_or_poll_call(*async_, *memory, self.offsets.task_wait()) + } + Trampoline::TaskPoll { async_, memory } => { + self.translate_task_wait_or_poll_call(*async_, *memory, self.offsets.task_poll()) + } + Trampoline::TaskYield { async_ } => self.translate_task_yield_call(*async_), + Trampoline::SubtaskDrop { instance } => self.translate_subtask_drop_call(*instance), + Trampoline::StreamNew { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + self.offsets.stream_new(), + Vec::new(), + vec![ir::AbiParam::new(ir::types::I32)], + ), + Trampoline::StreamRead { ty, options } => { + if let Some(info) = self.flat_stream_element_info(*ty) { + self.translate_flat_stream_call( + *ty, + options, + self.offsets.flat_stream_read(), + &info, + ) + } else { + self.translate_future_or_stream_call( + ty.as_u32(), + Some(options), + self.offsets.stream_read(), + vec![ + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ], + vec![ir::AbiParam::new(ir::types::I32)], + ) + } + } + Trampoline::StreamWrite { ty, options } => { + if let Some(info) = self.flat_stream_element_info(*ty) { + self.translate_flat_stream_call( + *ty, + options, + self.offsets.flat_stream_write(), + &info, + ) + } else { + self.translate_future_or_stream_call( + ty.as_u32(), + Some(options), + self.offsets.stream_write(), + vec![ + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ], + vec![ir::AbiParam::new(ir::types::I32)], + ) + } + } + Trampoline::StreamCancelRead { ty, async_ } => { + self.translate_cancel_call(ty.as_u32(), *async_, self.offsets.stream_cancel_read()) + } + Trampoline::StreamCancelWrite { ty, async_ } => { + self.translate_cancel_call(ty.as_u32(), *async_, self.offsets.stream_cancel_write()) + } + Trampoline::StreamCloseReadable { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + self.offsets.stream_close_readable(), + vec![ir::AbiParam::new(ir::types::I32)], + Vec::new(), + ), + Trampoline::StreamCloseWritable { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + self.offsets.stream_close_writable(), + vec![ + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ], + Vec::new(), + ), + Trampoline::FutureNew { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + self.offsets.future_new(), + Vec::new(), + vec![ir::AbiParam::new(ir::types::I32)], + ), + Trampoline::FutureRead { ty, options } => self.translate_future_or_stream_call( + ty.as_u32(), + Some(&options), + self.offsets.future_read(), + vec![ + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ], + vec![ir::AbiParam::new(ir::types::I32)], + ), + Trampoline::FutureWrite { ty, options } => self.translate_future_or_stream_call( + ty.as_u32(), + Some(options), + self.offsets.future_write(), + vec![ + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ], + vec![ir::AbiParam::new(ir::types::I32)], + ), + Trampoline::FutureCancelRead { ty, async_ } => { + self.translate_cancel_call(ty.as_u32(), *async_, self.offsets.future_cancel_read()) + } + Trampoline::FutureCancelWrite { ty, async_ } => { + self.translate_cancel_call(ty.as_u32(), *async_, self.offsets.future_cancel_write()) + } + Trampoline::FutureCloseReadable { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + self.offsets.future_close_readable(), + vec![ir::AbiParam::new(ir::types::I32)], + Vec::new(), + ), + Trampoline::FutureCloseWritable { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + self.offsets.future_close_writable(), + vec![ + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ], + Vec::new(), + ), + Trampoline::ErrorContextNew { ty, options } => self.translate_error_context_call( + *ty, + options, + self.offsets.error_context_new(), + vec![ + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ], + vec![ir::AbiParam::new(ir::types::I32)], + ), + Trampoline::ErrorContextDebugMessage { ty, options } => self + .translate_error_context_call( + *ty, + options, + self.offsets.error_context_debug_message(), + vec![ + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ], + Vec::new(), + ), + Trampoline::ErrorContextDrop { ty } => self.translate_error_context_drop_call(*ty), Trampoline::ResourceTransferOwn => { - self.translate_resource_libcall(host::resource_transfer_own) + self.translate_host_libcall(host::resource_transfer_own) } Trampoline::ResourceTransferBorrow => { - self.translate_resource_libcall(host::resource_transfer_borrow) + self.translate_host_libcall(host::resource_transfer_borrow) + } + Trampoline::ResourceEnterCall => self.translate_host_libcall(host::resource_enter_call), + Trampoline::ResourceExitCall => self.translate_host_libcall(host::resource_exit_call), + Trampoline::AsyncEnterCall => { + let pointer_type = self.isa.pointer_type(); + self.translate_async_enter_or_exit( + self.offsets.async_enter(), + None, + vec![ + ir::AbiParam::new(pointer_type), + ir::AbiParam::new(pointer_type), + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ], + Vec::new(), + ) + } + Trampoline::AsyncExitCall(callback) => { + let pointer_type = self.isa.pointer_type(); + self.translate_async_enter_or_exit( + self.offsets.async_exit(), + Some(*callback), + vec![ + ir::AbiParam::new(pointer_type), + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ], + vec![ir::AbiParam::new(ir::types::I32)], + ) + } + Trampoline::FutureTransfer => self.translate_host_libcall(host::future_transfer), + Trampoline::StreamTransfer => self.translate_host_libcall(host::stream_transfer), + Trampoline::ErrorContextTransfer => { + self.translate_host_libcall(host::error_context_transfer) + } + } + } + + fn flat_stream_element_info(&self, ty: TypeStreamTableIndex) -> Option { + let payload = self.types[self.types[ty].ty].payload; + match payload { + InterfaceType::Bool + | InterfaceType::S8 + | InterfaceType::U8 + | InterfaceType::S16 + | InterfaceType::U16 + | InterfaceType::S32 + | InterfaceType::U32 + | InterfaceType::S64 + | InterfaceType::U64 + | InterfaceType::Float32 + | InterfaceType::Float64 + | InterfaceType::Char => Some(self.types.canonical_abi(&payload).clone()), + // TODO: Recursively check for other "flat" types (i.e. those without pointers or handles), + // e.g. `record`s, `variant`s, etc. which contain only flat types. + _ => None, + } + } + + fn store_wasm_arguments(&mut self, args: &[Value]) -> (Value, Value) { + let pointer_type = self.isa.pointer_type(); + let wasm_func_ty = &self.types[self.signature].unwrap_func(); + + // Start off by spilling all the wasm arguments into a stack slot to be + // passed to the host function. + match self.abi { + Abi::Wasm => { + let (ptr, len) = self.compiler.allocate_stack_array_and_spill_args( + wasm_func_ty, + &mut self.builder, + args, + ); + let len = self.builder.ins().iconst(pointer_type, i64::from(len)); + (ptr, len) + } + Abi::Array => { + let params = self.builder.func.dfg.block_params(self.block0); + (params[2], params[3]) + } + } + } + + fn load_wasm_results(&mut self, values_vec_ptr: Value, values_vec_len: Value) { + let wasm_func_ty = &self.types[self.signature].unwrap_func(); + + match self.abi { + Abi::Wasm => { + // After the host function has returned the results are loaded from + // `values_vec_ptr` and then returned. + let results = self.compiler.load_values_from_array( + wasm_func_ty.returns(), + &mut self.builder, + values_vec_ptr, + values_vec_len, + ); + self.builder.ins().return_(&results); + } + Abi::Array => { + self.builder.ins().return_(&[]); } - Trampoline::ResourceEnterCall => { - self.translate_resource_libcall(host::resource_enter_call) + } + } + + fn translate_task_return_call(&mut self) { + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + + let (values_vec_ptr, values_vec_len) = self.store_wasm_arguments(&args[2..]); + + let mut callee_args = Vec::new(); + let mut host_sig = ir::Signature::new(CallConv::triple_default(self.isa.triple())); + + // vmctx: *mut VMComponentContext + host_sig.params.push(ir::AbiParam::new(pointer_type)); + callee_args.push(vmctx); + + // storage: *mut ValRaw + host_sig.params.push(ir::AbiParam::new(pointer_type)); + callee_args.push(values_vec_ptr); + + // storage_len: usize + host_sig.params.push(ir::AbiParam::new(pointer_type)); + callee_args.push(values_vec_len); + + // Load host function pointer from the vmcontext and then call that + // indirect function pointer with the list of arguments. + let host_fn = self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.task_return()).unwrap(), + ); + let host_sig = self.builder.import_signature(host_sig); + self.builder + .ins() + .call_indirect(host_sig, host_fn, &callee_args); + + self.load_wasm_results(values_vec_ptr, values_vec_len); + } + + fn translate_async_enter_or_exit( + &mut self, + offset: u32, + callback: Option>, + params: Vec, + results: Vec, + ) { + match self.abi { + Abi::Wasm => {} + + // These trampolines can only actually be called by Wasm, so + // let's assert that here. + Abi::Array => { + self.builder.ins().trap(TRAP_INTERNAL_ASSERT); + return; } - Trampoline::ResourceExitCall => { - self.translate_resource_libcall(host::resource_exit_call) + } + + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + + let mut callee_args = Vec::new(); + let mut host_sig = ir::Signature::new(CallConv::triple_default(self.isa.triple())); + + // vmctx: *mut VMComponentContext + host_sig.params.push(ir::AbiParam::new(pointer_type)); + callee_args.push(vmctx); + + if let Some(callback) = callback { + // callback: *mut VMFuncRef + host_sig.params.push(ir::AbiParam::new(pointer_type)); + if let Some(callback) = callback { + callee_args.push(self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_callback(callback)).unwrap(), + )); + } else { + callee_args.push(self.builder.ins().iconst(pointer_type, 0)); } } + + // remaining parameters + host_sig.params.extend(params); + callee_args.extend(args[2..].iter().copied()); + + host_sig.returns.extend(results); + + // Load host function pointer from the vmcontext and then call that + // indirect function pointer with the list of arguments. + let host_fn = self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(offset).unwrap(), + ); + let host_sig = self.builder.import_signature(host_sig); + let call = self + .builder + .ins() + .call_indirect(host_sig, host_fn, &callee_args); + + let results = self.builder.func.dfg.inst_results(call).to_vec(); + self.builder.ins().return_(&results); } fn translate_lower_import( @@ -150,10 +510,14 @@ impl<'a> TrampolineCompiler<'a> { instance, memory, realloc, + callback, post_return, string_encoding, + async_, } = *options; + assert!(callback.is_none()); + // vmctx: *mut VMComponentContext host_sig.params.push(ir::AbiParam::new(pointer_type)); callee_args.push(vmctx); @@ -175,6 +539,14 @@ impl<'a> TrampolineCompiler<'a> { .iconst(ir::types::I32, i64::from(lower_ty.as_u32())), ); + // caller_instance: RuntimeComponentInstanceIndex + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + callee_args.push( + self.builder + .ins() + .iconst(ir::types::I32, i64::from(instance.as_u32())), + ); + // flags: *mut VMGlobalDefinition host_sig.params.push(ir::AbiParam::new(pointer_type)); callee_args.push( @@ -220,6 +592,14 @@ impl<'a> TrampolineCompiler<'a> { .iconst(ir::types::I8, i64::from(string_encoding as u8)), ); + // async_: bool + host_sig.params.push(ir::AbiParam::new(ir::types::I8)); + callee_args.push( + self.builder + .ins() + .iconst(ir::types::I8, if async_ { 1 } else { 0 }), + ); + // storage: *mut ValRaw host_sig.params.push(ir::AbiParam::new(pointer_type)); callee_args.push(values_vec_ptr); @@ -241,22 +621,7 @@ impl<'a> TrampolineCompiler<'a> { .ins() .call_indirect(host_sig, host_fn, &callee_args); - match self.abi { - Abi::Wasm => { - // After the host function has returned the results are loaded from - // `values_vec_ptr` and then returned. - let results = self.compiler.load_values_from_array( - wasm_func_ty.returns(), - &mut self.builder, - values_vec_ptr, - values_vec_len, - ); - self.builder.ins().return_(&results); - } - Abi::Array => { - self.builder.ins().return_(&[]); - } - } + self.load_wasm_results(values_vec_ptr, values_vec_len); } fn translate_always_trap(&mut self) { @@ -535,7 +900,7 @@ impl<'a> TrampolineCompiler<'a> { /// /// Only intended for simple trampolines and effectively acts as a bridge /// from the wasm abi to host. - fn translate_resource_libcall( + fn translate_host_libcall( &mut self, get_libcall: fn(&dyn TargetIsa, &mut ir::Function) -> (ir::SigRef, u32), ) { @@ -564,6 +929,573 @@ impl<'a> TrampolineCompiler<'a> { self.builder.ins().return_(&results); } + fn translate_cancel_call(&mut self, ty: u32, async_: bool, offset: u32) { + match self.abi { + Abi::Wasm => {} + + // These trampolines can only actually be called by Wasm, so + // let's assert that here. + Abi::Array => { + self.builder.ins().trap(TRAP_INTERNAL_ASSERT); + return; + } + } + + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut host_args = vec![vmctx]; + let mut host_sig = ir::Signature::new(CallConv::triple_default(self.isa.triple())); + host_sig.params.push(ir::AbiParam::new(pointer_type)); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.push(self.builder.ins().iconst(ir::types::I32, i64::from(ty))); + + // async_: bool + host_sig.params.push(ir::AbiParam::new(ir::types::I8)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I8, if async_ { 1 } else { 0 }), + ); + + // Load host function pointer from the vmcontext and then call that + // indirect function pointer with the list of arguments. + let host_fn = self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(offset).unwrap(), + ); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.extend(args[2..].iter().copied()); + + host_sig.returns.push(ir::AbiParam::new(ir::types::I32)); + + let host_sig = self.builder.import_signature(host_sig); + let call = self + .builder + .ins() + .call_indirect(host_sig, host_fn, &host_args); + + let results = self.builder.func.dfg.inst_results(call).to_vec(); + self.builder.ins().return_(&results); + } + + fn translate_future_or_stream_call( + &mut self, + ty: u32, + options: Option<&CanonicalOptions>, + offset: u32, + params: Vec, + results: Vec, + ) { + match self.abi { + Abi::Wasm => {} + + // These trampolines can only actually be called by Wasm, so + // let's assert that here. + Abi::Array => { + self.builder.ins().trap(TRAP_INTERNAL_ASSERT); + return; + } + } + + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut host_args = vec![vmctx]; + let mut host_sig = ir::Signature::new(CallConv::triple_default(self.isa.triple())); + host_sig.params.push(ir::AbiParam::new(pointer_type)); + + if let Some(options) = options { + // memory: *mut VMMemoryDefinition + host_sig.params.push(ir::AbiParam::new(pointer_type)); + host_args.push(self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_memory(options.memory.unwrap())).unwrap(), + )); + + // realloc: *mut VMFuncRef + host_sig.params.push(ir::AbiParam::new(pointer_type)); + host_args.push(match options.realloc { + Some(idx) => self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_realloc(idx)).unwrap(), + ), + None => self.builder.ins().iconst(pointer_type, 0), + }); + + // string_encoding: StringEncoding + host_sig.params.push(ir::AbiParam::new(ir::types::I8)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I8, i64::from(options.string_encoding as u8)), + ); + } + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.push(self.builder.ins().iconst(ir::types::I32, i64::from(ty))); + + host_sig.params.extend(params); + host_args.extend(args[2..].iter().copied()); + + host_sig.returns.extend(results); + + // Load host function pointer from the vmcontext and then call that + // indirect function pointer with the list of arguments. + let host_fn = self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(offset).unwrap(), + ); + let host_sig = self.builder.import_signature(host_sig); + let call = self + .builder + .ins() + .call_indirect(host_sig, host_fn, &host_args); + + let results = self.builder.func.dfg.inst_results(call).to_vec(); + self.builder.ins().return_(&results); + } + + fn translate_flat_stream_call( + &mut self, + ty: TypeStreamTableIndex, + options: &CanonicalOptions, + offset: u32, + info: &CanonicalAbiInfo, + ) { + match self.abi { + Abi::Wasm => {} + + // These trampolines can only actually be called by Wasm, so + // let's assert that here. + Abi::Array => { + self.builder.ins().trap(TRAP_INTERNAL_ASSERT); + return; + } + } + + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut host_args = vec![vmctx]; + let mut host_sig = ir::Signature::new(CallConv::triple_default(self.isa.triple())); + host_sig.params.push(ir::AbiParam::new(pointer_type)); + + // memory: *mut VMMemoryDefinition + host_sig.params.push(ir::AbiParam::new(pointer_type)); + host_args.push(self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_memory(options.memory.unwrap())).unwrap(), + )); + + // realloc: *mut VMFuncRef + host_sig.params.push(ir::AbiParam::new(pointer_type)); + host_args.push(match options.realloc { + Some(idx) => self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_realloc(idx)).unwrap(), + ), + None => self.builder.ins().iconst(pointer_type, 0), + }); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I32, i64::from(ty.as_u32())), + ); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I32, i64::from(info.size32)), + ); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I32, i64::from(info.align32)), + ); + + host_sig.params.extend(vec![ + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ir::AbiParam::new(ir::types::I32), + ]); + host_args.extend(args[2..].iter().copied()); + + host_sig + .returns + .extend(vec![ir::AbiParam::new(ir::types::I32)]); + + // Load host function pointer from the vmcontext and then call that + // indirect function pointer with the list of arguments. + let host_fn = self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(offset).unwrap(), + ); + let host_sig = self.builder.import_signature(host_sig); + let call = self + .builder + .ins() + .call_indirect(host_sig, host_fn, &host_args); + + let results = self.builder.func.dfg.inst_results(call).to_vec(); + self.builder.ins().return_(&results); + } + + fn translate_error_context_call( + &mut self, + ty: TypeErrorContextTableIndex, + options: &CanonicalOptions, + offset: u32, + params: Vec, + results: Vec, + ) { + match self.abi { + Abi::Wasm => {} + + // These trampolines can only actually be called by Wasm, so + // let's assert that here. + Abi::Array => { + self.builder.ins().trap(TRAP_INTERNAL_ASSERT); + return; + } + } + + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut host_args = vec![vmctx]; + let mut host_sig = ir::Signature::new(CallConv::triple_default(self.isa.triple())); + host_sig.params.push(ir::AbiParam::new(pointer_type)); + + // memory: *mut VMMemoryDefinition + host_sig.params.push(ir::AbiParam::new(pointer_type)); + host_args.push(self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_memory(options.memory.unwrap())).unwrap(), + )); + + // realloc: *mut VMFuncRef + host_sig.params.push(ir::AbiParam::new(pointer_type)); + host_args.push(match options.realloc { + Some(idx) => self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_realloc(idx)).unwrap(), + ), + None => self.builder.ins().iconst(pointer_type, 0), + }); + + // string_encoding: StringEncoding + host_sig.params.push(ir::AbiParam::new(ir::types::I8)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I8, i64::from(options.string_encoding as u8)), + ); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I32, i64::from(ty.as_u32())), + ); + + host_sig.params.extend(params); + host_args.extend(args[2..].iter().copied()); + + host_sig.returns.extend(results); + + // Load host function pointer from the vmcontext and then call that + // indirect function pointer with the list of arguments. + let host_fn = self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(offset).unwrap(), + ); + let host_sig = self.builder.import_signature(host_sig); + let call = self + .builder + .ins() + .call_indirect(host_sig, host_fn, &host_args); + + let results = self.builder.func.dfg.inst_results(call).to_vec(); + self.builder.ins().return_(&results); + } + + fn translate_error_context_drop_call(&mut self, ty: TypeErrorContextTableIndex) { + match self.abi { + Abi::Wasm => {} + + // These trampolines can only actually be called by Wasm, so + // let's assert that here. + Abi::Array => { + self.builder.ins().trap(TRAP_INTERNAL_ASSERT); + return; + } + } + + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut host_args = vec![vmctx]; + let mut host_sig = ir::Signature::new(CallConv::triple_default(self.isa.triple())); + host_sig.params.push(ir::AbiParam::new(pointer_type)); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I32, i64::from(ty.as_u32())), + ); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.extend(args[2..].iter().copied()); + + // Load host function pointer from the vmcontext and then call that + // indirect function pointer with the list of arguments. + let offset = self.offsets.error_context_drop(); + let host_fn = self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(offset).unwrap(), + ); + let host_sig = self.builder.import_signature(host_sig); + let call = self + .builder + .ins() + .call_indirect(host_sig, host_fn, &host_args); + + let results = self.builder.func.dfg.inst_results(call).to_vec(); + self.builder.ins().return_(&results); + } + + fn translate_task_backpressure_call(&mut self, caller_instance: RuntimeComponentInstanceIndex) { + match self.abi { + Abi::Wasm => {} + + // These trampolines can only actually be called by Wasm, so + // let's assert that here. + Abi::Array => { + self.builder.ins().trap(TRAP_INTERNAL_ASSERT); + return; + } + } + + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut host_args = vec![vmctx]; + let mut host_sig = ir::Signature::new(CallConv::triple_default(self.isa.triple())); + host_sig.params.push(ir::AbiParam::new(pointer_type)); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I32, i64::from(caller_instance.as_u32())), + ); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.extend(args[2..].iter().copied()); + + // Load host function pointer from the vmcontext and then call that + // indirect function pointer with the list of arguments. + let host_fn = self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.task_backpressure()).unwrap(), + ); + let host_sig = self.builder.import_signature(host_sig); + let call = self + .builder + .ins() + .call_indirect(host_sig, host_fn, &host_args); + + let results = self.builder.func.dfg.inst_results(call).to_vec(); + self.builder.ins().return_(&results); + } + + fn translate_subtask_drop_call(&mut self, caller_instance: RuntimeComponentInstanceIndex) { + match self.abi { + Abi::Wasm => {} + + // These trampolines can only actually be called by Wasm, so + // let's assert that here. + Abi::Array => { + self.builder.ins().trap(TRAP_INTERNAL_ASSERT); + return; + } + } + + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut host_args = vec![vmctx]; + let mut host_sig = ir::Signature::new(CallConv::triple_default(self.isa.triple())); + host_sig.params.push(ir::AbiParam::new(pointer_type)); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I32, i64::from(caller_instance.as_u32())), + ); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.extend(args[2..].iter().copied()); + + // Load host function pointer from the vmcontext and then call that + // indirect function pointer with the list of arguments. + let host_fn = self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.subtask_drop()).unwrap(), + ); + let host_sig = self.builder.import_signature(host_sig); + let call = self + .builder + .ins() + .call_indirect(host_sig, host_fn, &host_args); + + let results = self.builder.func.dfg.inst_results(call).to_vec(); + self.builder.ins().return_(&results); + } + + fn translate_task_wait_or_poll_call( + &mut self, + async_: bool, + memory: RuntimeMemoryIndex, + offset: u32, + ) { + match self.abi { + Abi::Wasm => {} + + // These trampolines can only actually be called by Wasm, so + // let's assert that here. + Abi::Array => { + self.builder.ins().trap(TRAP_INTERNAL_ASSERT); + return; + } + } + + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut host_args = vec![vmctx]; + let mut host_sig = ir::Signature::new(CallConv::triple_default(self.isa.triple())); + host_sig.params.push(ir::AbiParam::new(pointer_type)); + + host_sig.params.push(ir::AbiParam::new(ir::types::I8)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I8, if async_ { 1 } else { 0 }), + ); + + // memory: *mut VMMemoryDefinition + host_sig.params.push(ir::AbiParam::new(pointer_type)); + host_args.push(self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_memory(memory)).unwrap(), + )); + + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + host_args.extend(args[2..].iter().copied()); + + host_sig.returns.push(ir::AbiParam::new(ir::types::I32)); + + // Load host function pointer from the vmcontext and then call that + // indirect function pointer with the list of arguments. + let host_fn = self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(offset).unwrap(), + ); + let host_sig = self.builder.import_signature(host_sig); + let call = self + .builder + .ins() + .call_indirect(host_sig, host_fn, &host_args); + + let results = self.builder.func.dfg.inst_results(call).to_vec(); + self.builder.ins().return_(&results); + } + + fn translate_task_yield_call(&mut self, async_: bool) { + match self.abi { + Abi::Wasm => {} + + // These trampolines can only actually be called by Wasm, so + // let's assert that here. + Abi::Array => { + self.builder.ins().trap(TRAP_INTERNAL_ASSERT); + return; + } + } + + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut host_args = vec![vmctx]; + let mut host_sig = ir::Signature::new(CallConv::triple_default(self.isa.triple())); + host_sig.params.push(ir::AbiParam::new(pointer_type)); + + host_sig.params.push(ir::AbiParam::new(ir::types::I8)); + host_args.push( + self.builder + .ins() + .iconst(ir::types::I8, if async_ { 1 } else { 0 }), + ); + + // Load host function pointer from the vmcontext and then call that + // indirect function pointer with the list of arguments. + let host_fn = self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.task_yield()).unwrap(), + ); + let host_sig = self.builder.import_signature(host_sig); + let call = self + .builder + .ins() + .call_indirect(host_sig, host_fn, &host_args); + + let results = self.builder.func.dfg.inst_results(call).to_vec(); + self.builder.ins().return_(&results); + } + /// Loads a host function pointer for a libcall stored at the `offset` /// provided in the libcalls array. /// diff --git a/crates/environ/examples/factc.rs b/crates/environ/examples/factc.rs index 2c692a926465..d1f94586f4fe 100644 --- a/crates/environ/examples/factc.rs +++ b/crates/environ/examples/factc.rs @@ -147,6 +147,8 @@ impl Factc { realloc: Some(dummy_def()), // Lowering never allows `post-return` post_return: None, + async_: false, + callback: None, }, lift_options: AdapterOptions { instance: RuntimeComponentInstanceIndex::from_u32(1), @@ -159,6 +161,8 @@ impl Factc { } else { None }, + async_: false, + callback: None, }, func: dummy_def(), }); diff --git a/crates/environ/src/component.rs b/crates/environ/src/component.rs index 4b72baa71322..d8fa8707b6c8 100644 --- a/crates/environ/src/component.rs +++ b/crates/environ/src/component.rs @@ -104,8 +104,10 @@ macro_rules! foreach_builtin_component_function { resource_transfer_borrow(vmctx: vmctx, src_idx: u32, src_table: u32, dst_table: u32) -> u32; resource_enter_call(vmctx: vmctx); resource_exit_call(vmctx: vmctx); - trap(vmctx: vmctx, code: u8); + future_transfer(vmctx: vmctx, src_idx: u32, src_table: u32, dst_table: u32) -> u32; + stream_transfer(vmctx: vmctx, src_idx: u32, src_table: u32, dst_table: u32) -> u32; + error_context_transfer(vmctx: vmctx, src_idx: u32, src_table: u32, dst_table: u32) -> u32; } }; } diff --git a/crates/environ/src/component/dfg.rs b/crates/environ/src/component/dfg.rs index a2a78c5643ad..fa9c8fc996be 100644 --- a/crates/environ/src/component/dfg.rs +++ b/crates/environ/src/component/dfg.rs @@ -57,6 +57,8 @@ pub struct ComponentDfg { /// used by the host) pub reallocs: Intern, + pub callbacks: Intern, + /// Same as `reallocs`, but for post-return. pub post_returns: Intern, @@ -103,7 +105,7 @@ pub struct ComponentDfg { /// The values here are the module that the adapter is present within along /// as the core wasm index of the export corresponding to the lowered /// version of the adapter. - pub adapter_paritionings: PrimaryMap, + pub adapter_partitionings: PrimaryMap, /// Defined resources in this component sorted by index with metadata about /// each resource. @@ -163,6 +165,7 @@ id! { pub struct InstanceId(u32); pub struct MemoryId(u32); pub struct ReallocId(u32); + pub struct CallbackId(u32); pub struct AdapterId(u32); pub struct PostReturnId(u32); pub struct AdapterModuleId(u32); @@ -211,7 +214,7 @@ pub enum CoreDef { /// This is a special variant not present in `info::CoreDef` which /// represents that this definition refers to a fused adapter function. This /// adapter is fully processed after the initial translation and - /// identificatino of adapters. + /// identification of adapters. /// /// During translation into `info::CoreDef` this variant is erased and /// replaced by `info::CoreDef::Export` since adapters are always @@ -269,10 +272,108 @@ pub enum Trampoline { ResourceNew(TypeResourceTableIndex), ResourceRep(TypeResourceTableIndex), ResourceDrop(TypeResourceTableIndex), + TaskBackpressure { + instance: RuntimeComponentInstanceIndex, + }, + TaskReturn, + TaskWait { + async_: bool, + memory: MemoryId, + }, + TaskPoll { + async_: bool, + memory: MemoryId, + }, + TaskYield { + async_: bool, + }, + SubtaskDrop { + instance: RuntimeComponentInstanceIndex, + }, + StreamNew { + ty: TypeStreamTableIndex, + }, + StreamRead { + ty: TypeStreamTableIndex, + options: CanonicalOptions, + }, + StreamWrite { + ty: TypeStreamTableIndex, + options: CanonicalOptions, + }, + StreamCancelRead { + ty: TypeStreamTableIndex, + async_: bool, + }, + StreamCancelWrite { + ty: TypeStreamTableIndex, + async_: bool, + }, + StreamCloseReadable { + ty: TypeStreamTableIndex, + }, + StreamCloseWritable { + ty: TypeStreamTableIndex, + }, + FutureNew { + ty: TypeFutureTableIndex, + }, + FutureRead { + ty: TypeFutureTableIndex, + options: CanonicalOptions, + }, + FutureWrite { + ty: TypeFutureTableIndex, + options: CanonicalOptions, + }, + FutureCancelRead { + ty: TypeFutureTableIndex, + async_: bool, + }, + FutureCancelWrite { + ty: TypeFutureTableIndex, + async_: bool, + }, + FutureCloseReadable { + ty: TypeFutureTableIndex, + }, + FutureCloseWritable { + ty: TypeFutureTableIndex, + }, + ErrorContextNew { + ty: TypeErrorContextTableIndex, + options: CanonicalOptions, + }, + ErrorContextDebugMessage { + ty: TypeErrorContextTableIndex, + options: CanonicalOptions, + }, + ErrorContextDrop { + ty: TypeErrorContextTableIndex, + }, ResourceTransferOwn, ResourceTransferBorrow, ResourceEnterCall, ResourceExitCall, + AsyncEnterCall, + AsyncExitCall(Option), + FutureTransfer, + StreamTransfer, + ErrorContextTransfer, +} + +#[derive(Copy, Clone, Hash, Eq, PartialEq)] +#[allow(missing_docs)] +pub struct FutureInfo { + pub instance: RuntimeComponentInstanceIndex, + pub payload_type: Option, +} + +#[derive(Copy, Clone, Hash, Eq, PartialEq)] +#[allow(missing_docs)] +pub struct StreamInfo { + pub instance: RuntimeComponentInstanceIndex, + pub payload_type: InterfaceType, } /// Same as `info::CanonicalOptions` @@ -283,7 +384,9 @@ pub struct CanonicalOptions { pub string_encoding: StringEncoding, pub memory: Option, pub realloc: Option, + pub callback: Option, pub post_return: Option, + pub async_: bool, } /// Same as `info::Resource` @@ -348,7 +451,7 @@ impl Default for Intern { impl ComponentDfg { /// Consumes the intermediate `ComponentDfg` to produce a final `Component` - /// with a linear innitializer list. + /// with a linear initializer list. pub fn finish( self, wasmtime_types: &mut ComponentTypesBuilder, @@ -360,6 +463,7 @@ impl ComponentDfg { runtime_memories: Default::default(), runtime_post_return: Default::default(), runtime_reallocs: Default::default(), + runtime_callbacks: Default::default(), runtime_instances: Default::default(), num_lowerings: 0, trampolines: Default::default(), @@ -400,6 +504,7 @@ impl ComponentDfg { num_runtime_memories: linearize.runtime_memories.len() as u32, num_runtime_post_returns: linearize.runtime_post_return.len() as u32, num_runtime_reallocs: linearize.runtime_reallocs.len() as u32, + num_runtime_callbacks: linearize.runtime_callbacks.len() as u32, num_runtime_instances: linearize.runtime_instances.len() as u32, imports: self.imports, import_types: self.import_types, @@ -431,6 +536,7 @@ struct LinearizeDfg<'a> { trampoline_map: HashMap, runtime_memories: HashMap, runtime_reallocs: HashMap, + runtime_callbacks: HashMap, runtime_post_return: HashMap, runtime_instances: HashMap, num_lowerings: u32, @@ -539,13 +645,16 @@ impl LinearizeDfg<'_> { fn options(&mut self, options: &CanonicalOptions) -> info::CanonicalOptions { let memory = options.memory.map(|mem| self.runtime_memory(mem)); let realloc = options.realloc.map(|mem| self.runtime_realloc(mem)); + let callback = options.callback.map(|mem| self.runtime_callback(mem)); let post_return = options.post_return.map(|mem| self.runtime_post_return(mem)); info::CanonicalOptions { instance: options.instance, string_encoding: options.string_encoding, memory, realloc, + callback, post_return, + async_: options.async_, } } @@ -567,6 +676,15 @@ impl LinearizeDfg<'_> { ) } + fn runtime_callback(&mut self, callback: CallbackId) -> RuntimeCallbackIndex { + self.intern( + callback, + |me| &mut me.runtime_callbacks, + |me, callback| me.core_def(&me.dfg.callbacks[callback]), + |index, def| GlobalInitializer::ExtractCallback(ExtractCallback { index, def }), + ) + } + fn runtime_post_return(&mut self, post_return: PostReturnId) -> RuntimePostReturnIndex { self.intern( post_return, @@ -625,10 +743,90 @@ impl LinearizeDfg<'_> { Trampoline::ResourceNew(ty) => info::Trampoline::ResourceNew(*ty), Trampoline::ResourceDrop(ty) => info::Trampoline::ResourceDrop(*ty), Trampoline::ResourceRep(ty) => info::Trampoline::ResourceRep(*ty), + Trampoline::TaskBackpressure { instance } => info::Trampoline::TaskBackpressure { + instance: *instance, + }, + Trampoline::TaskReturn => info::Trampoline::TaskReturn, + Trampoline::TaskWait { async_, memory } => info::Trampoline::TaskWait { + async_: *async_, + memory: self.runtime_memory(*memory), + }, + Trampoline::TaskPoll { async_, memory } => info::Trampoline::TaskPoll { + async_: *async_, + memory: self.runtime_memory(*memory), + }, + Trampoline::TaskYield { async_ } => info::Trampoline::TaskYield { async_: *async_ }, + Trampoline::SubtaskDrop { instance } => info::Trampoline::SubtaskDrop { + instance: *instance, + }, + Trampoline::StreamNew { ty } => info::Trampoline::StreamNew { ty: *ty }, + Trampoline::StreamRead { ty, options } => info::Trampoline::StreamRead { + ty: *ty, + options: self.options(options), + }, + Trampoline::StreamWrite { ty, options } => info::Trampoline::StreamWrite { + ty: *ty, + options: self.options(options), + }, + Trampoline::StreamCancelRead { ty, async_ } => info::Trampoline::StreamCancelRead { + ty: *ty, + async_: *async_, + }, + Trampoline::StreamCancelWrite { ty, async_ } => info::Trampoline::StreamCancelWrite { + ty: *ty, + async_: *async_, + }, + Trampoline::StreamCloseReadable { ty } => { + info::Trampoline::StreamCloseReadable { ty: *ty } + } + Trampoline::StreamCloseWritable { ty } => { + info::Trampoline::StreamCloseWritable { ty: *ty } + } + Trampoline::FutureNew { ty } => info::Trampoline::FutureNew { ty: *ty }, + Trampoline::FutureRead { ty, options } => info::Trampoline::FutureRead { + ty: *ty, + options: self.options(options), + }, + Trampoline::FutureWrite { ty, options } => info::Trampoline::FutureWrite { + ty: *ty, + options: self.options(options), + }, + Trampoline::FutureCancelRead { ty, async_ } => info::Trampoline::FutureCancelRead { + ty: *ty, + async_: *async_, + }, + Trampoline::FutureCancelWrite { ty, async_ } => info::Trampoline::FutureCancelWrite { + ty: *ty, + async_: *async_, + }, + Trampoline::FutureCloseReadable { ty } => { + info::Trampoline::FutureCloseReadable { ty: *ty } + } + Trampoline::FutureCloseWritable { ty } => { + info::Trampoline::FutureCloseWritable { ty: *ty } + } + Trampoline::ErrorContextNew { ty, options } => info::Trampoline::ErrorContextNew { + ty: *ty, + options: self.options(options), + }, + Trampoline::ErrorContextDebugMessage { ty, options } => { + info::Trampoline::ErrorContextDebugMessage { + ty: *ty, + options: self.options(options), + } + } + Trampoline::ErrorContextDrop { ty } => info::Trampoline::ErrorContextDrop { ty: *ty }, Trampoline::ResourceTransferOwn => info::Trampoline::ResourceTransferOwn, Trampoline::ResourceTransferBorrow => info::Trampoline::ResourceTransferBorrow, Trampoline::ResourceEnterCall => info::Trampoline::ResourceEnterCall, Trampoline::ResourceExitCall => info::Trampoline::ResourceExitCall, + Trampoline::AsyncEnterCall => info::Trampoline::AsyncEnterCall, + Trampoline::AsyncExitCall(callback) => { + info::Trampoline::AsyncExitCall(callback.map(|v| self.runtime_callback(v))) + } + Trampoline::FutureTransfer => info::Trampoline::FutureTransfer, + Trampoline::StreamTransfer => info::Trampoline::StreamTransfer, + Trampoline::ErrorContextTransfer => info::Trampoline::ErrorContextTransfer, }; let i1 = self.trampolines.push(*signature); let i2 = self.trampoline_defs.push(trampoline); @@ -650,7 +848,7 @@ impl LinearizeDfg<'_> { } fn adapter(&mut self, adapter: AdapterId) -> info::CoreExport { - let (adapter_module, entity_index) = self.dfg.adapter_paritionings[adapter]; + let (adapter_module, entity_index) = self.dfg.adapter_partitionings[adapter]; // Instantiates the adapter module if it hasn't already been // instantiated or otherwise returns the index that the module was diff --git a/crates/environ/src/component/info.rs b/crates/environ/src/component/info.rs index 6fb367cca92a..5522653d1b7b 100644 --- a/crates/environ/src/component/info.rs +++ b/crates/environ/src/component/info.rs @@ -148,6 +148,10 @@ pub struct Component { /// `VMComponentContext`. pub num_runtime_reallocs: u32, + /// The number of runtime async callbacks (maximum `RuntimeCallbackIndex`) + /// needed to instantiate this component. + pub num_runtime_callbacks: u32, + /// Same as `num_runtime_reallocs`, but for post-return functions. pub num_runtime_post_returns: u32, @@ -252,6 +256,10 @@ pub enum GlobalInitializer { /// used as a `realloc` function. ExtractRealloc(ExtractRealloc), + /// Same as `ExtractMemory`, except it's extracting a function pointer to be + /// used as an async `callback` function. + ExtractCallback(ExtractCallback), + /// Same as `ExtractMemory`, except it's extracting a function pointer to be /// used as a `post-return` function. ExtractPostReturn(ExtractPostReturn), @@ -281,6 +289,15 @@ pub struct ExtractRealloc { pub def: CoreDef, } +/// Same as `ExtractMemory` but for the `callback` canonical option. +#[derive(Debug, Serialize, Deserialize)] +pub struct ExtractCallback { + /// The index of the callback being defined. + pub index: RuntimeCallbackIndex, + /// Where this callback is being extracted from. + pub def: CoreDef, +} + /// Same as `ExtractMemory` but for the `post-return` canonical option. #[derive(Debug, Serialize, Deserialize)] pub struct ExtractPostReturn { @@ -447,8 +464,14 @@ pub struct CanonicalOptions { /// The realloc function used by these options, if specified. pub realloc: Option, + /// The async callback function used by these options, if specified. + pub callback: Option, + /// The post-return function used by these options, if specified. pub post_return: Option, + + /// Whether to use the async ABI for lifting or lowering. + pub async_: bool, } /// Possible encodings of strings within the component model. @@ -633,6 +656,195 @@ pub enum Trampoline { /// Same as `ResourceNew`, but for the `resource.drop` intrinsic. ResourceDrop(TypeResourceTableIndex), + /// A `task.backpressure` intrinsic, which tells the host to enable or + /// disable backpressure for the caller's instance. + TaskBackpressure { + /// The specific component instance which is calling the intrinsic. + instance: RuntimeComponentInstanceIndex, + }, + + /// A `task.return` intrinsic, which returns a result to the caller of a + /// lifted export function. This allows the callee to continue executing + /// after returning a result. + TaskReturn, + + /// A `task.wait` intrinsic, which waits for at least one outstanding async + /// task/stream/future to make progress, returning the first such event. + TaskWait { + /// If `true`, indicates the caller instance maybe reentered. + async_: bool, + /// Memory to use when storing the event. + memory: RuntimeMemoryIndex, + }, + + /// A `task.poll` intrinsic, which checks whether any outstanding async + /// task/stream/future has made progress. Unlike `task.wait`, this does not + /// block and may return nothing if no such event has occurred. + TaskPoll { + /// If `true`, indicates the caller instance maybe reentered. + async_: bool, + /// Memory to use when storing the event. + memory: RuntimeMemoryIndex, + }, + + /// A `task.yield` intrinsic, which yields control to the host so that other + /// tasks are able to make progress, if any. + TaskYield { + /// If `true`, indicates the caller instance maybe reentered. + async_: bool, + }, + + /// A `subtask.drop` intrinsic to drop a specified task which has completed. + SubtaskDrop { + /// The specific component instance which is calling the intrinsic. + instance: RuntimeComponentInstanceIndex, + }, + + /// A `stream.new` intrinsic to create a new `stream` handle of the + /// specified type. + StreamNew { + /// The table index for the specific `stream` type and caller instance. + ty: TypeStreamTableIndex, + }, + + /// A `stream.read` intrinsic to read from a `stream` of the specified type. + StreamRead { + /// The table index for the specific `stream` type and caller instance. + ty: TypeStreamTableIndex, + /// Any options (e.g. string encoding) to use when storing values to + /// memory. + options: CanonicalOptions, + }, + + /// A `stream.write` intrinsic to write to a `stream` of the specified type. + StreamWrite { + /// The table index for the specific `stream` type and caller instance. + ty: TypeStreamTableIndex, + /// Any options (e.g. string encoding) to use when storing values to + /// memory. + options: CanonicalOptions, + }, + + /// A `stream.cancel-read` intrinsic to cancel an in-progress read from a + /// `stream` of the specified type. + StreamCancelRead { + /// The table index for the specific `stream` type and caller instance. + ty: TypeStreamTableIndex, + /// If `false`, block until cancellation completes rather than return + /// `BLOCKED`. + async_: bool, + }, + + /// A `stream.cancel-write` intrinsic to cancel an in-progress write from a + /// `stream` of the specified type. + StreamCancelWrite { + /// The table index for the specific `stream` type and caller instance. + ty: TypeStreamTableIndex, + /// If `false`, block until cancellation completes rather than return + /// `BLOCKED`. + async_: bool, + }, + + /// A `stream.close-readable` intrinsic to close the readable end of a + /// `stream` of the specified type. + StreamCloseReadable { + /// The table index for the specific `stream` type and caller instance. + ty: TypeStreamTableIndex, + }, + + /// A `stream.close-writable` intrinsic to close the writable end of a + /// `stream` of the specified type. + StreamCloseWritable { + /// The table index for the specific `stream` type and caller instance. + ty: TypeStreamTableIndex, + }, + + /// A `future.new` intrinsic to create a new `future` handle of the + /// specified type. + FutureNew { + /// The table index for the specific `future` type and caller instance. + ty: TypeFutureTableIndex, + }, + + /// A `future.read` intrinsic to read from a `future` of the specified type. + FutureRead { + /// The table index for the specific `future` type and caller instance. + ty: TypeFutureTableIndex, + /// Any options (e.g. string encoding) to use when storing values to + /// memory. + options: CanonicalOptions, + }, + + /// A `future.write` intrinsic to write to a `future` of the specified type. + FutureWrite { + /// The table index for the specific `future` type and caller instance. + ty: TypeFutureTableIndex, + /// Any options (e.g. string encoding) to use when storing values to + /// memory. + options: CanonicalOptions, + }, + + /// A `future.cancel-read` intrinsic to cancel an in-progress read from a + /// `future` of the specified type. + FutureCancelRead { + /// The table index for the specific `future` type and caller instance. + ty: TypeFutureTableIndex, + /// If `false`, block until cancellation completes rather than return + /// `BLOCKED`. + async_: bool, + }, + + /// A `future.cancel-write` intrinsic to cancel an in-progress write from a + /// `future` of the specified type. + FutureCancelWrite { + /// The table index for the specific `future` type and caller instance. + ty: TypeFutureTableIndex, + /// If `false`, block until cancellation completes rather than return + /// `BLOCKED`. + async_: bool, + }, + + /// A `future.close-readable` intrinsic to close the readable end of a + /// `future` of the specified type. + FutureCloseReadable { + /// The table index for the specific `future` type and caller instance. + ty: TypeFutureTableIndex, + }, + + /// A `future.close-writable` intrinsic to close the writable end of a + /// `future` of the specified type. + FutureCloseWritable { + /// The table index for the specific `future` type and caller instance. + ty: TypeFutureTableIndex, + }, + + /// A `error-context.new` intrinsic to create a new `error-context` with a + /// specified debug message. + ErrorContextNew { + /// The table index for the `error-context` type in the caller instance. + ty: TypeErrorContextTableIndex, + /// String encoding, memory, etc. to use when loading debug message. + options: CanonicalOptions, + }, + + /// A `error-context.debug-message` intrinsic to get the debug message for a + /// specified `error-context`. + /// + /// Note that the debug message might not necessarily match what was passed + /// to `error.new`. + ErrorContextDebugMessage { + /// The table index for the `error-context` type in the caller instance. + ty: TypeErrorContextTableIndex, + /// String encoding, memory, etc. to use when storing debug message. + options: CanonicalOptions, + }, + + /// A `error-context.drop` intrinsic to drop a specified `error-context`. + ErrorContextDrop { + /// The table index for the `error-context` type in the caller instance. + ty: TypeErrorContextTableIndex, + }, + /// An intrinsic used by FACT-generated modules which will transfer an owned /// resource from one table to another. Used in component-to-component /// adapter trampolines. @@ -650,6 +862,43 @@ pub enum Trampoline { /// Same as `ResourceEnterCall` except for when exiting a call. ResourceExitCall, + + /// An intrinsic used by FACT-generated modules to begin a call to an + /// async-lowered import function. + AsyncEnterCall, + + /// An intrinsic used by FACT-generated modules to complete a call to an + /// async-lowered import function. + /// + /// Note that `AsyncEnterCall` and `AsyncExitCall` could theoretically be + /// combined into a single `AsyncCall` intrinsic, but we separate them to + /// allow the FACT-generated module to optionally call the callee directly + /// without an intermediate host stack frame. + AsyncExitCall(Option), + + /// An intrinisic used by FACT-generated modules to (partially or entirely) transfer + /// ownership of a `future`. + /// + /// Transfering a `future` can either mean giving away the readable end + /// while retaining the writable end or only the former, depending on the + /// ownership status of the `future`. + FutureTransfer, + + /// An intrinisic used by FACT-generated modules to (partially or entirely) transfer + /// ownership of a `stream`. + /// + /// Transfering a `stream` can either mean giving away the readable end + /// while retaining the writable end or only the former, depending on the + /// ownership status of the `stream`. + StreamTransfer, + + /// An intrinisic used by FACT-generated modules to (partially or entirely) transfer + /// ownership of an `error-context`. + /// + /// Unlike futures, streams, and resource handles, `error-context` handles + /// are reference counted, meaning that sharing the handle with another + /// component does not invalidate the handle in the original component. + ErrorContextTransfer, } impl Trampoline { @@ -673,10 +922,38 @@ impl Trampoline { ResourceNew(i) => format!("component-resource-new[{}]", i.as_u32()), ResourceRep(i) => format!("component-resource-rep[{}]", i.as_u32()), ResourceDrop(i) => format!("component-resource-drop[{}]", i.as_u32()), + TaskBackpressure { .. } => format!("task-backpressure"), + TaskReturn => format!("task-return"), + TaskWait { .. } => format!("task-wait"), + TaskPoll { .. } => format!("task-poll"), + TaskYield { .. } => format!("task-yield"), + SubtaskDrop { .. } => format!("subtask-drop"), + StreamNew { .. } => format!("stream-new"), + StreamRead { .. } => format!("stream-read"), + StreamWrite { .. } => format!("stream-write"), + StreamCancelRead { .. } => format!("stream-cancel-read"), + StreamCancelWrite { .. } => format!("stream-cancel-write"), + StreamCloseReadable { .. } => format!("stream-close-readable"), + StreamCloseWritable { .. } => format!("stream-close-writable"), + FutureNew { .. } => format!("future-new"), + FutureRead { .. } => format!("future-read"), + FutureWrite { .. } => format!("future-write"), + FutureCancelRead { .. } => format!("future-cancel-read"), + FutureCancelWrite { .. } => format!("future-cancel-write"), + FutureCloseReadable { .. } => format!("future-close-readable"), + FutureCloseWritable { .. } => format!("future-close-writable"), + ErrorContextNew { .. } => format!("error-context-new"), + ErrorContextDebugMessage { .. } => format!("error-context-debug-message"), + ErrorContextDrop { .. } => format!("error-context-drop"), ResourceTransferOwn => format!("component-resource-transfer-own"), ResourceTransferBorrow => format!("component-resource-transfer-borrow"), ResourceEnterCall => format!("component-resource-enter-call"), ResourceExitCall => format!("component-resource-exit-call"), + AsyncEnterCall => format!("component-async-enter-call"), + AsyncExitCall(_) => format!("component-async-exit-call"), + FutureTransfer => format!("future-transfer"), + StreamTransfer => format!("stream-transfer"), + ErrorContextTransfer => format!("error-context-transfer"), } } } diff --git a/crates/environ/src/component/translate.rs b/crates/environ/src/component/translate.rs index 02bd0f497b28..cfd1ec07e062 100644 --- a/crates/environ/src/component/translate.rs +++ b/crates/environ/src/component/translate.rs @@ -12,8 +12,8 @@ use indexmap::IndexMap; use std::collections::HashMap; use std::mem; use wasmparser::component_types::{ - AliasableResourceId, ComponentCoreModuleTypeId, ComponentEntityType, ComponentFuncTypeId, - ComponentInstanceTypeId, + AliasableResourceId, ComponentCoreModuleTypeId, ComponentDefinedTypeId, ComponentEntityType, + ComponentFuncTypeId, ComponentInstanceTypeId, }; use wasmparser::types::Types; use wasmparser::{Chunk, ComponentImportName, Encoding, Parser, Payload, Validator}; @@ -189,6 +189,105 @@ enum LocalInitializer<'data> { ResourceRep(AliasableResourceId, ModuleInternedTypeIndex), ResourceDrop(AliasableResourceId, ModuleInternedTypeIndex), + TaskBackpressure { + func: ModuleInternedTypeIndex, + }, + TaskReturn { + func: ModuleInternedTypeIndex, + }, + TaskWait { + func: ModuleInternedTypeIndex, + async_: bool, + memory: MemoryIndex, + }, + TaskPoll { + func: ModuleInternedTypeIndex, + async_: bool, + memory: MemoryIndex, + }, + TaskYield { + func: ModuleInternedTypeIndex, + async_: bool, + }, + SubtaskDrop { + func: ModuleInternedTypeIndex, + }, + StreamNew { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + }, + StreamRead { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + options: LocalCanonicalOptions, + }, + StreamWrite { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + options: LocalCanonicalOptions, + }, + StreamCancelRead { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + async_: bool, + }, + StreamCancelWrite { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + async_: bool, + }, + StreamCloseReadable { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + }, + StreamCloseWritable { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + }, + FutureNew { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + }, + FutureRead { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + options: LocalCanonicalOptions, + }, + FutureWrite { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + options: LocalCanonicalOptions, + }, + FutureCancelRead { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + async_: bool, + }, + FutureCancelWrite { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + async_: bool, + }, + FutureCloseReadable { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + }, + FutureCloseWritable { + ty: ComponentDefinedTypeId, + func: ModuleInternedTypeIndex, + }, + ErrorContextNew { + func: ModuleInternedTypeIndex, + options: LocalCanonicalOptions, + }, + ErrorContextDebugMessage { + func: ModuleInternedTypeIndex, + options: LocalCanonicalOptions, + }, + ErrorContextDrop { + func: ModuleInternedTypeIndex, + }, + // core wasm modules ModuleStatic(StaticModuleIndex, ComponentCoreModuleTypeId), @@ -256,6 +355,8 @@ struct LocalCanonicalOptions { memory: Option, realloc: Option, post_return: Option, + async_: bool, + callback: Option, } enum Action { @@ -472,7 +573,8 @@ impl<'a, 'data> Translator<'a, 'data> { // Entries in the canonical section will get initializers recorded // with the listed options for lifting/lowering. Payload::ComponentCanonicalSection(s) => { - let mut core_func_index = self.validator.types(0).unwrap().function_count(); + let types = self.validator.types(0).unwrap(); + let mut core_func_index = types.function_count(); self.validator.component_canonical_section(&s)?; for func in s { let types = self.validator.types(0).unwrap(); @@ -522,11 +624,153 @@ impl<'a, 'data> Translator<'a, 'data> { core_func_index += 1; LocalInitializer::ResourceRep(resource, ty) } - wasmparser::CanonicalFunction::ThreadSpawn { .. } | wasmparser::CanonicalFunction::ThreadHwConcurrency => { bail!("unsupported intrinsic") } + wasmparser::CanonicalFunction::TaskBackpressure => { + let core_type = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::TaskBackpressure { func: core_type } + } + wasmparser::CanonicalFunction::TaskReturn { .. } => { + let core_type = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::TaskReturn { func: core_type } + } + wasmparser::CanonicalFunction::TaskWait { async_, memory } => { + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::TaskWait { + func, + async_, + memory: MemoryIndex::from_u32(memory), + } + } + wasmparser::CanonicalFunction::TaskPoll { async_, memory } => { + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::TaskPoll { + func, + async_, + memory: MemoryIndex::from_u32(memory), + } + } + wasmparser::CanonicalFunction::TaskYield { async_ } => { + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::TaskYield { func, async_ } + } + wasmparser::CanonicalFunction::SubtaskDrop => { + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::SubtaskDrop { func } + } + wasmparser::CanonicalFunction::StreamNew { ty } => { + let ty = types.component_defined_type_at(ty); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::StreamNew { ty, func } + } + wasmparser::CanonicalFunction::StreamRead { ty, options } => { + let ty = types.component_defined_type_at(ty); + let options = self.canonical_options(&options); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::StreamRead { ty, func, options } + } + wasmparser::CanonicalFunction::StreamWrite { ty, options } => { + let ty = types.component_defined_type_at(ty); + let options = self.canonical_options(&options); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::StreamWrite { ty, func, options } + } + wasmparser::CanonicalFunction::StreamCancelRead { ty, async_ } => { + let ty = types.component_defined_type_at(ty); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::StreamCancelRead { ty, func, async_ } + } + wasmparser::CanonicalFunction::StreamCancelWrite { ty, async_ } => { + let ty = types.component_defined_type_at(ty); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::StreamCancelWrite { ty, func, async_ } + } + wasmparser::CanonicalFunction::StreamCloseReadable { ty } => { + let ty = types.component_defined_type_at(ty); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::StreamCloseReadable { ty, func } + } + wasmparser::CanonicalFunction::StreamCloseWritable { ty } => { + let ty = types.component_defined_type_at(ty); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::StreamCloseWritable { ty, func } + } + wasmparser::CanonicalFunction::FutureNew { ty } => { + let ty = types.component_defined_type_at(ty); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::FutureNew { ty, func } + } + wasmparser::CanonicalFunction::FutureRead { ty, options } => { + let ty = types.component_defined_type_at(ty); + let options = self.canonical_options(&options); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::FutureRead { ty, func, options } + } + wasmparser::CanonicalFunction::FutureWrite { ty, options } => { + let ty = types.component_defined_type_at(ty); + let options = self.canonical_options(&options); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::FutureWrite { ty, func, options } + } + wasmparser::CanonicalFunction::FutureCancelRead { ty, async_ } => { + let ty = types.component_defined_type_at(ty); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::FutureCancelRead { ty, func, async_ } + } + wasmparser::CanonicalFunction::FutureCancelWrite { ty, async_ } => { + let ty = types.component_defined_type_at(ty); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::FutureCancelWrite { ty, func, async_ } + } + wasmparser::CanonicalFunction::FutureCloseReadable { ty } => { + let ty = types.component_defined_type_at(ty); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::FutureCloseReadable { ty, func } + } + wasmparser::CanonicalFunction::FutureCloseWritable { ty } => { + let ty = types.component_defined_type_at(ty); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::FutureCloseWritable { ty, func } + } + wasmparser::CanonicalFunction::ErrorContextNew { options } => { + let options = self.canonical_options(&options); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::ErrorContextNew { func, options } + } + wasmparser::CanonicalFunction::ErrorContextDebugMessage { options } => { + let options = self.canonical_options(&options); + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::ErrorContextDebugMessage { func, options } + } + wasmparser::CanonicalFunction::ErrorContextDrop => { + let func = self.core_func_signature(core_func_index)?; + core_func_index += 1; + LocalInitializer::ErrorContextDrop { func } + } }; self.result.initializers.push(init); } @@ -897,6 +1141,8 @@ impl<'a, 'data> Translator<'a, 'data> { memory: None, realloc: None, post_return: None, + async_: false, + callback: None, }; for opt in opts { match opt { @@ -921,6 +1167,11 @@ impl<'a, 'data> Translator<'a, 'data> { let idx = FuncIndex::from_u32(*idx); ret.post_return = Some(idx); } + wasmparser::CanonicalOption::Async => ret.async_ = true, + wasmparser::CanonicalOption::Callback(idx) => { + let idx = FuncIndex::from_u32(*idx); + ret.callback = Some(idx); + } } } return ret; diff --git a/crates/environ/src/component/translate/adapt.rs b/crates/environ/src/component/translate/adapt.rs index 8989e6fce1a1..f49c04f0a0ab 100644 --- a/crates/environ/src/component/translate/adapt.rs +++ b/crates/environ/src/component/translate/adapt.rs @@ -157,8 +157,12 @@ pub struct AdapterOptions { pub memory64: bool, /// An optional definition of `realloc` to used. pub realloc: Option, + /// The async callback function used by these options, if specified. + pub callback: Option, /// An optional definition of a `post-return` to use. pub post_return: Option, + /// Whether to use the async ABI for lifting or lowering. + pub async_: bool, } impl<'data> Translator<'_, 'data> { @@ -227,7 +231,7 @@ impl<'data> Translator<'_, 'data> { // in-order here as well. (with an assert to double-check) for (adapter, name) in adapter_module.adapters.iter().zip(&names) { let index = translation.module.exports[name]; - let i = component.adapter_paritionings.push((module_id, index)); + let i = component.adapter_partitionings.push((module_id, index)); assert_eq!(i, *adapter); } @@ -300,6 +304,15 @@ fn fact_import_to_core_def( } fact::Import::ResourceEnterCall => simple_intrinsic(dfg::Trampoline::ResourceEnterCall), fact::Import::ResourceExitCall => simple_intrinsic(dfg::Trampoline::ResourceExitCall), + fact::Import::AsyncEnterCall => simple_intrinsic(dfg::Trampoline::AsyncEnterCall), + fact::Import::AsyncExitCall(callback) => simple_intrinsic(dfg::Trampoline::AsyncExitCall( + callback.clone().map(|v| dfg.callbacks.push(v)), + )), + fact::Import::FutureTransfer => simple_intrinsic(dfg::Trampoline::FutureTransfer), + fact::Import::StreamTransfer => simple_intrinsic(dfg::Trampoline::StreamTransfer), + fact::Import::ErrorContextTransfer => { + simple_intrinsic(dfg::Trampoline::ErrorContextTransfer) + } } } @@ -363,6 +376,9 @@ impl PartitionAdapterModules { if let Some(def) = &options.realloc { self.core_def(dfg, def); } + if let Some(def) = &options.callback { + self.core_def(dfg, def); + } if let Some(def) = &options.post_return { self.core_def(dfg, def); } diff --git a/crates/environ/src/component/translate/inline.rs b/crates/environ/src/component/translate/inline.rs index 881cb7440f08..c004e9ef3fbe 100644 --- a/crates/environ/src/component/translate/inline.rs +++ b/crates/environ/src/component/translate/inline.rs @@ -667,6 +667,286 @@ impl<'a> Inliner<'a> { .push((*ty, dfg::Trampoline::ResourceDrop(id))); frame.funcs.push(dfg::CoreDef::Trampoline(index)); } + TaskBackpressure { func } => { + let index = self.result.trampolines.push(( + *func, + dfg::Trampoline::TaskBackpressure { + instance: frame.instance, + }, + )); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + TaskReturn { func } => { + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::TaskReturn)); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + TaskWait { + func, + async_, + memory, + } => { + let (memory, _) = self.memory(frame, types, *memory); + let memory = self.result.memories.push(memory); + let index = self.result.trampolines.push(( + *func, + dfg::Trampoline::TaskWait { + async_: *async_, + memory, + }, + )); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + TaskPoll { + func, + async_, + memory, + } => { + let (memory, _) = self.memory(frame, types, *memory); + let memory = self.result.memories.push(memory); + let index = self.result.trampolines.push(( + *func, + dfg::Trampoline::TaskPoll { + async_: *async_, + memory, + }, + )); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + TaskYield { func, async_ } => { + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::TaskYield { async_: *async_ })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + SubtaskDrop { func } => { + let index = self.result.trampolines.push(( + *func, + dfg::Trampoline::SubtaskDrop { + instance: frame.instance, + }, + )); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + StreamNew { ty, func } => { + let InterfaceType::Stream(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::StreamNew { ty })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + StreamRead { ty, func, options } => { + let InterfaceType::Stream(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let options = self.adapter_options(frame, types, options); + let options = self.canonical_options(options); + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::StreamRead { ty, options })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + StreamWrite { ty, func, options } => { + let InterfaceType::Stream(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let options = self.adapter_options(frame, types, options); + let options = self.canonical_options(options); + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::StreamWrite { ty, options })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + StreamCancelRead { ty, func, async_ } => { + let InterfaceType::Stream(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let index = self.result.trampolines.push(( + *func, + dfg::Trampoline::StreamCancelRead { + ty, + async_: *async_, + }, + )); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + StreamCancelWrite { ty, func, async_ } => { + let InterfaceType::Stream(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let index = self.result.trampolines.push(( + *func, + dfg::Trampoline::StreamCancelWrite { + ty, + async_: *async_, + }, + )); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + StreamCloseReadable { ty, func } => { + let InterfaceType::Stream(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::StreamCloseReadable { ty })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + StreamCloseWritable { ty, func } => { + let InterfaceType::Stream(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::StreamCloseWritable { ty })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + FutureNew { ty, func } => { + let InterfaceType::Future(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::FutureNew { ty })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + FutureRead { ty, func, options } => { + let InterfaceType::Future(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let options = self.adapter_options(frame, types, options); + let options = self.canonical_options(options); + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::FutureRead { ty, options })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + FutureWrite { ty, func, options } => { + let InterfaceType::Future(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let options = self.adapter_options(frame, types, options); + let options = self.canonical_options(options); + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::FutureWrite { ty, options })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + FutureCancelRead { ty, func, async_ } => { + let InterfaceType::Future(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let index = self.result.trampolines.push(( + *func, + dfg::Trampoline::FutureCancelRead { + ty, + async_: *async_, + }, + )); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + FutureCancelWrite { ty, func, async_ } => { + let InterfaceType::Future(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let index = self.result.trampolines.push(( + *func, + dfg::Trampoline::FutureCancelWrite { + ty, + async_: *async_, + }, + )); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + FutureCloseReadable { ty, func } => { + let InterfaceType::Future(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::FutureCloseReadable { ty })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + FutureCloseWritable { ty, func } => { + let InterfaceType::Future(ty) = + types.defined_type(frame.translation.types_ref(), *ty)? + else { + unreachable!() + }; + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::FutureCloseWritable { ty })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + ErrorContextNew { func, options } => { + let ty = types.error_context_type()?; + let options = self.adapter_options(frame, types, options); + let options = self.canonical_options(options); + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::ErrorContextNew { ty, options })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + ErrorContextDebugMessage { func, options } => { + let ty = types.error_context_type()?; + let options = self.adapter_options(frame, types, options); + let options = self.canonical_options(options); + let index = self.result.trampolines.push(( + *func, + dfg::Trampoline::ErrorContextDebugMessage { ty, options }, + )); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } + ErrorContextDrop { func } => { + let ty = types.error_context_type()?; + let index = self + .result + .trampolines + .push((*func, dfg::Trampoline::ErrorContextDrop { ty })); + frame.funcs.push(dfg::CoreDef::Trampoline(index)); + } ModuleStatic(idx, ty) => { frame.modules.push(ModuleDef::Static(*idx, *ty)); @@ -948,6 +1228,41 @@ impl<'a> Inliner<'a> { } } + fn memory( + &mut self, + frame: &InlinerFrame<'a>, + types: &ComponentTypesBuilder, + memory: MemoryIndex, + ) -> (dfg::CoreExport, bool) { + let memory = frame.memories[memory].clone().map_index(|i| match i { + EntityIndex::Memory(i) => i, + _ => unreachable!(), + }); + let memory64 = match &self.runtime_instances[memory.instance] { + InstanceModule::Static(idx) => match &memory.item { + ExportItem::Index(i) => { + let ty = &self.nested_modules[*idx].module.memories[*i]; + match ty.idx_type { + IndexType::I32 => false, + IndexType::I64 => true, + } + } + ExportItem::Name(_) => unreachable!(), + }, + InstanceModule::Import(ty) => match &memory.item { + ExportItem::Name(name) => match types[*ty].exports[name] { + EntityType::Memory(m) => match m.idx_type { + IndexType::I32 => false, + IndexType::I64 => true, + }, + _ => unreachable!(), + }, + ExportItem::Index(_) => unreachable!(), + }, + }; + (memory, memory64) + } + /// Translates a `LocalCanonicalOptions` which indexes into the `frame` /// specified into a runtime representation. fn adapter_options( @@ -988,6 +1303,7 @@ impl<'a> Inliner<'a> { None => false, }; let realloc = options.realloc.map(|i| frame.funcs[i].clone()); + let callback = options.callback.map(|i| frame.funcs[i].clone()); let post_return = options.post_return.map(|i| frame.funcs[i].clone()); AdapterOptions { instance: frame.instance, @@ -995,7 +1311,9 @@ impl<'a> Inliner<'a> { memory, memory64, realloc, + callback, post_return, + async_: options.async_, } } @@ -1008,6 +1326,7 @@ impl<'a> Inliner<'a> { .memory .map(|export| self.result.memories.push(export)); let realloc = options.realloc.map(|def| self.result.reallocs.push(def)); + let callback = options.callback.map(|def| self.result.callbacks.push(def)); let post_return = options .post_return .map(|def| self.result.post_returns.push(def)); @@ -1016,7 +1335,9 @@ impl<'a> Inliner<'a> { string_encoding: options.string_encoding, memory, realloc, + callback, post_return, + async_: options.async_, } } diff --git a/crates/environ/src/component/types.rs b/crates/environ/src/component/types.rs index 6d99e1f55f86..d9d24f770d1c 100644 --- a/crates/environ/src/component/types.rs +++ b/crates/environ/src/component/types.rs @@ -89,6 +89,25 @@ indices! { pub struct TypeResultIndex(u32); /// Index pointing to a list type in the component model. pub struct TypeListIndex(u32); + /// Index pointing to a future type in the component model. + pub struct TypeFutureIndex(u32); + /// Index pointing to a future table within a component. + /// + /// This is analogous to `TypeResourceTableIndex` in that it tracks + /// ownership of futures within each (sub)component instance. + pub struct TypeFutureTableIndex(u32); + /// Index pointing to a stream type in the component model. + pub struct TypeStreamIndex(u32); + /// Index pointing to a stream table within a component. + /// + /// This is analogous to `TypeResourceTableIndex` in that it tracks + /// ownership of stream within each (sub)component instance. + pub struct TypeStreamTableIndex(u32); + /// Index pointing to a error context table within a component. + /// + /// This is analogous to `TypeResourceTableIndex` in that it tracks + /// ownership of error contexts within each (sub)component instance. + pub struct TypeErrorContextTableIndex(u32); /// Index pointing to a resource table within a component. /// @@ -186,6 +205,9 @@ indices! { /// Same as `RuntimeMemoryIndex` except for the `realloc` function. pub struct RuntimeReallocIndex(u32); + /// Same as `RuntimeMemoryIndex` except for the `callback` function. + pub struct RuntimeCallbackIndex(u32); + /// Same as `RuntimeMemoryIndex` except for the `post-return` function. pub struct RuntimePostReturnIndex(u32); @@ -194,7 +216,7 @@ indices! { /// /// This is used to point to various bits of metadata within a compiled /// component and is stored in the final compilation artifact. This does not - /// have a direct corresponance to any wasm definition. + /// have a direct correspondence to any wasm definition. pub struct TrampolineIndex(u32); /// An index into `Component::export_items` at the end of compilation. @@ -237,8 +259,12 @@ pub struct ComponentTypes { pub(super) options: PrimaryMap, pub(super) results: PrimaryMap, pub(super) resource_tables: PrimaryMap, - pub(super) module_types: Option, + pub(super) futures: PrimaryMap, + pub(super) future_tables: PrimaryMap, + pub(super) streams: PrimaryMap, + pub(super) stream_tables: PrimaryMap, + pub(super) error_context_tables: PrimaryMap, } impl ComponentTypes { @@ -261,7 +287,10 @@ impl ComponentTypes { | InterfaceType::Float32 | InterfaceType::Char | InterfaceType::Own(_) - | InterfaceType::Borrow(_) => &CanonicalAbiInfo::SCALAR4, + | InterfaceType::Borrow(_) + | InterfaceType::Future(_) + | InterfaceType::Stream(_) + | InterfaceType::ErrorContext(_) => &CanonicalAbiInfo::SCALAR4, InterfaceType::U64 | InterfaceType::S64 | InterfaceType::Float64 => { &CanonicalAbiInfo::SCALAR8 @@ -320,6 +349,11 @@ impl_index! { impl Index for ComponentTypes { TypeResult => results } impl Index for ComponentTypes { TypeList => lists } impl Index for ComponentTypes { TypeResourceTable => resource_tables } + impl Index for ComponentTypes { TypeFuture => futures } + impl Index for ComponentTypes { TypeStream => streams } + impl Index for ComponentTypes { TypeFutureTable => future_tables } + impl Index for ComponentTypes { TypeStreamTable => stream_tables } + impl Index for ComponentTypes { TypeErrorContextTable => error_context_tables } } // Additionally forward anything that can index `ModuleTypes` to `ModuleTypes` @@ -462,6 +496,9 @@ pub enum InterfaceType { Result(TypeResultIndex), Own(TypeResourceTableIndex), Borrow(TypeResourceTableIndex), + Future(TypeFutureTableIndex), + Stream(TypeStreamTableIndex), + ErrorContext(TypeErrorContextTableIndex), } impl From<&wasmparser::PrimitiveValType> for InterfaceType { @@ -970,6 +1007,45 @@ pub struct TypeResult { pub info: VariantInfo, } +/// Shape of a "future" interface type. +#[derive(Serialize, Deserialize, Clone, Hash, Eq, PartialEq, Debug)] +pub struct TypeFuture { + /// The `T` in `future` + pub payload: Option, +} + +/// Metadata about a future table added to a component. +#[derive(Serialize, Deserialize, Clone, Hash, Eq, PartialEq, Debug)] +pub struct TypeFutureTable { + /// The specific future type this table is used for. + pub ty: TypeFutureIndex, + /// The specific component instance this table is used for. + pub instance: RuntimeComponentInstanceIndex, +} + +/// Shape of a "stream" interface type. +#[derive(Serialize, Deserialize, Clone, Hash, Eq, PartialEq, Debug)] +pub struct TypeStream { + /// The `T` in `stream` + pub payload: InterfaceType, +} + +/// Metadata about a stream table added to a component. +#[derive(Serialize, Deserialize, Clone, Hash, Eq, PartialEq, Debug)] +pub struct TypeStreamTable { + /// The specific stream type this table is used for. + pub ty: TypeStreamIndex, + /// The specific component instance this table is used for. + pub instance: RuntimeComponentInstanceIndex, +} + +/// Metadata about a error context table added to a component. +#[derive(Serialize, Deserialize, Clone, Hash, Eq, PartialEq, Debug)] +pub struct TypeErrorContextTable { + /// The specific component instance this table is used for. + pub instance: RuntimeComponentInstanceIndex, +} + /// Metadata about a resource table added to a component. #[derive(Serialize, Deserialize, Clone, Hash, Eq, PartialEq, Debug)] pub struct TypeResourceTable { diff --git a/crates/environ/src/component/types_builder.rs b/crates/environ/src/component/types_builder.rs index 053415b8ce66..858efc9856ea 100644 --- a/crates/environ/src/component/types_builder.rs +++ b/crates/environ/src/component/types_builder.rs @@ -31,7 +31,7 @@ pub use resources::ResourcesBuilder; /// Some more information about this can be found in #4814 const MAX_TYPE_DEPTH: u32 = 100; -/// Structured used to build a [`ComponentTypes`] during translation. +/// Structure used to build a [`ComponentTypes`] during translation. /// /// This contains tables to intern any component types found as well as /// managing building up core wasm [`ModuleTypes`] as well. @@ -45,6 +45,11 @@ pub struct ComponentTypesBuilder { flags: HashMap, options: HashMap, results: HashMap, + futures: HashMap, + streams: HashMap, + future_tables: HashMap, + stream_tables: HashMap, + error_context_tables: HashMap, component_types: ComponentTypes, module_types: ModuleTypesBuilder, @@ -70,15 +75,16 @@ where macro_rules! intern_and_fill_flat_types { ($me:ident, $name:ident, $val:ident) => {{ if let Some(idx) = $me.$name.get(&$val) { - return *idx; + *idx + } else { + let idx = $me.component_types.$name.push($val.clone()); + let mut info = TypeInformation::new(); + info.$name($me, &$val); + let idx2 = $me.type_info.$name.push(info); + assert_eq!(idx, idx2); + $me.$name.insert($val, idx); + idx } - let idx = $me.component_types.$name.push($val.clone()); - let mut info = TypeInformation::new(); - info.$name($me, &$val); - let idx2 = $me.type_info.$name.push(info); - assert_eq!(idx, idx2); - $me.$name.insert($val, idx); - return idx; }}; } @@ -97,6 +103,11 @@ impl ComponentTypesBuilder { flags: HashMap::default(), options: HashMap::default(), results: HashMap::default(), + futures: HashMap::default(), + streams: HashMap::default(), + future_tables: HashMap::default(), + stream_tables: HashMap::default(), + error_context_tables: HashMap::default(), component_types: ComponentTypes::default(), type_info: TypeInformationCache::default(), resources: ResourcesBuilder::default(), @@ -354,7 +365,8 @@ impl ComponentTypesBuilder { }) } - fn defined_type( + /// Convert a wasmparser `ComponentDefinedTypeId` into Wasmtime's type representation. + pub fn defined_type( &mut self, types: TypesRef<'_>, id: ComponentDefinedTypeId, @@ -378,6 +390,15 @@ impl ComponentTypesBuilder { ComponentDefinedType::Borrow(r) => { InterfaceType::Borrow(self.resource_id(r.resource())) } + ComponentDefinedType::Future(ty) => { + InterfaceType::Future(self.future_table_type(types, ty)?) + } + ComponentDefinedType::Stream(ty) => { + InterfaceType::Stream(self.stream_table_type(types, ty)?) + } + ComponentDefinedType::ErrorContext => { + InterfaceType::ErrorContext(self.error_context_table_type()?) + } }; let info = self.type_information(&ret); if info.depth > MAX_TYPE_DEPTH { @@ -386,6 +407,11 @@ impl ComponentTypesBuilder { Ok(ret) } + /// Retrieve Wasmtime's type representation of the `error-context` type. + pub fn error_context_type(&mut self) -> Result { + self.error_context_table_type() + } + fn valtype(&mut self, types: TypesRef<'_>, ty: &ComponentValType) -> Result { assert_eq!(types.id(), self.module_types.validator_id()); match ty { @@ -511,6 +537,38 @@ impl ComponentTypesBuilder { Ok(self.add_result_type(TypeResult { ok, err, abi, info })) } + fn future_table_type( + &mut self, + types: TypesRef<'_>, + ty: &Option, + ) -> Result { + let payload = ty.as_ref().map(|ty| self.valtype(types, ty)).transpose()?; + let ty = self.add_future_type(TypeFuture { payload }); + Ok(self.add_future_table_type(TypeFutureTable { + ty, + instance: self.resources.get_current_instance().unwrap(), + })) + } + + fn stream_table_type( + &mut self, + types: TypesRef<'_>, + ty: &ComponentValType, + ) -> Result { + let payload = self.valtype(types, ty)?; + let ty = self.add_stream_type(TypeStream { payload }); + Ok(self.add_stream_table_type(TypeStreamTable { + ty, + instance: self.resources.get_current_instance().unwrap(), + })) + } + + fn error_context_table_type(&mut self) -> Result { + Ok(self.add_error_context_table_type(TypeErrorContextTable { + instance: self.resources.get_current_instance().unwrap(), + })) + } + fn list_type(&mut self, types: TypesRef<'_>, ty: &ComponentValType) -> Result { assert_eq!(types.id(), self.module_types.validator_id()); let element = self.valtype(types, ty)?; @@ -563,11 +621,51 @@ impl ComponentTypesBuilder { intern_and_fill_flat_types!(self, results, ty) } - /// Interns a new type within this type information. + /// Interns a new list type within this type information. pub fn add_list_type(&mut self, ty: TypeList) -> TypeListIndex { intern_and_fill_flat_types!(self, lists, ty) } + /// Interns a new future type within this type information. + pub fn add_future_type(&mut self, ty: TypeFuture) -> TypeFutureIndex { + intern(&mut self.futures, &mut self.component_types.futures, ty) + } + + /// Interns a new future table type within this type information. + pub fn add_future_table_type(&mut self, ty: TypeFutureTable) -> TypeFutureTableIndex { + intern( + &mut self.future_tables, + &mut self.component_types.future_tables, + ty, + ) + } + + /// Interns a new stream type within this type information. + pub fn add_stream_type(&mut self, ty: TypeStream) -> TypeStreamIndex { + intern(&mut self.streams, &mut self.component_types.streams, ty) + } + + /// Interns a new stream table type within this type information. + pub fn add_stream_table_type(&mut self, ty: TypeStreamTable) -> TypeStreamTableIndex { + intern( + &mut self.stream_tables, + &mut self.component_types.stream_tables, + ty, + ) + } + + /// Interns a new error context table type within this type information. + pub fn add_error_context_table_type( + &mut self, + ty: TypeErrorContextTable, + ) -> TypeErrorContextTableIndex { + intern( + &mut self.error_context_tables, + &mut self.component_types.error_context_tables, + ty, + ) + } + /// Returns the canonical ABI information about the specified type. pub fn canonical_abi(&self, ty: &InterfaceType) -> &CanonicalAbiInfo { self.component_types.canonical_abi(ty) @@ -598,7 +696,10 @@ impl ComponentTypesBuilder { | InterfaceType::U32 | InterfaceType::S32 | InterfaceType::Char - | InterfaceType::Own(_) => { + | InterfaceType::Own(_) + | InterfaceType::Future(_) + | InterfaceType::Stream(_) + | InterfaceType::ErrorContext(_) => { static INFO: TypeInformation = TypeInformation::primitive(FlatType::I32); &INFO } diff --git a/crates/environ/src/component/types_builder/resources.rs b/crates/environ/src/component/types_builder/resources.rs index d02426d49b6b..cba01c60747f 100644 --- a/crates/environ/src/component/types_builder/resources.rs +++ b/crates/environ/src/component/types_builder/resources.rs @@ -231,4 +231,9 @@ impl ResourcesBuilder { pub fn set_current_instance(&mut self, instance: RuntimeComponentInstanceIndex) { self.current_instance = Some(instance); } + + /// Retrieves the `current_instance` field. + pub fn get_current_instance(&self) -> Option { + self.current_instance + } } diff --git a/crates/environ/src/component/vmcomponent_offsets.rs b/crates/environ/src/component/vmcomponent_offsets.rs index a770fb2eeecf..e3f25a74106f 100644 --- a/crates/environ/src/component/vmcomponent_offsets.rs +++ b/crates/environ/src/component/vmcomponent_offsets.rs @@ -4,13 +4,40 @@ // magic: u32, // libcalls: &'static VMComponentLibcalls, // store: *mut dyn Store, +// task_backpressure: VMTaskBackpressureCallback, +// task_return: VMTaskReturnCallback, +// task_wait: VMTaskWaitOrPollCallback, +// task_poll: VMTaskWaitOrPollCallback, +// task_yield: VMTaskYieldCallback, +// subtask_drop: VMSubtaskDropCallback, +// async_enter: VMAsyncEnterCallback, +// async_exit: VMAsyncExitCallback, +// future_new: VMFutureNewCallback, +// future_write: VMFutureTransmitCallback, +// future_read: VMFutureTransmitCallback, +// future_cancel_write: VMFutureCancelCallback, +// future_cancel_read: VMFutureCancelCallback, +// stream_cancel_write: VMStreamCancelCallback, +// stream_cancel_read: VMStreamCancelCallback, +// future_close_writable: VMFutureCloseWritableCallback, +// future_close_readable: VMFutureCloseReadableCallback, +// stream_close_writable: VMStreamCloseWritableCallback, +// stream_close_readable: VMStreamCloseReadableCallback, +// stream_new: VMStreamNewCallback, +// stream_write: VMStreamTransmitCallback, +// stream_read: VMStreamTransmitCallback, +// flat_stream_write: VMFlatStreamTransmitCallback, +// flat_stream_read: VMFlatStreamTransmitCallback, +// error_context_new: VMErrorContextNewCallback, +// error_context_debug_string: VMErrorContextDebugStringCallback, +// error_context_drop: VMErrorContextDropCallback, // limits: *const VMRuntimeLimits, // flags: [VMGlobalDefinition; component.num_runtime_component_instances], // trampoline_func_refs: [VMFuncRef; component.num_trampolines], // lowerings: [VMLowering; component.num_lowerings], -// memories: [*mut VMMemoryDefinition; component.num_memories], -// reallocs: [*mut VMFuncRef; component.num_reallocs], -// post_returns: [*mut VMFuncRef; component.num_post_returns], +// memories: [*mut VMMemoryDefinition; component.num_runtime_memories], +// reallocs: [*mut VMFuncRef; component.num_runtime_reallocs], +// post_returns: [*mut VMFuncRef; component.num_runtime_post_returns], // resource_destructors: [*mut VMFuncRef; component.num_resources], // } @@ -47,6 +74,8 @@ pub struct VMComponentOffsets

{ pub num_runtime_memories: u32, /// The number of reallocs which are recorded in this component for options. pub num_runtime_reallocs: u32, + /// The number of callbacks which are recorded in this component for options. + pub num_runtime_callbacks: u32, /// The number of post-returns which are recorded in this component for options. pub num_runtime_post_returns: u32, /// Number of component instances internally in the component (always at @@ -61,12 +90,40 @@ pub struct VMComponentOffsets

{ magic: u32, libcalls: u32, store: u32, + task_backpressure: u32, + task_return: u32, + task_wait: u32, + task_poll: u32, + task_yield: u32, + subtask_drop: u32, + async_enter: u32, + async_exit: u32, + future_new: u32, + future_write: u32, + future_read: u32, + future_cancel_write: u32, + future_cancel_read: u32, + future_close_writable: u32, + future_close_readable: u32, + stream_new: u32, + stream_write: u32, + stream_read: u32, + stream_cancel_write: u32, + stream_cancel_read: u32, + stream_close_writable: u32, + stream_close_readable: u32, + flat_stream_write: u32, + flat_stream_read: u32, + error_context_new: u32, + error_context_debug_message: u32, + error_context_drop: u32, limits: u32, flags: u32, trampoline_func_refs: u32, lowerings: u32, memories: u32, reallocs: u32, + callbacks: u32, post_returns: u32, resource_destructors: u32, size: u32, @@ -87,6 +144,7 @@ impl VMComponentOffsets

{ num_lowerings: component.num_lowerings, num_runtime_memories: component.num_runtime_memories.try_into().unwrap(), num_runtime_reallocs: component.num_runtime_reallocs.try_into().unwrap(), + num_runtime_callbacks: component.num_runtime_callbacks.try_into().unwrap(), num_runtime_post_returns: component.num_runtime_post_returns.try_into().unwrap(), num_runtime_component_instances: component .num_runtime_component_instances @@ -103,9 +161,37 @@ impl VMComponentOffsets

{ lowerings: 0, memories: 0, reallocs: 0, + callbacks: 0, post_returns: 0, resource_destructors: 0, size: 0, + task_backpressure: 0, + task_return: 0, + task_wait: 0, + task_poll: 0, + task_yield: 0, + subtask_drop: 0, + async_enter: 0, + async_exit: 0, + future_new: 0, + future_write: 0, + future_read: 0, + future_cancel_write: 0, + future_cancel_read: 0, + future_close_writable: 0, + future_close_readable: 0, + stream_new: 0, + stream_write: 0, + stream_read: 0, + stream_cancel_write: 0, + stream_cancel_read: 0, + stream_close_writable: 0, + stream_close_readable: 0, + flat_stream_write: 0, + flat_stream_read: 0, + error_context_new: 0, + error_context_debug_message: 0, + error_context_drop: 0, }; // Convenience functions for checked addition and multiplication. @@ -138,6 +224,33 @@ impl VMComponentOffsets

{ size(libcalls) = ret.ptr.size(), size(store) = cmul(2, ret.ptr.size()), size(limits) = ret.ptr.size(), + size(task_backpressure) = ret.ptr.size(), + size(task_return) = ret.ptr.size(), + size(task_wait) = ret.ptr.size(), + size(task_poll) = ret.ptr.size(), + size(task_yield) = ret.ptr.size(), + size(subtask_drop) = ret.ptr.size(), + size(async_enter) = ret.ptr.size(), + size(async_exit) = ret.ptr.size(), + size(future_new) = ret.ptr.size(), + size(future_write) = ret.ptr.size(), + size(future_read) = ret.ptr.size(), + size(future_cancel_write) = ret.ptr.size(), + size(future_cancel_read) = ret.ptr.size(), + size(future_close_writable) = ret.ptr.size(), + size(future_close_readable) = ret.ptr.size(), + size(stream_new) = ret.ptr.size(), + size(stream_write) = ret.ptr.size(), + size(stream_read) = ret.ptr.size(), + size(stream_cancel_write) = ret.ptr.size(), + size(stream_cancel_read) = ret.ptr.size(), + size(stream_close_writable) = ret.ptr.size(), + size(stream_close_readable) = ret.ptr.size(), + size(flat_stream_write) = ret.ptr.size(), + size(flat_stream_read) = ret.ptr.size(), + size(error_context_new) = ret.ptr.size(), + size(error_context_debug_message) = ret.ptr.size(), + size(error_context_drop) = ret.ptr.size(), align(16), size(flags) = cmul(ret.num_runtime_component_instances, ret.ptr.size_of_vmglobal_definition()), align(u32::from(ret.ptr.size())), @@ -145,6 +258,7 @@ impl VMComponentOffsets

{ size(lowerings) = cmul(ret.num_lowerings, ret.ptr.size() * 2), size(memories) = cmul(ret.num_runtime_memories, ret.ptr.size()), size(reallocs) = cmul(ret.num_runtime_reallocs, ret.ptr.size()), + size(callbacks) = cmul(ret.num_runtime_callbacks, ret.ptr.size()), size(post_returns) = cmul(ret.num_runtime_post_returns, ret.ptr.size()), size(resource_destructors) = cmul(ret.num_resources, ret.ptr.size()), } @@ -215,6 +329,141 @@ impl VMComponentOffsets

{ self.lowerings } + /// The offset of the `task_backpressure` field. + pub fn task_backpressure(&self) -> u32 { + self.task_backpressure + } + + /// The offset of the `task_return` field. + pub fn task_return(&self) -> u32 { + self.task_return + } + + /// The offset of the `task_wait` field. + pub fn task_wait(&self) -> u32 { + self.task_wait + } + + /// The offset of the `task_poll` field. + pub fn task_poll(&self) -> u32 { + self.task_poll + } + + /// The offset of the `task_yield` field. + pub fn task_yield(&self) -> u32 { + self.task_yield + } + + /// The offset of the `subtask_drop` field. + pub fn subtask_drop(&self) -> u32 { + self.subtask_drop + } + + /// The offset of the `async_enter` field. + pub fn async_enter(&self) -> u32 { + self.async_enter + } + + /// The offset of the `async_exit` field. + pub fn async_exit(&self) -> u32 { + self.async_exit + } + + /// The offset of the `future_new` field. + pub fn future_new(&self) -> u32 { + self.future_new + } + + /// The offset of the `future_write` field. + pub fn future_write(&self) -> u32 { + self.future_write + } + + /// The offset of the `future_read` field. + pub fn future_read(&self) -> u32 { + self.future_read + } + + /// The offset of the `future_cancel_write` field. + pub fn future_cancel_write(&self) -> u32 { + self.future_cancel_write + } + + /// The offset of the `future_cancel_read` field. + pub fn future_cancel_read(&self) -> u32 { + self.future_cancel_read + } + + /// The offset of the `future_close_writable` field. + pub fn future_close_writable(&self) -> u32 { + self.future_close_writable + } + + /// The offset of the `future_close_readable` field. + pub fn future_close_readable(&self) -> u32 { + self.future_close_readable + } + + /// The offset of the `stream_new` field. + pub fn stream_new(&self) -> u32 { + self.stream_new + } + + /// The offset of the `stream_write` field. + pub fn stream_write(&self) -> u32 { + self.stream_write + } + + /// The offset of the `stream_read` field. + pub fn stream_read(&self) -> u32 { + self.stream_read + } + + /// The offset of the `stream_cancel_write` field. + pub fn stream_cancel_write(&self) -> u32 { + self.stream_cancel_write + } + + /// The offset of the `stream_cancel_read` field. + pub fn stream_cancel_read(&self) -> u32 { + self.stream_cancel_read + } + + /// The offset of the `stream_close_writable` field. + pub fn stream_close_writable(&self) -> u32 { + self.stream_close_writable + } + + /// The offset of the `stream_close_readable` field. + pub fn stream_close_readable(&self) -> u32 { + self.stream_close_readable + } + + /// The offset of the `flat_stream_write` field. + pub fn flat_stream_write(&self) -> u32 { + self.flat_stream_write + } + + /// The offset of the `flat_stream_read` field. + pub fn flat_stream_read(&self) -> u32 { + self.flat_stream_read + } + + /// The offset of the `error_context_new` field. + pub fn error_context_new(&self) -> u32 { + self.error_context_new + } + + /// The offset of the `error_context_debug_message` field. + pub fn error_context_debug_message(&self) -> u32 { + self.error_context_debug_message + } + + /// The offset of the `error_context_drop` field. + pub fn error_context_drop(&self) -> u32 { + self.error_context_drop + } + /// The offset of the `VMLowering` for the `index` specified. #[inline] pub fn lowering(&self, index: LoweredIndex) -> u32 { @@ -280,6 +529,20 @@ impl VMComponentOffsets

{ self.runtime_reallocs() + index.as_u32() * u32::from(self.ptr.size()) } + /// The offset of the base of the `runtime_callbacks` field + #[inline] + pub fn runtime_callbacks(&self) -> u32 { + self.callbacks + } + + /// The offset of the `*mut VMFuncRef` for the runtime index + /// provided. + #[inline] + pub fn runtime_callback(&self, index: RuntimeCallbackIndex) -> u32 { + assert!(index.as_u32() < self.num_runtime_callbacks); + self.runtime_callbacks() + index.as_u32() * u32::from(self.ptr.size()) + } + /// The offset of the base of the `runtime_post_returns` field #[inline] pub fn runtime_post_returns(&self) -> u32 { diff --git a/crates/environ/src/fact.rs b/crates/environ/src/fact.rs index 4afdac971ff7..b8485444eba7 100644 --- a/crates/environ/src/fact.rs +++ b/crates/environ/src/fact.rs @@ -21,7 +21,7 @@ use crate::component::dfg::CoreDef; use crate::component::{ Adapter, AdapterOptions as AdapterOptionsDfg, ComponentTypesBuilder, FlatType, InterfaceType, - StringEncoding, Transcode, TypeFuncIndex, + RuntimeComponentInstanceIndex, StringEncoding, Transcode, TypeFuncIndex, }; use crate::fact::transcode::Transcoder; use crate::prelude::*; @@ -64,6 +64,11 @@ pub struct Module<'a> { imported_resource_transfer_borrow: Option, imported_resource_enter_call: Option, imported_resource_exit_call: Option, + imported_async_enter_call: Option, + imported_async_exit_call: Option, + imported_future_transfer: Option, + imported_stream_transfer: Option, + imported_error_context_transfer: Option, // Current status of index spaces from the imports generated so far. imported_funcs: PrimaryMap>, @@ -73,6 +78,11 @@ pub struct Module<'a> { funcs: PrimaryMap, helper_funcs: HashMap, helper_worklist: Vec<(FunctionId, Helper)>, + + globals_by_type: [Vec; 4], + globals: Vec, + + exports: Vec<(u32, String)>, } struct AdapterData { @@ -95,6 +105,7 @@ struct AdapterData { /// These options are typically unique per-adapter and generally aren't needed /// when translating recursive types within an adapter. struct AdapterOptions { + instance: RuntimeComponentInstanceIndex, /// The ascribed type of this adapter. ty: TypeFuncIndex, /// The global that represents the instance flags for where this adapter @@ -122,6 +133,8 @@ struct Options { /// An optionally-specified function to be used to allocate space for /// types such as strings as they go into a module. realloc: Option, + callback: Option, + async_: bool, } enum Context { @@ -187,6 +200,14 @@ impl<'a> Module<'a> { imported_resource_transfer_borrow: None, imported_resource_enter_call: None, imported_resource_exit_call: None, + imported_async_enter_call: None, + imported_async_exit_call: None, + imported_future_transfer: None, + imported_stream_transfer: None, + imported_error_context_transfer: None, + globals_by_type: Default::default(), + globals: Default::default(), + exports: Vec::new(), } } @@ -240,6 +261,28 @@ impl<'a> Module<'a> { } } + fn allocate(&mut self, counts: &mut [usize; 4], ty: ValType) -> u32 { + let which = match ty { + ValType::I32 => 0, + ValType::I64 => 1, + ValType::F32 => 2, + ValType::F64 => 3, + _ => unreachable!(), + }; + + let index = counts[which]; + counts[which] += 1; + + if let Some(offset) = self.globals_by_type[which].get(index) { + *offset + } else { + let offset = u32::try_from(self.globals.len()).unwrap(); + self.globals_by_type[which].push(offset); + self.globals.push(ty); + offset + } + } + fn import_options(&mut self, ty: TypeFuncIndex, options: &AdapterOptionsDfg) -> AdapterOptions { let AdapterOptionsDfg { instance, @@ -248,7 +291,10 @@ impl<'a> Module<'a> { memory64, realloc, post_return: _, // handled above + callback, + async_, } = options; + let flags = self.import_global( "flags", &format!("instance{}", instance.as_u32()), @@ -287,8 +333,26 @@ impl<'a> Module<'a> { func.clone(), ) }); + let callback = callback.as_ref().map(|func| { + let ptr = if *memory64 { + ValType::I64 + } else { + ValType::I32 + }; + let ty = self.core_types.function( + &[ptr, ValType::I32, ValType::I32, ValType::I32], + &[ValType::I32], + ); + self.import_func( + "callback", + &format!("f{}", self.imported_funcs.len()), + ty, + func.clone(), + ) + }); AdapterOptions { + instance: *instance, ty, flags, post_return: None, @@ -297,6 +361,8 @@ impl<'a> Module<'a> { memory64: *memory64, memory, realloc, + callback, + async_: *async_, }, } } @@ -397,6 +463,76 @@ impl<'a> Module<'a> { idx } + fn import_async_enter_call(&mut self) -> FuncIndex { + self.import_simple( + "async", + "enter-call", + &[ + ValType::FUNCREF, + ValType::FUNCREF, + ValType::I32, + ValType::I32, + ValType::I32, + ], + &[], + Import::AsyncEnterCall, + |me| &mut me.imported_async_enter_call, + ) + } + + fn import_async_exit_call(&mut self, callback: Option) -> FuncIndex { + self.import_simple( + "async", + "exit-call", + &[ + ValType::FUNCREF, + ValType::I32, + ValType::I32, + ValType::I32, + ValType::I32, + ], + &[ValType::I32], + Import::AsyncExitCall( + callback + .map(|callback| self.imported_funcs.get(callback).unwrap().clone().unwrap()), + ), + |me| &mut me.imported_async_exit_call, + ) + } + + fn import_future_transfer(&mut self) -> FuncIndex { + self.import_simple( + "future", + "transfer", + &[ValType::I32; 4], + &[ValType::I32], + Import::FutureTransfer, + |me| &mut me.imported_future_transfer, + ) + } + + fn import_stream_transfer(&mut self) -> FuncIndex { + self.import_simple( + "stream", + "transfer", + &[ValType::I32; 4], + &[ValType::I32], + Import::StreamTransfer, + |me| &mut me.imported_stream_transfer, + ) + } + + fn import_error_context_transfer(&mut self) -> FuncIndex { + self.import_simple( + "error-context", + "transfer", + &[ValType::I32; 3], + &[ValType::I32], + Import::ErrorContextTransfer, + |me| &mut me.imported_error_context_transfer, + ) + } + fn import_resource_transfer_own(&mut self) -> FuncIndex { self.import_simple( "resource", @@ -472,6 +608,11 @@ impl<'a> Module<'a> { exports.export(name, ExportKind::Func, idx.as_u32()); } } + for (idx, name) in &self.exports { + exports.export(name, ExportKind::Func, *idx); + } + + let imported_global_count = u32::try_from(self.imported_globals.len()).unwrap(); // With all functions numbered the fragments of the body of each // function can be assigned into one final adapter function. @@ -504,6 +645,15 @@ impl<'a> Module<'a> { Body::Call(id) => { Instruction::Call(id_to_index[*id].as_u32()).encode(&mut body); } + Body::RefFunc(id) => { + Instruction::RefFunc(id_to_index[*id].as_u32()).encode(&mut body); + } + Body::GlobalGet(offset) => { + Instruction::GlobalGet(offset + imported_global_count).encode(&mut body); + } + Body::GlobalSet(offset) => { + Instruction::GlobalSet(offset + imported_global_count).encode(&mut body); + } } } code.raw(&body); @@ -512,10 +662,29 @@ impl<'a> Module<'a> { let traps = traps.finish(); + let mut globals = GlobalSection::new(); + for ty in &self.globals { + globals.global( + GlobalType { + val_type: *ty, + mutable: true, + shared: false, + }, + &match ty { + ValType::I32 => ConstExpr::i32_const(0), + ValType::I64 => ConstExpr::i64_const(0), + ValType::F32 => ConstExpr::f32_const(0_f32), + ValType::F64 => ConstExpr::f64_const(0_f64), + _ => unreachable!(), + }, + ); + } + let mut result = wasm_encoder::Module::new(); result.section(&self.core_types.section); result.section(&self.core_imports); result.section(&funcs); + result.section(&globals); result.section(&exports); result.section(&code); if self.debug { @@ -561,6 +730,21 @@ pub enum Import { /// Tears down a previous entry and handles checking borrow-related /// metadata. ResourceExitCall, + /// An intrinsic used by FACT-generated modules to begin a call to an + /// async-lowered import function. + AsyncEnterCall, + /// An intrinsic used by FACT-generated modules to complete a call to an + /// async-lowered import function. + AsyncExitCall(Option), + /// An intrinisic used by FACT-generated modules to (partially or entirely) transfer + /// ownership of a `future`. + FutureTransfer, + /// An intrinisic used by FACT-generated modules to (partially or entirely) transfer + /// ownership of a `stream`. + StreamTransfer, + /// An intrinisic used by FACT-generated modules to (partially or entirely) transfer + /// ownership of an `error-context`. + ErrorContextTransfer, } impl Options { @@ -659,6 +843,9 @@ struct Function { enum Body { Raw(Vec, Vec<(usize, traps::Trap)>), Call(FunctionId), + RefFunc(FunctionId), + GlobalGet(u32), + GlobalSet(u32), } impl Function { diff --git a/crates/environ/src/fact/signature.rs b/crates/environ/src/fact/signature.rs index 328ec085e359..72474cd4468f 100644 --- a/crates/environ/src/fact/signature.rs +++ b/crates/environ/src/fact/signature.rs @@ -13,6 +13,14 @@ pub struct Signature { pub params: Vec, /// Core wasm results. pub results: Vec, + /// Indicator to whether parameters are indirect, meaning that the first + /// entry of `params` is a pointer type which all parameters are loaded + /// through. + pub params_indirect: bool, + /// Indicator whether results are passed indirectly. This may mean that + /// `results` is an `i32` or that `params` ends with an `i32` depending on + /// the `Context`. + pub results_indirect: bool, } impl ComponentTypesBuilder { @@ -26,6 +34,16 @@ impl ComponentTypesBuilder { let ty = &self[options.ty]; let ptr_ty = options.options.ptr(); + if let (Context::Lower, true) = (&context, options.options.async_) { + return Signature { + params: vec![ptr_ty; 2], + results: vec![ValType::I32], + params_indirect: true, + results_indirect: true, + }; + } + + let mut params_indirect = false; let mut params = match self.flatten_types( &options.options, MAX_FLAT_PARAMS, @@ -33,10 +51,21 @@ impl ComponentTypesBuilder { ) { Some(list) => list, None => { + params_indirect = true; vec![ptr_ty] } }; + if options.options.async_ { + return Signature { + params, + results: vec![ptr_ty], + params_indirect, + results_indirect: false, + }; + } + + let mut results_indirect = false; let results = match self.flatten_types( &options.options, MAX_FLAT_RESULTS, @@ -44,6 +73,7 @@ impl ComponentTypesBuilder { ) { Some(list) => list, None => { + results_indirect = true; match context { // For a lifted function too-many-results gets translated to a // returned pointer where results are read from. The callee @@ -59,7 +89,70 @@ impl ComponentTypesBuilder { } } }; - Signature { params, results } + Signature { + params, + results, + params_indirect, + results_indirect, + } + } + + pub(super) fn async_start_signature(&self, options: &AdapterOptions) -> Signature { + let ty = &self[options.ty]; + let ptr_ty = options.options.ptr(); + + let mut params = vec![ptr_ty]; + + let mut results_indirect = false; + let results = match self.flatten_types( + &options.options, + MAX_FLAT_PARAMS, + self[ty.params].types.iter().map(|ty| *ty), + ) { + Some(list) => list, + None => { + results_indirect = true; + params.push(ptr_ty); + Vec::new() + } + }; + Signature { + params, + results, + params_indirect: false, + results_indirect, + } + } + + pub(super) fn async_return_signature(&self, options: &AdapterOptions) -> Signature { + let ty = &self[options.ty]; + let ptr_ty = options.options.ptr(); + + let mut params_indirect = false; + let mut params = match self.flatten_types( + &options.options, + if options.options.async_ { + MAX_FLAT_PARAMS + } else { + MAX_FLAT_RESULTS + }, + self[ty.results].types.iter().copied(), + ) { + Some(list) => list, + None => { + params_indirect = true; + vec![ptr_ty] + } + }; + // Add return pointer + params.push(ptr_ty); + + Signature { + params, + results: Vec::new(), + params_indirect, + results_indirect: false, + } } /// Pushes the flat version of a list of component types into a final result diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs index 8af700de421b..abf180f31a18 100644 --- a/crates/environ/src/fact/trampoline.rs +++ b/crates/environ/src/fact/trampoline.rs @@ -17,9 +17,10 @@ use crate::component::{ CanonicalAbiInfo, ComponentTypesBuilder, FixedEncoding as FE, FlatType, InterfaceType, - StringEncoding, Transcode, TypeEnumIndex, TypeFlagsIndex, TypeListIndex, TypeOptionIndex, - TypeRecordIndex, TypeResourceTableIndex, TypeResultIndex, TypeTupleIndex, TypeVariantIndex, - VariantInfo, FLAG_MAY_ENTER, FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + StringEncoding, Transcode, TypeEnumIndex, TypeErrorContextTableIndex, TypeFlagsIndex, + TypeFutureTableIndex, TypeListIndex, TypeOptionIndex, TypeRecordIndex, TypeResourceTableIndex, + TypeResultIndex, TypeStreamTableIndex, TypeTupleIndex, TypeVariantIndex, VariantInfo, + FLAG_MAY_ENTER, FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, }; use crate::fact::signature::Signature; use crate::fact::transcode::Transcoder; @@ -39,6 +40,9 @@ use wasmtime_component_util::{DiscriminantSize, FlagsSize}; const MAX_STRING_BYTE_LENGTH: u32 = 1 << 31; const UTF16_TAG: u32 = 1 << 31; +const EXIT_FLAG_ASYNC_CALLER: i32 = 1 << 0; +const EXIT_FLAG_ASYNC_CALLEE: i32 = 1 << 1; + /// This value is arbitrarily chosen and should be fine to change at any time, /// it just seemed like a halfway reasonable starting point. const INITIAL_FUEL: usize = 1_000; @@ -80,37 +84,168 @@ struct Compiler<'a, 'b> { } pub(super) fn compile(module: &mut Module<'_>, adapter: &AdapterData) { - let lower_sig = module.types.signature(&adapter.lower, Context::Lower); - let lift_sig = module.types.signature(&adapter.lift, Context::Lift); - let ty = module - .core_types - .function(&lower_sig.params, &lower_sig.results); - let result = module - .funcs - .push(Function::new(Some(adapter.name.clone()), ty)); - - // If this type signature contains any borrowed resources then invocations - // of enter/exit call for resource-related metadata tracking must be used. - // It shouldn't matter whether the lower/lift signature is used here as both - // should return the same answer. - let emit_resource_call = module.types.contains_borrow_resource(&adapter.lower); - assert_eq!( - emit_resource_call, - module.types.contains_borrow_resource(&adapter.lift) - ); - - Compiler { - types: module.types, - module, - code: Vec::new(), - nlocals: lower_sig.params.len() as u32, - free_locals: HashMap::new(), - traps: Vec::new(), - result, - fuel: INITIAL_FUEL, - emit_resource_call, + fn compiler<'a, 'b>( + module: &'b mut Module<'a>, + adapter: &AdapterData, + ) -> (Compiler<'a, 'b>, Signature, Signature) { + let lower_sig = module.types.signature(&adapter.lower, Context::Lower); + let lift_sig = module.types.signature(&adapter.lift, Context::Lift); + let ty = module + .core_types + .function(&lower_sig.params, &lower_sig.results); + let result = module + .funcs + .push(Function::new(Some(adapter.name.clone()), ty)); + + // If this type signature contains any borrowed resources then invocations + // of enter/exit call for resource-related metadata tracking must be used. + // It shouldn't matter whether the lower/lift signature is used here as both + // should return the same answer. + let emit_resource_call = module.types.contains_borrow_resource(&adapter.lower); + assert_eq!( + emit_resource_call, + module.types.contains_borrow_resource(&adapter.lift) + ); + + ( + Compiler { + types: module.types, + module, + code: Vec::new(), + nlocals: lower_sig.params.len() as u32, + free_locals: HashMap::new(), + traps: Vec::new(), + result, + fuel: INITIAL_FUEL, + emit_resource_call, + }, + lower_sig, + lift_sig, + ) + } + + let start_adapter = |module: &mut Module, param_globals| { + let sig = module.types.async_start_signature(&adapter.lift); + let ty = module.core_types.function(&sig.params, &sig.results); + let result = module.funcs.push(Function::new( + Some(format!("[async-start]{}", adapter.name)), + ty, + )); + + Compiler { + types: module.types, + module, + code: Vec::new(), + nlocals: sig.params.len() as u32, + free_locals: HashMap::new(), + traps: Vec::new(), + result, + fuel: INITIAL_FUEL, + emit_resource_call: false, + } + .compile_async_start_adapter(adapter, &sig, param_globals); + + result + }; + + let return_adapter = |module: &mut Module, result_globals| { + let sig = module.types.async_return_signature(&adapter.lift); + let ty = module.core_types.function(&sig.params, &sig.results); + let result = module.funcs.push(Function::new( + Some(format!("[async-return]{}", adapter.name)), + ty, + )); + + Compiler { + types: module.types, + module, + code: Vec::new(), + nlocals: sig.params.len() as u32, + free_locals: HashMap::new(), + traps: Vec::new(), + result, + fuel: INITIAL_FUEL, + emit_resource_call: false, + } + .compile_async_return_adapter(adapter, &sig, result_globals); + + result + }; + + match (adapter.lower.options.async_, adapter.lift.options.async_) { + (false, false) => { + let (compiler, lower_sig, lift_sig) = compiler(module, adapter); + compiler.compile_sync_to_sync_adapter(adapter, &lower_sig, &lift_sig) + } + (true, true) => { + let start = start_adapter(module, None); + let return_ = return_adapter(module, None); + let (compiler, _, lift_sig) = compiler(module, adapter); + compiler.compile_async_to_async_adapter( + adapter, + start, + return_, + i32::try_from(lift_sig.params.len()).unwrap(), + ); + } + (false, true) => { + let lower_sig = module.types.signature(&adapter.lower, Context::Lower); + let param_globals = if lower_sig.params_indirect { + None + } else { + let mut counts = [0; 4]; + Some( + lower_sig + .params + .iter() + .take(if lower_sig.results_indirect { + lower_sig.params.len() - 1 + } else { + lower_sig.params.len() + }) + .map(|ty| module.allocate(&mut counts, *ty)) + .collect::>(), + ) + }; + let result_globals = if lower_sig.results_indirect { + None + } else { + let mut counts = [0; 4]; + Some( + lower_sig + .results + .iter() + .map(|ty| module.allocate(&mut counts, *ty)) + .collect::>(), + ) + }; + + let start = start_adapter(module, param_globals.as_deref()); + let return_ = return_adapter(module, result_globals.as_deref()); + let (compiler, _, lift_sig) = compiler(module, adapter); + compiler.compile_sync_to_async_adapter( + adapter, + start, + return_, + i32::try_from(lift_sig.params.len()).unwrap(), + param_globals.as_deref(), + result_globals.as_deref(), + ); + } + (true, false) => { + let lift_sig = module.types.signature(&adapter.lift, Context::Lift); + let start = start_adapter(module, None); + let return_ = return_adapter(module, None); + let (compiler, ..) = compiler(module, adapter); + compiler.compile_async_to_sync_adapter( + adapter, + start, + return_, + i32::try_from(lift_sig.params.len()).unwrap(), + i32::try_from(lift_sig.results.len()).unwrap(), + ); + } } - .compile_adapter(adapter, &lower_sig, &lift_sig) } /// Compiles a helper function as specified by the `Helper` configuration. @@ -244,7 +379,249 @@ struct Memory<'a> { } impl Compiler<'_, '_> { - fn compile_adapter( + fn compile_async_to_async_adapter( + mut self, + adapter: &AdapterData, + start: FunctionId, + return_: FunctionId, + param_count: i32, + ) { + let enter = self.module.import_async_enter_call(); + let exit = self + .module + .import_async_exit_call(Some(adapter.lift.options.callback.unwrap())); + + self.flush_code(); + self.module.funcs[self.result] + .body + .push(Body::RefFunc(start)); + self.module.funcs[self.result] + .body + .push(Body::RefFunc(return_)); + self.instruction(I32Const( + i32::try_from(adapter.lower.instance.as_u32()).unwrap(), + )); + self.instruction(LocalGet(0)); + self.instruction(LocalGet(1)); + self.instruction(Call(enter.as_u32())); + + // TODO: As an optimization, consider checking the backpressure flag on the callee instance and, if it's + // unset, translate the params and call the callee function directly here (and make sure `exit` knows _not_ + // to call it in that case). + + self.module.exports.push(( + adapter.callee.as_u32(), + format!("[adapter-callee]{}", adapter.name), + )); + self.instruction(RefFunc(adapter.callee.as_u32())); + self.instruction(I32Const( + i32::try_from(adapter.lift.instance.as_u32()).unwrap(), + )); + self.instruction(I32Const(param_count)); + self.instruction(I32Const(1)); // leave room for the guest context result + self.instruction(I32Const(EXIT_FLAG_ASYNC_CALLER | EXIT_FLAG_ASYNC_CALLEE)); + self.instruction(Call(exit.as_u32())); + + self.finish() + } + + fn compile_sync_to_async_adapter( + mut self, + adapter: &AdapterData, + start: FunctionId, + return_: FunctionId, + param_count: i32, + param_globals: Option<&[u32]>, + result_globals: Option<&[u32]>, + ) { + let enter = self.module.import_async_enter_call(); + let exit = self + .module + .import_async_exit_call(Some(adapter.lift.options.callback.unwrap())); + + self.flush_code(); + self.module.funcs[self.result] + .body + .push(Body::RefFunc(start)); + self.module.funcs[self.result] + .body + .push(Body::RefFunc(return_)); + self.instruction(I32Const( + i32::try_from(adapter.lower.instance.as_u32()).unwrap(), + )); + + let results_local = if let Some(globals) = param_globals { + for (local, global) in globals.iter().enumerate() { + self.instruction(LocalGet(u32::try_from(local).unwrap())); + self.flush_code(); + self.module.funcs[self.result] + .body + .push(Body::GlobalSet(*global)); + } + self.instruction(I32Const(0)); // dummy params pointer + u32::try_from(globals.len()).unwrap() + } else { + self.instruction(LocalGet(0)); + 1 + }; + + if result_globals.is_some() { + self.instruction(I32Const(0)); // dummy results pointer + } else { + self.instruction(LocalGet(results_local)); + } + + self.instruction(Call(enter.as_u32())); + + // TODO: As an optimization, consider checking the backpressure flag on the callee instance and, if it's + // unset, translate the params and call the callee function directly here (and make sure `exit` knows _not_ + // to call it in that case). + + self.module.exports.push(( + adapter.callee.as_u32(), + format!("[adapter-callee]{}", adapter.name), + )); + self.instruction(RefFunc(adapter.callee.as_u32())); + self.instruction(I32Const( + i32::try_from(adapter.lift.instance.as_u32()).unwrap(), + )); + self.instruction(I32Const(param_count)); + self.instruction(I32Const(1)); // leave room for the guest context result + self.instruction(I32Const(EXIT_FLAG_ASYNC_CALLEE)); + self.instruction(Call(exit.as_u32())); + self.instruction(Drop); + + if let Some(globals) = result_globals { + self.flush_code(); + for global in globals { + self.module.funcs[self.result] + .body + .push(Body::GlobalGet(*global)); + } + } + + self.finish() + } + + fn compile_async_to_sync_adapter( + mut self, + adapter: &AdapterData, + start: FunctionId, + return_: FunctionId, + param_count: i32, + result_count: i32, + ) { + let enter = self.module.import_async_enter_call(); + let exit = self.module.import_async_exit_call(None); + + self.flush_code(); + self.module.funcs[self.result] + .body + .push(Body::RefFunc(start)); + self.module.funcs[self.result] + .body + .push(Body::RefFunc(return_)); + self.instruction(I32Const( + i32::try_from(adapter.lower.instance.as_u32()).unwrap(), + )); + self.instruction(LocalGet(0)); + self.instruction(LocalGet(1)); + self.instruction(Call(enter.as_u32())); + self.module.exports.push(( + adapter.callee.as_u32(), + format!("[adapter-callee]{}", adapter.name), + )); + self.instruction(RefFunc(adapter.callee.as_u32())); + self.instruction(I32Const( + i32::try_from(adapter.lift.instance.as_u32()).unwrap(), + )); + self.instruction(I32Const(param_count)); + self.instruction(I32Const(result_count)); + self.instruction(I32Const(EXIT_FLAG_ASYNC_CALLER)); + self.instruction(Call(exit.as_u32())); + + self.finish() + } + + fn compile_async_start_adapter( + mut self, + adapter: &AdapterData, + sig: &Signature, + param_globals: Option<&[u32]>, + ) { + let mut temps = Vec::new(); + let param_locals = if let Some(globals) = param_globals { + for global in globals { + let ty = self.module.globals[usize::try_from(*global).unwrap()]; + + self.flush_code(); + self.module.funcs[self.result] + .body + .push(Body::GlobalGet(*global)); + temps.push(self.local_set_new_tmp(ty)); + } + temps + .iter() + .map(|t| (t.idx, t.ty)) + .chain(if sig.results_indirect { + sig.params + .iter() + .enumerate() + .map(|(i, ty)| (i as u32, *ty)) + .last() + } else { + None + }) + .collect::>() + } else { + sig.params + .iter() + .enumerate() + .map(|(i, ty)| (i as u32, *ty)) + .collect::>() + }; + + self.set_flag(adapter.lift.flags, FLAG_MAY_LEAVE, false); + self.translate_params(adapter, ¶m_locals); + self.set_flag(adapter.lift.flags, FLAG_MAY_LEAVE, true); + + for tmp in temps { + self.free_temp_local(tmp); + } + + self.finish(); + } + + fn compile_async_return_adapter( + mut self, + adapter: &AdapterData, + sig: &Signature, + result_globals: Option<&[u32]>, + ) { + let param_locals = sig + .params + .iter() + .enumerate() + .map(|(i, ty)| (i as u32, *ty)) + .collect::>(); + + self.set_flag(adapter.lower.flags, FLAG_MAY_LEAVE, false); + self.translate_results(adapter, ¶m_locals, ¶m_locals); + self.set_flag(adapter.lower.flags, FLAG_MAY_LEAVE, true); + + if let Some(globals) = result_globals { + self.flush_code(); + for global in globals { + self.module.funcs[self.result] + .body + .push(Body::GlobalSet(*global)); + } + } + + self.finish() + } + + fn compile_sync_to_sync_adapter( mut self, adapter: &AdapterData, lower_sig: &Signature, @@ -362,9 +739,12 @@ impl Compiler<'_, '_> { // TODO: handle subtyping assert_eq!(src_tys.len(), dst_tys.len()); - let src_flat = + let src_flat = if adapter.lower.options.async_ { + None + } else { self.types - .flatten_types(lower_opts, MAX_FLAT_PARAMS, src_tys.iter().copied()); + .flatten_types(lower_opts, MAX_FLAT_PARAMS, src_tys.iter().copied()) + }; let dst_flat = self.types .flatten_types(lift_opts, MAX_FLAT_PARAMS, dst_tys.iter().copied()); @@ -391,16 +771,28 @@ impl Compiler<'_, '_> { let dst = if let Some(flat) = &dst_flat { Destination::Stack(flat, lift_opts) } else { - // If there are too many parameters then space is allocated in the - // destination module for the parameters via its `realloc` function. - let abi = CanonicalAbiInfo::record(dst_tys.iter().map(|t| self.types.canonical_abi(t))); - let (size, align) = if lift_opts.memory64 { - (abi.size64, abi.align64) + if lift_opts.async_ { + let align = dst_tys + .iter() + .map(|t| self.types.align(lift_opts, t)) + .max() + .unwrap_or(1); + let (addr, ty) = *param_locals.last().expect("no retptr"); + assert_eq!(ty, lift_opts.ptr()); + Destination::Memory(self.memory_operand(lift_opts, TempLocal::new(addr, ty), align)) } else { - (abi.size32, abi.align32) - }; - let size = MallocSize::Const(size); - Destination::Memory(self.malloc(lift_opts, size, align)) + // If there are too many parameters then space is allocated in the + // destination module for the parameters via its `realloc` function. + let abi = + CanonicalAbiInfo::record(dst_tys.iter().map(|t| self.types.canonical_abi(t))); + let (size, align) = if lift_opts.memory64 { + (abi.size64, abi.align64) + } else { + (abi.size32, abi.align32) + }; + let size = MallocSize::Const(size); + Destination::Memory(self.malloc(lift_opts, size, align)) + } }; let srcs = src @@ -416,7 +808,7 @@ impl Compiler<'_, '_> { // If the destination was linear memory instead of the stack then the // actual parameter that we're passing is the address of the values // stored, so ensure that's happening in the wasm body here. - if let Destination::Memory(mem) = dst { + if let (Destination::Memory(mem), false) = (dst, lift_opts.async_) { self.instruction(LocalGet(mem.addr.idx)); self.free_temp_local(mem.addr); } @@ -443,12 +835,21 @@ impl Compiler<'_, '_> { let lift_opts = &adapter.lift.options; let lower_opts = &adapter.lower.options; - let src_flat = - self.types - .flatten_types(lift_opts, MAX_FLAT_RESULTS, src_tys.iter().copied()); - let dst_flat = + let src_flat = self.types.flatten_types( + lift_opts, + if lift_opts.async_ { + MAX_FLAT_PARAMS + } else { + MAX_FLAT_RESULTS + }, + src_tys.iter().copied(), + ); + let dst_flat = if lower_opts.async_ { + None + } else { self.types - .flatten_types(lower_opts, MAX_FLAT_RESULTS, dst_tys.iter().copied()); + .flatten_types(lower_opts, MAX_FLAT_RESULTS, dst_tys.iter().copied()) + }; let src = if src_flat.is_some() { Source::Stack(Stack { @@ -465,7 +866,7 @@ impl Compiler<'_, '_> { .map(|t| self.types.align(lift_opts, t)) .max() .unwrap_or(1); - assert_eq!(result_locals.len(), 1); + assert_eq!(result_locals.len(), if lower_opts.async_ { 2 } else { 1 }); let (addr, ty) = result_locals[0]; assert_eq!(ty, lift_opts.ptr()); Source::Memory(self.memory_operand(lift_opts, TempLocal::new(addr, ty), align)) @@ -587,7 +988,11 @@ impl Compiler<'_, '_> { InterfaceType::Option(_) | InterfaceType::Result(_) => 2, // TODO(#6696) - something nonzero, is 1 right? - InterfaceType::Own(_) | InterfaceType::Borrow(_) => 1, + InterfaceType::Own(_) + | InterfaceType::Borrow(_) + | InterfaceType::Future(_) + | InterfaceType::Stream(_) + | InterfaceType::ErrorContext(_) => 1, }; match self.fuel.checked_sub(cost) { @@ -622,6 +1027,11 @@ impl Compiler<'_, '_> { InterfaceType::Result(t) => self.translate_result(*t, src, dst_ty, dst), InterfaceType::Own(t) => self.translate_own(*t, src, dst_ty, dst), InterfaceType::Borrow(t) => self.translate_borrow(*t, src, dst_ty, dst), + InterfaceType::Future(t) => self.translate_future(*t, src, dst_ty, dst), + InterfaceType::Stream(t) => self.translate_stream(*t, src, dst_ty, dst), + InterfaceType::ErrorContext(t) => { + self.translate_error_context_context(*t, src, dst_ty, dst) + } } } @@ -2448,6 +2858,54 @@ impl Compiler<'_, '_> { } } + fn translate_future( + &mut self, + src_ty: TypeFutureTableIndex, + src: &Source<'_>, + dst_ty: &InterfaceType, + dst: &Destination, + ) { + let dst_ty = match dst_ty { + InterfaceType::Future(t) => *t, + _ => panic!("expected a `Future`"), + }; + assert!(src_ty == dst_ty); + let transfer = self.module.import_future_transfer(); + self.translate_handle(src_ty.as_u32(), src, dst_ty.as_u32(), dst, transfer); + } + + fn translate_stream( + &mut self, + src_ty: TypeStreamTableIndex, + src: &Source<'_>, + dst_ty: &InterfaceType, + dst: &Destination, + ) { + let dst_ty = match dst_ty { + InterfaceType::Stream(t) => *t, + _ => panic!("expected a `Stream`"), + }; + assert!(src_ty == dst_ty); + let transfer = self.module.import_stream_transfer(); + self.translate_handle(src_ty.as_u32(), src, dst_ty.as_u32(), dst, transfer); + } + + fn translate_error_context_context( + &mut self, + src_ty: TypeErrorContextTableIndex, + src: &Source<'_>, + dst_ty: &InterfaceType, + dst: &Destination, + ) { + let dst_ty = match dst_ty { + InterfaceType::ErrorContext(t) => *t, + _ => panic!("expected an `ErrorContext`"), + }; + assert!(src_ty == dst_ty); + let transfer = self.module.import_error_context_transfer(); + self.translate_handle(src_ty.as_u32(), src, dst_ty.as_u32(), dst, transfer); + } + fn translate_own( &mut self, src_ty: TypeResourceTableIndex, @@ -2460,7 +2918,7 @@ impl Compiler<'_, '_> { _ => panic!("expected an `Own`"), }; let transfer = self.module.import_resource_transfer_own(); - self.translate_resource(src_ty, src, dst_ty, dst, transfer); + self.translate_handle(src_ty.as_u32(), src, dst_ty.as_u32(), dst, transfer); } fn translate_borrow( @@ -2476,7 +2934,7 @@ impl Compiler<'_, '_> { }; let transfer = self.module.import_resource_transfer_borrow(); - self.translate_resource(src_ty, src, dst_ty, dst, transfer); + self.translate_handle(src_ty.as_u32(), src, dst_ty.as_u32(), dst, transfer); } /// Translates the index `src`, which resides in the table `src_ty`, into @@ -2486,11 +2944,11 @@ impl Compiler<'_, '_> { /// cranelift-generated trampoline to satisfy this import will call. The /// `transfer` function is an imported function which takes the src, src_ty, /// and dst_ty, and returns the dst index. - fn translate_resource( + fn translate_handle( &mut self, - src_ty: TypeResourceTableIndex, + src_ty: u32, src: &Source<'_>, - dst_ty: TypeResourceTableIndex, + dst_ty: u32, dst: &Destination, transfer: FuncIndex, ) { @@ -2499,8 +2957,8 @@ impl Compiler<'_, '_> { Source::Memory(mem) => self.i32_load(mem), Source::Stack(stack) => self.stack_get(stack, ValType::I32), } - self.instruction(I32Const(src_ty.as_u32() as i32)); - self.instruction(I32Const(dst_ty.as_u32() as i32)); + self.instruction(I32Const(src_ty as i32)); + self.instruction(I32Const(dst_ty as i32)); self.instruction(Call(transfer.as_u32())); match dst { Destination::Memory(mem) => self.i32_store(mem), diff --git a/crates/environ/src/trap_encoding.rs b/crates/environ/src/trap_encoding.rs index 6c08471fd7ca..0cb1fd558fb8 100644 --- a/crates/environ/src/trap_encoding.rs +++ b/crates/environ/src/trap_encoding.rs @@ -88,7 +88,10 @@ pub enum Trap { /// would have violated the reentrance rules of the component model, /// triggering a trap instead. CannotEnterComponent, - // if adding a variant here be sure to update the `check!` macro below + + /// Async-lifted export failed to produce a result by calling `task.return` + /// before returning `STATUS_DONE`. + NoAsyncResult, // if adding a variant here be sure to update the `check!` macro below } impl Trap { @@ -124,6 +127,7 @@ impl Trap { AllocationTooLarge CastFailure CannotEnterComponent + NoAsyncResult } None @@ -154,6 +158,7 @@ impl fmt::Display for Trap { AllocationTooLarge => "allocation size too large", CastFailure => "cast failure", CannotEnterComponent => "cannot enter component instance", + NoAsyncResult => "async-lifted export failed to produce a result", }; write!(f, "wasm trap: {desc}") } diff --git a/crates/misc/component-test-util/src/lib.rs b/crates/misc/component-test-util/src/lib.rs index eddb7654722f..9875aec02af5 100644 --- a/crates/misc/component-test-util/src/lib.rs +++ b/crates/misc/component-test-util/src/lib.rs @@ -8,15 +8,23 @@ use wasmtime::component::{ComponentNamedList, ComponentType, Func, Lift, Lower, use wasmtime::{AsContextMut, Config, Engine}; pub trait TypedFuncExt { - fn call_and_post_return(&self, store: impl AsContextMut, params: P) -> Result; + fn call_and_post_return( + &self, + store: impl AsContextMut, + params: P, + ) -> Result; } impl TypedFuncExt for TypedFunc where P: ComponentNamedList + Lower, - R: ComponentNamedList + Lift, + R: ComponentNamedList + Lift + Send + Sync + 'static, { - fn call_and_post_return(&self, mut store: impl AsContextMut, params: P) -> Result { + fn call_and_post_return( + &self, + mut store: impl AsContextMut, + params: P, + ) -> Result { let result = self.call(&mut store, params)?; self.post_return(&mut store)?; Ok(result) diff --git a/crates/wasi-config/Cargo.toml b/crates/wasi-config/Cargo.toml index 81bd61ef7184..cadaf77bb7a2 100644 --- a/crates/wasi-config/Cargo.toml +++ b/crates/wasi-config/Cargo.toml @@ -13,7 +13,7 @@ workspace = true [dependencies] anyhow = { workspace = true } -wasmtime = { workspace = true, features = ["runtime", "component-model"] } +wasmtime = { workspace = true, features = ["runtime", "component-model", "async"] } [dev-dependencies] test-programs-artifacts = { workspace = true } diff --git a/crates/wasi-keyvalue/src/lib.rs b/crates/wasi-keyvalue/src/lib.rs index 4d53a9acf20b..d88b5220aeec 100644 --- a/crates/wasi-keyvalue/src/lib.rs +++ b/crates/wasi-keyvalue/src/lib.rs @@ -75,7 +75,7 @@ mod generated { }, trappable_error_type: { "wasi:keyvalue/store/error" => crate::Error, - }, + } }); } diff --git a/crates/wasi-keyvalue/wit/deps/keyvalue/atomic.wit b/crates/wasi-keyvalue/wit/deps/keyvalue/atomic.wit index 059efc48896b..02b3f6210c08 100644 --- a/crates/wasi-keyvalue/wit/deps/keyvalue/atomic.wit +++ b/crates/wasi-keyvalue/wit/deps/keyvalue/atomic.wit @@ -19,4 +19,4 @@ interface atomics { /// /// If any other error occurs, it returns an `Err(error)`. increment: func(bucket: borrow, key: string, delta: u64) -> result; -} \ No newline at end of file +} diff --git a/crates/wasmtime/Cargo.toml b/crates/wasmtime/Cargo.toml index 5d0818df1e80..65e0ccb01f57 100644 --- a/crates/wasmtime/Cargo.toml +++ b/crates/wasmtime/Cargo.toml @@ -60,6 +60,7 @@ smallvec = { workspace = true, optional = true } hashbrown = { workspace = true } libm = "0.2.7" bitflags = { workspace = true } +futures = { workspace = true, features = ["alloc"] } [target.'cfg(target_os = "windows")'.dependencies.windows-sys] workspace = true @@ -315,3 +316,12 @@ memory-protection-keys = ["pooling-allocator"] # version changes. This is why the re-export is guarded by a feature flag which # is off by default. reexport-wasmparser = [] + +# Enables support for the Component Model Async ABI, along with `future`, +# `stream`, and `error-context` types. +component-model-async = [ + "async", + "component-model", + "std", + 'wasmtime-component-macro?/component-model-async', +] diff --git a/crates/wasmtime/src/config.rs b/crates/wasmtime/src/config.rs index 4c0de5b0e810..7ad00d1b37ec 100644 --- a/crates/wasmtime/src/config.rs +++ b/crates/wasmtime/src/config.rs @@ -1027,6 +1027,17 @@ impl Config { self } + /// Configures whether components support the async ABI [proposal] for + /// lifting and lowering functions, as well as `stream`, `future`, and + /// `error-context` types. + /// + /// [proposal]: https://github.com/WebAssembly/component-model/blob/main/design/mvp/Async.md + #[cfg(feature = "component-model-async")] + pub fn wasm_component_model_async(&mut self, enable: bool) -> &mut Self { + self.wasm_feature(WasmFeatures::COMPONENT_MODEL_ASYNC, enable); + self + } + /// Configures which compilation strategy will be used for wasm modules. /// /// This method can be used to configure which compiler is used for wasm diff --git a/crates/wasmtime/src/engine/serialization.rs b/crates/wasmtime/src/engine/serialization.rs index f4bdd79eb0ce..98e942f20c68 100644 --- a/crates/wasmtime/src/engine/serialization.rs +++ b/crates/wasmtime/src/engine/serialization.rs @@ -204,6 +204,7 @@ struct WasmFeatures { custom_page_sizes: bool, component_model_more_flags: bool, component_model_multiple_returns: bool, + component_model_async: bool, gc_types: bool, wide_arithmetic: bool, } @@ -233,6 +234,7 @@ impl Metadata<'_> { component_model_nested_names, component_model_more_flags, component_model_multiple_returns, + component_model_async, legacy_exceptions, gc_types, stack_switching, @@ -278,6 +280,7 @@ impl Metadata<'_> { custom_page_sizes, component_model_more_flags, component_model_multiple_returns, + component_model_async, gc_types, wide_arithmetic, }, @@ -482,6 +485,7 @@ impl Metadata<'_> { custom_page_sizes, component_model_more_flags, component_model_multiple_returns, + component_model_async, gc_types, wide_arithmetic, } = self.features; @@ -568,6 +572,11 @@ impl Metadata<'_> { other.contains(F::COMPONENT_MODEL_MULTIPLE_RETURNS), "WebAssembly component model support for multiple returns", )?; + Self::check_bool( + component_model_async, + other.contains(F::COMPONENT_MODEL_ASYNC), + "WebAssembly component model support for async lifts/lowers, futures, streams, and errors", + )?; Self::check_cfg_bool( cfg!(feature = "gc"), "gc", diff --git a/crates/wasmtime/src/runtime/component/component.rs b/crates/wasmtime/src/runtime/component/component.rs index 21add33282fd..140e372e8b58 100644 --- a/crates/wasmtime/src/runtime/component/component.rs +++ b/crates/wasmtime/src/runtime/component/component.rs @@ -602,6 +602,7 @@ impl Component { GlobalInitializer::LowerImport { .. } | GlobalInitializer::ExtractMemory(_) | GlobalInitializer::ExtractRealloc(_) + | GlobalInitializer::ExtractCallback(_) | GlobalInitializer::ExtractPostReturn(_) | GlobalInitializer::Resource(_) => {} } diff --git a/crates/wasmtime/src/runtime/component/concurrent.rs b/crates/wasmtime/src/runtime/component/concurrent.rs new file mode 100644 index 000000000000..da304c36241e --- /dev/null +++ b/crates/wasmtime/src/runtime/component/concurrent.rs @@ -0,0 +1,1817 @@ +use { + crate::{ + component::func::{self, Func, Lower as _, LowerContext, Options}, + vm::{ + component::VMComponentContext, + mpk::{self, ProtectionMask}, + AsyncWasmCallState, PreviousAsyncWasmCallState, SendSyncPtr, VMFuncRef, + VMMemoryDefinition, VMOpaqueContext, VMStore, + }, + AsContextMut, Engine, StoreContextMut, ValRaw, + }, + anyhow::{anyhow, bail, Result}, + futures::{ + future::{self, Either, FutureExt}, + stream::{FuturesUnordered, StreamExt}, + }, + once_cell::sync::Lazy, + ready_chunks::ReadyChunks, + std::{ + any::Any, + borrow::ToOwned, + boxed::Box, + cell::UnsafeCell, + collections::{HashMap, HashSet, VecDeque}, + future::Future, + marker::PhantomData, + mem::{self, MaybeUninit}, + pin::{pin, Pin}, + ptr::{self, NonNull}, + sync::Arc, + task::{Context, Poll, Wake, Waker}, + vec::Vec, + }, + table::{Table, TableId}, + wasmtime_environ::component::{ + InterfaceType, RuntimeComponentInstanceIndex, StringEncoding, MAX_FLAT_PARAMS, + MAX_FLAT_RESULTS, + }, + wasmtime_fiber::{Fiber, Suspend}, +}; + +use futures_and_streams::TransmitState; +pub(crate) use futures_and_streams::{ + error_context_debug_message, error_context_drop, error_context_new, flat_stream_read, + flat_stream_write, future_cancel_read, future_cancel_write, future_close_readable, + future_close_writable, future_new, future_read, future_write, stream_cancel_read, + stream_cancel_write, stream_close_readable, stream_close_writable, stream_new, stream_read, + stream_write, +}; +pub use futures_and_streams::{future, stream, ErrorContext, FutureReader, StreamReader}; + +mod futures_and_streams; +mod ready_chunks; +mod table; + +// TODO: Currently, we're exposing global (to the top-level component instance) task IDs to guests; which is a +// slight information leak and source of nondeterminsm. We should instead convert between global IDs and +// per-instance IDs. + +// TODO: The handling of `task.yield` and `task.backpressure` was bolted on late in the implementation and is +// currently haphazard. We need a refactor to manage yielding, backpressure, and event polling and delivery in a +// more unified and structured way. + +// TODO: move these into an enum: +const STATUS_STARTING: u32 = 0; +const STATUS_STARTED: u32 = 1; +const STATUS_RETURNED: u32 = 2; +const STATUS_DONE: u32 = 3; + +mod events { + // TODO: move these into an enum: + pub const _EVENT_CALL_STARTING: u32 = 0; + pub const EVENT_CALL_STARTED: u32 = 1; + pub const EVENT_CALL_RETURNED: u32 = 2; + pub const EVENT_CALL_DONE: u32 = 3; + pub const _EVENT_YIELDED: u32 = 4; + pub const EVENT_STREAM_READ: u32 = 5; + pub const EVENT_STREAM_WRITE: u32 = 6; + pub const EVENT_FUTURE_READ: u32 = 7; + pub const EVENT_FUTURE_WRITE: u32 = 8; +} + +const EXIT_FLAG_ASYNC_CALLER: u32 = 1 << 0; +const EXIT_FLAG_ASYNC_CALLEE: u32 = 1 << 1; + +struct HostTaskResult { + event: u32, + param: u32, + caller: TableId, +} + +type HostTaskFuture = Pin< + Box< + dyn Future< + Output = ( + u32, + Box Result>, + ), + > + Send + + Sync + + 'static, + >, +>; + +struct HostTask { + caller_instance: RuntimeComponentInstanceIndex, +} + +enum Deferred { + None, + Sync(StoreFiber<'static>), + Async { + call: Box Result + Send + Sync + 'static>, + instance: RuntimeComponentInstanceIndex, + callback: SendSyncPtr, + }, +} + +impl Deferred { + fn take_fiber(&mut self) -> Option> { + if let Self::Sync(_) = self { + let Self::Sync(fiber) = mem::replace(self, Self::None) else { + unreachable!() + }; + Some(fiber) + } else { + None + } + } +} + +struct GuestTask { + lower_params: Option, + lift_result: Option, + result: Option, + callback: Option<(SendSyncPtr, u32)>, + events: VecDeque<(u32, AnyTask, u32)>, + caller: Option>, + caller_instance: Option, + deferred: Deferred, + should_yield: bool, +} + +impl Default for GuestTask { + fn default() -> Self { + Self { + lower_params: None, + lift_result: None, + result: None, + callback: None, + events: VecDeque::new(), + caller: None, + caller_instance: None, + deferred: Deferred::None, + should_yield: false, + } + } +} + +#[derive(Copy, Clone)] +enum AnyTask { + Host(TableId), + Guest(TableId), + Transmit(TableId), +} + +impl AnyTask { + fn rep(&self) -> u32 { + match self { + Self::Host(task) => task.rep(), + Self::Guest(task) => task.rep(), + Self::Transmit(task) => task.rep(), + } + } + + fn delete_all_from(&self, mut store: StoreContextMut) -> Result<()> { + match self { + Self::Host(task) => { + log::trace!("delete {}", task.rep()); + store.concurrent_state().table.delete(*task).map(drop) + } + Self::Guest(task) => { + let finished = store + .concurrent_state() + .table + .get(*task)? + .events + .iter() + .filter_map(|(event, call, _)| { + (*event == events::EVENT_CALL_DONE).then_some(*call) + }) + .collect::>(); + + for call in finished { + log::trace!("will delete call {}", call.rep()); + call.delete_all_from(store.as_context_mut())?; + } + + log::trace!("delete {}", task.rep()); + store.concurrent_state().table.delete(*task).map(drop) + } + Self::Transmit(task) => store.concurrent_state().table.delete(*task).map(drop), + }?; + + Ok(()) + } +} + +pub(crate) struct LiftLowerContext { + pub(crate) pointer: *mut u8, + pub(crate) dropper: fn(*mut u8), +} + +unsafe impl Send for LiftLowerContext {} +unsafe impl Sync for LiftLowerContext {} + +impl Drop for LiftLowerContext { + fn drop(&mut self) { + (self.dropper)(self.pointer); + } +} + +type RawLower = + Box]) -> Result<()> + Send + Sync>; + +type LowerFn = fn(LiftLowerContext, *mut dyn VMStore, &mut [MaybeUninit]) -> Result<()>; + +type RawLift = Box< + dyn FnOnce(*mut dyn VMStore, &[ValRaw]) -> Result>> + + Send + + Sync, +>; + +type LiftFn = + fn(LiftLowerContext, *mut dyn VMStore, &[ValRaw]) -> Result>>; + +type LiftedResult = Box; + +struct Reset(*mut T, T); + +impl Drop for Reset { + fn drop(&mut self) { + unsafe { + *self.0 = self.1; + } + } +} + +struct AsyncState { + current_suspend: UnsafeCell< + *mut Suspend< + (Option<*mut dyn VMStore>, Result<()>), + Option<*mut dyn VMStore>, + (Option<*mut dyn VMStore>, Result<()>), + >, + >, + current_poll_cx: UnsafeCell<*mut Context<'static>>, +} + +unsafe impl Send for AsyncState {} +unsafe impl Sync for AsyncState {} + +pub(crate) struct AsyncCx { + current_suspend: *mut *mut wasmtime_fiber::Suspend< + (Option<*mut dyn VMStore>, Result<()>), + Option<*mut dyn VMStore>, + (Option<*mut dyn VMStore>, Result<()>), + >, + current_stack_limit: *mut usize, + current_poll_cx: *mut *mut Context<'static>, + track_pkey_context_switch: bool, +} + +impl AsyncCx { + pub(crate) fn new(store: &mut StoreContextMut) -> Self { + Self { + current_suspend: store.concurrent_state().async_state.current_suspend.get(), + current_stack_limit: store.0.runtime_limits().stack_limit.get(), + current_poll_cx: store.concurrent_state().async_state.current_poll_cx.get(), + track_pkey_context_switch: store.has_pkey(), + } + } + + unsafe fn poll(&self, mut future: Pin<&mut (dyn Future + Send)>) -> Poll { + let poll_cx = *self.current_poll_cx; + let _reset = Reset(self.current_poll_cx, poll_cx); + *self.current_poll_cx = ptr::null_mut(); + assert!(!poll_cx.is_null()); + future.as_mut().poll(&mut *poll_cx) + } + + pub(crate) unsafe fn block_on<'a, T, U>( + &self, + mut future: Pin<&mut (dyn Future + Send)>, + mut store: Option>, + ) -> Result<(U, Option>)> { + loop { + match self.poll(future.as_mut()) { + Poll::Ready(v) => break Ok((v, store)), + Poll::Pending => {} + } + + store = self.suspend(store)?; + } + } + + unsafe fn suspend<'a, T>( + &self, + store: Option>, + ) -> Result>> { + let previous_mask = if self.track_pkey_context_switch { + let previous_mask = mpk::current_mask(); + mpk::allow(ProtectionMask::all()); + previous_mask + } else { + ProtectionMask::all() + }; + let store = suspend_fiber(self.current_suspend, self.current_stack_limit, store); + if self.track_pkey_context_switch { + mpk::allow(previous_mask); + } + store + } +} + +#[derive(Default)] +struct InstanceState { + backpressure: bool, + task_queue: VecDeque>, +} + +pub struct ConcurrentState { + guest_task: Option>, + futures: ReadyChunks>, + table: Table, + async_state: AsyncState, + instance_states: HashMap, + yielding: HashSet, + unblocked: HashSet, + _phantom: PhantomData, +} + +impl Default for ConcurrentState { + fn default() -> Self { + Self { + guest_task: None, + table: Table::new(), + futures: ReadyChunks::new(FuturesUnordered::new(), 1024), + async_state: AsyncState { + current_suspend: UnsafeCell::new(ptr::null_mut()), + current_poll_cx: UnsafeCell::new(ptr::null_mut()), + }, + instance_states: HashMap::new(), + yielding: HashSet::new(), + unblocked: HashSet::new(), + _phantom: PhantomData, + } + } +} + +fn dummy_waker() -> Waker { + struct DummyWaker; + + impl Wake for DummyWaker { + fn wake(self: Arc) {} + } + + static WAKER: Lazy> = Lazy::new(|| Arc::new(DummyWaker)); + + WAKER.clone().into() +} + +/// Provide a hint to Rust type inferencer that we're returning a compatible +/// closure from a `LinkerInstance::func_wrap_concurrent` future. +pub fn for_any(fun: F) -> F +where + F: FnOnce(StoreContextMut) -> R + 'static, + R: 'static, +{ + fun +} + +fn for_any_lower< + F: FnOnce(*mut dyn VMStore, &mut [MaybeUninit]) -> Result<()> + Send + Sync, +>( + fun: F, +) -> F { + fun +} + +fn for_any_lift< + F: FnOnce(*mut dyn VMStore, &[ValRaw]) -> Result>> + Send + Sync, +>( + fun: F, +) -> F { + fun +} + +pub(crate) fn first_poll( + mut store: StoreContextMut, + future: impl Future) -> Result + 'static> + + Send + + Sync + + 'static, + caller_instance: RuntimeComponentInstanceIndex, + lower: impl FnOnce(StoreContextMut, R) -> Result<()> + Send + Sync + 'static, +) -> Result> { + let caller = store.concurrent_state().guest_task.unwrap(); + let task = store + .concurrent_state() + .table + .push_child(HostTask { caller_instance }, caller)?; + log::trace!("new child of {}: {}", caller.rep(), task.rep()); + let mut future = Box::pin(future.map(move |fun| { + ( + task.rep(), + Box::new(move |store: *mut dyn VMStore| { + let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; + let result = fun(store.as_context_mut())?; + lower(store, result)?; + Ok(HostTaskResult { + event: events::EVENT_CALL_DONE, + param: 0u32, + caller, + }) + }) as Box Result>, + ) + })) as HostTaskFuture; + + Ok( + match future + .as_mut() + .poll(&mut Context::from_waker(&dummy_waker())) + { + Poll::Ready((_, fun)) => { + log::trace!("delete {} (already ready)", task.rep()); + store.concurrent_state().table.delete(task)?; + fun(store.0.traitobj())?; + None + } + Poll::Pending => { + store.concurrent_state().futures.get_mut().push(future); + Some(task.rep()) + } + }, + ) +} + +pub(crate) fn poll_and_block<'a, T, R: Send + Sync + 'static>( + mut store: StoreContextMut<'a, T>, + future: impl Future) -> Result + 'static> + + Send + + Sync + + 'static, + caller_instance: RuntimeComponentInstanceIndex, +) -> Result<(R, StoreContextMut<'a, T>)> { + let caller = store.concurrent_state().guest_task.unwrap(); + let old_result = store + .concurrent_state() + .table + .get_mut(caller)? + .result + .take(); + let task = store + .concurrent_state() + .table + .push_child(HostTask { caller_instance }, caller)?; + log::trace!("new child of {}: {}", caller.rep(), task.rep()); + let mut future = Box::pin(future.map(move |fun| { + ( + task.rep(), + Box::new(move |store: *mut dyn VMStore| { + let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; + let result = fun(store.as_context_mut())?; + store.concurrent_state().table.get_mut(caller)?.result = + Some(Box::new(result) as _); + Ok(HostTaskResult { + event: events::EVENT_CALL_DONE, + param: 0u32, + caller, + }) + }) as Box Result>, + ) + })) as HostTaskFuture; + + Ok( + match unsafe { AsyncCx::new(&mut store).poll(future.as_mut()) } { + Poll::Ready((_, fun)) => { + log::trace!("delete {} (already ready)", task.rep()); + store.concurrent_state().table.delete(task)?; + let store = store.0.traitobj(); + fun(store)?; + let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; + let result = *mem::replace( + &mut store.concurrent_state().table.get_mut(caller)?.result, + old_result, + ) + .unwrap() + .downcast() + .unwrap(); + (result, store) + } + Poll::Pending => { + store.concurrent_state().futures.get_mut().push(future); + loop { + if let Some(result) = store + .concurrent_state() + .table + .get_mut(caller)? + .result + .take() + { + store.concurrent_state().table.get_mut(caller)?.result = old_result; + break (*result.downcast().unwrap(), store); + } else { + let async_cx = AsyncCx::new(&mut store); + store = unsafe { async_cx.suspend(Some(store)) }?.unwrap(); + } + } + } + }, + ) +} + +pub(crate) async fn on_fiber<'a, R: Send + Sync + 'static, T: Send>( + mut store: StoreContextMut<'a, T>, + handle: Func, + func: impl FnOnce(&mut StoreContextMut) -> R + Send, +) -> Result<(R, StoreContextMut<'a, T>)> { + assert!(store.concurrent_state().guest_task.is_none()); + + let guest_task = store.concurrent_state().table.push(GuestTask::default())?; + store.concurrent_state().guest_task = Some(guest_task); + + let is_concurrent = store.0[handle.0].options.async_(); + let instance = store.0[handle.0].component_instance; + + let mut fiber = make_fiber(&mut store, instance, move |mut store| { + let result = func(&mut store); + assert!(store + .concurrent_state() + .table + .get(guest_task)? + .result + .is_none()); + store.concurrent_state().table.get_mut(guest_task)?.result = Some(Box::new(result) as _); + Ok(()) + })?; + + if !is_concurrent { + let state = &mut store + .concurrent_state() + .instance_states + .entry(instance) + .or_default(); + assert!(state.task_queue.is_empty()); + state.task_queue.push_back(guest_task); + } + + store = poll_fn(store, move |_, mut store| { + match resume_fiber(&mut fiber, store.take(), Ok(())) { + Ok(Ok((store, result))) => Ok(result.map(|()| store)), + Ok(Err(s)) => Err(s), + Err(e) => Ok(Err(e)), + } + }) + .await?; + + let result = store + .concurrent_state() + .table + .get_mut(guest_task)? + .result + .take(); + + store = poll_fn(store, move |_, mut store| { + Ok(poll_ready(store.take().unwrap())) + }) + .await?; + + if !is_concurrent { + store = poll_fn(store, move |_, mut store| { + Ok(maybe_resume_next_task( + store.take().unwrap(), + guest_task, + instance, + )) + }) + .await?; + } + + if let Some(result) = result { + let task = store.concurrent_state().table.get(guest_task)?; + + if task.callback.is_none() { + // The task is finished -- delete it. + // + // Note that this isn't always the case -- a task may yield a result without completing, in which case + // it should be polled until it's completed using `poll_for_completion`. + + AnyTask::Guest(guest_task).delete_all_from(store.as_context_mut())?; + store.concurrent_state().guest_task = None; + } + Ok((*result.downcast().unwrap(), store)) + } else { + // All outstanding host tasks completed, but the guest never yielded a result. + Err(anyhow!(crate::Trap::NoAsyncResult)) + } +} + +fn maybe_send_event<'a, T>( + mut store: StoreContextMut<'a, T>, + guest_task: TableId, + event: u32, + call: AnyTask, + result: u32, +) -> Result> { + assert_ne!(guest_task.rep(), call.rep()); + if let Some((callback, context)) = store.concurrent_state().table.get(guest_task)?.callback { + log::trace!( + "use callback to deliver event {event} to {} for {}: {:?} {context}", + guest_task.rep(), + call.rep(), + callback + ); + let old_task = store.concurrent_state().guest_task.replace(guest_task); + let params = &mut [ + ValRaw::u32(context), + ValRaw::u32(event), + ValRaw::u32(call.rep()), + ValRaw::u32(result), + ]; + unsafe { + crate::Func::call_unchecked_raw( + &mut store, + callback.as_non_null(), + params.as_mut_ptr(), + params.len(), + )?; + } + let done = params[0].get_u32() != 0; + log::trace!("{} done? {done}", guest_task.rep()); + if done { + store.concurrent_state().table.get_mut(guest_task)?.callback = None; + + if let Some(next) = store.concurrent_state().table.get(guest_task)?.caller { + store = maybe_send_event( + store, + next, + events::EVENT_CALL_DONE, + AnyTask::Guest(guest_task), + 0, + )?; + } + } + store.concurrent_state().guest_task = old_task; + Ok(store) + } else { + store + .concurrent_state() + .table + .get_mut(guest_task)? + .events + .push_back((event, call, result)); + + if let Some(fiber) = store + .concurrent_state() + .table + .get_mut(guest_task)? + .deferred + .take_fiber() + { + log::trace!( + "use fiber to deliver event {event} to {} for {}", + guest_task.rep(), + call.rep() + ); + let old_task = store.concurrent_state().guest_task.replace(guest_task); + store = resume_sync(store, guest_task, fiber)?; + store.concurrent_state().guest_task = old_task; + } else { + log::trace!( + "queue event {event} to {} for {}", + guest_task.rep(), + call.rep() + ); + } + + Ok(store) + } +} + +fn resume_sync<'a, T>( + mut store: StoreContextMut<'a, T>, + guest_task: TableId, + mut fiber: StoreFiber<'static>, +) -> Result> { + match resume_fiber(&mut fiber, Some(store), Ok(()))? { + Ok((mut store, result)) => { + result?; + store = maybe_resume_next_task(store, guest_task, fiber.instance)?; + for (event, call, _) in + mem::take(&mut store.concurrent_state().table.get_mut(guest_task)?.events) + { + if event == events::EVENT_CALL_DONE { + log::trace!("resume_sync will delete {}", call.rep()); + call.delete_all_from(store.as_context_mut())?; + } + } + if let Some(next) = store.concurrent_state().table.get(guest_task)?.caller { + maybe_send_event( + store, + next, + events::EVENT_CALL_DONE, + AnyTask::Guest(guest_task), + 0, + ) + } else { + Ok(store) + } + } + Err(new_store) => { + store = new_store.unwrap(); + store.concurrent_state().table.get_mut(guest_task)?.deferred = Deferred::Sync(fiber); + Ok(store) + } + } +} + +fn resume_async<'a, T>( + store: StoreContextMut<'a, T>, + guest_task: TableId, + call: Box Result>, + instance: RuntimeComponentInstanceIndex, + callback: SendSyncPtr, +) -> Result> { + let store = store.0.traitobj(); + let guest_context = call(store)?; + let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; + + let task = store.concurrent_state().table.get_mut(guest_task)?; + let event = if task.lift_result.is_some() { + events::EVENT_CALL_STARTED + } else if guest_context != 0 { + events::EVENT_CALL_RETURNED + } else { + events::EVENT_CALL_DONE + }; + if guest_context != 0 { + log::trace!("set callback for {}", guest_task.rep()); + task.callback = Some((callback, guest_context)); + for (event, call, result) in mem::take(&mut task.events) { + store = maybe_send_event(store, guest_task, event, call, result)?; + } + } + store = maybe_resume_next_task(store, guest_task, instance)?; + if let Some(next) = store.concurrent_state().table.get(guest_task)?.caller { + maybe_send_event(store, next, event, AnyTask::Guest(guest_task), 0) + } else { + Ok(store) + } +} + +fn poll_for_result<'a, T>(mut store: StoreContextMut<'a, T>) -> Result> { + let task = store.concurrent_state().guest_task; + poll_loop(store, move |store| { + task.map(|task| { + Ok::<_, anyhow::Error>(store.concurrent_state().table.get(task)?.result.is_none()) + }) + .unwrap_or(Ok(true)) + }) +} + +fn handle_ready<'a, T>( + mut store: StoreContextMut<'a, T>, + ready: Vec<( + u32, + Box Result>, + )>, +) -> Result> { + for (task, fun) in ready { + let vm_store = store.0.traitobj(); + let result = fun(vm_store)?; + store = unsafe { StoreContextMut::(&mut *vm_store.cast()) }; + let task = match result.event { + events::EVENT_CALL_DONE => AnyTask::Host(TableId::::new(task)), + events::EVENT_STREAM_READ + | events::EVENT_FUTURE_READ + | events::EVENT_STREAM_WRITE + | events::EVENT_FUTURE_WRITE => AnyTask::Transmit(TableId::::new(task)), + _ => unreachable!(), + }; + store = maybe_send_event(store, result.caller, result.event, task, result.param)?; + } + Ok(store) +} + +fn maybe_yield<'a, T>(mut store: StoreContextMut<'a, T>) -> Result> { + let guest_task = store.concurrent_state().guest_task.unwrap(); + + if store.concurrent_state().table.get(guest_task)?.should_yield { + log::trace!("maybe_yield suspend"); + + store.concurrent_state().yielding.insert(guest_task.rep()); + let cx = AsyncCx::new(&mut store); + store = unsafe { cx.suspend(Some(store)) }?.unwrap(); + + log::trace!("maybe_yield resume"); + } else { + log::trace!("maybe_yield skip"); + } + + Ok(store) +} + +fn unyield<'a, T>(mut store: StoreContextMut<'a, T>) -> Result<(StoreContextMut<'a, T>, bool)> { + let mut resumed = false; + for task in mem::take(&mut store.concurrent_state().yielding) { + let guest_task = TableId::::new(task); + if let Some(fiber) = store + .concurrent_state() + .table + .get_mut(guest_task)? + .deferred + .take_fiber() + { + resumed = true; + let old_task = store.concurrent_state().guest_task.replace(guest_task); + store = resume_sync(store, guest_task, fiber)?; + store.concurrent_state().guest_task = old_task; + } + } + + for instance in mem::take(&mut store.concurrent_state().unblocked) { + let entry = store + .concurrent_state() + .instance_states + .entry(instance) + .or_default(); + + if !entry.backpressure { + if let Some(task) = entry.task_queue.iter().copied().next() { + resumed = true; + store = resume(store, task)?; + } + } + } + + Ok((store, resumed)) +} + +fn poll_loop<'a, T>( + mut store: StoreContextMut<'a, T>, + continue_: impl Fn(&mut StoreContextMut<'a, T>) -> Result, +) -> Result> { + loop { + let cx = AsyncCx::new(&mut store); + let mut future = pin!(store.concurrent_state().futures.next()); + let ready = unsafe { cx.poll(future.as_mut()) }; + + match ready { + Poll::Ready(Some(ready)) => { + store = handle_ready(store, ready)?; + } + Poll::Ready(None) => { + let (s, resumed) = unyield(store)?; + store = s; + if !resumed { + break; + } + } + Poll::Pending => { + let (s, resumed) = unyield(store)?; + store = s; + if continue_(&mut store)? { + let cx = AsyncCx::new(&mut store); + store = unsafe { cx.suspend(Some(store)) }?.unwrap(); + } else if !resumed { + break; + } + } + } + } + + Ok(store) +} + +fn resume<'a, T>( + mut store: StoreContextMut<'a, T>, + task: TableId, +) -> Result> { + log::trace!("resume {}", task.rep()); + + // TODO: Avoid tail calling `resume_sync` or `resume_async` here, because it may call us, leading to + // recursion limited only by the number of waiters. Flatten this into an iteration instead. + match mem::replace( + &mut store.concurrent_state().table.get_mut(task)?.deferred, + Deferred::None, + ) { + Deferred::None => unreachable!(), + Deferred::Sync(fiber) => resume_sync(store, task, fiber), + Deferred::Async { + call, + instance, + callback, + } => resume_async(store, task, call, instance, callback), + } +} + +fn maybe_resume_next_task<'a, T>( + mut store: StoreContextMut<'a, T>, + current_task: TableId, + instance: RuntimeComponentInstanceIndex, +) -> Result> { + let state = store + .concurrent_state() + .instance_states + .get_mut(&instance) + .unwrap(); + + if state.backpressure { + Ok(store) + } else { + assert_eq!( + current_task.rep(), + state.task_queue.pop_front().unwrap().rep() + ); + + if let Some(next) = state.task_queue.iter().copied().next() { + resume(store, next) + } else { + Ok(store) + } + } +} + +struct StoreFiber<'a> { + fiber: Option< + Fiber< + 'a, + (Option<*mut dyn VMStore>, Result<()>), + Option<*mut dyn VMStore>, + (Option<*mut dyn VMStore>, Result<()>), + >, + >, + state: Option, + engine: Engine, + suspend: *mut *mut Suspend< + (Option<*mut dyn VMStore>, Result<()>), + Option<*mut dyn VMStore>, + (Option<*mut dyn VMStore>, Result<()>), + >, + stack_limit: *mut usize, + instance: RuntimeComponentInstanceIndex, +} + +impl<'a> Drop for StoreFiber<'a> { + fn drop(&mut self) { + if !self.fiber.as_ref().unwrap().done() { + let result = unsafe { resume_fiber_raw(self, None, Err(anyhow!("future dropped"))) }; + debug_assert!(result.is_ok()); + } + + self.state.take().unwrap().assert_null(); + + unsafe { + self.engine + .allocator() + .deallocate_fiber_stack(self.fiber.take().unwrap().into_stack()); + } + } +} + +unsafe impl<'a> Send for StoreFiber<'a> {} +unsafe impl<'a> Sync for StoreFiber<'a> {} + +fn make_fiber<'a, T>( + store: &mut StoreContextMut, + instance: RuntimeComponentInstanceIndex, + fun: impl FnOnce(StoreContextMut) -> Result<()> + 'a, +) -> Result> { + let engine = store.engine().clone(); + let stack = engine.allocator().allocate_fiber_stack()?; + Ok(StoreFiber { + fiber: Some(Fiber::new( + stack, + move |(store_ptr, result): (Option<*mut dyn VMStore>, Result<()>), suspend| { + if result.is_err() { + (store_ptr, result) + } else { + unsafe { + let store_ptr = store_ptr.unwrap(); + let mut store = StoreContextMut(&mut *store_ptr.cast()); + let suspend_ptr = + store.concurrent_state().async_state.current_suspend.get(); + let _reset = Reset(suspend_ptr, *suspend_ptr); + *suspend_ptr = suspend; + (Some(store_ptr), fun(store.as_context_mut())) + } + } + }, + )?), + state: Some(AsyncWasmCallState::new()), + engine, + suspend: store.concurrent_state().async_state.current_suspend.get(), + stack_limit: store.0.runtime_limits().stack_limit.get(), + instance, + }) +} + +unsafe fn resume_fiber_raw<'a>( + fiber: *mut StoreFiber<'a>, + store: Option<*mut dyn VMStore>, + result: Result<()>, +) -> Result<(Option<*mut dyn VMStore>, Result<()>), Option<*mut dyn VMStore>> { + struct Restore<'a> { + fiber: *mut StoreFiber<'a>, + state: Option, + } + + impl Drop for Restore<'_> { + fn drop(&mut self) { + unsafe { + (*self.fiber).state = Some(self.state.take().unwrap().restore()); + } + } + } + + let _reset_suspend = Reset((*fiber).suspend, *(*fiber).suspend); + let _reset_stack_limit = Reset((*fiber).stack_limit, *(*fiber).stack_limit); + let state = Some((*fiber).state.take().unwrap().push()); + let restore = Restore { fiber, state }; + (*restore.fiber) + .fiber + .as_ref() + .unwrap() + .resume((store, result)) +} + +fn poll_ready<'a, T>(mut store: StoreContextMut<'a, T>) -> Result> { + unsafe { + let cx = *store.concurrent_state().async_state.current_poll_cx.get(); + assert!(!cx.is_null()); + while let Poll::Ready(Some(ready)) = + store.concurrent_state().futures.poll_next_unpin(&mut *cx) + { + match handle_ready(store, ready) { + Ok(s) => { + store = s; + } + Err(e) => { + return Err(e); + } + } + } + } + Ok(store) +} + +fn resume_fiber<'a, T>( + fiber: &mut StoreFiber, + mut store: Option>, + result: Result<()>, +) -> Result, Result<()>), Option>>> { + if let Some(s) = store.take() { + store = Some(poll_ready(s)?); + } + + unsafe { + match resume_fiber_raw(fiber, store.map(|s| s.0.traitobj()), result) + .map(|(store, result)| (StoreContextMut(&mut *store.unwrap().cast()), result)) + .map_err(|v| v.map(|v| StoreContextMut(&mut *v.cast()))) + { + Ok(pair) => Ok(Ok(pair)), + Err(s) => { + if let Some(range) = fiber.fiber.as_ref().unwrap().stack().range() { + AsyncWasmCallState::assert_current_state_not_in_range(range); + } + + Ok(Err(s)) + } + } + } +} + +unsafe fn suspend_fiber<'a, T>( + suspend: *mut *mut Suspend< + (Option<*mut dyn VMStore>, Result<()>), + Option<*mut dyn VMStore>, + (Option<*mut dyn VMStore>, Result<()>), + >, + stack_limit: *mut usize, + store: Option>, +) -> Result>> { + let _reset_suspend = Reset(suspend, *suspend); + let _reset_stack_limit = Reset(stack_limit, *stack_limit); + let (store, result) = (**suspend).suspend(store.map(|s| s.0.traitobj())); + result?; + Ok(store.map(|v| StoreContextMut(&mut *v.cast()))) +} + +enum TaskCheck { + Wait(*mut VMMemoryDefinition, u32), + Poll(*mut VMMemoryDefinition, u32), + Yield, +} + +unsafe fn task_check(cx: *mut VMOpaqueContext, async_: bool, check: TaskCheck) -> Result { + if async_ { + bail!("todo: async `task.wait`, `task.poll`, and `task.yield` not yet implemented"); + } + + let cx = VMComponentContext::from_opaque(cx); + let instance = (*cx).instance(); + let mut cx = StoreContextMut::(&mut *(*instance).store().cast()); + + cx = maybe_yield(cx)?; + + let guest_task = cx.concurrent_state().guest_task.unwrap(); + + let wait = matches!(check, TaskCheck::Wait(..)); + + if wait + && cx + .concurrent_state() + .table + .get(guest_task)? + .callback + .is_some() + { + bail!("cannot call `task.wait` from async-lifted export with callback"); + } + + let mut cx = poll_loop(cx, move |cx| { + Ok::<_, anyhow::Error>( + wait && cx + .concurrent_state() + .table + .get(guest_task)? + .events + .is_empty(), + ) + })?; + + let result = match check { + TaskCheck::Wait(memory, payload) => { + let (event, call, result) = cx + .concurrent_state() + .table + .get_mut(guest_task)? + .events + .pop_front() + .unwrap(); + + log::trace!( + "deliver event {event} via task.wait to {} for {}", + guest_task.rep(), + call.rep() + ); + + let options = Options::new( + cx.0.id(), + NonNull::new(memory), + None, + StringEncoding::Utf8, + true, + None, + ); + let types = (*instance).component_types(); + let ptr = + func::validate_inbounds::(options.memory_mut(cx.0), &ValRaw::u32(payload))?; + let mut lower = LowerContext::new(cx, &options, types, instance); + call.rep().store(&mut lower, InterfaceType::U32, ptr)?; + result.store(&mut lower, InterfaceType::U32, ptr + 4)?; + + Ok(event) + } + TaskCheck::Poll(memory, payload) => { + if let Some((event, call, result)) = cx + .concurrent_state() + .table + .get_mut(guest_task)? + .events + .pop_front() + { + log::trace!( + "deliver event {event} via task.poll to {} for {}", + guest_task.rep(), + call.rep() + ); + + let options = Options::new( + cx.0.id(), + NonNull::new(memory), + None, + StringEncoding::Utf8, + true, + None, + ); + let types = (*instance).component_types(); + let ptr = func::validate_inbounds::<(u32, u32)>( + options.memory_mut(cx.0), + &ValRaw::u32(payload), + )?; + let mut lower = LowerContext::new(cx, &options, types, instance); + event.store(&mut lower, InterfaceType::U32, ptr)?; + call.rep().store(&mut lower, InterfaceType::U32, ptr + 4)?; + result.store(&mut lower, InterfaceType::U32, ptr + 8)?; + + Ok(1) + } else { + log::trace!( + "no events ready to deliver via task.poll to {}", + guest_task.rep() + ); + + Ok(0) + } + } + TaskCheck::Yield => Ok(0), + }; + + result +} + +unsafe fn handle_result(func: impl FnOnce() -> Result) -> T { + match crate::runtime::vm::catch_unwind_and_longjmp(func) { + Ok(value) => value, + Err(e) => { + log::trace!("handle_result error: {e:?}"); + crate::trap::raise(e) + } + } +} + +pub(crate) extern "C" fn task_backpressure( + cx: *mut VMOpaqueContext, + caller_instance: RuntimeComponentInstanceIndex, + enabled: u32, +) { + unsafe { + handle_result(|| { + let cx = VMComponentContext::from_opaque(cx); + let instance = (*cx).instance(); + let mut cx = StoreContextMut::(&mut *(*instance).store().cast()); + let entry = cx + .concurrent_state() + .instance_states + .entry(caller_instance) + .or_default(); + let old = entry.backpressure; + let new = enabled != 0; + entry.backpressure = new; + + if old && !new { + if let Some(_) = entry.task_queue.iter().next() { + cx.concurrent_state().unblocked.insert(caller_instance); + } + } + + Ok(()) + }) + } +} + +pub(crate) extern "C" fn task_return( + cx: *mut VMOpaqueContext, + storage: *mut MaybeUninit, + storage_len: usize, +) { + unsafe { + handle_result(|| { + let storage = std::slice::from_raw_parts(storage, storage_len); + let cx = VMComponentContext::from_opaque(cx); + let instance = (*cx).instance(); + let mut cx = StoreContextMut::(&mut *(*instance).store().cast()); + let guest_task = cx.concurrent_state().guest_task.unwrap(); + let lift = cx + .concurrent_state() + .table + .get_mut(guest_task)? + .lift_result + .take() + .ok_or_else(|| anyhow!("call.return called more than once"))?; + + assert!(cx + .concurrent_state() + .table + .get(guest_task)? + .result + .is_none()); + + let cx = cx.0.traitobj(); + let result = lift( + cx, + mem::transmute::<&[MaybeUninit], &[ValRaw]>(storage), + )?; + + let mut cx = StoreContextMut::(&mut *cx.cast()); + cx.concurrent_state().table.get_mut(guest_task)?.result = result; + + Ok(()) + }) + } +} + +pub(crate) extern "C" fn task_wait( + cx: *mut VMOpaqueContext, + async_: bool, + memory: *mut VMMemoryDefinition, + payload: u32, +) -> u32 { + unsafe { handle_result(|| task_check::(cx, async_, TaskCheck::Wait(memory, payload))) } +} + +pub(crate) extern "C" fn task_poll( + cx: *mut VMOpaqueContext, + async_: bool, + memory: *mut VMMemoryDefinition, + payload: u32, +) -> u32 { + unsafe { handle_result(|| task_check::(cx, async_, TaskCheck::Poll(memory, payload))) } +} + +pub(crate) extern "C" fn task_yield(cx: *mut VMOpaqueContext, async_: bool) { + unsafe { + handle_result(|| task_check::(cx, async_, TaskCheck::Yield)); + } +} + +pub(crate) extern "C" fn subtask_drop( + cx: *mut VMOpaqueContext, + caller_instance: RuntimeComponentInstanceIndex, + task_id: u32, +) { + unsafe { + handle_result(|| { + let cx = VMComponentContext::from_opaque(cx); + let instance = (*cx).instance(); + let mut cx = StoreContextMut::(&mut *(*instance).store().cast()); + let table = &mut cx.concurrent_state().table; + log::trace!("subtask_drop delete {task_id}"); + let task = table.delete_any(task_id)?; + let expected_caller_instance = match task.downcast::() { + Ok(task) => Some(task.caller_instance), + Err(task) => match task.downcast::() { + Ok(task) => task.caller_instance, + Err(_) => bail!("invalid handle: {task_id}"), + }, + }; + if expected_caller_instance != Some(caller_instance) { + bail!("invalid handle: {task_id}"); + } + Ok(()) + }) + } +} + +pub(crate) extern "C" fn async_enter( + cx: *mut VMOpaqueContext, + start: *mut VMFuncRef, + return_: *mut VMFuncRef, + caller_instance: RuntimeComponentInstanceIndex, + params: u32, + results: u32, +) { + unsafe { + handle_result(|| { + let cx = VMComponentContext::from_opaque(cx); + let instance = (*cx).instance(); + let mut cx = StoreContextMut::(&mut *(*instance).store().cast()); + let start = SendSyncPtr::new(NonNull::new(start).unwrap()); + let return_ = SendSyncPtr::new(NonNull::new(return_).unwrap()); + let old_task = cx.concurrent_state().guest_task.take(); + let old_task_rep = old_task.map(|v| v.rep()); + let new_task = GuestTask { + lower_params: Some(Box::new(move |cx, dst| { + let mut cx = StoreContextMut::(&mut *cx.cast()); + assert!(dst.len() <= MAX_FLAT_PARAMS); + let mut src = [MaybeUninit::uninit(); MAX_FLAT_PARAMS]; + src[0] = MaybeUninit::new(ValRaw::u32(params)); + crate::Func::call_unchecked_raw( + &mut cx, + start.as_non_null(), + src.as_mut_ptr() as _, + 1.max(dst.len()), + )?; + dst.copy_from_slice(&src[..dst.len()]); + let task = cx.concurrent_state().guest_task.unwrap(); + if let Some(rep) = old_task_rep { + maybe_send_event( + cx, + TableId::new(rep), + events::EVENT_CALL_STARTED, + AnyTask::Guest(task), + 0, + )?; + } + Ok(()) + })), + lift_result: Some(Box::new(move |cx, src| { + let mut cx = StoreContextMut::(&mut *cx.cast()); + let mut my_src = src.to_owned(); // TODO: use stack to avoid allocation? + my_src.push(ValRaw::u32(results)); + crate::Func::call_unchecked_raw( + &mut cx, + return_.as_non_null(), + my_src.as_mut_ptr(), + my_src.len(), + )?; + let task = cx.concurrent_state().guest_task.unwrap(); + if let Some(rep) = old_task_rep { + maybe_send_event( + cx, + TableId::new(rep), + events::EVENT_CALL_RETURNED, + AnyTask::Guest(task), + 0, + )?; + } + Ok(None) + })), + result: None, + callback: None, + caller: old_task, + caller_instance: Some(caller_instance), + deferred: Deferred::None, + events: VecDeque::new(), + should_yield: false, + }; + let guest_task = if let Some(old_task) = old_task { + let child = cx.concurrent_state().table.push_child(new_task, old_task)?; + log::trace!("new child of {}: {}", old_task.rep(), child.rep()); + child + } else { + cx.concurrent_state().table.push(new_task)? + }; + + cx.concurrent_state().guest_task = Some(guest_task); + + Ok(()) + }) + } +} + +fn make_call( + guest_task: TableId, + callee: SendSyncPtr, + param_count: usize, + result_count: usize, +) -> impl FnOnce( + StoreContextMut, +) -> Result<([MaybeUninit; MAX_FLAT_PARAMS], StoreContextMut)> + + Send + + Sync + + 'static { + move |mut cx: StoreContextMut| { + let mut storage = [MaybeUninit::uninit(); MAX_FLAT_PARAMS]; + let lower = cx + .concurrent_state() + .table + .get_mut(guest_task)? + .lower_params + .take() + .unwrap(); + let cx = cx.0.traitobj(); + lower(cx, &mut storage[..param_count])?; + let mut cx = unsafe { StoreContextMut::(&mut *cx.cast()) }; + + unsafe { + crate::Func::call_unchecked_raw( + &mut cx, + callee.as_non_null(), + storage.as_mut_ptr() as _, + param_count.max(result_count), + )?; + } + + Ok((storage, cx)) + } +} + +pub(crate) extern "C" fn async_exit( + cx: *mut VMOpaqueContext, + callback: *mut VMFuncRef, + callee: *mut VMFuncRef, + callee_instance: RuntimeComponentInstanceIndex, + param_count: u32, + result_count: u32, + flags: u32, +) -> u32 { + unsafe { + handle_result(|| { + let cx = VMComponentContext::from_opaque(cx); + let instance = (*cx).instance(); + let mut cx = StoreContextMut::(&mut *(*instance).store().cast()); + + let guest_task = cx.concurrent_state().guest_task.unwrap(); + let callee = SendSyncPtr::new(NonNull::new(callee).unwrap()); + let param_count = usize::try_from(param_count).unwrap(); + assert!(param_count <= MAX_FLAT_PARAMS); + let result_count = usize::try_from(result_count).unwrap(); + assert!(result_count <= MAX_FLAT_RESULTS); + + let call = make_call(guest_task, callee, param_count, result_count); + + let state = &mut cx + .concurrent_state() + .instance_states + .entry(callee_instance) + .or_default(); + let ready = state.task_queue.is_empty() && !state.backpressure; + + let mut guest_context = 0; + + let mut cx = if (flags & EXIT_FLAG_ASYNC_CALLEE) == 0 { + state.task_queue.push_back(guest_task); + + let mut fiber = make_fiber(&mut cx, callee_instance, move |cx| { + let (storage, mut cx) = call(cx)?; + + let lift = cx + .concurrent_state() + .table + .get_mut(guest_task)? + .lift_result + .take() + .unwrap(); + + assert!(cx + .concurrent_state() + .table + .get(guest_task)? + .result + .is_none()); + + let cx = cx.0.traitobj(); + let result = lift( + cx, + mem::transmute::<&[MaybeUninit], &[ValRaw]>( + &storage[..result_count], + ), + )?; + let mut cx = StoreContextMut::(&mut *cx.cast()); + + cx.concurrent_state().table.get_mut(guest_task)?.result = result; + + Ok(()) + })?; + + cx.concurrent_state() + .table + .get_mut(guest_task)? + .should_yield = true; + + if ready { + let mut cx = Some(cx); + loop { + match resume_fiber(&mut fiber, cx.take(), Ok(()))? { + Ok((cx, result)) => { + result?; + break maybe_resume_next_task(cx, guest_task, callee_instance)?; + } + Err(cx) => { + if let Some(mut cx) = cx { + cx.concurrent_state().table.get_mut(guest_task)?.deferred = + Deferred::Sync(fiber); + break cx; + } else { + suspend_fiber::(fiber.suspend, fiber.stack_limit, None)?; + } + } + } + } + } else { + cx.concurrent_state().table.get_mut(guest_task)?.deferred = + Deferred::Sync(fiber); + cx + } + } else { + if ready { + let (storage, cx) = call(cx)?; + guest_context = storage[0].assume_init().get_i32() as u32; + cx + } else { + cx.concurrent_state() + .instance_states + .get_mut(&callee_instance) + .unwrap() + .task_queue + .push_back(guest_task); + + cx.concurrent_state().table.get_mut(guest_task)?.deferred = Deferred::Async { + call: Box::new(move |cx| { + let mut cx = StoreContextMut(&mut *cx.cast()); + let old_task = cx.concurrent_state().guest_task.replace(guest_task); + let (storage, mut cx) = call(cx)?; + cx.concurrent_state().guest_task = old_task; + Ok(storage[0].assume_init().get_i32() as u32) + }), + instance: callee_instance, + callback: SendSyncPtr::new(NonNull::new(callback).unwrap()), + }; + cx + } + }; + + let guest_task = cx.concurrent_state().guest_task.take().unwrap(); + + let caller = cx.concurrent_state().table.get(guest_task)?.caller; + cx.concurrent_state().guest_task = caller; + + let task = cx.concurrent_state().table.get_mut(guest_task)?; + + if guest_context != 0 { + log::trace!("set callback for {}", guest_task.rep()); + task.callback = Some(( + SendSyncPtr::new(NonNull::new(callback).unwrap()), + guest_context, + )); + for (event, call, result) in mem::take(&mut task.events) { + cx = maybe_send_event(cx, guest_task, event, call, result)?; + } + } + + let task = cx.concurrent_state().table.get(guest_task)?; + + let mut status = if task.lower_params.is_some() { + STATUS_STARTING + } else if task.lift_result.is_some() { + STATUS_STARTED + } else if guest_context != 0 { + STATUS_RETURNED + } else { + STATUS_DONE + }; + + let call = if status != STATUS_DONE { + if (flags & EXIT_FLAG_ASYNC_CALLER) != 0 { + guest_task.rep() + } else { + poll_for_result(cx)?; + status = STATUS_DONE; + 0 + } + } else { + 0 + }; + + Ok((status << 30) | call) + }) + } +} + +pub(crate) fn call<'a, T: Send, LowerParams: Copy, R: 'static>( + mut store: StoreContextMut<'a, T>, + lower_params: LowerFn, + lower_context: LiftLowerContext, + lift_result: LiftFn, + lift_context: LiftLowerContext, + callback: NonNull, + callee: NonNull, + callee_instance: RuntimeComponentInstanceIndex, +) -> Result<(R, StoreContextMut<'a, T>)> { + let guest_task = store.concurrent_state().guest_task.unwrap(); + let task = store.concurrent_state().table.get_mut(guest_task)?; + task.lower_params = Some(Box::new(for_any_lower(move |store, params| { + lower_params(lower_context, store, params) + })) as RawLower); + task.lift_result = Some(Box::new(for_any_lift(move |store, result| { + lift_result(lift_context, store, result) + })) as RawLift); + + let state = &mut store + .concurrent_state() + .instance_states + .entry(callee_instance) + .or_default(); + let ready = state.task_queue.is_empty() && !state.backpressure; + + let call = make_call( + guest_task, + SendSyncPtr::new(callee), + mem::size_of::() / mem::size_of::(), + 1, + ); + + let result = if ready { + let (storage, cx) = call(store)?; + store = cx; + Ok(unsafe { storage[0].assume_init() }.get_i32() as u32) + } else { + Err(call) + }; + + if !matches!(result, Ok(0)) { + match result { + Ok(guest_context) => { + log::trace!("set callback for {}", guest_task.rep()); + let task = store.concurrent_state().table.get_mut(guest_task)?; + task.callback = Some((SendSyncPtr::new(callback), guest_context)); + for (event, call, result) in mem::take(&mut task.events) { + store = maybe_send_event(store, guest_task, event, call, result)?; + } + } + Err(call) => { + store + .concurrent_state() + .instance_states + .get_mut(&callee_instance) + .unwrap() + .task_queue + .push_back(guest_task); + + store.concurrent_state().table.get_mut(guest_task)?.deferred = Deferred::Async { + call: Box::new(move |cx| { + let cx = unsafe { StoreContextMut(&mut *cx.cast()) }; + let (storage, _) = call(cx)?; + Ok(unsafe { storage[0].assume_init() }.get_i32() as u32) + }), + instance: callee_instance, + callback: SendSyncPtr::new(callback), + }; + } + } + + store = poll_for_result(store)?; + } + + if let Some(result) = store + .concurrent_state() + .table + .get_mut(guest_task)? + .result + .take() + { + Ok((*result.downcast().unwrap(), store)) + } else { + // All outstanding host tasks completed, but the guest never yielded a result. + Err(anyhow!(crate::Trap::NoAsyncResult)) + } +} + +pub(crate) async fn poll<'a, T: Send>( + mut store: StoreContextMut<'a, T>, +) -> Result> { + let guest_task = store.concurrent_state().guest_task.unwrap(); + while store + .concurrent_state() + .table + .get(guest_task)? + .callback + .is_some() + { + let ready = store.concurrent_state().futures.next().await.unwrap(); + store = handle_ready(store, ready)?; + } + + Ok(store) +} + +pub(crate) async fn poll_until<'a, T: Send, U>( + mut store: StoreContextMut<'a, T>, + future: impl Future, +) -> Result<(StoreContextMut<'a, T>, U)> { + let mut future = Box::pin(future); + loop { + let ready = pin!(store.concurrent_state().futures.next()); + + match future::select(ready, future).await { + Either::Left((None, future_again)) => break Ok((store, future_again.await)), + Either::Left((Some(ready), future_again)) => { + let mut ready = Some(ready); + store = poll_fn(store, move |_, mut store| { + Ok(handle_ready(store.take().unwrap(), ready.take().unwrap())) + }) + .await?; + future = future_again; + } + Either::Right((result, _)) => break Ok((store, result)), + } + } +} + +async fn poll_fn<'a, T>( + mut store: StoreContextMut<'a, T>, + mut fun: impl FnMut( + &mut Context, + Option>, + ) -> Result>, Option>>, +) -> Result> { + #[derive(Clone, Copy)] + struct PollCx(*mut *mut Context<'static>); + + unsafe impl Send for PollCx {} + + let poll_cx = PollCx(store.concurrent_state().async_state.current_poll_cx.get()); + future::poll_fn({ + let mut store = Some(store); + + move |cx| unsafe { + let _reset = Reset(poll_cx.0, *poll_cx.0); + *poll_cx.0 = mem::transmute::<&mut Context<'_>, *mut Context<'static>>(cx); + #[allow(dropping_copy_types)] + drop(poll_cx); + + match fun(cx, store.take()) { + Ok(v) => Poll::Ready(v), + Err(s) => { + store = s; + Poll::Pending + } + } + } + }) + .await +} diff --git a/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs b/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs new file mode 100644 index 000000000000..a9cbef0074ca --- /dev/null +++ b/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs @@ -0,0 +1,1729 @@ +use { + super::{events, handle_result, table::TableId, GuestTask, HostTaskFuture, HostTaskResult}, + crate::{ + component::{ + func::{self, LiftContext, LowerContext, Options}, + matching::InstanceType, + Val, WasmList, + }, + vm::{ + component::{ComponentInstance, TableIndex, TableState, VMComponentContext}, + SendSyncPtr, VMFuncRef, VMMemoryDefinition, VMOpaqueContext, VMStore, + }, + AsContextMut, StoreContextMut, + }, + anyhow::{anyhow, bail, Context, Result}, + futures::{channel::oneshot, future::FutureExt}, + std::{ + any::Any, + boxed::Box, + future::Future, + marker::PhantomData, + mem::{self, MaybeUninit}, + pin::Pin, + ptr::NonNull, + string::ToString, + sync::Arc, + vec::Vec, + }, + wasmtime_environ::component::{ + CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, + TypeErrorContextTableIndex, TypeFutureTableIndex, TypeStreamTableIndex, + }, +}; + +// TODO: add `validate_inbounds` calls where appropriate + +// TODO: Many of the functions in this module are used for both futures and streams, using runtime branches for +// specialization. We should consider using generics instead to move those branches to compile time. + +// TODO: Improve the host APIs for sending to and receiving from streams. Currently, they require explicitly +// interleaving calls to `write` or `read` and `StoreContextMut::wait_until`; see +// https://github.com/dicej/rfcs/blob/component-async/accepted/component-model-async.md#host-apis-for-creating-using-and-sharing-streams-futures-and-errors +// for an alternative approach. + +const BLOCKED: usize = 0xffff_ffff; +const CLOSED: usize = 0x8000_0000; + +fn payload(ty: TableIndex, types: &Arc) -> Option { + match ty { + TableIndex::Future(ty) => types[types[ty].ty].payload, + TableIndex::Stream(ty) => Some(types[types[ty].ty].payload), + } +} + +fn accept( + values: Vec, + mut offset: usize, + transmit_id: TableId, +) -> impl FnOnce(Reader) -> Result + Send + Sync + 'static { + move |reader| { + Ok(match reader { + Reader::Guest { + lower: + RawLowerContext { + store, + options, + types, + instance, + }, + ty, + address, + count, + } => { + let mut store = unsafe { StoreContextMut::(&mut *store.cast()) }; + let lower = &mut unsafe { + LowerContext::new(store.as_context_mut(), options, types, instance) + }; + let count = values.len().min(usize::try_from(count).unwrap()); + if let Some(ty) = payload(ty, types) { + T::store_list(lower, ty, address, &values[offset..][..count])?; + } + offset += count; + + if offset < values.len() { + let transmit = store.concurrent_state().table.get_mut(transmit_id)?; + assert!(matches!(&transmit.write, WriteState::Open)); + + transmit.write = WriteState::HostReady { + accept: Box::new(accept::(values, offset, transmit_id)), + close: false, + }; + } + + count + } + Reader::Host { accept } => { + assert!(offset == 0); // todo: do we need to handle offset != 0? + let count = values.len(); + accept(Box::new(values))?; + count + } + Reader::None => 0, + }) + } +} + +fn host_write>( + mut store: S, + rep: u32, + values: Vec, +) -> Result<()> { + let mut store = store.as_context_mut(); + let transmit_id = TableId::::new(rep); + let mut offset = 0; + + loop { + let transmit = store + .concurrent_state() + .table + .get_mut(transmit_id) + .with_context(|| rep.to_string())?; + let new_state = if let ReadState::Closed = &transmit.read { + ReadState::Closed + } else { + ReadState::Open + }; + + match mem::replace(&mut transmit.read, new_state) { + ReadState::Open => { + assert!(matches!(&transmit.write, WriteState::Open)); + + transmit.write = WriteState::HostReady { + accept: Box::new(accept::(values, offset, transmit_id)), + close: false, + }; + } + + ReadState::GuestReady { + ty, + flat_abi: _, + options, + address, + count, + instance, + tx, + handle, + caller, + } => unsafe { + let types = (*instance.as_ptr()).component_types(); + let lower = &mut LowerContext::new( + store.as_context_mut(), + &options, + types, + instance.as_ptr(), + ); + let count = values.len().min(count); + if let Some(ty) = payload(ty, types) { + T::store_list(lower, ty, address, &values[offset..][..count])?; + } + offset += count; + + log::trace!( + "remove read child of {}: {}", + caller.rep(), + transmit_id.rep() + ); + store + .concurrent_state() + .table + .remove_child(transmit_id, caller)?; + + (*instance.as_ptr()) + .handle_table() + .insert((ty, handle), TableState::Read); + + _ = tx.send(count); + + if offset < values.len() { + continue; + } + }, + + ReadState::HostReady { accept } => { + accept(Writer::Host { + values: Box::new(values), + })?; + } + + ReadState::Closed => {} + } + + break Ok(()); + } +} + +pub fn host_read>( + mut store: S, + rep: u32, +) -> Result>>> { + let mut store = store.as_context_mut(); + let (tx, rx) = oneshot::channel(); + let transmit_id = TableId::::new(rep); + let transmit = store + .concurrent_state() + .table + .get_mut(transmit_id) + .with_context(|| rep.to_string())?; + let new_state = if let WriteState::Closed = &transmit.write { + WriteState::Closed + } else { + WriteState::Open + }; + + match mem::replace(&mut transmit.write, new_state) { + WriteState::Open => { + assert!(matches!(&transmit.read, ReadState::Open)); + + transmit.read = ReadState::HostReady { + accept: Box::new(move |writer| { + Ok(match writer { + Writer::Guest { + lift, + ty, + address, + count, + } => { + _ = tx.send( + ty.map(|ty| { + let list = &WasmList::new(address, count, lift, ty)?; + T::load_list(lift, list) + }) + .transpose()?, + ); + count + } + Writer::Host { values } => { + let values = *values + .downcast::>() + .map_err(|_| anyhow!("transmit type mismatch"))?; + let count = values.len(); + _ = tx.send(Some(values)); + count + } + Writer::None => 0, + }) + }), + }; + } + + WriteState::GuestReady { + ty, + flat_abi: _, + options, + address, + count, + instance, + tx: write_tx, + handle, + caller, + close, + } => unsafe { + let types = (*instance.as_ptr()).component_types(); + let lift = &mut LiftContext::new(store.0, &options, types, instance.as_ptr()); + _ = tx.send( + payload(ty, types) + .map(|ty| { + let list = &WasmList::new(address, count, lift, ty)?; + T::load_list(lift, list) + }) + .transpose()?, + ); + + log::trace!( + "remove write child of {}: {}", + caller.rep(), + transmit_id.rep() + ); + store + .concurrent_state() + .table + .remove_child(transmit_id, caller)?; + + if close { + store.concurrent_state().table.get_mut(transmit_id)?.write = WriteState::Closed; + } else { + (*instance.as_ptr()) + .handle_table() + .insert((ty, handle), TableState::Write); + } + + _ = write_tx.send(count); + }, + + WriteState::HostReady { accept, close } => { + accept(Reader::Host { + accept: Box::new(move |any| { + _ = tx.send(Some( + *any.downcast() + .map_err(|_| anyhow!("transmit type mismatch"))?, + )); + Ok(()) + }), + })?; + + if close { + store.concurrent_state().table.get_mut(transmit_id)?.write = WriteState::Closed; + } + } + + WriteState::Closed => {} + } + + Ok(rx) +} + +fn host_close_writer>(mut store: S, rep: u32) -> Result<()> { + let mut store = store.as_context_mut(); + let transmit_id = TableId::::new(rep); + let transmit = store.concurrent_state().table.get_mut(transmit_id)?; + + match &mut transmit.write { + WriteState::GuestReady { close, .. } => { + *close = true; + } + + WriteState::HostReady { close, .. } => { + *close = true; + } + + v @ WriteState::Open => { + *v = WriteState::Closed; + } + + WriteState::Closed => unreachable!(), + } + + let new_state = if let ReadState::Closed = &transmit.read { + ReadState::Closed + } else { + ReadState::Open + }; + + match mem::replace(&mut transmit.read, new_state) { + ReadState::GuestReady { + ty, + tx, + instance, + handle, + .. + } => unsafe { + _ = tx.send(CLOSED); + + (*instance.as_ptr()) + .handle_table() + .insert((ty, handle), TableState::Read); + }, + + ReadState::HostReady { accept } => { + accept(Writer::None)?; + } + + ReadState::Open => {} + + ReadState::Closed => { + log::trace!("host_close_writer delete {}", transmit_id.rep()); + store.concurrent_state().table.delete(transmit_id)?; + } + } + Ok(()) +} + +fn host_close_reader>(mut store: S, rep: u32) -> Result<()> { + let mut store = store.as_context_mut(); + let transmit_id = TableId::::new(rep); + let transmit = store.concurrent_state().table.get_mut(transmit_id)?; + + transmit.read = ReadState::Closed; + + let new_state = if let WriteState::Closed = &transmit.write { + WriteState::Closed + } else { + WriteState::Open + }; + + match mem::replace(&mut transmit.write, new_state) { + WriteState::GuestReady { + ty, + tx, + instance, + handle, + close, + .. + } => unsafe { + _ = tx.send(CLOSED); + + if close { + store.concurrent_state().table.delete(transmit_id)?; + } else { + (*instance.as_ptr()) + .handle_table() + .insert((ty, handle), TableState::Write); + } + }, + + WriteState::HostReady { accept, close } => { + accept(Reader::None)?; + + if close { + store.concurrent_state().table.delete(transmit_id)?; + } + } + + WriteState::Open => {} + + WriteState::Closed => { + log::trace!("host_close_reader delete {}", transmit_id.rep()); + store.concurrent_state().table.delete(transmit_id)?; + } + } + Ok(()) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct FlatAbi { + size: u32, + align: u32, +} + +/// Represents the writable end of a Component Model `future`. +pub struct FutureWriter { + rep: u32, + _phantom: PhantomData, +} + +impl FutureWriter { + /// Write the specified value to this `future`. + pub fn write>(self, store: S, value: T) -> Result<()> + where + T: func::Lower + Send + Sync + 'static, + { + host_write(store, self.rep, vec![value]) + } + + /// Close this object without writing a value. + /// + /// If this object is dropped without calling either this method or `write`, + /// any read on the readable end will remain pending forever. + pub fn close>(self, store: S) -> Result<()> { + host_close_writer(store, self.rep) + } +} + +/// Represents the readable end of a Component Model `future`. +pub struct FutureReader { + rep: u32, + _phantom: PhantomData, +} + +impl FutureReader { + /// Read the value from this `future` by converting it to a `std::future::Future`. + pub fn read>( + self, + store: S, + ) -> Result< + Pin, oneshot::Canceled>> + Send + Sync + 'static>>, + > + where + T: func::Lift + Sync + Send + 'static, + { + host_read(store, self.rep).map(|v| { + Box::pin(v.map(|v| v.map(|v| v.map(|v| v.into_iter().next().unwrap())))) + as Pin< + Box< + dyn Future, oneshot::Canceled>> + + Send + + Sync + + 'static, + >, + > + }) + } + + fn lower_to_index(&self, cx: &mut LowerContext<'_, U>, ty: InterfaceType) -> Result { + match ty { + InterfaceType::Future(dst) => { + unsafe { + assert!((*cx.instance) + .handle_table() + .insert((TableIndex::Future(dst), self.rep), TableState::Read) + .is_none()); + } + + Ok(self.rep) + } + _ => func::bad_type_info(), + } + } + + fn lift_from_index(cx: &mut LiftContext<'_>, ty: InterfaceType, index: u32) -> Result { + match ty { + InterfaceType::Future(src) => { + let handle_table = unsafe { (*cx.instance).handle_table() }; + let src = TableIndex::Future(src); + + let Some(state) = handle_table.get(&(src, index)) else { + bail!("invalid handle"); + }; + + match state { + TableState::Local => { + handle_table.insert((src, index), TableState::Write); + } + TableState::Read => { + handle_table.remove(&(src, index)); + } + TableState::Write => bail!("cannot transfer write end of future"), + TableState::Busy => bail!("cannot transfer busy future"), + } + + Ok(Self { + rep: index, + _phantom: PhantomData, + }) + } + _ => func::bad_type_info(), + } + } + + /// Close this object without reading the value. + /// + /// If this object is dropped without calling either this method or `read`, + /// any write on the writable end will remain pending forever. + pub fn close>(self, store: S) -> Result<()> { + host_close_reader(store, self.rep) + } +} + +unsafe impl func::ComponentType for FutureReader { + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR4; + + type Lower = ::Lower; + + fn typecheck(ty: &InterfaceType, _types: &InstanceType<'_>) -> Result<()> { + match ty { + InterfaceType::Future(_) => Ok(()), + other => bail!("expected `future`, found `{}`", func::desc(other)), + } + } +} + +unsafe impl func::Lower for FutureReader { + fn lower( + &self, + cx: &mut LowerContext<'_, U>, + ty: InterfaceType, + dst: &mut MaybeUninit, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .lower(cx, InterfaceType::U32, dst) + } + + fn store( + &self, + cx: &mut LowerContext<'_, U>, + ty: InterfaceType, + offset: usize, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .store(cx, InterfaceType::U32, offset) + } +} + +unsafe impl func::Lift for FutureReader { + fn lift(cx: &mut LiftContext<'_>, ty: InterfaceType, src: &Self::Lower) -> Result { + let index = u32::lift(cx, InterfaceType::U32, src)?; + Self::lift_from_index(cx, ty, index) + } + + fn load(cx: &mut LiftContext<'_>, ty: InterfaceType, bytes: &[u8]) -> Result { + let index = u32::load(cx, InterfaceType::U32, bytes)?; + Self::lift_from_index(cx, ty, index) + } +} + +/// Create a new Component Model `future` as pair of writable and readable ends, +/// the latter of which may be passed to guest code. +pub fn future>( + mut store: S, +) -> Result<(FutureWriter, FutureReader)> { + let mut store = store.as_context_mut(); + let transmit = store.concurrent_state().table.push(TransmitState { + read: ReadState::Open, + write: WriteState::Open, + })?; + + Ok(( + FutureWriter { + rep: transmit.rep(), + _phantom: PhantomData, + }, + FutureReader { + rep: transmit.rep(), + _phantom: PhantomData, + }, + )) +} + +/// Represents the writable end of a Component Model `stream`. +pub struct StreamWriter { + rep: u32, + _phantom: PhantomData, +} + +impl StreamWriter { + /// Write the specified values to the `stream`. + pub fn write>(&mut self, store: S, values: Vec) -> Result<()> + where + T: func::Lower + Send + Sync + 'static, + { + host_write(store, self.rep, values) + } + + /// Close this object without writing any more values. + /// + /// If this object is dropped without calling this method, any read on the + /// readable end will remain pending forever. + pub fn close>(self, store: S) -> Result<()> { + host_close_writer(store, self.rep) + } +} + +/// Represents the readable end of a Component Model `stream`. +pub struct StreamReader { + rep: u32, + _phantom: PhantomData, +} + +impl StreamReader { + /// Read the next values (if any) from this `stream`. + pub fn read>( + &mut self, + store: S, + ) -> Result>>> + where + T: func::Lift + Sync + Send + 'static, + { + host_read(store, self.rep) + } + + fn lower_to_index(&self, cx: &mut LowerContext<'_, U>, ty: InterfaceType) -> Result { + match ty { + InterfaceType::Stream(dst) => { + unsafe { + assert!((*cx.instance) + .handle_table() + .insert((TableIndex::Stream(dst), self.rep), TableState::Read) + .is_none()); + } + + Ok(self.rep) + } + _ => func::bad_type_info(), + } + } + + fn lift_from_index(cx: &mut LiftContext<'_>, ty: InterfaceType, index: u32) -> Result { + match ty { + InterfaceType::Stream(src) => { + let handle_table = unsafe { (*cx.instance).handle_table() }; + let src = TableIndex::Stream(src); + + let Some(state) = handle_table.get(&(src, index)) else { + bail!("invalid handle"); + }; + + match state { + TableState::Local => { + handle_table.insert((src, index), TableState::Write); + } + TableState::Read => { + handle_table.remove(&(src, index)); + } + TableState::Write => bail!("cannot transfer write end of stream"), + TableState::Busy => bail!("cannot transfer busy stream"), + } + + Ok(Self { + rep: index, + _phantom: PhantomData, + }) + } + _ => func::bad_type_info(), + } + } + + /// Close this object without reading any more values. + /// + /// If the object is dropped without calling this method, any write on the + /// writable end will remain pending forever. + pub fn close>(self, store: S) -> Result<()> { + host_close_reader(store, self.rep) + } +} + +unsafe impl func::ComponentType for StreamReader { + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR4; + + type Lower = ::Lower; + + fn typecheck(ty: &InterfaceType, _types: &InstanceType<'_>) -> Result<()> { + match ty { + InterfaceType::Stream(_) => Ok(()), + other => bail!("expected `stream`, found `{}`", func::desc(other)), + } + } +} + +unsafe impl func::Lower for StreamReader { + fn lower( + &self, + cx: &mut LowerContext<'_, U>, + ty: InterfaceType, + dst: &mut MaybeUninit, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .lower(cx, InterfaceType::U32, dst) + } + + fn store( + &self, + cx: &mut LowerContext<'_, U>, + ty: InterfaceType, + offset: usize, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .store(cx, InterfaceType::U32, offset) + } +} + +unsafe impl func::Lift for StreamReader { + fn lift(cx: &mut LiftContext<'_>, ty: InterfaceType, src: &Self::Lower) -> Result { + let index = u32::lift(cx, InterfaceType::U32, src)?; + Self::lift_from_index(cx, ty, index) + } + + fn load(cx: &mut LiftContext<'_>, ty: InterfaceType, bytes: &[u8]) -> Result { + let index = u32::load(cx, InterfaceType::U32, bytes)?; + Self::lift_from_index(cx, ty, index) + } +} + +/// Create a new Component Model `stream` as pair of writable and readable ends, +/// the latter of which may be passed to guest code. +pub fn stream>( + mut store: S, +) -> Result<(StreamWriter, StreamReader)> { + let mut store = store.as_context_mut(); + let transmit = store.concurrent_state().table.push(TransmitState { + read: ReadState::Open, + write: WriteState::Open, + })?; + + Ok(( + StreamWriter { + rep: transmit.rep(), + _phantom: PhantomData, + }, + StreamReader { + rep: transmit.rep(), + _phantom: PhantomData, + }, + )) +} + +/// Represents a Component Model `error-context`. +pub struct ErrorContext { + rep: u32, +} + +impl ErrorContext { + fn lower_to_index(&self, cx: &mut LowerContext<'_, U>, ty: InterfaceType) -> Result { + match ty { + InterfaceType::ErrorContext(dst) => { + unsafe { + *(*cx.instance) + .error_context_table() + .entry((dst, self.rep)) + .or_default() += 1; + } + Ok(self.rep) + } + _ => func::bad_type_info(), + } + } + + fn lift_from_index(cx: &mut LiftContext<'_>, ty: InterfaceType, index: u32) -> Result { + match ty { + InterfaceType::ErrorContext(src) => { + if !unsafe { + (*cx.instance) + .error_context_table() + .contains_key(&(src, index)) + } { + bail!("invalid handle"); + } + Ok(Self { rep: index }) + } + _ => func::bad_type_info(), + } + } +} + +unsafe impl func::ComponentType for ErrorContext { + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR4; + + type Lower = ::Lower; + + fn typecheck(ty: &InterfaceType, _types: &InstanceType<'_>) -> Result<()> { + match ty { + InterfaceType::ErrorContext(_) => Ok(()), + other => bail!("expected `error`, found `{}`", func::desc(other)), + } + } +} + +unsafe impl func::Lower for ErrorContext { + fn lower( + &self, + cx: &mut LowerContext<'_, T>, + ty: InterfaceType, + dst: &mut MaybeUninit, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .lower(cx, InterfaceType::U32, dst) + } + + fn store( + &self, + cx: &mut LowerContext<'_, T>, + ty: InterfaceType, + offset: usize, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .store(cx, InterfaceType::U32, offset) + } +} + +unsafe impl func::Lift for ErrorContext { + fn lift(cx: &mut LiftContext<'_>, ty: InterfaceType, src: &Self::Lower) -> Result { + let index = u32::lift(cx, InterfaceType::U32, src)?; + Self::lift_from_index(cx, ty, index) + } + + fn load(cx: &mut LiftContext<'_>, ty: InterfaceType, bytes: &[u8]) -> Result { + let index = u32::load(cx, InterfaceType::U32, bytes)?; + Self::lift_from_index(cx, ty, index) + } +} + +pub(super) struct TransmitState { + write: WriteState, + read: ReadState, +} + +enum WriteState { + Open, + GuestReady { + ty: TableIndex, + flat_abi: Option, + options: Options, + address: usize, + count: usize, + instance: SendSyncPtr, + tx: oneshot::Sender, + handle: u32, + caller: TableId, + close: bool, + }, + HostReady { + accept: Box Result + Send + Sync>, + close: bool, + }, + Closed, +} + +enum ReadState { + Open, + GuestReady { + ty: TableIndex, + flat_abi: Option, + options: Options, + address: usize, + count: usize, + instance: SendSyncPtr, + tx: oneshot::Sender, + handle: u32, + caller: TableId, + }, + HostReady { + accept: Box Result + Send + Sync>, + }, + Closed, +} + +enum Writer<'a> { + Guest { + lift: &'a mut LiftContext<'a>, + ty: Option, + address: usize, + count: usize, + }, + Host { + values: Box, + }, + None, +} + +struct RawLowerContext<'a> { + store: *mut dyn VMStore, + options: &'a Options, + types: &'a Arc, + instance: *mut ComponentInstance, +} + +enum Reader<'a> { + Guest { + lower: RawLowerContext<'a>, + ty: TableIndex, + address: usize, + count: usize, + }, + Host { + accept: Box) -> Result<()>>, + }, + None, +} + +fn guest_new(vmctx: *mut VMOpaqueContext, ty: TableIndex) -> u32 { + unsafe { + handle_result(|| { + let cx = VMComponentContext::from_opaque(vmctx); + let instance = (*cx).instance(); + let mut cx = StoreContextMut::(&mut *(*instance).store().cast()); + let transmit = cx.concurrent_state().table.push(TransmitState { + read: ReadState::Open, + write: WriteState::Open, + })?; + (*instance) + .handle_table() + .insert((ty, transmit.rep()), TableState::Local); + Ok(transmit.rep()) + }) + } +} + +unsafe fn copy( + mut cx: StoreContextMut<'_, T>, + types: &Arc, + instance: *mut ComponentInstance, + flat_abi: Option, + write_ty: TableIndex, + write_options: &Options, + write_address: usize, + read_ty: TableIndex, + read_options: &Options, + read_address: usize, + count: usize, +) -> Result<()> { + match (write_ty, read_ty) { + (TableIndex::Future(write_ty), TableIndex::Future(read_ty)) => { + assert_eq!(count, 1); + + let val = types[types[write_ty].ty] + .payload + .map(|ty| { + let lift = &mut LiftContext::new(cx.0, write_options, types, instance); + Val::load( + lift, + ty, + &lift.memory()[usize::try_from(write_address).unwrap()..] + [..usize::try_from(types.canonical_abi(&ty).size32).unwrap()], + ) + }) + .transpose()?; + + let mut lower = LowerContext::new(cx.as_context_mut(), read_options, types, instance); + if let Some(val) = val { + val.store( + &mut lower, + types[types[read_ty].ty].payload.unwrap(), + usize::try_from(read_address).unwrap(), + )?; + } + } + (TableIndex::Stream(write_ty), TableIndex::Stream(read_ty)) => { + let lift = &mut LiftContext::new(cx.0, write_options, types, instance); + if let Some(flat_abi) = flat_abi { + // Fast path memcpy for "flat" (i.e. no pointers or handles) payloads: + let length_in_bytes = usize::try_from(flat_abi.size).unwrap() * count; + + { + let src = + write_options.memory(cx.0)[write_address..][..length_in_bytes].as_ptr(); + let dst = read_options.memory_mut(cx.0)[read_address..][..length_in_bytes] + .as_mut_ptr(); + src.copy_to(dst, length_in_bytes); + } + } else { + let ty = types[types[write_ty].ty].payload; + let abi = lift.types.canonical_abi(&ty); + let size = usize::try_from(abi.size32).unwrap(); + let values = (0..count) + .map(|index| { + Val::load( + lift, + ty, + &lift.memory()[write_address + (index * size)..][..size], + ) + }) + .collect::>>()?; + + let lower = + &mut LowerContext::new(cx.as_context_mut(), read_options, types, instance); + let mut ptr = read_address; + let ty = types[types[read_ty].ty].payload; + let abi = lower.types.canonical_abi(&ty); + let size = usize::try_from(abi.size32).unwrap(); + for value in values { + value.store(lower, ty, ptr)?; + ptr += size + } + } + } + _ => unreachable!(), + } + + Ok(()) +} + +fn guest_write( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TableIndex, + flat_abi: Option, + handle: u32, + address: usize, + count: usize, +) -> usize { + unsafe { + handle_result(|| { + let cx = VMComponentContext::from_opaque(vmctx); + let instance = (*cx).instance(); + let mut cx = StoreContextMut::(&mut *(*instance).store().cast()); + let options = Options::new( + cx.0.id(), + NonNull::new(memory), + NonNull::new(realloc), + string_encoding, + true, + None, + ); + let types = (*instance).component_types(); + let Some(TableState::Write) = (*instance) + .handle_table() + .insert((ty, handle), TableState::Busy) + else { + bail!("invalid handle"); + }; + let transmit_id = TableId::::new(handle); + let transmit = cx.concurrent_state().table.get_mut(transmit_id)?; + let new_state = if let ReadState::Closed = &transmit.read { + ReadState::Closed + } else { + ReadState::Open + }; + + let result = match mem::replace(&mut transmit.read, new_state) { + ReadState::GuestReady { + ty: read_ty, + flat_abi: read_flat_abi, + options: read_options, + address: read_address, + count: read_count, + instance: _, + tx: read_tx, + handle: read_handle, + caller: read_caller, + } => { + assert_eq!(flat_abi, read_flat_abi); + + let count = count.min(read_count); + + copy( + cx.as_context_mut(), + types, + instance, + flat_abi, + ty, + &options, + address, + read_ty, + &read_options, + read_address, + count, + )?; + + log::trace!( + "remove read child of {}: {}", + read_caller.rep(), + transmit_id.rep() + ); + cx.concurrent_state() + .table + .remove_child(transmit_id, read_caller)?; + + (*instance) + .handle_table() + .insert((read_ty, read_handle), TableState::Read); + + _ = read_tx.send(count); + + count + } + + ReadState::HostReady { accept } => { + let lift = &mut LiftContext::new(cx.0, &options, types, instance); + accept(Writer::Guest { + lift, + ty: payload(ty, types), + address, + count, + })? + } + + ReadState::Open => { + assert!(matches!(&transmit.write, WriteState::Open)); + + let (event, name) = match ty { + TableIndex::Future(_) => (events::EVENT_FUTURE_READ, "future"), + TableIndex::Stream(_) => (events::EVENT_STREAM_READ, "stream"), + }; + let caller = cx.concurrent_state().guest_task.unwrap(); + log::trace!( + "add write {name} child of {}: {}", + caller.rep(), + transmit_id.rep() + ); + cx.concurrent_state().table.add_child(transmit_id, caller)?; + let (tx, rx) = oneshot::channel(); + let future = Box::pin(rx.map(move |result| { + ( + handle, + Box::new(move |_| { + Ok(HostTaskResult { + event, + param: u32::try_from(result?).unwrap(), + caller, + }) + }) + as Box Result>, + ) + })) as HostTaskFuture; + cx.concurrent_state().futures.get_mut().push(future); + + let transmit = cx.concurrent_state().table.get_mut(transmit_id)?; + transmit.write = WriteState::GuestReady { + ty, + flat_abi, + options, + address: usize::try_from(address).unwrap(), + count: usize::try_from(count).unwrap(), + instance: SendSyncPtr::new(NonNull::new(instance).unwrap()), + tx, + handle, + caller, + close: false, + }; + + BLOCKED + } + + ReadState::Closed => CLOSED, + }; + + if result != BLOCKED { + (*instance) + .handle_table() + .insert((ty, handle), TableState::Write); + } + + Ok(result) + }) + } +} + +fn guest_read( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TableIndex, + flat_abi: Option, + handle: u32, + address: usize, + count: usize, +) -> usize { + unsafe { + handle_result(|| { + let cx = VMComponentContext::from_opaque(vmctx); + let instance = (*cx).instance(); + let mut cx = StoreContextMut::(&mut *(*instance).store().cast()); + let options = Options::new( + cx.0.id(), + NonNull::new(memory), + NonNull::new(realloc), + string_encoding, + true, + None, + ); + let types = (*instance).component_types(); + let Some(TableState::Read) = (*instance) + .handle_table() + .insert((ty, handle), TableState::Busy) + else { + bail!("invalid handle"); + }; + let transmit_id = TableId::::new(handle); + let transmit = cx.concurrent_state().table.get_mut(transmit_id)?; + let new_state = if let WriteState::Closed = &transmit.write { + WriteState::Closed + } else { + WriteState::Open + }; + + let result = match mem::replace(&mut transmit.write, new_state) { + WriteState::GuestReady { + ty: write_ty, + flat_abi: write_flat_abi, + options: write_options, + address: write_address, + count: write_count, + instance: _, + tx: write_tx, + handle: write_handle, + caller: write_caller, + close, + } => { + assert_eq!(flat_abi, write_flat_abi); + + let count = count.min(write_count); + + copy( + cx.as_context_mut(), + types, + instance, + flat_abi, + write_ty, + &write_options, + write_address, + ty, + &options, + address, + count, + )?; + + log::trace!( + "remove write child of {}: {}", + write_caller.rep(), + transmit_id.rep() + ); + cx.concurrent_state() + .table + .remove_child(transmit_id, write_caller)?; + + if close { + cx.concurrent_state().table.get_mut(transmit_id)?.write = + WriteState::Closed; + } else { + (*instance) + .handle_table() + .insert((write_ty, write_handle), TableState::Write); + } + + _ = write_tx.send(count); + + count + } + + WriteState::HostReady { accept, close } => { + let count = accept(Reader::Guest { + lower: RawLowerContext { + store: cx.0.traitobj(), + options: &options, + types, + instance, + }, + ty, + address: usize::try_from(address).unwrap(), + count, + })?; + + if close { + cx.concurrent_state().table.get_mut(transmit_id)?.write = + WriteState::Closed; + } + + count + } + + WriteState::Open => { + assert!(matches!(&transmit.read, ReadState::Open)); + + let (event, name) = match ty { + TableIndex::Future(_) => (events::EVENT_FUTURE_READ, "future"), + TableIndex::Stream(_) => (events::EVENT_STREAM_READ, "stream"), + }; + let caller = cx.concurrent_state().guest_task.unwrap(); + log::trace!( + "add read {name} child of {}: {}", + caller.rep(), + transmit_id.rep() + ); + cx.concurrent_state().table.add_child(transmit_id, caller)?; + let (tx, rx) = oneshot::channel(); + let future = Box::pin(rx.map(move |result| { + ( + transmit_id.rep(), + Box::new(move |_| { + Ok(HostTaskResult { + event, + param: u32::try_from(result?).unwrap(), + caller, + }) + }) + as Box Result>, + ) + })) as HostTaskFuture; + cx.concurrent_state().futures.get_mut().push(future); + + let transmit = cx.concurrent_state().table.get_mut(transmit_id)?; + transmit.read = ReadState::GuestReady { + ty, + flat_abi, + options, + address: usize::try_from(address).unwrap(), + count: usize::try_from(count).unwrap(), + instance: SendSyncPtr::new(NonNull::new(instance).unwrap()), + tx, + handle, + caller, + }; + + BLOCKED + } + + WriteState::Closed => CLOSED, + }; + + if result != BLOCKED { + (*instance) + .handle_table() + .insert((ty, handle), TableState::Read); + } + + Ok(result) + }) + } +} + +fn guest_close_writable(vmctx: *mut VMOpaqueContext, ty: TableIndex, writer: u32, error: u32) { + unsafe { + handle_result(|| { + if error != 0 { + bail!("todo: closing writable streams and futures with errors not yet implemented"); + } + + let cx = VMComponentContext::from_opaque(vmctx); + let instance = (*cx).instance(); + let cx = StoreContextMut::(&mut *(*instance).store().cast()); + let Some(state) = (*instance).handle_table().remove(&(ty, writer)) else { + bail!("invalid handle"); + }; + match state { + TableState::Local | TableState::Write => {} + TableState::Read => bail!("passed read end to `{{stream|future}}.close-writable`"), + TableState::Busy => bail!("cannot drop busy stream or future"), + } + host_close_writer(cx, writer) + }) + } +} + +fn guest_close_readable(vmctx: *mut VMOpaqueContext, ty: TableIndex, reader: u32) { + unsafe { + handle_result(|| { + let cx = VMComponentContext::from_opaque(vmctx); + let instance = (*cx).instance(); + let cx = StoreContextMut::(&mut *(*instance).store().cast()); + let Some(state) = (*instance).handle_table().remove(&(ty, reader)) else { + bail!("invalid handle"); + }; + match state { + TableState::Local | TableState::Read => {} + TableState::Write => { + bail!("passed write end to `{{stream|future}}.close-readable`") + } + TableState::Busy => bail!("cannot drop busy stream or future"), + } + host_close_reader(cx, reader) + }) + } +} + +pub(crate) extern "C" fn future_new( + vmctx: *mut VMOpaqueContext, + ty: TypeFutureTableIndex, +) -> u32 { + guest_new::(vmctx, TableIndex::Future(ty)) +} + +pub(crate) extern "C" fn future_write( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TypeFutureTableIndex, + future: u32, + address: u32, +) -> u32 { + u32::try_from(guest_write::( + vmctx, + memory, + realloc, + string_encoding, + TableIndex::Future(ty), + None, + future, + usize::try_from(address).unwrap(), + 1, + )) + .unwrap() +} + +pub(crate) extern "C" fn future_read( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TypeFutureTableIndex, + future: u32, + address: u32, +) -> u32 { + u32::try_from(guest_read::( + vmctx, + memory, + realloc, + string_encoding, + TableIndex::Future(ty), + None, + future, + usize::try_from(address).unwrap(), + 1, + )) + .unwrap() +} + +pub(crate) extern "C" fn future_cancel_write( + vmctx: *mut VMOpaqueContext, + ty: TypeFutureTableIndex, + async_: bool, + writer: u32, +) -> u32 { + unsafe { + handle_result(|| { + _ = (vmctx, ty, async_, writer); + bail!("todo: `future.cancel-write` not yet implemented"); + }) + } +} + +pub(crate) extern "C" fn future_cancel_read( + vmctx: *mut VMOpaqueContext, + ty: TypeFutureTableIndex, + async_: bool, + reader: u32, +) -> u32 { + unsafe { + handle_result(|| { + _ = (vmctx, ty, async_, reader); + bail!("todo: `future.cancel-read` not yet implemented"); + }) + } +} + +pub(crate) extern "C" fn future_close_writable( + vmctx: *mut VMOpaqueContext, + ty: TypeFutureTableIndex, + writer: u32, + error: u32, +) { + guest_close_writable::(vmctx, TableIndex::Future(ty), writer, error) +} + +pub(crate) extern "C" fn future_close_readable( + vmctx: *mut VMOpaqueContext, + ty: TypeFutureTableIndex, + reader: u32, +) { + guest_close_readable::(vmctx, TableIndex::Future(ty), reader) +} + +pub(crate) extern "C" fn stream_new( + vmctx: *mut VMOpaqueContext, + ty: TypeStreamTableIndex, +) -> u32 { + guest_new::(vmctx, TableIndex::Stream(ty)) +} + +pub(crate) extern "C" fn stream_write( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TypeStreamTableIndex, + stream: u32, + address: u32, + count: u32, +) -> u32 { + u32::try_from(guest_write::( + vmctx, + memory, + realloc, + string_encoding, + TableIndex::Stream(ty), + None, + stream, + usize::try_from(address).unwrap(), + usize::try_from(count).unwrap(), + )) + .unwrap() +} + +pub(crate) extern "C" fn stream_read( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TypeStreamTableIndex, + stream: u32, + address: u32, + count: u32, +) -> u32 { + u32::try_from(guest_read::( + vmctx, + memory, + realloc, + string_encoding, + TableIndex::Stream(ty), + None, + stream, + usize::try_from(address).unwrap(), + usize::try_from(count).unwrap(), + )) + .unwrap() +} + +pub(crate) extern "C" fn stream_cancel_write( + vmctx: *mut VMOpaqueContext, + ty: TypeStreamTableIndex, + async_: bool, + writer: u32, +) -> u32 { + unsafe { + handle_result(|| { + _ = (vmctx, ty, async_, writer); + bail!("todo: `stream.cancel-write` not yet implemented"); + }) + } +} + +pub(crate) extern "C" fn stream_cancel_read( + vmctx: *mut VMOpaqueContext, + ty: TypeStreamTableIndex, + async_: bool, + reader: u32, +) -> u32 { + unsafe { + handle_result(|| { + _ = (vmctx, ty, async_, reader); + bail!("todo: `stream.cancel-read` not yet implemented"); + }) + } +} + +pub(crate) extern "C" fn stream_close_writable( + vmctx: *mut VMOpaqueContext, + ty: TypeStreamTableIndex, + writer: u32, + error: u32, +) { + guest_close_writable::(vmctx, TableIndex::Stream(ty), writer, error) +} + +pub(crate) extern "C" fn stream_close_readable( + vmctx: *mut VMOpaqueContext, + ty: TypeStreamTableIndex, + reader: u32, +) { + guest_close_readable::(vmctx, TableIndex::Stream(ty), reader) +} + +pub(crate) extern "C" fn flat_stream_write( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + ty: TypeStreamTableIndex, + payload_size: u32, + payload_align: u32, + stream: u32, + address: u32, + count: u32, +) -> u32 { + u32::try_from(guest_write::( + vmctx, + memory, + realloc, + StringEncoding::Utf8, + TableIndex::Stream(ty), + Some(FlatAbi { + size: payload_size, + align: payload_align, + }), + stream, + usize::try_from(address).unwrap(), + usize::try_from(count).unwrap(), + )) + .unwrap() +} + +pub(crate) extern "C" fn flat_stream_read( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + ty: TypeStreamTableIndex, + payload_size: u32, + payload_align: u32, + stream: u32, + address: u32, + count: u32, +) -> u32 { + u32::try_from(guest_read::( + vmctx, + memory, + realloc, + StringEncoding::Utf8, + TableIndex::Stream(ty), + Some(FlatAbi { + size: payload_size, + align: payload_align, + }), + stream, + usize::try_from(address).unwrap(), + usize::try_from(count).unwrap(), + )) + .unwrap() +} + +pub(crate) extern "C" fn error_context_new( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TypeErrorContextTableIndex, + address: u32, + count: u32, +) -> u32 { + unsafe { + handle_result(|| { + _ = (vmctx, memory, realloc, string_encoding, ty, address, count); + bail!("todo: `error.new` not yet implemented"); + }) + } +} + +pub(crate) extern "C" fn error_context_debug_message( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TypeErrorContextTableIndex, + handle: u32, + address: u32, +) { + unsafe { + handle_result(|| { + _ = (vmctx, memory, realloc, string_encoding, ty, handle, address); + bail!("todo: `error.debug-message` not yet implemented"); + }) + } +} + +pub(crate) extern "C" fn error_context_drop( + vmctx: *mut VMOpaqueContext, + ty: TypeErrorContextTableIndex, + error: u32, +) { + unsafe { + handle_result(|| { + let cx = VMComponentContext::from_opaque(vmctx); + let instance = (*cx).instance(); + let count = if let Some(count) = (*instance).error_context_table().get_mut(&(ty, error)) + { + assert!(*count > 0); + *count -= 1; + *count + } else { + bail!("invalid handle"); + }; + + if count == 0 { + (*instance).error_context_table().remove(&(ty, error)); + } + + Ok(()) + }) + } +} diff --git a/crates/wasmtime/src/runtime/component/concurrent/ready_chunks.rs b/crates/wasmtime/src/runtime/component/concurrent/ready_chunks.rs new file mode 100644 index 000000000000..f82bddcee4c7 --- /dev/null +++ b/crates/wasmtime/src/runtime/component/concurrent/ready_chunks.rs @@ -0,0 +1,59 @@ +//! Like `futures::stream::ReadyChunks` but without fusing the inner stream. +//! +//! We use this with `FuturesUnordered` which may produce `Poll::Ready(None)` but later produce more elements due +//! to additional futures having been added, so fusing is not appropriate. + +use { + futures::{Stream, StreamExt}, + std::{ + pin::Pin, + task::{Context, Poll}, + vec::Vec, + }, +}; + +pub struct ReadyChunks { + stream: S, + capacity: usize, +} + +impl ReadyChunks { + pub fn new(stream: S, capacity: usize) -> Self { + Self { stream, capacity } + } + + pub fn get_mut(&mut self) -> &mut S { + &mut self.stream + } +} + +impl Stream for ReadyChunks { + type Item = Vec; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut items = Vec::new(); + + loop { + match self.stream.poll_next_unpin(cx) { + Poll::Pending => { + break if items.is_empty() { + Poll::Pending + } else { + Poll::Ready(Some(items)) + } + } + + Poll::Ready(Some(item)) => { + items.push(item); + if items.len() >= self.capacity { + break Poll::Ready(Some(items)); + } + } + + Poll::Ready(None) => { + break Poll::Ready(if items.is_empty() { None } else { Some(items) }); + } + } + } + } +} diff --git a/crates/wasmtime/src/runtime/component/concurrent/table.rs b/crates/wasmtime/src/runtime/component/concurrent/table.rs new file mode 100644 index 000000000000..ffad2579d7f4 --- /dev/null +++ b/crates/wasmtime/src/runtime/component/concurrent/table.rs @@ -0,0 +1,312 @@ +// TODO: This duplicates a lot of resource_table.rs; consider reducing that +// duplication. +// +// The main difference between this and resource_table.rs is that the key type, +// `TableId` implements `Copy`, making them much easier to work with than +// `Resource`. I've also added a `Table::delete_any` function, useful for +// implementing `subtask.drop`. + +use std::{any::Any, boxed::Box, collections::BTreeSet, marker::PhantomData, vec::Vec}; + +pub struct TableId { + rep: u32, + _marker: PhantomData T>, +} + +impl TableId { + pub fn new(rep: u32) -> Self { + Self { + rep, + _marker: PhantomData, + } + } +} + +impl Clone for TableId { + fn clone(&self) -> Self { + Self::new(self.rep) + } +} + +impl Copy for TableId {} + +impl TableId { + pub fn rep(&self) -> u32 { + self.rep + } +} + +#[derive(Debug)] +/// Errors returned by operations on `Table` +pub enum TableError { + /// Table has no free keys + Full, + /// Entry not present in table + NotPresent, + /// Resource present in table, but with a different type + WrongType, + /// Entry cannot be deleted because child entrys exist in the table. + HasChildren, +} + +impl std::fmt::Display for TableError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Full => write!(f, "table has no free keys"), + Self::NotPresent => write!(f, "entry not present"), + Self::WrongType => write!(f, "entry is of another type"), + Self::HasChildren => write!(f, "entry has children"), + } + } +} +impl std::error::Error for TableError {} + +/// The `Table` type maps a `TableId` to its entry. +#[derive(Default)] +pub struct Table { + entries: Vec, + free_head: Option, +} + +enum Entry { + Free { next: Option }, + Occupied { entry: TableEntry }, +} + +impl Entry { + pub fn occupied(&self) -> Option<&TableEntry> { + match self { + Self::Occupied { entry } => Some(entry), + Self::Free { .. } => None, + } + } + + pub fn occupied_mut(&mut self) -> Option<&mut TableEntry> { + match self { + Self::Occupied { entry } => Some(entry), + Self::Free { .. } => None, + } + } +} + +/// This structure tracks parent and child relationships for a given table entry. +/// +/// Parents and children are referred to by table index. We maintain the +/// following invariants to prevent orphans and cycles: +/// * parent can only be assigned on creating the entry. +/// * parent, if some, must exist when creating the entry. +/// * whenever a child is created, its index is added to children. +/// * whenever a child is deleted, its index is removed from children. +/// * an entry with children may not be deleted. +struct TableEntry { + /// The entry in the table + entry: Box, + /// The index of the parent of this entry, if it has one. + parent: Option, + /// The indicies of any children of this entry. + children: BTreeSet, +} + +impl TableEntry { + fn new(entry: Box, parent: Option) -> Self { + Self { + entry, + parent, + children: BTreeSet::new(), + } + } + fn add_child(&mut self, child: u32) { + assert!(self.children.insert(child)); + } + fn remove_child(&mut self, child: u32) { + assert!(self.children.remove(&child)); + } +} + +impl Table { + /// Create an empty table + pub fn new() -> Self { + let mut me = Self { + entries: Vec::new(), + free_head: None, + }; + + // TODO: remove this once we've stopped exposing these indexes to guest code: + me.push(()).unwrap(); + + me + } + + /// Inserts a new entry into this table, returning a corresponding + /// `TableId` which can be used to refer to it after it was inserted. + pub fn push(&mut self, entry: T) -> Result, TableError> { + let idx = self.push_(TableEntry::new(Box::new(entry), None))?; + Ok(TableId::new(idx)) + } + + /// Pop an index off of the free list, if it's not empty. + fn pop_free_list(&mut self) -> Option { + if let Some(ix) = self.free_head { + // Advance free_head to the next entry if one is available. + match &self.entries[ix] { + Entry::Free { next } => self.free_head = *next, + Entry::Occupied { .. } => unreachable!(), + } + Some(ix) + } else { + None + } + } + + /// Free an entry in the table, returning its [`TableEntry`]. Add the index to the free list. + fn free_entry(&mut self, ix: usize) -> TableEntry { + let entry = match std::mem::replace( + &mut self.entries[ix], + Entry::Free { + next: self.free_head, + }, + ) { + Entry::Occupied { entry } => entry, + Entry::Free { .. } => unreachable!(), + }; + + self.free_head = Some(ix); + + entry + } + + /// Push a new entry into the table, returning its handle. This will prefer to use free entries + /// if they exist, falling back on pushing new entries onto the end of the table. + fn push_(&mut self, e: TableEntry) -> Result { + if let Some(free) = self.pop_free_list() { + self.entries[free] = Entry::Occupied { entry: e }; + Ok(free as u32) + } else { + let ix = self + .entries + .len() + .try_into() + .map_err(|_| TableError::Full)?; + self.entries.push(Entry::Occupied { entry: e }); + Ok(ix) + } + } + + fn occupied(&self, key: u32) -> Result<&TableEntry, TableError> { + self.entries + .get(key as usize) + .and_then(Entry::occupied) + .ok_or(TableError::NotPresent) + } + + fn occupied_mut(&mut self, key: u32) -> Result<&mut TableEntry, TableError> { + self.entries + .get_mut(key as usize) + .and_then(Entry::occupied_mut) + .ok_or(TableError::NotPresent) + } + + /// Insert a entry at the next available index, and track that it has a + /// parent entry. + /// + /// The parent must exist to create a child. All child entrys must be + /// destroyed before a parent can be destroyed - otherwise [`Table::delete`] + /// will fail with [`TableError::HasChildren`]. + /// + /// Parent-child relationships are tracked inside the table to ensure that a + /// parent is not deleted while it has live children. This allows children + /// to hold "references" to a parent by table index, to avoid needing + /// e.g. an `Arc>` and the associated locking overhead and + /// design issues, such as child existence extending lifetime of parent + /// referent even after parent is destroyed, possibility for deadlocks. + /// + /// Parent-child relationships may not be modified once created. There is no + /// way to observe these relationships through the [`Table`] methods except + /// for erroring on deletion, or the [`std::fmt::Debug`] impl. + pub fn push_child( + &mut self, + entry: T, + parent: TableId, + ) -> Result, TableError> { + let parent = parent.rep(); + self.occupied(parent)?; + let child = self.push_(TableEntry::new(Box::new(entry), Some(parent)))?; + self.occupied_mut(parent)?.add_child(child); + Ok(TableId::new(child)) + } + + pub fn add_child( + &mut self, + child: TableId, + parent: TableId, + ) -> Result<(), TableError> { + self.occupied(child.rep())?; + self.occupied_mut(parent.rep())?.add_child(child.rep()); + Ok(()) + } + + pub fn remove_child( + &mut self, + child: TableId, + parent: TableId, + ) -> Result<(), TableError> { + self.occupied(child.rep())?; + self.occupied_mut(parent.rep())?.remove_child(child.rep()); + Ok(()) + } + + /// Get an immutable reference to a task of a given type at a given index. + /// + /// Multiple shared references can be borrowed at any given time. + pub fn get(&self, key: TableId) -> Result<&T, TableError> { + self.get_(key.rep())? + .downcast_ref() + .ok_or(TableError::WrongType) + } + + fn get_(&self, key: u32) -> Result<&dyn Any, TableError> { + let r = self.occupied(key)?; + Ok(&*r.entry) + } + + /// Get an mutable reference to a task of a given type at a given index. + pub fn get_mut(&mut self, key: TableId) -> Result<&mut T, TableError> { + self.get_mut_(key.rep())? + .downcast_mut() + .ok_or(TableError::WrongType) + } + + pub fn get_mut_(&mut self, key: u32) -> Result<&mut dyn Any, TableError> { + let r = self.occupied_mut(key)?; + Ok(&mut *r.entry) + } + + /// Delete the specified task + pub fn delete(&mut self, key: TableId) -> Result { + self.delete_entry(key.rep())? + .entry + .downcast() + .map(|v| *v) + .map_err(|_| TableError::WrongType) + } + + pub fn delete_any(&mut self, key: u32) -> Result, TableError> { + Ok(self.delete_entry(key)?.entry) + } + + fn delete_entry(&mut self, key: u32) -> Result { + if !self.occupied(key)?.children.is_empty() { + return Err(TableError::HasChildren); + } + let e = self.free_entry(key as usize); + if let Some(parent) = e.parent { + // Remove deleted task from parent's child list. Parent must still + // be present because it cant be deleted while still having + // children: + self.occupied_mut(parent) + .expect("missing parent") + .remove_child(key); + } + Ok(e) + } +} diff --git a/crates/wasmtime/src/runtime/component/func.rs b/crates/wasmtime/src/runtime/component/func.rs index c7fda91181ea..0485e053601e 100644 --- a/crates/wasmtime/src/runtime/component/func.rs +++ b/crates/wasmtime/src/runtime/component/func.rs @@ -36,16 +36,16 @@ union ParamsAndResults { /// [`wasmtime::Func`](crate::Func) it's possible to call functions either /// synchronously or asynchronously and either typed or untyped. #[derive(Copy, Clone, Debug)] -pub struct Func(Stored); +pub struct Func(pub(crate) Stored); #[doc(hidden)] pub struct FuncData { export: ExportFunction, ty: TypeFuncIndex, types: Arc, - options: Options, + pub(crate) options: Options, instance: Instance, - component_instance: RuntimeComponentInstanceIndex, + pub(crate) component_instance: RuntimeComponentInstanceIndex, post_return: Option, post_return_arg: Option, } @@ -72,7 +72,19 @@ impl Func { ExportFunction { func_ref } }); let component_instance = options.instance; - let options = unsafe { Options::new(store.id(), memory, realloc, options.string_encoding) }; + let callback = options + .callback + .map(|i| data.instance().runtime_callback(i)); + let options = unsafe { + Options::new( + store.id(), + memory, + realloc, + options.string_encoding, + options.async_, + callback, + ) + }; Func(store.store_data_mut().insert(FuncData { export, options, @@ -267,9 +279,9 @@ impl Func { /// Panics if this is called on a function in an asynchronous store. This /// only works with functions defined within a synchronous store. Also /// panics if `store` does not own this function. - pub fn call( + pub fn call( &self, - mut store: impl AsContextMut, + mut store: impl AsContextMut, params: &[Val], results: &mut [Val], ) -> Result<()> { @@ -292,15 +304,12 @@ impl Func { /// only works with functions defined within an asynchronous store. Also /// panics if `store` does not own this function. #[cfg(feature = "async")] - pub async fn call_async( + pub async fn call_async( &self, mut store: impl AsContextMut, params: &[Val], results: &mut [Val], - ) -> Result<()> - where - T: Send, - { + ) -> Result<()> { let mut store = store.as_context_mut(); assert!( store.0.async_support(), @@ -311,13 +320,13 @@ impl Func { .await? } - fn call_impl( + fn call_impl( &self, - mut store: impl AsContextMut, + mut store: impl AsContextMut, params: &[Val], results: &mut [Val], ) -> Result<()> { - let store = &mut store.as_context_mut(); + let store = store.as_context_mut(); let param_tys = self.params(&store); let result_tys = self.results(&store); @@ -377,6 +386,73 @@ impl Func { ) } + #[cfg(feature = "component-model-async")] + fn call_raw_async< + 'a, + T: Send, + Params, + Return: Send + Sync + 'static, + LowerParams, + LowerReturn, + Lower: FnOnce( + &mut LowerContext, + &Params, + InterfaceType, + &mut MaybeUninit, + ) -> Result<()> + + Send + + Sync, + Lift: FnOnce(&mut LiftContext, InterfaceType, &LowerReturn) -> Result + Send + Sync, + >( + &self, + store: StoreContextMut<'a, T>, + params: Params, + lower: Lower, + lift: Lift, + ) -> Result<(Return, StoreContextMut<'a, T>)> + where + LowerParams: Copy, + LowerReturn: Copy, + { + use crate::component::concurrent; + + let me = self.0; + let FuncData { + export, + component_instance, + .. + } = store.0[self.0]; + let callback = store.0[me] + .options + .callback + .expect("todo: support callback-less async exports"); + + // Note that we smuggle the params through as raw pointers to avoid + // requiring `Params: Send + Sync + 'static` bounds on this function, + // which would prevent passing references as parameters. Technically, + // we don't need to do that for the return type, but we do it anyway for + // symmetry. + // + // This is only safe because `concurrent::call` will either consume or + // drop the contexts before returning. + concurrent::call::<_, LowerParams, _>( + store, + lower_params::, + concurrent::LiftLowerContext { + pointer: &std::mem::ManuallyDrop::new((me, params, lower)) as *const _ as _, + dropper: drop_context::<(Stored, Params, Lower)>, + }, + lift_results::, + concurrent::LiftLowerContext { + pointer: &std::mem::ManuallyDrop::new((me, lift)) as *const _ as _, + dropper: drop_context::<(Stored, Lift)>, + }, + callback, + export.func_ref, + component_instance, + ) + } + /// Invokes the underlying wasm function, lowering arguments and lifting the /// result. /// @@ -387,7 +463,7 @@ impl Func { /// happening. fn call_raw( &self, - store: &mut StoreContextMut<'_, T>, + mut store: StoreContextMut<'_, T>, params: &Params, lower: impl FnOnce( &mut LowerContext<'_, T>, @@ -466,7 +542,7 @@ impl Func { // on the correctness of this module and `ComponentType` // implementations, hence `ComponentType` being an `unsafe` trait. crate::Func::call_unchecked_raw( - store, + &mut store, export.func_ref, space.as_mut_ptr().cast(), mem::size_of_val(space) / mem::size_of::(), @@ -686,3 +762,112 @@ impl Func { Ok(()) } } + +#[cfg(feature = "component-model-async")] +fn drop_context(pointer: *mut u8) { + drop(unsafe { std::ptr::read(pointer as *const T) }) +} + +#[cfg(feature = "component-model-async")] +fn lower_params< + Params, + LowerParams, + T, + F: FnOnce( + &mut LowerContext, + &Params, + InterfaceType, + &mut MaybeUninit, + ) -> Result<()> + + Send + + Sync, +>( + context: crate::component::concurrent::LiftLowerContext, + store: *mut dyn crate::vm::VMStore, + lowered: &mut [MaybeUninit], +) -> Result<()> { + use crate::component::storage::slice_to_storage_mut; + + let (me, params, lower) = unsafe { + std::ptr::read( + std::mem::ManuallyDrop::new(context).pointer as *const (Stored, Params, F), + ) + }; + + let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; + let FuncData { + options, + instance, + component_instance, + ty, + .. + } = store.0[me]; + + let instance = store.0[instance.0].as_ref().unwrap(); + let types = instance.component_types().clone(); + let instance_ptr = instance.instance_ptr(); + let mut flags = instance.instance().instance_flags(component_instance); + + unsafe { + if !flags.may_enter() { + bail!(crate::Trap::CannotEnterComponent); + } + flags.set_may_enter(false); + + flags.set_may_leave(false); + let mut cx = LowerContext::new(store.as_context_mut(), &options, &types, instance_ptr); + cx.enter_call(); + let result = lower( + &mut cx, + ¶ms, + InterfaceType::Tuple(types[ty].params), + slice_to_storage_mut(lowered), + ); + flags.set_may_leave(true); + result?; + Ok(()) + } +} + +#[cfg(feature = "component-model-async")] +fn lift_results< + Return: Send + Sync + 'static, + LowerReturn, + T, + F: FnOnce(&mut LiftContext, InterfaceType, &LowerReturn) -> Result + Send + Sync, +>( + context: crate::component::concurrent::LiftLowerContext, + store: *mut dyn crate::vm::VMStore, + lowered: &[ValRaw], +) -> Result>> { + use crate::component::storage::slice_to_storage; + + let (me, lift) = unsafe { + std::ptr::read(std::mem::ManuallyDrop::new(context).pointer as *const (Stored, F)) + }; + + let store = unsafe { StoreContextMut::(&mut *store.cast()) }; + let FuncData { + options, + instance, + component_instance, + ty, + .. + } = store.0[me]; + + let instance = store.0[instance.0].as_ref().unwrap(); + let types = instance.component_types().clone(); + let instance_ptr = instance.instance_ptr(); + let mut flags = instance.instance().instance_flags(component_instance); + + store.0[me].post_return_arg = Some(ValRaw::i32(0)); + + unsafe { + flags.set_needs_post_return(true); + Ok(Some(Box::new(lift( + &mut LiftContext::new(store.0, &options, &types, instance_ptr), + InterfaceType::Tuple(types[ty].results), + slice_to_storage(lowered), + )?) as Box)) + } +} diff --git a/crates/wasmtime/src/runtime/component/func/host.rs b/crates/wasmtime/src/runtime/component/func/host.rs index 0433e87252e6..cceb8b9f6461 100644 --- a/crates/wasmtime/src/runtime/component/func/host.rs +++ b/crates/wasmtime/src/runtime/component/func/host.rs @@ -1,3 +1,4 @@ +use crate::component::concurrent; use crate::component::func::{LiftContext, LowerContext, Options}; use crate::component::matching::InstanceType; use crate::component::storage::slice_to_storage_mut; @@ -12,11 +13,15 @@ use alloc::sync::Arc; use core::any::Any; use core::mem::{self, MaybeUninit}; use core::ptr::NonNull; +use std::future::Future; use wasmtime_environ::component::{ - CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, TypeFuncIndex, - MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + CanonicalAbiInfo, ComponentTypes, InterfaceType, RuntimeComponentInstanceIndex, StringEncoding, + TypeFuncIndex, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, }; +#[cfg(feature = "component-model-async")] +use crate::runtime::vm::SendSyncPtr; + pub struct HostFunc { entrypoint: VMLoweringCallee, typecheck: Box) -> Result<()>) + Send + Sync>, @@ -28,9 +33,23 @@ impl HostFunc { where F: Fn(StoreContextMut, P) -> Result + Send + Sync + 'static, P: ComponentNamedList + Lift + 'static, - R: ComponentNamedList + Lower + 'static, + R: ComponentNamedList + Lower + Send + Sync + 'static, + { + Self::from_concurrent(move |store, params| { + let result = func(store, params); + async move { concurrent::for_any(move |_| result) } + }) + } + + pub(crate) fn from_concurrent(func: F) -> Arc + where + N: FnOnce(StoreContextMut) -> Result + 'static, + FN: Future + Send + Sync + 'static, + F: Fn(StoreContextMut, P) -> FN + Send + Sync + 'static, + P: ComponentNamedList + Lift + 'static, + R: ComponentNamedList + Lower + Send + Sync + 'static, { - let entrypoint = Self::entrypoint::; + let entrypoint = Self::entrypoint::; Arc::new(HostFunc { entrypoint, typecheck: Box::new(typecheck::), @@ -38,35 +57,46 @@ impl HostFunc { }) } - extern "C" fn entrypoint( + extern "C" fn entrypoint( cx: *mut VMOpaqueContext, data: *mut u8, ty: TypeFuncIndex, + caller_instance: RuntimeComponentInstanceIndex, flags: InstanceFlags, memory: *mut VMMemoryDefinition, realloc: *mut VMFuncRef, string_encoding: StringEncoding, + async_: bool, storage: *mut MaybeUninit, storage_len: usize, ) where - F: Fn(StoreContextMut, P) -> Result, + N: FnOnce(StoreContextMut) -> Result + 'static, + FN: Future + Send + Sync + 'static, + F: Fn(StoreContextMut, P) -> FN + Send + Sync + 'static, P: ComponentNamedList + Lift + 'static, - R: ComponentNamedList + Lower + 'static, + R: ComponentNamedList + Lower + Send + Sync + 'static, { - let data = data as *const F; + struct Ptr(*const F); + + unsafe impl Sync for Ptr {} + unsafe impl Send for Ptr {} + + let data = Ptr(data as *const F); unsafe { call_host_and_handle_result::(cx, |instance, types, store| { - call_host::<_, _, _, _>( + call_host( instance, types, store, ty, + caller_instance, flags, memory, realloc, string_encoding, + async_, core::slice::from_raw_parts_mut(storage, storage_len), - |store, args| (*data)(store, args), + move |store, args| (*data.0)(store, args), ) }) } @@ -132,22 +162,26 @@ where /// This function is in general `unsafe` as the validity of all the parameters /// must be upheld. Generally that's done by ensuring this is only called from /// the select few places it's intended to be called from. -unsafe fn call_host( +unsafe fn call_host( instance: *mut ComponentInstance, types: &Arc, mut cx: StoreContextMut<'_, T>, ty: TypeFuncIndex, + caller_instance: RuntimeComponentInstanceIndex, mut flags: InstanceFlags, memory: *mut VMMemoryDefinition, realloc: *mut VMFuncRef, string_encoding: StringEncoding, + async_: bool, storage: &mut [MaybeUninit], closure: F, ) -> Result<()> where + N: FnOnce(StoreContextMut) -> Result + 'static, + FN: Future + Send + Sync + 'static, + F: Fn(StoreContextMut, Params) -> FN + 'static, Params: Lift, - Return: Lower, - F: FnOnce(StoreContextMut<'_, T>, Params) -> Result, + Return: Lower + Send + Sync + 'static, { /// Representation of arguments to this function when a return pointer is in /// use, namely the argument list is followed by a single value which is the @@ -172,6 +206,8 @@ where NonNull::new(memory), NonNull::new(realloc), string_encoding, + async_, + None, ); // Perform a dynamic check that this instance can indeed be left. Exiting @@ -185,39 +221,87 @@ where let param_tys = InterfaceType::Tuple(ty.params); let result_tys = InterfaceType::Tuple(ty.results); - // There's a 2x2 matrix of whether parameters and results are stored on the - // stack or on the heap. Each of the 4 branches here have a different - // representation of the storage of arguments/returns. - // - // Also note that while four branches are listed here only one is taken for - // any particular `Params` and `Return` combination. This should be - // trivially DCE'd by LLVM. Perhaps one day with enough const programming in - // Rust we can make monomorphizations of this function codegen only one - // branch, but today is not that day. - let mut storage: Storage<'_, Params, Return> = if Params::flatten_count() <= MAX_FLAT_PARAMS { - if Return::flatten_count() <= MAX_FLAT_RESULTS { - Storage::Direct(slice_to_storage_mut(storage)) - } else { - Storage::ResultsIndirect(slice_to_storage_mut(storage).assume_init_ref()) + if async_ { + #[cfg(feature = "component-model-async")] + { + const STATUS_PARAMS_READ: u32 = 1; + const STATUS_DONE: u32 = 3; + + let paramptr = storage[0].assume_init(); + let retptr = storage[1].assume_init(); + + let params = { + let lift = &mut LiftContext::new(cx.0, &options, types, instance); + lift.enter_call(); + let ptr = validate_inbounds::(lift.memory(), ¶mptr)?; + Params::load(lift, param_tys, &lift.memory()[ptr..][..Params::SIZE32])? + }; + + let future = closure(cx.as_context_mut(), params); + + let task = concurrent::first_poll(cx.as_context_mut(), future, caller_instance, { + let types = types.clone(); + let instance = SendSyncPtr::new(NonNull::new(instance).unwrap()); + move |cx, ret: Return| { + let mut lower = LowerContext::new(cx, &options, &types, instance.as_ptr()); + let ptr = validate_inbounds::(lower.as_slice_mut(), &retptr)?; + ret.store(&mut lower, result_tys, ptr) + } + })?; + + let status = if let Some(task) = task { + (STATUS_PARAMS_READ << 30) | task + } else { + STATUS_DONE << 30 + }; + + storage[0] = MaybeUninit::new(ValRaw::i32(status as i32)); + } + #[cfg(not(feature = "component-model-async"))] + { + unreachable!( + "async-lowered imports should have failed validation \ + when `component-model-async` feature disabled" + ); } } else { - if Return::flatten_count() <= MAX_FLAT_RESULTS { - Storage::ParamsIndirect(slice_to_storage_mut(storage)) + // There's a 2x2 matrix of whether parameters and results are stored on the + // stack or on the heap. Each of the 4 branches here have a different + // representation of the storage of arguments/returns. + // + // Also note that while four branches are listed here only one is taken for + // any particular `Params` and `Return` combination. This should be + // trivially DCE'd by LLVM. Perhaps one day with enough const programming in + // Rust we can make monomorphizations of this function codegen only one + // branch, but today is not that day. + let mut storage: Storage<'_, Params, Return> = if Params::flatten_count() <= MAX_FLAT_PARAMS + { + if Return::flatten_count() <= MAX_FLAT_RESULTS { + Storage::Direct(slice_to_storage_mut(storage)) + } else { + Storage::ResultsIndirect(slice_to_storage_mut(storage).assume_init_ref()) + } } else { - Storage::Indirect(slice_to_storage_mut(storage).assume_init_ref()) - } - }; - let mut lift = LiftContext::new(cx.0, &options, types, instance); - lift.enter_call(); - let params = storage.lift_params(&mut lift, param_tys)?; + if Return::flatten_count() <= MAX_FLAT_RESULTS { + Storage::ParamsIndirect(slice_to_storage_mut(storage)) + } else { + Storage::Indirect(slice_to_storage_mut(storage).assume_init_ref()) + } + }; + let mut lift = LiftContext::new(cx.0, &options, types, instance); + lift.enter_call(); + let params = storage.lift_params(&mut lift, param_tys)?; - let ret = closure(cx.as_context_mut(), params)?; - flags.set_may_leave(false); - let mut lower = LowerContext::new(cx, &options, types, instance); - storage.lower_results(&mut lower, result_tys, ret)?; - flags.set_may_leave(true); + let future = closure(cx.as_context_mut(), params); + + let (ret, cx) = concurrent::poll_and_block(cx, future, caller_instance)?; - lower.exit_call()?; + flags.set_may_leave(false); + let mut lower = LowerContext::new(cx, &options, types, instance); + storage.lower_results(&mut lower, result_tys, ret)?; + flags.set_may_leave(true); + lower.exit_call()?; + } return Ok(()); @@ -272,7 +356,7 @@ where } } -fn validate_inbounds(memory: &[u8], ptr: &ValRaw) -> Result { +pub(crate) fn validate_inbounds(memory: &[u8], ptr: &ValRaw) -> Result { // FIXME: needs memory64 support let ptr = usize::try_from(ptr.get_u32()).err2anyhow()?; if ptr % usize::try_from(T::ALIGN32).err2anyhow()? != 0 { @@ -320,21 +404,29 @@ unsafe fn call_host_dynamic( types: &Arc, mut store: StoreContextMut<'_, T>, ty: TypeFuncIndex, + _caller_instance: RuntimeComponentInstanceIndex, mut flags: InstanceFlags, memory: *mut VMMemoryDefinition, realloc: *mut VMFuncRef, string_encoding: StringEncoding, + async_: bool, storage: &mut [MaybeUninit], closure: F, ) -> Result<()> where F: FnOnce(StoreContextMut<'_, T>, &[Val], &mut [Val]) -> Result<()>, { + if async_ { + todo!("support async-lowered imports in `dynamic_entrypoint`"); + } + let options = Options::new( store.0.id(), NonNull::new(memory), NonNull::new(realloc), string_encoding, + async_, + None, ); // Perform a dynamic check that this instance can indeed be left. Exiting @@ -429,10 +521,12 @@ extern "C" fn dynamic_entrypoint( cx: *mut VMOpaqueContext, data: *mut u8, ty: TypeFuncIndex, + caller_instance: RuntimeComponentInstanceIndex, flags: InstanceFlags, memory: *mut VMMemoryDefinition, realloc: *mut VMFuncRef, string_encoding: StringEncoding, + async_: bool, storage: *mut MaybeUninit, storage_len: usize, ) where @@ -446,10 +540,12 @@ extern "C" fn dynamic_entrypoint( types, store, ty, + caller_instance, flags, memory, realloc, string_encoding, + async_, core::slice::from_raw_parts_mut(storage, storage_len), |store, params, results| (*data)(store, params, results), ) diff --git a/crates/wasmtime/src/runtime/component/func/options.rs b/crates/wasmtime/src/runtime/component/func/options.rs index 83accf3dee97..d63c1c92e203 100644 --- a/crates/wasmtime/src/runtime/component/func/options.rs +++ b/crates/wasmtime/src/runtime/component/func/options.rs @@ -43,6 +43,11 @@ pub struct Options { /// /// This defaults to utf-8 but can be changed if necessary. string_encoding: StringEncoding, + + async_: bool, + + #[cfg_attr(not(feature = "component-model-async"), allow(unused))] + pub(crate) callback: Option>, } // The `Options` structure stores raw pointers but they're never used unless a @@ -66,12 +71,16 @@ impl Options { memory: Option>, realloc: Option>, string_encoding: StringEncoding, + async_: bool, + callback: Option>, ) -> Options { Options { store_id, memory, realloc, string_encoding, + async_, + callback, } } @@ -163,6 +172,11 @@ impl Options { pub fn store_id(&self) -> StoreId { self.store_id } + + /// Returns whether this lifting or lowering uses the async ABI. + pub fn async_(&self) -> bool { + self.async_ + } } /// A helper structure which is a "package" of the context used during lowering @@ -196,7 +210,7 @@ pub struct LowerContext<'a, T> { /// into. /// /// This pointer is required to be owned by the `store` provided. - instance: *mut ComponentInstance, + pub(crate) instance: *mut ComponentInstance, } #[doc(hidden)] @@ -402,7 +416,7 @@ pub struct LiftContext<'a> { memory: Option<&'a [u8]>, - instance: *mut ComponentInstance, + pub(crate) instance: *mut ComponentInstance, host_table: &'a mut ResourceTable, host_resource_data: &'a mut HostResourceData, diff --git a/crates/wasmtime/src/runtime/component/func/typed.rs b/crates/wasmtime/src/runtime/component/func/typed.rs index a2ab99c55ba8..bd57f986fa2f 100644 --- a/crates/wasmtime/src/runtime/component/func/typed.rs +++ b/crates/wasmtime/src/runtime/component/func/typed.rs @@ -154,7 +154,14 @@ where /// Panics if this is called on a function in an asynchronous store. This /// only works with functions defined within a synchronous store. Also /// panics if `store` does not own this function. - pub fn call(&self, store: impl AsContextMut, params: Params) -> Result { + pub fn call( + &self, + store: impl AsContextMut, + params: Params, + ) -> Result + where + Return: Send + Sync + 'static, + { assert!( !store.as_context().async_support(), "must use `call_async` when async support is enabled on the config" @@ -170,28 +177,95 @@ where /// only works with functions defined within an asynchronous store. Also /// panics if `store` does not own this function. #[cfg(feature = "async")] - pub async fn call_async( - &self, + pub async fn call_async( + self, mut store: impl AsContextMut, params: Params, ) -> Result where - T: Send, Params: Send + Sync, - Return: Send + Sync, + Return: Send + Sync + 'static, { - let mut store = store.as_context_mut(); + let store = store.as_context_mut(); assert!( store.0.async_support(), "cannot use `call_async` when async support is not enabled on the config" ); - store - .on_fiber(|store| self.call_impl(store, params)) + #[cfg(feature = "component-model-async")] + { + // TODO: do we need to return the store here due to the possible + // invalidation of the reference we were passed? + crate::component::concurrent::on_fiber(store, self.func, move |store| { + self.call_impl(store, params) + }) .await? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + let mut store = store; + store + .on_fiber(|store| self.call_impl(store, params)) + .await? + } } - fn call_impl(&self, mut store: impl AsContextMut, params: Params) -> Result { - let store = &mut store.as_context_mut(); + fn call_impl( + &self, + mut store: impl AsContextMut, + params: Params, + ) -> Result + where + Return: Send + Sync + 'static, + { + let store = store.as_context_mut(); + + if store.0[self.func.0].options.async_() { + #[cfg(feature = "component-model-async")] + { + return Ok(if Params::flatten_count() <= MAX_FLAT_PARAMS { + if Return::flatten_count() <= MAX_FLAT_PARAMS { + self.func.call_raw_async( + store, + params, + Self::lower_stack_args, + Self::lift_stack_result, + ) + } else { + self.func.call_raw_async( + store, + params, + Self::lower_stack_args, + Self::lift_heap_result_guest, + ) + } + } else { + if Return::flatten_count() <= MAX_FLAT_PARAMS { + self.func.call_raw_async( + store, + params, + Self::lower_heap_args_guest, + Self::lift_stack_result, + ) + } else { + self.func.call_raw_async( + store, + params, + Self::lower_heap_args_guest, + Self::lift_heap_result_guest, + ) + } + }? + .0); + } + #[cfg(not(feature = "component-model-async"))] + { + bail!( + "must enable the `component-model-async` feature to call async-lifted exports" + ) + } + } + // Note that this is in theory simpler than it might read at this time. // Here we're doing a runtime dispatch on the `flatten_count` for the // params/results to see whether they're inbounds. This creates 4 cases @@ -266,8 +340,6 @@ where ty: InterfaceType, dst: &mut MaybeUninit, ) -> Result<()> { - assert!(Params::flatten_count() > MAX_FLAT_PARAMS); - // Memory must exist via validation if the arguments are stored on the // heap, so we can create a `MemoryMut` at this point. Afterwards // `realloc` is used to allocate space for all the arguments and then @@ -302,7 +374,6 @@ where ty: InterfaceType, dst: &Return::Lower, ) -> Result { - assert!(Return::flatten_count() <= MAX_FLAT_RESULTS); Return::lift(cx, ty, dst) } @@ -328,6 +399,30 @@ where Return::load(cx, ty, bytes) } + #[cfg(feature = "component-model-async")] + fn lower_heap_args_guest( + cx: &mut LowerContext<'_, T>, + params: &Params, + ty: InterfaceType, + dst: &mut MaybeUninit, + ) -> Result<()> { + let ptr = + crate::component::func::validate_inbounds::(cx.as_slice_mut(), &unsafe { + dst.assume_init() + })?; + params.store(cx, ty, ptr) + } + + #[cfg(feature = "component-model-async")] + fn lift_heap_result_guest( + cx: &mut LiftContext<'_>, + ty: InterfaceType, + dst: &ValRaw, + ) -> Result { + _ = (cx, ty, dst); + todo!() + } + /// See [`Func::post_return`] pub fn post_return(&self, store: impl AsContextMut) -> Result<()> { self.func.post_return(store) @@ -1526,7 +1621,7 @@ pub struct WasmList { } impl WasmList { - fn new( + pub(crate) fn new( ptr: usize, len: usize, cx: &mut LiftContext<'_>, @@ -2484,6 +2579,9 @@ pub fn desc(ty: &InterfaceType) -> &'static str { InterfaceType::Enum(_) => "enum", InterfaceType::Own(_) => "owned resource", InterfaceType::Borrow(_) => "borrowed resource", + InterfaceType::Future(_) => "future", + InterfaceType::Stream(_) => "stream", + InterfaceType::ErrorContext(_) => "error-context", } } diff --git a/crates/wasmtime/src/runtime/component/instance.rs b/crates/wasmtime/src/runtime/component/instance.rs index bb7c337eb842..28bdd9d4333b 100644 --- a/crates/wasmtime/src/runtime/component/instance.rs +++ b/crates/wasmtime/src/runtime/component/instance.rs @@ -1,3 +1,4 @@ +use crate::component::concurrent; use crate::component::func::HostFunc; use crate::component::matching::InstanceType; use crate::component::{ @@ -512,9 +513,39 @@ impl<'a> Instantiator<'a> { } } - fn run(&mut self, store: &mut StoreContextMut<'_, T>) -> Result<()> { + fn run(&mut self, store: &mut StoreContextMut<'_, T>) -> Result<()> { let env_component = self.component.env_component(); + self.data.state.set_async_callbacks( + concurrent::task_backpressure::, + concurrent::task_return::, + concurrent::task_wait::, + concurrent::task_poll::, + concurrent::task_yield::, + concurrent::subtask_drop::, + concurrent::async_enter::, + concurrent::async_exit::, + concurrent::future_new::, + concurrent::future_write::, + concurrent::future_read::, + concurrent::future_cancel_write::, + concurrent::future_cancel_read::, + concurrent::future_close_writable::, + concurrent::future_close_readable::, + concurrent::stream_new::, + concurrent::stream_write::, + concurrent::stream_read::, + concurrent::stream_cancel_write::, + concurrent::stream_cancel_read::, + concurrent::stream_close_writable::, + concurrent::stream_close_readable::, + concurrent::flat_stream_write::, + concurrent::flat_stream_read::, + concurrent::error_context_new::, + concurrent::error_context_debug_message::, + concurrent::error_context_drop::, + ); + // Before all initializers are processed configure all destructors for // host-defined resources. No initializer will correspond to these and // it's required to happen before they're needed, so execute this first. @@ -607,6 +638,10 @@ impl<'a> Instantiator<'a> { self.extract_realloc(store.0, realloc) } + GlobalInitializer::ExtractCallback(callback) => { + self.extract_callback(store.0, callback) + } + GlobalInitializer::ExtractPostReturn(post_return) => { self.extract_post_return(store.0, post_return) } @@ -654,6 +689,16 @@ impl<'a> Instantiator<'a> { self.data.state.set_runtime_realloc(realloc.index, func_ref); } + fn extract_callback(&mut self, store: &mut StoreOpaque, callback: &ExtractCallback) { + let func_ref = match self.data.lookup_def(store, &callback.def) { + crate::runtime::vm::Export::Function(f) => f.func_ref, + _ => unreachable!(), + }; + self.data + .state + .set_runtime_callback(callback.index, func_ref); + } + fn extract_post_return(&mut self, store: &mut StoreOpaque, post_return: &ExtractPostReturn) { let func_ref = match self.data.lookup_def(store, &post_return.def) { crate::runtime::vm::Export::Function(f) => f.func_ref, @@ -796,7 +841,10 @@ impl InstancePre { /// Performs the instantiation process into the store specified. // // TODO: needs more docs - pub fn instantiate(&self, store: impl AsContextMut) -> Result { + pub fn instantiate(&self, store: impl AsContextMut) -> Result + where + T: 'static, + { assert!( !store.as_context().async_support(), "must use async instantiation when async support is enabled" @@ -814,7 +862,7 @@ impl InstancePre { mut store: impl AsContextMut, ) -> Result where - T: Send, + T: Send + 'static, { let mut store = store.as_context_mut(); assert!( @@ -824,7 +872,10 @@ impl InstancePre { store.on_fiber(|store| self.instantiate_impl(store)).await? } - fn instantiate_impl(&self, mut store: impl AsContextMut) -> Result { + fn instantiate_impl(&self, mut store: impl AsContextMut) -> Result + where + T: 'static, + { let mut store = store.as_context_mut(); store .engine() diff --git a/crates/wasmtime/src/runtime/component/linker.rs b/crates/wasmtime/src/runtime/component/linker.rs index 31a3d5254884..130bd97ff2d6 100644 --- a/crates/wasmtime/src/runtime/component/linker.rs +++ b/crates/wasmtime/src/runtime/component/linker.rs @@ -265,7 +265,10 @@ impl Linker { &self, store: impl AsContextMut, component: &Component, - ) -> Result { + ) -> Result + where + T: 'static, + { assert!( !store.as_context().async_support(), "must use async instantiation when async support is enabled" @@ -290,7 +293,7 @@ impl Linker { component: &Component, ) -> Result where - T: Send, + T: Send + 'static, { assert!( store.as_context().async_support(), @@ -404,7 +407,7 @@ impl LinkerInstance<'_, T> { where F: Fn(StoreContextMut, Params) -> Result + Send + Sync + 'static, Params: ComponentNamedList + Lift + 'static, - Return: ComponentNamedList + Lower + 'static, + Return: ComponentNamedList + Lower + Send + Sync + 'static, { self.insert(name, Definition::Func(HostFunc::from_closure(func)))?; Ok(()) @@ -425,20 +428,67 @@ impl LinkerInstance<'_, T> { + Sync + 'static, Params: ComponentNamedList + Lift + 'static, - Return: ComponentNamedList + Lower + 'static, + Return: ComponentNamedList + Lower + Send + Sync + 'static, { assert!( self.engine.config().async_support, "cannot use `func_wrap_async` without enabling async support in the config" ); + let ff = move |mut store: StoreContextMut<'_, T>, params: Params| -> Result { - let async_cx = store.as_context_mut().0.async_cx().expect("async cx"); - let mut future = Pin::from(f(store.as_context_mut(), params)); - unsafe { async_cx.block_on(future.as_mut()) }? + #[cfg(feature = "component-model-async")] + { + let async_cx = crate::component::concurrent::AsyncCx::new(&mut store); + let mut future = Pin::from(f(store.as_context_mut(), params)); + unsafe { async_cx.block_on::(future.as_mut(), None) }?.0 + } + #[cfg(not(feature = "component-model-async"))] + { + let async_cx = store.as_context_mut().0.async_cx().expect("async cx"); + let mut future = Pin::from(f(store.as_context_mut(), params)); + unsafe { async_cx.block_on(future.as_mut()) }? + } }; self.func_wrap(name, ff) } + /// Defines a new host-provided async function into this [`Linker`]. + /// + /// This allows the caller to register host functions with the + /// LinkerInstance such that multiple calls to such functions can run + /// concurrently. This isn't possible with the existing func_wrap_async + /// method because it takes a function which returns a future that owns a + /// unique reference to the Store, meaning the Store can't be used for + /// anything else until the future resolves. + /// + /// Ideally, we'd have a way to thread a `StoreContextMut` through an + /// arbitrary `Future` such that it has access to the `Store` only while + /// being polled (i.e. between, but not across, await points). However, + /// there's currently no way to express that in async Rust, so we make do + /// with a more awkward scheme: each function registered using + /// `func_wrap_concurrent` gets access to the `Store` twice: once before + /// doing any concurrent operations (i.e. before awaiting) and once + /// afterward. This allows multiple calls to proceed concurrently without + /// any one of them monopolizing the store. + #[cfg(feature = "async")] + #[cfg(feature = "component-model-async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + pub fn func_wrap_concurrent(&mut self, name: &str, f: F) -> Result<()> + where + N: FnOnce(StoreContextMut) -> Result + 'static, + FN: Future + Send + Sync + 'static, + F: Fn(StoreContextMut, Params) -> FN + Send + Sync + 'static, + Params: ComponentNamedList + Lift + 'static, + Return: ComponentNamedList + Lower + Send + Sync + 'static, + { + assert!( + self.engine.config().async_support, + "cannot use `func_wrap_concurrent` without enabling async support in the config" + ); + self.insert(name, Definition::Func(HostFunc::from_concurrent(f)))?; + Ok(()) + } + /// Define a new host-provided function using dynamically typed values. /// /// The `name` provided is the name of the function to define and the @@ -554,6 +604,7 @@ impl LinkerInstance<'_, T> { #[cfg(feature = "async")] pub fn func_new_async(&mut self, name: &str, f: F) -> Result<()> where + T: 'static, F: for<'a> Fn( StoreContextMut<'a, T>, &'a [Val], diff --git a/crates/wasmtime/src/runtime/component/matching.rs b/crates/wasmtime/src/runtime/component/matching.rs index 4222daa6dc62..d6cac001cb9a 100644 --- a/crates/wasmtime/src/runtime/component/matching.rs +++ b/crates/wasmtime/src/runtime/component/matching.rs @@ -1,5 +1,6 @@ use crate::component::func::HostFunc; use crate::component::linker::{Definition, Strings}; +use crate::component::types::{FutureType, StreamType}; use crate::component::ResourceType; use crate::prelude::*; use crate::runtime::vm::component::ComponentInstance; @@ -9,7 +10,7 @@ use alloc::sync::Arc; use core::any::Any; use wasmtime_environ::component::{ ComponentTypes, NameMap, ResourceIndex, TypeComponentInstance, TypeDef, TypeFuncIndex, - TypeModule, TypeResourceTableIndex, + TypeFutureTableIndex, TypeModule, TypeResourceTableIndex, TypeStreamTableIndex, }; use wasmtime_environ::PrimaryMap; @@ -199,6 +200,14 @@ impl<'a> InstanceType<'a> { .copied() .unwrap_or_else(|| ResourceType::uninstantiated(&self.types, index)) } + + pub fn future_type(&self, index: TypeFutureTableIndex) -> FutureType { + FutureType::from(self.types[index].ty, self) + } + + pub fn stream_type(&self, index: TypeStreamTableIndex) -> StreamType { + StreamType::from(self.types[index].ty, self) + } } /// Small helper method to downcast an `Arc` borrow into a borrow of a concrete diff --git a/crates/wasmtime/src/runtime/component/mod.rs b/crates/wasmtime/src/runtime/component/mod.rs index eaf4307e35a8..26bbe65a1fda 100644 --- a/crates/wasmtime/src/runtime/component/mod.rs +++ b/crates/wasmtime/src/runtime/component/mod.rs @@ -101,6 +101,8 @@ #![allow(rustdoc::redundant_explicit_links)] mod component; +#[cfg(feature = "component-model-async")] +pub(crate) mod concurrent; mod func; mod instance; mod linker; @@ -112,6 +114,8 @@ mod store; pub mod types; mod values; pub use self::component::{Component, ComponentExportIndex}; +#[cfg(feature = "component-model-async")] +pub use self::concurrent::{for_any, future, stream, ErrorContext, FutureReader, StreamReader}; pub use self::func::{ ComponentNamedList, ComponentType, Func, Lift, Lower, TypedFunc, WasmList, WasmStr, }; @@ -671,3 +675,347 @@ pub mod bindgen_examples; #[cfg(not(any(docsrs, test, doctest)))] #[doc(hidden)] pub mod bindgen_examples {} + +#[cfg(not(feature = "component-model-async"))] +pub(crate) mod concurrent { + use { + crate::{ + vm::{VMFuncRef, VMMemoryDefinition, VMOpaqueContext}, + AsContextMut, StoreContextMut, ValRaw, + }, + anyhow::Result, + once_cell::sync::Lazy, + std::{ + future::Future, + mem::MaybeUninit, + pin::pin, + sync::Arc, + task::{Context, Poll, Wake, Waker}, + }, + wasmtime_environ::component::{ + RuntimeComponentInstanceIndex, StringEncoding, TypeErrorContextTableIndex, + TypeFutureTableIndex, TypeStreamTableIndex, + }, + }; + + pub(crate) async fn poll<'a, T: Send>( + _store: StoreContextMut<'a, T>, + ) -> Result> { + anyhow::bail!( + "must enable the `component-model-async` feature \ + to use `wasmtime::component::concurrent::poll`" + ) + } + + pub(crate) async fn poll_until<'a, T: Send, U>( + _store: StoreContextMut<'a, T>, + _future: impl Future, + ) -> Result<(StoreContextMut<'a, T>, U)> { + anyhow::bail!( + "must enable the `component-model-async` feature \ + to use `wasmtime::component::concurrent::poll_until`" + ) + } + + pub fn for_any(fun: F) -> F + where + F: FnOnce(StoreContextMut) -> R + 'static, + R: 'static, + { + fun + } + + fn dummy_waker() -> Waker { + struct DummyWaker; + + impl Wake for DummyWaker { + fn wake(self: Arc) {} + } + + static WAKER: Lazy> = Lazy::new(|| Arc::new(DummyWaker)); + + WAKER.clone().into() + } + + pub(crate) fn poll_and_block<'a, T, R: Send + Sync + 'static>( + mut store: StoreContextMut<'a, T>, + future: impl Future) -> Result + 'static> + + Send + + Sync + + 'static, + _caller_instance: RuntimeComponentInstanceIndex, + ) -> Result<(R, StoreContextMut<'a, T>)> { + match pin!(future).poll(&mut Context::from_waker(&dummy_waker())) { + Poll::Ready(fun) => { + let result = fun(store.as_context_mut())?; + Ok((result, store)) + } + Poll::Pending => { + unreachable!() + } + } + } + + pub(crate) extern "C" fn task_backpressure( + _cx: *mut VMOpaqueContext, + _caller_instance: RuntimeComponentInstanceIndex, + _enabled: u32, + ) { + unreachable!() + } + + pub(crate) extern "C" fn task_return( + _cx: *mut VMOpaqueContext, + _storage: *mut MaybeUninit, + _storage_len: usize, + ) { + unreachable!() + } + + pub(crate) extern "C" fn task_wait( + _cx: *mut VMOpaqueContext, + _async_: bool, + _memory: *mut VMMemoryDefinition, + _payload: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn task_poll( + _cx: *mut VMOpaqueContext, + _async_: bool, + _memory: *mut VMMemoryDefinition, + _payload: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn task_yield(_cx: *mut VMOpaqueContext, _async_: bool) { + unreachable!() + } + + pub(crate) extern "C" fn subtask_drop( + _cx: *mut VMOpaqueContext, + _caller_instance: RuntimeComponentInstanceIndex, + _task_id: u32, + ) { + unreachable!() + } + + pub(crate) extern "C" fn async_enter( + _cx: *mut VMOpaqueContext, + _start: *mut VMFuncRef, + _return_: *mut VMFuncRef, + _caller_instance: RuntimeComponentInstanceIndex, + _params: u32, + _results: u32, + ) { + unreachable!() + } + + pub(crate) extern "C" fn async_exit( + _cx: *mut VMOpaqueContext, + _callback: *mut VMFuncRef, + _callee: *mut VMFuncRef, + _callee_instance: RuntimeComponentInstanceIndex, + _param_count: u32, + _result_count: u32, + _flags: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn future_new( + _vmctx: *mut VMOpaqueContext, + _ty: TypeFutureTableIndex, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn future_write( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: StringEncoding, + _ty: TypeFutureTableIndex, + _future: u32, + _address: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn future_read( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: StringEncoding, + _ty: TypeFutureTableIndex, + _future: u32, + _address: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn future_cancel_write( + _vmctx: *mut VMOpaqueContext, + _ty: TypeFutureTableIndex, + _async_: bool, + _writer: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn future_cancel_read( + _vmctx: *mut VMOpaqueContext, + _ty: TypeFutureTableIndex, + _async_: bool, + _reader: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn future_close_writable( + _vmctx: *mut VMOpaqueContext, + _ty: TypeFutureTableIndex, + _writer: u32, + _error: u32, + ) { + unreachable!() + } + + pub(crate) extern "C" fn future_close_readable( + _vmctx: *mut VMOpaqueContext, + _ty: TypeFutureTableIndex, + _reader: u32, + ) { + unreachable!() + } + + pub(crate) extern "C" fn stream_new( + _vmctx: *mut VMOpaqueContext, + _ty: TypeStreamTableIndex, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn stream_write( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: StringEncoding, + _ty: TypeStreamTableIndex, + _stream: u32, + _address: u32, + _count: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn stream_read( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: StringEncoding, + _ty: TypeStreamTableIndex, + _stream: u32, + _address: u32, + _count: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn stream_cancel_write( + _vmctx: *mut VMOpaqueContext, + _ty: TypeStreamTableIndex, + _async_: bool, + _writer: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn stream_cancel_read( + _vmctx: *mut VMOpaqueContext, + _ty: TypeStreamTableIndex, + _async_: bool, + _reader: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn stream_close_writable( + _vmctx: *mut VMOpaqueContext, + _ty: TypeStreamTableIndex, + _writer: u32, + _error: u32, + ) { + unreachable!() + } + + pub(crate) extern "C" fn stream_close_readable( + _vmctx: *mut VMOpaqueContext, + _ty: TypeStreamTableIndex, + _reader: u32, + ) { + unreachable!() + } + + pub(crate) extern "C" fn flat_stream_write( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _ty: TypeStreamTableIndex, + _payload_size: u32, + _payload_align: u32, + _stream: u32, + _address: u32, + _count: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn flat_stream_read( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _ty: TypeStreamTableIndex, + _payload_size: u32, + _payload_align: u32, + _stream: u32, + _address: u32, + _count: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn error_context_new( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: StringEncoding, + _ty: TypeErrorContextTableIndex, + _address: u32, + _count: u32, + ) -> u32 { + unreachable!() + } + + pub(crate) extern "C" fn error_context_debug_message( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: StringEncoding, + _ty: TypeErrorContextTableIndex, + _handle: u32, + _address: u32, + ) { + unreachable!() + } + + pub(crate) extern "C" fn error_context_drop( + _vmctx: *mut VMOpaqueContext, + _ty: TypeErrorContextTableIndex, + _error: u32, + ) { + unreachable!() + } +} diff --git a/crates/wasmtime/src/runtime/component/storage.rs b/crates/wasmtime/src/runtime/component/storage.rs index 01537b42984d..81655665ac77 100644 --- a/crates/wasmtime/src/runtime/component/storage.rs +++ b/crates/wasmtime/src/runtime/component/storage.rs @@ -41,3 +41,18 @@ pub unsafe fn slice_to_storage_mut(slice: &mut [MaybeUninit]) -> &mut &mut *slice.as_mut_ptr().cast() } + +/// Same as `storage_as_slice`, but in reverse +#[cfg(feature = "component-model-async")] +pub unsafe fn slice_to_storage(slice: &[ValRaw]) -> &T { + assert_raw_slice_compat::(); + + // This is an actual runtime assertion which if performance calls for we may + // need to relax to a debug assertion. This notably tries to ensure that we + // stay within the bounds of the number of actual values given rather than + // reading past the end of an array. This shouldn't actually trip unless + // there's a bug in Wasmtime though. + assert!(mem::size_of_val(slice) >= mem::size_of::()); + + &*slice.as_ptr().cast() +} diff --git a/crates/wasmtime/src/runtime/component/types.rs b/crates/wasmtime/src/runtime/component/types.rs index 6895ea4aa3fd..603dda89b0cf 100644 --- a/crates/wasmtime/src/runtime/component/types.rs +++ b/crates/wasmtime/src/runtime/component/types.rs @@ -7,9 +7,9 @@ use core::fmt; use core::ops::Deref; use wasmtime_environ::component::{ ComponentTypes, InterfaceType, ResourceIndex, TypeComponentIndex, TypeComponentInstanceIndex, - TypeDef, TypeEnumIndex, TypeFlagsIndex, TypeFuncIndex, TypeListIndex, TypeModuleIndex, - TypeOptionIndex, TypeRecordIndex, TypeResourceTableIndex, TypeResultIndex, TypeTupleIndex, - TypeVariantIndex, + TypeDef, TypeEnumIndex, TypeFlagsIndex, TypeFuncIndex, TypeFutureIndex, TypeFutureTableIndex, + TypeListIndex, TypeModuleIndex, TypeOptionIndex, TypeRecordIndex, TypeResourceTableIndex, + TypeResultIndex, TypeStreamIndex, TypeStreamTableIndex, TypeTupleIndex, TypeVariantIndex, }; use wasmtime_environ::PrimaryMap; @@ -145,6 +145,16 @@ impl TypeChecker<'_> { (InterfaceType::String, _) => false, (InterfaceType::Char, InterfaceType::Char) => true, (InterfaceType::Char, _) => false, + (InterfaceType::Future(t1), InterfaceType::Future(t2)) => { + self.future_table_types_equal(t1, t2) + } + (InterfaceType::Future(_), _) => false, + (InterfaceType::Stream(t1), InterfaceType::Stream(t2)) => { + self.stream_table_types_equal(t1, t2) + } + (InterfaceType::Stream(_), _) => false, + (InterfaceType::ErrorContext(_), InterfaceType::ErrorContext(_)) => true, + (InterfaceType::ErrorContext(_), _) => false, } } @@ -244,6 +254,30 @@ impl TypeChecker<'_> { let b = &self.b_types[f2]; a.names == b.names } + + fn future_table_types_equal(&self, t1: TypeFutureTableIndex, t2: TypeFutureTableIndex) -> bool { + self.futures_equal(self.a_types[t1].ty, self.b_types[t2].ty) + } + + fn futures_equal(&self, t1: TypeFutureIndex, t2: TypeFutureIndex) -> bool { + let a = &self.a_types[t1]; + let b = &self.b_types[t2]; + match (a.payload, b.payload) { + (Some(t1), Some(t2)) => self.interface_types_equal(t1, t2), + (None, None) => true, + _ => false, + } + } + + fn stream_table_types_equal(&self, t1: TypeStreamTableIndex, t2: TypeStreamTableIndex) -> bool { + self.streams_equal(self.a_types[t1].ty, self.b_types[t2].ty) + } + + fn streams_equal(&self, t1: TypeStreamIndex, t2: TypeStreamIndex) -> bool { + let a = &self.a_types[t1]; + let b = &self.b_types[t2]; + self.interface_types_equal(a.payload, b.payload) + } } /// A `list` interface type @@ -416,7 +450,7 @@ impl PartialEq for OptionType { impl Eq for OptionType {} -/// An `expected` interface type +/// A `result` interface type #[derive(Clone, Debug)] pub struct ResultType(Handle); @@ -476,6 +510,55 @@ impl PartialEq for Flags { impl Eq for Flags {} +/// An `future` interface type +#[derive(Clone, Debug)] +pub struct FutureType(Handle); + +impl FutureType { + pub(crate) fn from(index: TypeFutureIndex, ty: &InstanceType<'_>) -> Self { + FutureType(Handle::new(index, ty)) + } + + /// Retrieve the type parameter for this `future`. + pub fn ty(&self) -> Option { + Some(Type::from( + self.0.types[self.0.index].payload.as_ref()?, + &self.0.instance(), + )) + } +} + +impl PartialEq for FutureType { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::futures_equal) + } +} + +impl Eq for FutureType {} + +/// An `stream` interface type +#[derive(Clone, Debug)] +pub struct StreamType(Handle); + +impl StreamType { + pub(crate) fn from(index: TypeStreamIndex, ty: &InstanceType<'_>) -> Self { + StreamType(Handle::new(index, ty)) + } + + /// Retrieve the type parameter for this `stream`. + pub fn ty(&self) -> Type { + Type::from(&self.0.types[self.0.index].payload, &self.0.instance()) + } +} + +impl PartialEq for StreamType { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::streams_equal) + } +} + +impl Eq for StreamType {} + /// Represents a component model interface type #[derive(Clone, PartialEq, Eq, Debug)] #[allow(missing_docs)] @@ -503,6 +586,9 @@ pub enum Type { Flags(Flags), Own(ResourceType), Borrow(ResourceType), + Future(FutureType), + Stream(StreamType), + ErrorContext, } impl Type { @@ -660,6 +746,9 @@ impl Type { InterfaceType::Flags(index) => Type::Flags(Flags::from(*index, instance)), InterfaceType::Own(index) => Type::Own(instance.resource_type(*index)), InterfaceType::Borrow(index) => Type::Borrow(instance.resource_type(*index)), + InterfaceType::Future(index) => Type::Future(instance.future_type(*index)), + InterfaceType::Stream(index) => Type::Stream(instance.stream_type(*index)), + InterfaceType::ErrorContext(_) => Type::ErrorContext, } } @@ -688,6 +777,9 @@ impl Type { Type::Flags(_) => "flags", Type::Own(_) => "own", Type::Borrow(_) => "borrow", + Type::Future(_) => "future", + Type::Stream(_) => "stream", + Type::ErrorContext => "error-context", } } } diff --git a/crates/wasmtime/src/runtime/component/values.rs b/crates/wasmtime/src/runtime/component/values.rs index a57e783a6e94..bd1e811a808a 100644 --- a/crates/wasmtime/src/runtime/component/values.rs +++ b/crates/wasmtime/src/runtime/component/values.rs @@ -86,6 +86,9 @@ pub enum Val { Result(Result>, Option>>), Flags(Vec), Resource(ResourceAny), + Future(u32), + Stream(u32), + ErrorContext(u32), } impl Val { @@ -198,6 +201,11 @@ impl Val { Val::Flags(flags.into()) } + InterfaceType::Future(_) + | InterfaceType::Stream(_) + | InterfaceType::ErrorContext(_) => { + todo!() + } }) } @@ -319,6 +327,11 @@ impl Val { } Val::Flags(flags.into()) } + InterfaceType::Future(_) + | InterfaceType::Stream(_) + | InterfaceType::ErrorContext(_) => { + todo!() + } }) } @@ -429,6 +442,14 @@ impl Val { Ok(()) } (InterfaceType::Flags(_), _) => unexpected(ty, self), + ( + InterfaceType::Future(_) + | InterfaceType::Stream(_) + | InterfaceType::ErrorContext(_), + _, + ) => { + todo!() + } } } @@ -566,6 +587,14 @@ impl Val { Ok(()) } (InterfaceType::Flags(_), _) => unexpected(ty, self), + ( + InterfaceType::Future(_) + | InterfaceType::Stream(_) + | InterfaceType::ErrorContext(_), + _, + ) => { + todo!() + } } } @@ -593,6 +622,9 @@ impl Val { Val::Result(_) => "result", Val::Resource(_) => "resource", Val::Flags(_) => "flags", + Val::Future(_) => "future", + Val::Stream(_) => "stream", + Val::ErrorContext(_) => "error-context", } } } @@ -658,6 +690,12 @@ impl PartialEq for Val { (Self::Flags(_), _) => false, (Self::Resource(l), Self::Resource(r)) => l == r, (Self::Resource(_), _) => false, + (Self::Future(l), Self::Future(r)) => l == r, + (Self::Future(_), _) => false, + (Self::Stream(l), Self::Stream(r)) => l == r, + (Self::Stream(_), _) => false, + (Self::ErrorContext(l), Self::ErrorContext(r)) => l == r, + (Self::ErrorContext(_), _) => false, } } } diff --git a/crates/wasmtime/src/runtime/store.rs b/crates/wasmtime/src/runtime/store.rs index ac22f08407ba..8bd2e00bee65 100644 --- a/crates/wasmtime/src/runtime/store.rs +++ b/crates/wasmtime/src/runtime/store.rs @@ -76,6 +76,7 @@ //! contents of `StoreOpaque`. This is an invariant that we, as the authors of //! `wasmtime`, must uphold for the public interface to be safe. +use crate::component::concurrent; use crate::hash_set::HashSet; use crate::instance::InstanceData; use crate::linker::Definition; @@ -224,6 +225,8 @@ pub struct StoreInner { Option) -> Result + Send + Sync>>, // for comments about `ManuallyDrop`, see `Store::into_data` data: ManuallyDrop, + #[cfg(feature = "component-model-async")] + concurrent_state: concurrent::ConcurrentState, } enum ResourceLimiterInner { @@ -576,6 +579,8 @@ impl Store { call_hook: None, epoch_deadline_behavior: None, data: ManuallyDrop::new(data), + #[cfg(feature = "component-model-async")] + concurrent_state: Default::default(), }); // Wasmtime uses the callee argument to host functions to learn about @@ -1090,6 +1095,37 @@ impl<'a, T> StoreContextMut<'a, T> { self.0.data_mut() } + /// Asynchronously for all outstanding tasks involving this `Store` to + /// complete. + #[cfg(feature = "async")] + pub async fn wait(self) -> Result<()> + where + T: Send + 'static, + { + concurrent::poll(self).await.map(drop) + } + + /// Asynchronously for either the specified future to complete or all + /// outstanding tasks involving this `Store` to complete -- whichever + /// happens first. + #[cfg(feature = "async")] + pub async fn wait_until(self, future: impl Future) -> Result + where + T: Send + 'static, + { + concurrent::poll_until(self, future).await.map(|(_, v)| v) + } + + #[cfg(feature = "component-model-async")] + pub(crate) fn concurrent_state(&mut self) -> &mut concurrent::ConcurrentState { + self.0.concurrent_state() + } + + #[cfg(feature = "component-model-async")] + pub(crate) fn has_pkey(&self) -> bool { + self.0.pkey.is_some() + } + /// Returns the underlying [`Engine`] this store is connected to. pub fn engine(&self) -> &Engine { self.0.engine() @@ -1175,6 +1211,11 @@ impl StoreInner { &mut self.data } + #[cfg(feature = "component-model-async")] + fn concurrent_state(&mut self) -> &mut concurrent::ConcurrentState { + &mut self.concurrent_state + } + #[inline] pub fn call_hook(&mut self, s: CallHook) -> Result<()> { if self.inner.pkey.is_none() && self.call_hook.is_none() { diff --git a/crates/wasmtime/src/runtime/vm/component.rs b/crates/wasmtime/src/runtime/vm/component.rs index 59f5bcad3598..f2ea5c8e35a3 100644 --- a/crates/wasmtime/src/runtime/vm/component.rs +++ b/crates/wasmtime/src/runtime/vm/component.rs @@ -20,6 +20,7 @@ use core::mem::offset_of; use core::ops::Deref; use core::ptr::{self, NonNull}; use sptr::Strict; +use std::collections::HashMap; use wasmtime_environ::component::*; use wasmtime_environ::{HostPtr, PrimaryMap, VMSharedTypeIndex}; @@ -58,6 +59,10 @@ pub struct ComponentInstance { /// is how this field is manipulated. component_resource_tables: PrimaryMap, + handle_table: HashMap<(TableIndex, u32), TableState>, + + error_context_table: HashMap<(TypeErrorContextTableIndex, u32), usize>, + /// Storage for the type information about resources within this component /// instance. /// @@ -106,10 +111,12 @@ pub type VMLoweringCallee = extern "C" fn( vmctx: *mut VMOpaqueContext, data: *mut u8, ty: TypeFuncIndex, + caller_instance: RuntimeComponentInstanceIndex, flags: InstanceFlags, opt_memory: *mut VMMemoryDefinition, opt_realloc: *mut VMFuncRef, string_encoding: StringEncoding, + async_: bool, args_and_results: *mut mem::MaybeUninit, nargs_and_results: usize, ); @@ -126,6 +133,166 @@ pub struct VMLowering { pub data: *mut u8, } +/// Type signature for the host-defined `task.backpressure` built-in function. +pub type VMTaskBackpressureCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + caller_instance: RuntimeComponentInstanceIndex, + arg: u32, +); + +/// Type signature for the host-defined `task.return` built-in function. +pub type VMTaskReturnCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + args_and_results: *mut mem::MaybeUninit, + nargs_and_results: usize, +); + +/// Type signature for the host-defined `task.wait` and `task.poll` built-in functions. +pub type VMTaskWaitOrPollCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + async_: bool, + memory: *mut VMMemoryDefinition, + payload: u32, +) -> u32; + +/// Type signature for the host-defined `task.yield` built-in function. +pub type VMTaskYieldCallback = extern "C" fn(vmctx: *mut VMOpaqueContext, async_: bool); + +/// Type signature for the host-defined `subtask.drop` built-in function. +pub type VMSubtaskDropCallback = + extern "C" fn(vmctx: *mut VMOpaqueContext, instance: RuntimeComponentInstanceIndex, arg: u32); + +/// Type signature for the host-defined built-in function to represent starting +/// a call to an async-lowered import in a FACT-generated module. +pub type VMAsyncEnterCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + start: *mut VMFuncRef, + return_: *mut VMFuncRef, + caller_instance: RuntimeComponentInstanceIndex, + params: u32, + results: u32, +); + +/// Type signature for the host-defined built-in function to represent +/// completing a call to an async-lowered import in a FACT-generated module. +pub type VMAsyncExitCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + callback: *mut VMFuncRef, + callee: *mut VMFuncRef, + callee_instance: RuntimeComponentInstanceIndex, + param_count: u32, + result_count: u32, + flags: u32, +) -> u32; + +/// Type signature for the host-defined `future.new` built-in function. +pub type VMFutureNewCallback = + extern "C" fn(vmctx: *mut VMOpaqueContext, ty: TypeFutureTableIndex) -> u32; + +/// Type signature for the host-defined `future.read` and `future.write` +/// built-in functions. +pub type VMFutureTransmitCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TypeFutureTableIndex, + future: u32, + address: u32, +) -> u32; + +/// Type signature for the host-defined `future.cancel-read` and +/// `future.cancel-write` built-in functions. +pub type VMFutureCancelCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + ty: TypeFutureTableIndex, + async_: bool, + handle: u32, +) -> u32; + +/// Type signature for the host-defined `future.close-readable` built-in function. +pub type VMFutureCloseReadableCallback = + extern "C" fn(vmctx: *mut VMOpaqueContext, ty: TypeFutureTableIndex, handle: u32); + +/// Type signature for the host-defined `future.close-writable` built-in function. +pub type VMFutureCloseWritableCallback = + extern "C" fn(vmctx: *mut VMOpaqueContext, ty: TypeFutureTableIndex, handle: u32, error: u32); + +/// Type signature for the host-defined `stream.new` built-in function. +pub type VMStreamNewCallback = + extern "C" fn(vmctx: *mut VMOpaqueContext, ty: TypeStreamTableIndex) -> u32; + +/// Type signature for the host-defined `stream.read` and `stream.write` +/// built-in functions +pub type VMStreamTransmitCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TypeStreamTableIndex, + stream: u32, + address: u32, + count: u32, +) -> u32; + +/// Type signature for the host-defined `stream.read` ans `stream.write` +/// built-in functions for when the payload is trivially `memcpy`-able. +pub type VMFlatStreamTransmitCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + ty: TypeStreamTableIndex, + payload_size: u32, + payload_align: u32, + stream: u32, + address: u32, + count: u32, +) -> u32; + +/// Type signature for the host-defined `stream.close-readable` built-in function. +pub type VMStreamCloseReadableCallback = + extern "C" fn(vmctx: *mut VMOpaqueContext, ty: TypeStreamTableIndex, handle: u32); + +/// Type signature for the host-defined `stream.close-writable` built-in function. +pub type VMStreamCloseWritableCallback = + extern "C" fn(vmctx: *mut VMOpaqueContext, ty: TypeStreamTableIndex, handle: u32, error: u32); + +/// Type signature for the host-defined `stream.cancel-read` and +/// `stream.cancel-write` built-in functions. +pub type VMStreamCancelCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + ty: TypeStreamTableIndex, + async_: bool, + handle: u32, +) -> u32; + +/// Type signature for the host-defined `error-context.new` built-in function. +pub type VMErrorContextNewCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TypeErrorContextTableIndex, + address: u32, + count: u32, +) -> u32; + +/// Type signature for the host-defined `error-context.debug-message` built-in +/// function. +pub type VMErrorContextDebugMessageCallback = extern "C" fn( + vmctx: *mut VMOpaqueContext, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: StringEncoding, + ty: TypeErrorContextTableIndex, + handle: u32, + address: u32, +); + +/// Type signature for the host-defined `error-context.drop` built-in function. +pub type VMErrorContextDropCallback = + extern "C" fn(vmctx: *mut VMOpaqueContext, ty: TypeErrorContextTableIndex, handle: u32); + /// This is a marker type to represent the underlying allocation of a /// `VMComponentContext`. /// @@ -144,6 +311,28 @@ pub struct VMComponentContext { _marker: marker::PhantomPinned, } +/// Represents either a `TypeFutureTableIndex` or a `TypeStreamTableIndex`. +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +pub enum TableIndex { + /// A `TypeFutureTableIndex` + Future(TypeFutureTableIndex), + /// A `TypeStreamTableIndex` + Stream(TypeStreamTableIndex), +} + +/// Represents the state of a future or stream handle. +#[derive(Debug)] +pub enum TableState { + /// Both the read and write ends are owned by the same component instance. + Local, + /// Only the write end is owned by this component instance. + Write, + /// Only the read end is owned by this component instance. + Read, + /// A read or write is in progress. + Busy, +} + impl ComponentInstance { /// Converts the `vmctx` provided into a `ComponentInstance` and runs the /// provided closure with that instance. @@ -212,6 +401,8 @@ impl ComponentInstance { .unwrap(), ), component_resource_tables, + handle_table: HashMap::new(), + error_context_table: HashMap::new(), runtime_info, resource_types, vmctx: VMComponentContext { @@ -286,6 +477,18 @@ impl ComponentInstance { } } + /// Returns the async callback pointer corresponding to the index provided. + /// + /// This can only be called after `idx` has been initialized at runtime + /// during the instantiation process of a component. + pub fn runtime_callback(&self, idx: RuntimeCallbackIndex) -> NonNull { + unsafe { + let ret = *self.vmctx_plus_offset::>(self.offsets.runtime_callback(idx)); + debug_assert!(ret.as_ptr() as usize != INVALID_PTR); + ret + } + } + /// Returns the post-return pointer corresponding to the index provided. /// /// This can only be called after `idx` has been initialized at runtime @@ -359,6 +562,15 @@ impl ComponentInstance { } } + /// Same as `set_runtime_memory` but for async callback function pointers. + pub fn set_runtime_callback(&mut self, idx: RuntimeCallbackIndex, ptr: NonNull) { + unsafe { + let storage = self.vmctx_plus_offset_mut(self.offsets.runtime_callback(idx)); + debug_assert!(*storage as usize == INVALID_PTR); + *storage = ptr.as_ptr(); + } + } + /// Same as `set_runtime_memory` but for post-return function pointers. pub fn set_runtime_post_return( &mut self, @@ -435,6 +647,75 @@ impl ComponentInstance { } } + /// Set the host-provided callbacks for various async-, future-, stream-, + /// and error-context-related built-in functions. + pub fn set_async_callbacks( + &mut self, + task_backpressure: VMTaskBackpressureCallback, + task_return: VMTaskReturnCallback, + task_wait: VMTaskWaitOrPollCallback, + task_poll: VMTaskWaitOrPollCallback, + task_yield: VMTaskYieldCallback, + subtask_drop: VMSubtaskDropCallback, + async_enter: VMAsyncEnterCallback, + async_exit: VMAsyncExitCallback, + future_new: VMFutureNewCallback, + future_write: VMFutureTransmitCallback, + future_read: VMFutureTransmitCallback, + future_cancel_write: VMFutureCancelCallback, + future_cancel_read: VMFutureCancelCallback, + future_close_writable: VMFutureCloseWritableCallback, + future_close_readable: VMFutureCloseReadableCallback, + stream_new: VMStreamNewCallback, + stream_write: VMStreamTransmitCallback, + stream_read: VMStreamTransmitCallback, + stream_cancel_write: VMStreamCancelCallback, + stream_cancel_read: VMStreamCancelCallback, + stream_close_writable: VMStreamCloseWritableCallback, + stream_close_readable: VMStreamCloseReadableCallback, + flat_stream_write: VMFlatStreamTransmitCallback, + flat_stream_read: VMFlatStreamTransmitCallback, + error_context_new: VMErrorContextNewCallback, + error_context_debug_message: VMErrorContextDebugMessageCallback, + error_context_drop: VMErrorContextDropCallback, + ) { + unsafe { + *self.vmctx_plus_offset_mut(self.offsets.task_backpressure()) = task_backpressure; + *self.vmctx_plus_offset_mut(self.offsets.task_return()) = task_return; + *self.vmctx_plus_offset_mut(self.offsets.task_wait()) = task_wait; + *self.vmctx_plus_offset_mut(self.offsets.task_poll()) = task_poll; + *self.vmctx_plus_offset_mut(self.offsets.task_yield()) = task_yield; + *self.vmctx_plus_offset_mut(self.offsets.subtask_drop()) = subtask_drop; + *self.vmctx_plus_offset_mut(self.offsets.async_enter()) = async_enter; + *self.vmctx_plus_offset_mut(self.offsets.async_exit()) = async_exit; + *self.vmctx_plus_offset_mut(self.offsets.future_new()) = future_new; + *self.vmctx_plus_offset_mut(self.offsets.future_write()) = future_write; + *self.vmctx_plus_offset_mut(self.offsets.future_read()) = future_read; + *self.vmctx_plus_offset_mut(self.offsets.future_cancel_write()) = future_cancel_write; + *self.vmctx_plus_offset_mut(self.offsets.future_cancel_read()) = future_cancel_read; + *self.vmctx_plus_offset_mut(self.offsets.future_close_writable()) = + future_close_writable; + *self.vmctx_plus_offset_mut(self.offsets.future_close_readable()) = + future_close_readable; + *self.vmctx_plus_offset_mut(self.offsets.stream_new()) = stream_new; + *self.vmctx_plus_offset_mut(self.offsets.stream_write()) = stream_write; + *self.vmctx_plus_offset_mut(self.offsets.stream_read()) = stream_read; + *self.vmctx_plus_offset_mut(self.offsets.stream_cancel_write()) = stream_cancel_write; + *self.vmctx_plus_offset_mut(self.offsets.stream_cancel_read()) = stream_cancel_read; + *self.vmctx_plus_offset_mut(self.offsets.stream_close_writable()) = + stream_close_writable; + *self.vmctx_plus_offset_mut(self.offsets.stream_close_readable()) = + stream_close_readable; + *self.vmctx_plus_offset_mut(self.offsets.flat_stream_write()) = flat_stream_write; + *self.vmctx_plus_offset_mut(self.offsets.flat_stream_read()) = flat_stream_read; + *self.vmctx_plus_offset_mut(self.offsets.error_context_debug_message()) = + error_context_new; + *self.vmctx_plus_offset_mut(self.offsets.error_context_debug_message()) = + error_context_debug_message; + *self.vmctx_plus_offset_mut(self.offsets.error_context_drop()) = error_context_drop; + } + } + unsafe fn initialize_vmctx(&mut self, store: *mut dyn VMStore) { *self.vmctx_plus_offset_mut(self.offsets.magic()) = VMCOMPONENT_MAGIC; *self.vmctx_plus_offset_mut(self.offsets.libcalls()) = &libcalls::VMComponentLibcalls::INIT; @@ -449,7 +730,7 @@ impl ComponentInstance { } // In debug mode set non-null bad values to all "pointer looking" bits - // and pices related to lowering and such. This'll help detect any + // and pieces related to lowering and such. This'll help detect any // erroneous usage and enable debug assertions above as well to prevent // loading these before they're configured or setting them twice. if cfg!(debug_assertions) { @@ -475,6 +756,11 @@ impl ComponentInstance { let offset = self.offsets.runtime_realloc(i); *self.vmctx_plus_offset_mut(offset) = INVALID_PTR; } + for i in 0..self.offsets.num_runtime_callbacks { + let i = RuntimeCallbackIndex::from_u32(i); + let offset = self.offsets.runtime_callback(i); + *self.vmctx_plus_offset_mut(offset) = INVALID_PTR; + } for i in 0..self.offsets.num_runtime_post_returns { let i = RuntimePostReturnIndex::from_u32(i); let offset = self.offsets.runtime_post_return(i); @@ -561,6 +847,20 @@ impl ComponentInstance { } } + /// Retrieves the table for tracking future and stream handles and their + /// states with respect to the components which own them. + pub fn handle_table(&mut self) -> &mut HashMap<(TableIndex, u32), TableState> { + &mut self.handle_table + } + + /// Retrieves the table for tracking error context handles and their + /// reference counts with respect to the components which own them. + pub fn error_context_table( + &mut self, + ) -> &mut HashMap<(TypeErrorContextTableIndex, u32), usize> { + &mut self.error_context_table + } + /// Returns the runtime state of resources associated with this component. #[inline] pub fn component_resource_tables( @@ -629,6 +929,71 @@ impl ComponentInstance { pub(crate) fn resource_exit_call(&mut self) -> Result<()> { self.resource_tables().exit_call() } + + pub(crate) fn future_transfer( + &mut self, + idx: u32, + src: TypeFutureTableIndex, + dst: TypeFutureTableIndex, + ) -> Result { + self.handle_transfer(idx, TableIndex::Future(src), TableIndex::Future(dst)) + } + + pub(crate) fn stream_transfer( + &mut self, + idx: u32, + src: TypeStreamTableIndex, + dst: TypeStreamTableIndex, + ) -> Result { + self.handle_transfer(idx, TableIndex::Stream(src), TableIndex::Stream(dst)) + } + + fn handle_transfer(&mut self, idx: u32, src: TableIndex, dst: TableIndex) -> Result { + // TODO: avoid the information leak and nondeterminism of exposing global table indexes here + let Some(state) = self.handle_table.get(&(src, idx)) else { + bail!("invalid handle"); + }; + + match state { + TableState::Local => { + self.handle_table.insert((src, idx), TableState::Write); + assert!(self + .handle_table + .insert((dst, idx), TableState::Read) + .is_none()); + } + TableState::Read => { + self.handle_table.remove(&(src, idx)); + if let Some(TableState::Write) = self.handle_table.get(&(dst, idx)) { + self.handle_table.insert((dst, idx), TableState::Local); + } else { + assert!(self + .handle_table + .insert((dst, idx), TableState::Read) + .is_none()); + } + } + TableState::Write => bail!("cannot transfer write end of stream or future"), + TableState::Busy => bail!("cannot transfer busy stream or future"), + } + + Ok(idx) + } + + pub(crate) fn error_context_transfer( + &mut self, + idx: u32, + src: TypeErrorContextTableIndex, + dst: TypeErrorContextTableIndex, + ) -> Result { + if !self.error_context_table.contains_key(&(src, idx)) { + bail!("invalid handle"); + } + + *self.error_context_table.entry((dst, idx)).or_default() += 1; + + Ok(idx) + } } impl VMComponentContext { @@ -712,6 +1077,11 @@ impl OwnedComponentInstance { unsafe { self.instance_mut().set_runtime_realloc(idx, ptr) } } + /// See `ComponentInstance::set_runtime_callback` + pub fn set_runtime_callback(&mut self, idx: RuntimeCallbackIndex, ptr: NonNull) { + unsafe { self.instance_mut().set_runtime_callback(idx, ptr) } + } + /// See `ComponentInstance::set_runtime_post_return` pub fn set_runtime_post_return( &mut self, @@ -753,6 +1123,70 @@ impl OwnedComponentInstance { pub fn resource_types_mut(&mut self) -> &mut Arc { unsafe { &mut (*self.ptr.as_ptr()).resource_types } } + + /// See `ComponentInstance::set_async_callbacks` + pub fn set_async_callbacks( + &mut self, + task_backpressure: VMTaskBackpressureCallback, + task_return: VMTaskReturnCallback, + task_wait: VMTaskWaitOrPollCallback, + task_poll: VMTaskWaitOrPollCallback, + task_yield: VMTaskYieldCallback, + subtask_drop: VMSubtaskDropCallback, + async_enter: VMAsyncEnterCallback, + async_exit: VMAsyncExitCallback, + future_new: VMFutureNewCallback, + future_write: VMFutureTransmitCallback, + future_read: VMFutureTransmitCallback, + future_cancel_write: VMFutureCancelCallback, + future_cancel_read: VMFutureCancelCallback, + future_close_writable: VMFutureCloseWritableCallback, + future_close_readable: VMFutureCloseReadableCallback, + stream_new: VMStreamNewCallback, + stream_write: VMStreamTransmitCallback, + stream_read: VMStreamTransmitCallback, + stream_cancel_write: VMStreamCancelCallback, + stream_cancel_read: VMStreamCancelCallback, + stream_close_writable: VMStreamCloseWritableCallback, + stream_close_readable: VMStreamCloseReadableCallback, + flat_stream_write: VMFlatStreamTransmitCallback, + flat_stream_read: VMFlatStreamTransmitCallback, + error_context_new: VMErrorContextNewCallback, + error_context_debug_message: VMErrorContextDebugMessageCallback, + error_context_drop: VMErrorContextDropCallback, + ) { + unsafe { + self.instance_mut().set_async_callbacks( + task_backpressure, + task_return, + task_wait, + task_poll, + task_yield, + subtask_drop, + async_enter, + async_exit, + future_new, + future_write, + future_read, + future_cancel_write, + future_cancel_read, + future_close_writable, + future_close_readable, + stream_new, + stream_write, + stream_read, + stream_cancel_write, + stream_cancel_read, + stream_close_writable, + stream_close_readable, + flat_stream_write, + flat_stream_read, + error_context_new, + error_context_debug_message, + error_context_drop, + ) + } + } } impl Deref for OwnedComponentInstance { diff --git a/crates/wasmtime/src/runtime/vm/component/libcalls.rs b/crates/wasmtime/src/runtime/vm/component/libcalls.rs index 75f0cdc2421f..c8ca37711bb2 100644 --- a/crates/wasmtime/src/runtime/vm/component/libcalls.rs +++ b/crates/wasmtime/src/runtime/vm/component/libcalls.rs @@ -4,7 +4,9 @@ use crate::prelude::*; use crate::runtime::vm::component::{ComponentInstance, VMComponentContext}; use core::cell::Cell; use core::slice; -use wasmtime_environ::component::TypeResourceTableIndex; +use wasmtime_environ::component::{ + TypeErrorContextTableIndex, TypeFutureTableIndex, TypeResourceTableIndex, TypeStreamTableIndex, +}; const UTF16_TAG: usize = 1 << 31; @@ -570,3 +572,42 @@ unsafe fn resource_exit_call(vmctx: *mut VMComponentContext) -> Result<()> { unsafe fn trap(_vmctx: *mut VMComponentContext, code: u8) -> Result<()> { Err(wasmtime_environ::Trap::from_u8(code).unwrap()).err2anyhow() } + +unsafe fn future_transfer( + vmctx: *mut VMComponentContext, + src_idx: u32, + src_table: u32, + dst_table: u32, +) -> Result { + let src_table = TypeFutureTableIndex::from_u32(src_table); + let dst_table = TypeFutureTableIndex::from_u32(dst_table); + ComponentInstance::from_vmctx(vmctx, |instance| { + instance.future_transfer(src_idx, src_table, dst_table) + }) +} + +unsafe fn stream_transfer( + vmctx: *mut VMComponentContext, + src_idx: u32, + src_table: u32, + dst_table: u32, +) -> Result { + let src_table = TypeStreamTableIndex::from_u32(src_table); + let dst_table = TypeStreamTableIndex::from_u32(dst_table); + ComponentInstance::from_vmctx(vmctx, |instance| { + instance.stream_transfer(src_idx, src_table, dst_table) + }) +} + +unsafe fn error_context_transfer( + vmctx: *mut VMComponentContext, + src_idx: u32, + src_table: u32, + dst_table: u32, +) -> Result { + let src_table = TypeErrorContextTableIndex::from_u32(src_table); + let dst_table = TypeErrorContextTableIndex::from_u32(dst_table); + ComponentInstance::from_vmctx(vmctx, |instance| { + instance.error_context_transfer(src_idx, src_table, dst_table) + }) +} diff --git a/crates/wasmtime/src/runtime/vm/instance/allocator/pooling.rs b/crates/wasmtime/src/runtime/vm/instance/allocator/pooling.rs index 75e41b562399..72e0443b8bd3 100644 --- a/crates/wasmtime/src/runtime/vm/instance/allocator/pooling.rs +++ b/crates/wasmtime/src/runtime/vm/instance/allocator/pooling.rs @@ -519,6 +519,7 @@ unsafe impl InstanceAllocatorImpl for PoolingInstanceAllocator { LowerImport { .. } | ExtractMemory(_) | ExtractRealloc(_) + | ExtractCallback(_) | ExtractPostReturn(_) | Resource(_) => {} } diff --git a/crates/wast/src/component.rs b/crates/wast/src/component.rs index 8a7f19dc08d0..1346a7f11361 100644 --- a/crates/wast/src/component.rs +++ b/crates/wast/src/component.rs @@ -284,6 +284,9 @@ fn mismatch(expected: &WastVal<'_>, actual: &Val) -> Result<()> { Val::Result(..) => "result", Val::Flags(..) => "flags", Val::Resource(..) => "resource", + Val::Future(..) => "future", + Val::Stream(..) => "stream", + Val::ErrorContext(..) => "error-context", }; bail!("expected `{expected}` got `{actual}`") } diff --git a/crates/wast/src/wast.rs b/crates/wast/src/wast.rs index 1286f264abf6..13194f5d56ec 100644 --- a/crates/wast/src/wast.rs +++ b/crates/wast/src/wast.rs @@ -187,6 +187,7 @@ where .map(|v| match v { WastArg::Core(v) => core::val(&mut self.store, v), WastArg::Component(_) => bail!("expected component function, found core"), + _ => todo!(), }) .collect::>>()?; @@ -204,6 +205,7 @@ where .map(|v| match v { WastArg::Component(v) => component::val(v), WastArg::Core(_) => bail!("expected core function, found component"), + _ => todo!(), }) .collect::>>()?; @@ -350,6 +352,7 @@ where WastRet::Component(_) => { bail!("expected component value found core value") } + _ => todo!(), }; core::match_val(&mut self.store, v, e) .with_context(|| format!("result {i} didn't match"))?; @@ -366,6 +369,7 @@ where bail!("expected component value found core value") } WastRet::Component(val) => val, + _ => todo!(), }; component::match_val(e, v) .with_context(|| format!("result {i} didn't match"))?; diff --git a/crates/wit-bindgen/Cargo.toml b/crates/wit-bindgen/Cargo.toml index 90e8ea3e9959..9e958081ce96 100644 --- a/crates/wit-bindgen/Cargo.toml +++ b/crates/wit-bindgen/Cargo.toml @@ -20,3 +20,4 @@ indexmap = { workspace = true } [features] std = [] +component-model-async = ['std'] diff --git a/crates/wit-bindgen/src/lib.rs b/crates/wit-bindgen/src/lib.rs index 74687f9d673c..ff0d3e7209fb 100644 --- a/crates/wit-bindgen/src/lib.rs +++ b/crates/wit-bindgen/src/lib.rs @@ -59,10 +59,10 @@ struct Wasmtime { opts: Opts, /// A list of all interfaces which were imported by this world. /// - /// The second value here is the contents of the module that this interface - /// generated. The third value is the name of the interface as also present - /// in `self.interface_names`. - import_interfaces: Vec<(InterfaceId, String, InterfaceName)>, + /// The first two values identify the interface; the third is the contents of the + /// module that this interface generated. The fourth value is the name of the + /// interface as also present in `self.interface_names`. + import_interfaces: Vec<(WorldKey, InterfaceId, String, InterfaceName)>, import_functions: Vec, exports: Exports, types: Types, @@ -134,6 +134,14 @@ pub struct Opts { /// Whether or not to use async rust functions and traits. pub async_: AsyncConfig, + /// Whether or not to use `func_wrap_concurrent` when generating code for + /// async imports. + /// + /// Unlike `func_wrap_async`, `func_wrap_concurrent` allows host functions + /// to suspend without monopolizing the `Store`, meaning other guest tasks + /// can make progress concurrently. + pub concurrent_imports: bool, + /// A list of "trappable errors" which are used to replace the `E` in /// `result` found in WIT. pub trappable_error_type: Vec, @@ -175,6 +183,15 @@ pub struct Opts { /// Path to the `wasmtime` crate if it's not the default path. pub wasmtime_crate: Option, + + /// If true, write the generated bindings to a file for better error + /// messages from `rustc`. + /// + /// This can also be toggled via the `WASMTIME_DEBUG_BINDGEN` environment + /// variable, but that will affect _all_ `bindgen!` macro invocations (and + /// can sometimes lead to one invocation ovewriting another in unpredictable + /// ways), whereas this option lets you specify it on a case-by-case basis. + pub debug: bool, } #[derive(Debug, Clone)] @@ -213,28 +230,10 @@ pub enum AsyncConfig { OnlyImports(HashSet), } -impl AsyncConfig { - pub fn is_import_async(&self, f: &str) -> bool { - match self { - AsyncConfig::None => false, - AsyncConfig::All => true, - AsyncConfig::AllExceptImports(set) => !set.contains(f), - AsyncConfig::OnlyImports(set) => set.contains(f), - } - } - - pub fn is_drop_async(&self, r: &str) -> bool { - self.is_import_async(&format!("[drop]{r}")) - } - - pub fn maybe_async(&self) -> bool { - match self { - AsyncConfig::None => false, - AsyncConfig::All | AsyncConfig::AllExceptImports(_) | AsyncConfig::OnlyImports(_) => { - true - } - } - } +pub enum CallStyle { + Sync, + Async, + Concurrent, } #[derive(Default, Debug, Clone)] @@ -260,6 +259,22 @@ impl TrappableImports { impl Opts { pub fn generate(&self, resolve: &Resolve, world: WorldId) -> anyhow::Result { + // TODO: Should we refine this test to inspect only types reachable from + // the specified world? + if !cfg!(feature = "component-model-async") + && resolve.types.iter().any(|(_, ty)| { + matches!( + ty.kind, + TypeDefKind::Future(_) | TypeDefKind::Stream(_) | TypeDefKind::ErrorContext + ) + }) + { + anyhow::bail!( + "must enable `component-model-async` feature when using WIT files \ + containing future, stream, or error types" + ); + } + let mut r = Wasmtime::default(); r.sizes.fill(resolve); r.opts = self.clone(); @@ -268,7 +283,41 @@ impl Opts { } fn is_store_data_send(&self) -> bool { - self.async_.maybe_async() || self.require_store_data_send + matches!(self.call_style(), CallStyle::Async | CallStyle::Concurrent) + || self.require_store_data_send + } + + pub fn import_call_style(&self, qualifier: Option<&str>, f: &str) -> CallStyle { + let matched = |names: &HashSet| { + names.contains(f) + || qualifier + .map(|v| names.contains(&format!("{v}#{f}"))) + .unwrap_or(false) + }; + + match &self.async_ { + AsyncConfig::AllExceptImports(names) if matched(names) => CallStyle::Sync, + AsyncConfig::OnlyImports(names) if !matched(names) => CallStyle::Sync, + _ => self.call_style(), + } + } + + pub fn drop_call_style(&self, qualifier: Option<&str>, r: &str) -> CallStyle { + self.import_call_style(qualifier, &format!("[drop]{r}")) + } + + pub fn call_style(&self) -> CallStyle { + match &self.async_ { + AsyncConfig::None => CallStyle::Sync, + + AsyncConfig::All | AsyncConfig::AllExceptImports(_) | AsyncConfig::OnlyImports(_) => { + if self.concurrent_imports { + CallStyle::Concurrent + } else { + CallStyle::Async + } + } + } } } @@ -455,7 +504,7 @@ impl Wasmtime { // resource-related functions get their trait signatures // during `type_resource`. let sig = if let FunctionKind::Freestanding = func.kind { - gen.generate_function_trait_sig(func); + gen.generate_function_trait_sig(func, "Data"); Some(mem::take(&mut gen.src).into()) } else { None @@ -471,14 +520,14 @@ impl Wasmtime { WorldItem::Interface { id, .. } => { gen.gen.interface_last_seen_as_import.insert(*id, true); gen.current_interface = Some((*id, name, false)); - let snake = match name { + let snake = to_rust_ident(&match name { WorldKey::Name(s) => s.to_snake_case(), WorldKey::Interface(id) => resolve.interfaces[*id] .name .as_ref() .unwrap() .to_snake_case(), - }; + }); let module = if gen.gen.name_interface(resolve, *id, name, false) { // If this interface is remapped then that means that it was // provided via the `with` key in the bindgen configuration. @@ -523,8 +572,12 @@ impl Wasmtime { " ) }; - self.import_interfaces - .push((*id, module, self.interface_names[id].clone())); + self.import_interfaces.push(( + name.clone(), + *id, + module, + self.interface_names[id].clone(), + )); let interface_path = self.import_interface_path(id); self.interface_link_options[id] @@ -806,10 +859,11 @@ fn _new( let wt = self.wasmtime_path(); let world_name = &resolve.worlds[world].name; let camel = to_rust_upper_camel_case(&world_name); - let (async_, async__, where_clause, await_) = if self.opts.async_.maybe_async() { - ("async", "_async", "where _T: Send", ".await") - } else { - ("", "", "", "") + let (async_, async__, bounds, await_) = match self.opts.call_style() { + CallStyle::Async | CallStyle::Concurrent => { + ("async", "_async", ": Send + 'static", ".await") + } + CallStyle::Sync => ("", "", ": 'static", ""), }; uwriteln!( self.src, @@ -836,7 +890,7 @@ impl Clone for {camel}Pre {{ }} }} -impl<_T> {camel}Pre<_T> {{ +impl<_T{bounds}> {camel}Pre<_T> {{ /// Creates a new copy of `{camel}Pre` bindings which can then /// be used to instantiate into a particular store. /// @@ -866,7 +920,6 @@ impl<_T> {camel}Pre<_T> {{ &self, mut store: impl {wt}::AsContextMut, ) -> {wt}::Result<{camel}> - {where_clause} {{ let mut store = store.as_context_mut(); let instance = self.instance_pre.instantiate{async__}(&mut store){await_}?; @@ -1031,7 +1084,7 @@ impl<_T> {camel}Pre<_T> {{ component: &{wt}::component::Component, linker: &{wt}::component::Linker<_T>, ) -> {wt}::Result<{camel}> - {where_clause} + where _T{bounds} {{ let pre = linker.instantiate_pre(component)?; {camel}Pre::new(pre)?.instantiate{async__}(store){await_} @@ -1090,7 +1143,12 @@ impl<_T> {camel}Pre<_T> {{ } let imports = mem::take(&mut self.import_interfaces); - self.emit_modules(imports); + self.emit_modules( + imports + .into_iter() + .map(|(_, id, module, path)| (id, module, path)) + .collect(), + ); let exports = mem::take(&mut self.exports.modules); self.emit_modules(exports); @@ -1357,12 +1415,12 @@ impl Wasmtime { let wt = self.wasmtime_path(); let world_camel = to_rust_upper_camel_case(&resolve.worlds[world].name); - if self.opts.async_.maybe_async() { + if let CallStyle::Async = self.opts.call_style() { uwriteln!(self.src, "#[{wt}::component::__internal::async_trait]") } uwrite!(self.src, "pub trait {world_camel}Imports"); let mut supertraits = vec![]; - if self.opts.async_.maybe_async() { + if let CallStyle::Async = self.opts.call_style() { supertraits.push("Send".to_string()); } for (_, name) in get_world_resources(resolve, world) { @@ -1404,7 +1462,7 @@ impl Wasmtime { ); // Generate impl WorldImports for &mut WorldImports - let (async_trait, maybe_send) = if self.opts.async_.maybe_async() { + let (async_trait, maybe_send) = if let CallStyle::Async = self.opts.call_style() { ( format!("#[{wt}::component::__internal::async_trait]\n"), "+ Send", @@ -1413,10 +1471,25 @@ impl Wasmtime { (String::new(), "") }; if !self.opts.skip_mut_forwarding_impls { + let maybe_maybe_sized = if let CallStyle::Concurrent = self.opts.call_style() { + "" + } else { + "+ ?Sized" + }; uwriteln!( self.src, - "{async_trait}impl<_T: {world_camel}Imports + ?Sized {maybe_send}> {world_camel}Imports for &mut _T {{" + "{async_trait}impl<_T: {world_camel}Imports {maybe_maybe_sized} {maybe_send}> {world_camel}Imports for &mut _T {{" ); + let has_concurrent_function = self.import_functions.iter().any(|f| { + matches!( + self.opts.import_call_style(None, &f.func.name), + CallStyle::Concurrent + ) + }); + + if has_concurrent_function { + self.src.push_str("type Data = _T::Data;\n"); + } // Forward each method call to &mut T for f in self.import_functions.iter() { if let Some(sig) = &f.sig { @@ -1430,7 +1503,7 @@ impl Wasmtime { uwrite!(self.src, "{},", to_rust_ident(name)); } uwrite!(self.src, ")"); - if self.opts.async_.is_import_async(&f.func.name) { + if let CallStyle::Async = self.opts.import_call_style(None, &f.func.name) { uwrite!(self.src, ".await"); } uwriteln!(self.src, "}}"); @@ -1443,7 +1516,7 @@ impl Wasmtime { fn import_interface_paths(&self) -> Vec<(InterfaceId, String)> { self.import_interfaces .iter() - .map(|(id, _, name)| { + .map(|(_, id, _, name)| { let path = match name { InterfaceName::Path(path) => path.join("::"), InterfaceName::Remapped { name_at_root, .. } => name_at_root.clone(), @@ -1470,7 +1543,7 @@ impl Wasmtime { let world_camel = to_rust_upper_camel_case(&resolve.worlds[world].name); traits.push(format!("{world_camel}Imports")); } - if self.opts.async_.maybe_async() { + if let CallStyle::Async = self.opts.call_style() { traits.push("Send".to_string()); } traits @@ -1512,6 +1585,7 @@ impl Wasmtime { let gate = FeatureGate::open(&mut self.src, &resolve.worlds[world].stability); for (ty, name) in get_world_resources(resolve, world) { Self::generate_add_resource_to_linker( + None, &mut self.src, &self.opts, &wt, @@ -1528,7 +1602,45 @@ impl Wasmtime { uwriteln!(self.src, "Ok(())\n}}"); } - let host_bounds = format!("U: {}", self.world_host_traits(resolve, world).join(" + ")); + let (host_bounds, data_bounds) = if let CallStyle::Concurrent = self.opts.call_style() { + // TODO: include world imports trait if applicable + let bounds = self + .import_interfaces + .iter() + .map(|(key, id, _, name)| { + ( + key, + id, + match name { + InterfaceName::Path(path) => path.join("::"), + InterfaceName::Remapped { name_at_root, .. } => name_at_root.clone(), + }, + ) + }) + .map(|(key, id, path)| { + format!( + " + {path}::Host{}", + concurrent_constraints( + resolve, + &self.opts, + Some(&resolve.name_world_key(key)), + *id + )("T") + ) + }) + .collect::>() + .concat(); + + ( + format!("U: Send{bounds}"), + format!("T: Send{bounds} + 'static,"), + ) + } else { + ( + format!("U: {}", self.world_host_traits(resolve, world).join(" + ")), + data_bounds.to_string(), + ) + }; if !self.opts.skip_mut_forwarding_impls { uwriteln!( @@ -1584,6 +1696,7 @@ impl Wasmtime { } fn generate_add_resource_to_linker( + qualifier: Option<&str>, src: &mut Source, opts: &Opts, wt: &str, @@ -1593,7 +1706,7 @@ impl Wasmtime { ) { let gate = FeatureGate::open(src, stability); let camel = name.to_upper_camel_case(); - if opts.async_.is_drop_async(name) { + if let CallStyle::Async = opts.drop_call_style(qualifier, name) { uwriteln!( src, "{inst}.resource_async( @@ -1664,8 +1777,9 @@ impl<'a> InterfaceGenerator<'a> { TypeDefKind::Result(r) => self.type_result(id, name, r, &ty.docs), TypeDefKind::List(t) => self.type_list(id, name, t, &ty.docs), TypeDefKind::Type(t) => self.type_alias(id, name, t, &ty.docs), - TypeDefKind::Future(_) => todo!("generate for future"), - TypeDefKind::Stream(_) => todo!("generate for stream"), + TypeDefKind::Future(_) => panic!("future types need not be defined"), + TypeDefKind::Stream(_) => panic!("stream types need not be defined"), + TypeDefKind::ErrorContext => panic!("the error-context type needs not be defined"), TypeDefKind::Handle(handle) => self.type_handle(id, name, handle, &ty.docs), TypeDefKind::Resource => self.type_resource(id, name, ty, &ty.docs), TypeDefKind::Unknown => unreachable!(), @@ -1709,10 +1823,11 @@ impl<'a> InterfaceGenerator<'a> { } // Generate resource trait - if self.gen.opts.async_.maybe_async() { + if let CallStyle::Async = self.gen.opts.call_style() { uwriteln!(self.src, "#[{wt}::component::__internal::async_trait]") } - uwriteln!(self.src, "pub trait Host{camel} {{"); + + uwriteln!(self.src, "pub trait Host{camel}: Sized {{"); let mut functions = match resource.owner { TypeOwner::World(id) => self.resolve.worlds[id] @@ -1739,12 +1854,29 @@ impl<'a> InterfaceGenerator<'a> { | FunctionKind::Constructor(resource) => id == resource, }); + let has_concurrent_function = functions.iter().any(|func| { + matches!( + self.gen + .opts + .import_call_style(self.qualifier().as_deref(), &func.name), + CallStyle::Concurrent + ) + }); + + if has_concurrent_function { + uwriteln!(self.src, "type {camel}Data;"); + } + for func in &functions { - self.generate_function_trait_sig(func); + self.generate_function_trait_sig(func, &format!("{camel}Data")); self.push_str(";\n"); } - if self.gen.opts.async_.is_drop_async(name) { + if let CallStyle::Async = self + .gen + .opts + .drop_call_style(self.qualifier().as_deref(), name) + { uwrite!(self.src, "async "); } uwrite!( @@ -1756,7 +1888,8 @@ impl<'a> InterfaceGenerator<'a> { // Generate impl HostResource for &mut HostResource if !self.gen.opts.skip_mut_forwarding_impls { - let (async_trait, maybe_send) = if self.gen.opts.async_.maybe_async() { + let (async_trait, maybe_send) = if let CallStyle::Async = self.gen.opts.call_style() + { ( format!("#[{wt}::component::__internal::async_trait]\n"), "+ Send", @@ -1764,27 +1897,51 @@ impl<'a> InterfaceGenerator<'a> { } else { (String::new(), "") }; + let maybe_maybe_sized = if has_concurrent_function { + "" + } else { + "+ ?Sized" + }; uwriteln!( self.src, - "{async_trait}impl <_T: Host{camel} + ?Sized {maybe_send}> Host{camel} for &mut _T {{" + "{async_trait}impl <_T: Host{camel} {maybe_maybe_sized} {maybe_send}> Host{camel} for &mut _T {{" ); + if has_concurrent_function { + uwriteln!(self.src, "type {camel}Data = _T::{camel}Data;"); + } for func in &functions { - self.generate_function_trait_sig(func); - uwrite!( - self.src, - "{{ Host{camel}::{}(*self,", - rust_function_name(func) - ); + let call_style = self + .gen + .opts + .import_call_style(self.qualifier().as_deref(), &func.name); + self.generate_function_trait_sig(func, &format!("{camel}Data")); + if let CallStyle::Concurrent = call_style { + uwrite!( + self.src, + "{{ <_T as Host{camel}>::{}(store,", + rust_function_name(func) + ); + } else { + uwrite!( + self.src, + "{{ Host{camel}::{}(*self,", + rust_function_name(func) + ); + } for (name, _) in func.params.iter() { uwrite!(self.src, "{},", to_rust_ident(name)); } uwrite!(self.src, ")"); - if self.gen.opts.async_.is_import_async(&func.name) { + if let CallStyle::Async = call_style { uwrite!(self.src, ".await"); } uwriteln!(self.src, "}}"); } - if self.gen.opts.async_.is_drop_async(name) { + if let CallStyle::Async = self + .gen + .opts + .drop_call_style(self.qualifier().as_deref(), name) + { uwriteln!(self.src, " async fn drop(&mut self, rep: {wt}::component::Resource<{camel}>) -> {wt}::Result<()> {{ Host{camel}::drop(*self, rep).await @@ -2087,10 +2244,9 @@ impl<'a> InterfaceGenerator<'a> { self.push_str( "fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n", ); - self.push_str("write!(f, \"{:?}\", self)"); + self.push_str("write!(f, \"{:?}\", self)\n"); self.push_str("}\n"); self.push_str("}\n"); - self.push_str("\n"); if cfg!(feature = "std") { self.push_str("impl"); @@ -2342,6 +2498,26 @@ impl<'a> InterfaceGenerator<'a> { } } + fn print_result_ty_tuple(&mut self, results: &Results, mode: TypeMode) { + self.push_str("("); + match results { + Results::Named(rs) if rs.is_empty() => self.push_str(")"), + Results::Named(rs) => { + for (i, (_, ty)) in rs.iter().enumerate() { + if i > 0 { + self.push_str(", ") + } + self.print_ty(ty, mode) + } + self.push_str(")"); + } + Results::Anon(ty) => { + self.print_ty(ty, mode); + self.push_str(",)"); + } + } + } + fn special_case_trappable_error( &mut self, func: &Function, @@ -2384,31 +2560,52 @@ impl<'a> InterfaceGenerator<'a> { let owner = TypeOwner::Interface(id); let wt = self.gen.wasmtime_path(); - let is_maybe_async = self.gen.opts.async_.maybe_async(); + let is_maybe_async = matches!(self.gen.opts.call_style(), CallStyle::Async); if is_maybe_async { uwriteln!(self.src, "#[{wt}::component::__internal::async_trait]") } // Generate the `pub trait` which represents the host functionality for // this import which additionally inherits from all resource traits // for this interface defined by `type_resource`. + uwrite!(self.src, "pub trait Host"); let mut host_supertraits = vec![]; if is_maybe_async { host_supertraits.push("Send".to_string()); } + let mut saw_resources = false; for (_, name) in get_resources(self.resolve, id) { + saw_resources = true; host_supertraits.push(format!("Host{}", name.to_upper_camel_case())); } + if saw_resources { + host_supertraits.push("Sized".to_string()); + } if !host_supertraits.is_empty() { uwrite!(self.src, ": {}", host_supertraits.join(" + ")); } uwriteln!(self.src, " {{"); + + let has_concurrent_function = iface.functions.iter().any(|(_, func)| { + matches!(func.kind, FunctionKind::Freestanding) + && matches!( + self.gen + .opts + .import_call_style(self.qualifier().as_deref(), &func.name), + CallStyle::Concurrent + ) + }); + + if has_concurrent_function { + self.push_str("type Data;\n"); + } + for (_, func) in iface.functions.iter() { match func.kind { FunctionKind::Freestanding => {} _ => continue, } - self.generate_function_trait_sig(func); + self.generate_function_trait_sig(func, "Data"); self.push_str(";\n"); } @@ -2456,13 +2653,32 @@ impl<'a> InterfaceGenerator<'a> { } uwriteln!(self.src, "}}"); - let (data_bounds, mut host_bounds) = if self.gen.opts.is_store_data_send() { - ("T: Send,", "Host + Send".to_string()) - } else { - ("", "Host".to_string()) + let (data_bounds, mut host_bounds, mut get_host_bounds) = match self.gen.opts.call_style() { + CallStyle::Async => ( + "T: Send,".to_string(), + "Host + Send".to_string(), + "Host + Send".to_string(), + ), + CallStyle::Concurrent => { + let constraints = concurrent_constraints( + self.resolve, + &self.gen.opts, + self.qualifier().as_deref(), + id, + ); + + ( + "T: Send + 'static,".to_string(), + format!("Host{} + Send", constraints("T")), + format!("Host{} + Send", constraints("D")), + ) + } + CallStyle::Sync => (String::new(), "Host".to_string(), "Host".to_string()), }; + for ty in required_conversion_traits { uwrite!(host_bounds, " + {ty}"); + uwrite!(get_host_bounds, " + {ty}"); } let (options_param, options_arg) = if self.gen.interface_link_options[&id].has_any() { @@ -2474,28 +2690,28 @@ impl<'a> InterfaceGenerator<'a> { uwriteln!( self.src, " - pub trait GetHost: - Fn(T) -> >::Host + pub trait GetHost: + Fn(T) -> >::Host + Send + Sync + Copy + 'static {{ - type Host: {host_bounds}; + type Host: {get_host_bounds}; }} - impl GetHost for F + impl GetHost for F where F: Fn(T) -> O + Send + Sync + Copy + 'static, - O: {host_bounds}, + O: {get_host_bounds}, {{ type Host = O; }} - pub fn add_to_linker_get_host( + pub fn add_to_linker_get_host GetHost<&'a mut T, T, Host: {host_bounds}>>( linker: &mut {wt}::component::Linker, {options_param} - host_getter: impl for<'a> GetHost<&'a mut T>, + host_getter: G, ) -> {wt}::Result<()> where {data_bounds} {{ @@ -2506,6 +2722,7 @@ impl<'a> InterfaceGenerator<'a> { for (ty, name) in get_resources(self.resolve, id) { Wasmtime::generate_add_resource_to_linker( + self.qualifier().as_deref(), &mut self.src, &self.gen.opts, &wt, @@ -2550,23 +2767,46 @@ impl<'a> InterfaceGenerator<'a> { (String::new(), "") }; + let maybe_maybe_sized = if has_concurrent_function { + "" + } else { + "+ ?Sized" + }; + uwriteln!( self.src, - "{async_trait}impl<_T: Host + ?Sized {maybe_send}> Host for &mut _T {{" + "{async_trait}impl<_T: Host {maybe_maybe_sized} {maybe_send}> Host for &mut _T {{" ); + + if has_concurrent_function { + self.push_str("type Data = _T::Data;\n"); + } + // Forward each method call to &mut T for (_, func) in iface.functions.iter() { match func.kind { FunctionKind::Freestanding => {} _ => continue, } - self.generate_function_trait_sig(func); - uwrite!(self.src, "{{ Host::{}(*self,", rust_function_name(func)); + let call_style = self + .gen + .opts + .import_call_style(self.qualifier().as_deref(), &func.name); + self.generate_function_trait_sig(func, "Data"); + if let CallStyle::Concurrent = call_style { + uwrite!( + self.src, + "{{ <_T as Host>::{}(store,", + rust_function_name(func) + ); + } else { + uwrite!(self.src, "{{ Host::{}(*self,", rust_function_name(func)); + } for (name, _) in func.params.iter() { uwrite!(self.src, "{},", to_rust_ident(name)); } uwrite!(self.src, ")"); - if self.gen.opts.async_.is_import_async(&func.name) { + if let CallStyle::Async = call_style { uwrite!(self.src, ".await"); } uwriteln!(self.src, "}}"); @@ -2586,15 +2826,24 @@ impl<'a> InterfaceGenerator<'a> { } } + fn qualifier(&self) -> Option { + self.current_interface + .map(|(_, key, _)| self.resolve.name_world_key(key)) + } + fn generate_add_function_to_linker(&mut self, owner: TypeOwner, func: &Function, linker: &str) { let gate = FeatureGate::open(&mut self.src, &func.stability); uwrite!( self.src, "{linker}.{}(\"{}\", ", - if self.gen.opts.async_.is_import_async(&func.name) { - "func_wrap_async" - } else { - "func_wrap" + match self + .gen + .opts + .import_call_style(self.qualifier().as_deref(), &func.name) + { + CallStyle::Sync => "func_wrap", + CallStyle::Async => "func_wrap_async", + CallStyle::Concurrent => "func_wrap_concurrent", }, func.name ); @@ -2618,16 +2867,20 @@ impl<'a> InterfaceGenerator<'a> { self.src.push_str(") : ("); for (_, ty) in func.params.iter() { - // Lift is required to be impled for this type, so we can't use + // Lift is required to be implied for this type, so we can't use // a borrowed type: self.print_ty(ty, TypeMode::Owned); self.src.push_str(", "); } - self.src.push_str(") |"); - self.src.push_str(" {\n"); + self.src.push_str(")| {\n"); + + let style = self + .gen + .opts + .import_call_style(self.qualifier().as_deref(), &func.name); if self.gen.opts.tracing { - if self.gen.opts.async_.is_import_async(&func.name) { + if let CallStyle::Async = style { self.src.push_str("use tracing::Instrument;\n"); } @@ -2653,7 +2906,7 @@ impl<'a> InterfaceGenerator<'a> { ); } - if self.gen.opts.async_.is_import_async(&func.name) { + if let CallStyle::Async = &style { uwriteln!( self.src, " {wt}::component::__internal::Box::new(async move {{ " @@ -2685,8 +2938,11 @@ impl<'a> InterfaceGenerator<'a> { ); } - self.src - .push_str("let host = &mut host_getter(caller.data_mut());\n"); + self.src.push_str(if let CallStyle::Concurrent = &style { + "let host = caller;\n" + } else { + "let host = &mut host_getter(caller.data_mut());\n" + }); let func_name = rust_function_name(func); let host_trait = match func.kind { FunctionKind::Freestanding => match owner { @@ -2705,15 +2961,32 @@ impl<'a> InterfaceGenerator<'a> { format!("Host{resource}") } }; - uwrite!(self.src, "let r = {host_trait}::{func_name}(host, "); + + if let CallStyle::Concurrent = &style { + uwrite!( + self.src, + "let r = ::{func_name}(host, " + ); + } else { + uwrite!(self.src, "let r = {host_trait}::{func_name}(host, "); + } for (i, _) in func.params.iter().enumerate() { uwrite!(self.src, "arg{},", i); } - if self.gen.opts.async_.is_import_async(&func.name) { - uwrite!(self.src, ").await;\n"); - } else { - uwrite!(self.src, ");\n"); + self.src.push_str(match &style { + CallStyle::Sync | CallStyle::Concurrent => ");\n", + CallStyle::Async => ").await;\n", + }); + + if let CallStyle::Concurrent = &style { + self.src.push_str( + "Box::pin(async move { + let fun = r.await; + Box::new(move |mut caller: wasmtime::StoreContextMut<'_, T>| { + let r = fun(caller); + ", + ); } if self.gen.opts.tracing { @@ -2755,29 +3028,53 @@ impl<'a> InterfaceGenerator<'a> { uwrite!(self.src, "r\n"); } - if self.gen.opts.async_.is_import_async(&func.name) { - // Need to close Box::new and async block - - if self.gen.opts.tracing { - self.src.push_str("}.instrument(span))\n"); - } else { - self.src.push_str("})\n"); + match &style { + CallStyle::Sync => (), + CallStyle::Async => { + if self.gen.opts.tracing { + self.src.push_str("}.instrument(span))\n"); + } else { + self.src.push_str("})\n"); + } + } + CallStyle::Concurrent => { + let old_source = mem::take(&mut self.src); + self.print_result_ty_tuple(&func.results, TypeMode::Owned); + let result_type = String::from(mem::replace(&mut self.src, old_source)); + let box_fn = format!( + "Box) -> \ + wasmtime::Result<{result_type}>>" + ); + uwriteln!( + self.src, + " }}) as {box_fn} + }}) as ::std::pin::Pin \ + + Send + Sync + 'static>> + " + ); } } - self.src.push_str("}\n"); } - fn generate_function_trait_sig(&mut self, func: &Function) { + fn generate_function_trait_sig(&mut self, func: &Function, data: &str) { let wt = self.gen.wasmtime_path(); self.rustdoc(&func.docs); - if self.gen.opts.async_.is_import_async(&func.name) { + let style = self + .gen + .opts + .import_call_style(self.qualifier().as_deref(), &func.name); + if let CallStyle::Async = &style { self.push_str("async "); } self.push_str("fn "); self.push_str(&rust_function_name(func)); - self.push_str("(&mut self, "); + self.push_str(&if let CallStyle::Concurrent = &style { + format!("(store: wasmtime::StoreContextMut<'_, Self::{data}>, ") + } else { + "(&mut self, ".to_string() + }); for (name, param) in func.params.iter() { let name = to_rust_ident(name); self.push_str(&name); @@ -2788,6 +3085,10 @@ impl<'a> InterfaceGenerator<'a> { self.push_str(")"); self.push_str(" -> "); + if let CallStyle::Concurrent = &style { + uwrite!(self.src, "impl ::std::future::Future) -> "); + } + if !self.gen.opts.trappable_imports.can_trap(func) { self.print_result_ty(&func.results, TypeMode::Owned); } else if let Some((r, _id, error_typename)) = self.special_case_trappable_error(func) { @@ -2810,6 +3111,10 @@ impl<'a> InterfaceGenerator<'a> { self.print_result_ty(&func.results, TypeMode::Owned); self.push_str(">"); } + + if let CallStyle::Concurrent = &style { + self.push_str(" + 'static> + Send + Sync + 'static where Self: Sized"); + } } fn extract_typed_function(&mut self, func: &Function) -> (String, String) { @@ -2840,12 +3145,10 @@ impl<'a> InterfaceGenerator<'a> { ) { // Exports must be async if anything could be async, it's just imports // that get to be optionally async/sync. - let is_async = self.gen.opts.async_.maybe_async(); - - let (async_, async__, await_) = if is_async { - ("async", "_async", ".await") - } else { - ("", "", "") + let style = self.gen.opts.call_style(); + let (async_, async__, await_) = match &style { + CallStyle::Async | CallStyle::Concurrent => ("async", "_async", ".await"), + CallStyle::Sync => ("", "", ""), }; self.rustdoc(&func.docs); @@ -2857,23 +3160,28 @@ impl<'a> InterfaceGenerator<'a> { func.item_name().to_snake_case(), ); + let param_mode = if let CallStyle::Concurrent = &style { + TypeMode::Owned + } else { + TypeMode::AllBorrowed("'_") + }; + for (i, param) in func.params.iter().enumerate() { uwrite!(self.src, "arg{}: ", i); - self.print_ty(¶m.1, TypeMode::AllBorrowed("'_")); + self.print_ty(¶m.1, param_mode); self.push_str(","); } uwrite!(self.src, ") -> {wt}::Result<"); self.print_result_ty(&func.results, TypeMode::Owned); - if is_async { - uwriteln!(self.src, "> where ::Data: Send {{"); - } else { - self.src.push_str("> {\n"); - } + uwrite!( + self.src, + "> where ::Data: Send + 'static {{\n" + ); if self.gen.opts.tracing { - if is_async { + if let CallStyle::Async = &style { self.src.push_str("use tracing::Instrument;\n"); } @@ -2893,7 +3201,7 @@ impl<'a> InterfaceGenerator<'a> { func.name, )); - if !is_async { + if !matches!(&style, CallStyle::Async) { self.src.push_str( " let _enter = span.enter(); @@ -2905,7 +3213,7 @@ impl<'a> InterfaceGenerator<'a> { self.src.push_str("let callee = unsafe {\n"); uwrite!(self.src, "{wt}::component::TypedFunc::<("); for (_, ty) in func.params.iter() { - self.print_ty(ty, TypeMode::AllBorrowed("'_")); + self.print_ty(ty, param_mode); self.push_str(", "); } self.src.push_str("), ("); @@ -2935,14 +3243,14 @@ impl<'a> InterfaceGenerator<'a> { uwrite!(self.src, "arg{}, ", i); } - let instrument = if is_async && self.gen.opts.tracing { + let instrument = if matches!(&style, CallStyle::Async) && self.gen.opts.tracing { ".instrument(span.clone())" } else { "" }; uwriteln!(self.src, ")){instrument}{await_}?;"); - let instrument = if is_async && self.gen.opts.tracing { + let instrument = if matches!(&style, CallStyle::Async) && self.gen.opts.tracing { ".instrument(span)" } else { "" @@ -3243,8 +3551,9 @@ fn type_contains_lists(ty: Type, resolve: &Resolve) -> bool { | TypeDefKind::Unknown | TypeDefKind::Flags(_) | TypeDefKind::Handle(_) - | TypeDefKind::Enum(_) => false, - TypeDefKind::Option(ty) => type_contains_lists(*ty, resolve), + | TypeDefKind::Enum(_) + | TypeDefKind::ErrorContext => false, + TypeDefKind::Option(ty) | TypeDefKind::Stream(ty) => type_contains_lists(*ty, resolve), TypeDefKind::Result(Result_ { ok, err }) => { option_type_contains_lists(*ok, resolve) || option_type_contains_lists(*err, resolve) @@ -3263,10 +3572,6 @@ fn type_contains_lists(ty: Type, resolve: &Resolve) -> bool { .any(|case| option_type_contains_lists(case.ty, resolve)), TypeDefKind::Type(ty) => type_contains_lists(*ty, resolve), TypeDefKind::Future(ty) => option_type_contains_lists(*ty, resolve), - TypeDefKind::Stream(Stream { element, end }) => { - option_type_contains_lists(*element, resolve) - || option_type_contains_lists(*end, resolve) - } TypeDefKind::List(_) => true, }, @@ -3358,3 +3663,61 @@ fn get_world_resources<'a>( _ => None, }) } + +fn concurrent_constraints<'a>( + resolve: &'a Resolve, + opts: &Opts, + qualifier: Option<&str>, + id: InterfaceId, +) -> impl Fn(&str) -> String + 'a { + let has_concurrent_function = resolve.interfaces[id].functions.iter().any(|(_, func)| { + matches!(func.kind, FunctionKind::Freestanding) + && matches!( + opts.import_call_style(qualifier, &func.name), + CallStyle::Concurrent + ) + }); + + let types = resolve.interfaces[id] + .types + .iter() + .filter_map(|(name, ty)| match resolve.types[*ty].kind { + TypeDefKind::Resource + if resolve.interfaces[id] + .functions + .values() + .any(|func| match func.kind { + FunctionKind::Freestanding => false, + FunctionKind::Method(resource) + | FunctionKind::Static(resource) + | FunctionKind::Constructor(resource) => { + *ty == resource + && matches!( + opts.import_call_style(qualifier, &func.name), + CallStyle::Concurrent + ) + } + }) => + { + Some(format!("{}Data", name.to_upper_camel_case())) + } + _ => None, + }) + .chain(has_concurrent_function.then_some("Data".to_string())) + .collect::>(); + + move |v| { + if types.is_empty() { + String::new() + } else { + format!( + "<{}>", + types + .iter() + .map(|s| format!("{s} = {v}")) + .collect::>() + .join(", ") + ) + } + } +} diff --git a/crates/wit-bindgen/src/rust.rs b/crates/wit-bindgen/src/rust.rs index d676d095f60d..77a0cb2246f1 100644 --- a/crates/wit-bindgen/src/rust.rs +++ b/crates/wit-bindgen/src/rust.rs @@ -120,7 +120,7 @@ pub trait RustGenerator<'a> { needs_generics(resolve, &resolve.types[*t].kind) } TypeDefKind::Type(Type::String) => true, - TypeDefKind::Type(_) => false, + TypeDefKind::Type(_) | TypeDefKind::ErrorContext => false, TypeDefKind::Unknown => unreachable!(), } } @@ -166,17 +166,18 @@ pub trait RustGenerator<'a> { panic!("unsupported anonymous type reference: enum") } TypeDefKind::Future(ty) => { - self.push_str("Future<"); + self.push_str("wasmtime::component::FutureReader<"); self.print_optional_ty(ty.as_ref(), mode); self.push_str(">"); } - TypeDefKind::Stream(stream) => { - self.push_str("Stream<"); - self.print_optional_ty(stream.element.as_ref(), mode); - self.push_str(","); - self.print_optional_ty(stream.end.as_ref(), mode); + TypeDefKind::Stream(ty) => { + self.push_str("wasmtime::component::StreamReader<"); + self.print_ty(ty, mode); self.push_str(">"); } + TypeDefKind::ErrorContext => { + self.push_str("wasmtime::component::ErrorContext"); + } TypeDefKind::Handle(handle) => { self.print_handle(handle); diff --git a/crates/wit-bindgen/src/types.rs b/crates/wit-bindgen/src/types.rs index c45d3d80f159..b29c0da728c1 100644 --- a/crates/wit-bindgen/src/types.rs +++ b/crates/wit-bindgen/src/types.rs @@ -148,10 +148,7 @@ impl Types { info = self.type_info(resolve, ty); info.has_list = true; } - TypeDefKind::Type(ty) => { - info = self.type_info(resolve, ty); - } - TypeDefKind::Option(ty) => { + TypeDefKind::Type(ty) | TypeDefKind::Option(ty) | TypeDefKind::Stream(ty) => { info = self.type_info(resolve, ty); } TypeDefKind::Result(r) => { @@ -161,12 +158,8 @@ impl Types { TypeDefKind::Future(ty) => { info = self.optional_type_info(resolve, ty.as_ref()); } - TypeDefKind::Stream(stream) => { - info = self.optional_type_info(resolve, stream.element.as_ref()); - info |= self.optional_type_info(resolve, stream.end.as_ref()); - } TypeDefKind::Handle(_) => info.has_handle = true, - TypeDefKind::Resource => {} + TypeDefKind::Resource | TypeDefKind::ErrorContext => {} TypeDefKind::Unknown => unreachable!(), } self.type_info.insert(ty, info); diff --git a/tests/all/pooling_allocator.rs b/tests/all/pooling_allocator.rs index 14474401a735..126712670198 100644 --- a/tests/all/pooling_allocator.rs +++ b/tests/all/pooling_allocator.rs @@ -880,7 +880,7 @@ fn component_instance_size_limit() -> Result<()> { Ok(_) => panic!("should have hit limit"), Err(e) => assert_eq!( e.to_string(), - "instance allocation for this component requires 64 bytes of `VMComponentContext` space \ + "instance allocation for this component requires 272 bytes of `VMComponentContext` space \ which exceeds the configured maximum of 1 bytes" ), }