diff --git a/packages/typegpu/src/core/function/tgpuFn.ts b/packages/typegpu/src/core/function/tgpuFn.ts index 247f7eb5..3c2222e7 100644 --- a/packages/typegpu/src/core/function/tgpuFn.ts +++ b/packages/typegpu/src/core/function/tgpuFn.ts @@ -1,7 +1,7 @@ import type { Infer } from '../../data'; import type { Exotic, ExoticArray } from '../../data/exotic'; import type { AnyWgslData } from '../../data/wgslTypes'; -import { inGPUMode } from '../../gpuMode'; +import { getResolutionCtx, inGPUMode } from '../../gpuMode'; import type { TgpuNamable } from '../../namable'; import type { Labelled, @@ -95,7 +95,15 @@ export function fn< does( implementation: Implementation, InferReturn>, ): TgpuFn { - return createFn(this, implementation as Implementation); + const fn = createFn(this, implementation as Implementation); + + const ctx = getResolutionCtx(); + if (ctx) { + // Creating a function during resolution. + const snapshot = Array.from(ctx.getSlotSnapshot()); + return createBoundFunction(fn, snapshot); + } + return fn; }, }; } diff --git a/packages/typegpu/src/resolutionCtx.ts b/packages/typegpu/src/resolutionCtx.ts index 5b5e5c7a..22f13a43 100644 --- a/packages/typegpu/src/resolutionCtx.ts +++ b/packages/typegpu/src/resolutionCtx.ts @@ -58,7 +58,7 @@ type ItemLayer = { type SlotBindingLayer = { type: 'slotBinding'; - bindingMap: WeakMap, unknown>; + bindingMap: Map, unknown>; }; type FunctionScopeLayer = { @@ -105,7 +105,7 @@ class ItemStateStack { pushSlotBindings(pairs: SlotValuePair[]) { this._stack.push({ type: 'slotBinding', - bindingMap: new WeakMap(pairs), + bindingMap: new Map(pairs), }); } @@ -154,6 +154,30 @@ class ItemStateStack { return slot.defaultValue; } + /** + * @returns A flattened list of slot->value pairs that are bound at the + * time of calling the function. + */ + getSlotSnapshot(): Iterable> { + const flattened = new Map, unknown>(); + + for (let i = this._stack.length - 1; i >= 0; --i) { + const layer = this._stack[i]; + + if (layer?.type === 'slotBinding') { + for (const [slot, value] of layer.bindingMap.entries()) { + // Since we're going bottom-up (or fine->coarse), we only + // acknowledge a slot if it hasn't been seen yet. + if (!flattened.has(slot)) { + flattened.set(slot, value); + } + } + } + } + + return flattened.entries(); + } + getResourceById(id: string): Resource | undefined { for (let i = this._stack.length - 1; i >= 0; --i) { const layer = this._stack[i]; @@ -372,6 +396,10 @@ class ResolutionCtxImpl implements ResolutionCtx { return value; } + getSlotSnapshot(): Iterable> { + return this._itemStateStack.getSlotSnapshot(); + } + withSlots(pairs: SlotValuePair[], callback: () => T): T { this._itemStateStack.pushSlotBindings(pairs); @@ -547,6 +575,22 @@ class ResolutionCtxImpl implements ResolutionCtx { `Value ${value} (as json: ${JSON.stringify(value)}) of schema ${schema} is not resolvable to WGSL`, ); } + + resolveCall(fn: unknown, args: unknown[]): Resource { + if (isProviding(fn)) { + return this.withSlots(fn['~providing'].pairs, () => + this.resolveCall(fn['~providing'].inner, args), + ); + } + + if (typeof fn === 'function') { + const result = fn(...args); + // TODO: Make function calls return resources instead of just values. + return { value: result, dataType: UnknownData }; + } + + throw new Error(`Cannot call ${fn} (as json: ${JSON.stringify(fn)}).`); + } } export interface ResolutionResult { diff --git a/packages/typegpu/src/smol/wgslGenerator.ts b/packages/typegpu/src/smol/wgslGenerator.ts index c3d5394b..7a3933dd 100644 --- a/packages/typegpu/src/smol/wgslGenerator.ts +++ b/packages/typegpu/src/smol/wgslGenerator.ts @@ -5,7 +5,6 @@ import { type ResolutionCtx, type Resource, UnknownData, - type Wgsl, isWgsl, } from '../types'; @@ -180,13 +179,7 @@ function generateExpression( }; } - // Assuming that `id` is callable - // TODO: Pass in resources, not just values. - const result = (idValue as unknown as (...args: unknown[]) => unknown)( - ...argValues, - ) as Wgsl; - // TODO: Make function calls return resources instead of just values. - return { value: result, dataType: UnknownData }; + return ctx.resolveCall(idValue, argValues); } assertExhaustive(expression); diff --git a/packages/typegpu/src/types.ts b/packages/typegpu/src/types.ts index d791be61..21742792 100644 --- a/packages/typegpu/src/types.ts +++ b/packages/typegpu/src/types.ts @@ -109,6 +109,12 @@ export interface ResolutionCtx { binding: number; }; + /** + * @returns A flattened list of slot->value pairs that are bound at the + * time of calling the function. + */ + getSlotSnapshot(): Iterable>; + withSlots(pairs: SlotValuePair[], callback: () => T): T; /** @@ -119,6 +125,7 @@ export interface ResolutionCtx { resolve(item: unknown): string; resolveValue(value: Infer, schema: T): string; + resolveCall(fn: unknown, args: unknown[]): Resource; transpileFn(fn: string): { argNames: string[]; diff --git a/packages/typegpu/tests/derived.test.ts b/packages/typegpu/tests/derived.test.ts index a1793b5b..ac858269 100644 --- a/packages/typegpu/tests/derived.test.ts +++ b/packages/typegpu/tests/derived.test.ts @@ -183,11 +183,10 @@ describe('TgpuDerived', () => { }); it('allows slot bindings to pass downstream from derived (#697)', () => { - const utgpu = tgpu['~unstable']; - const valueSlot = utgpu.slot(1).$name('valueSlot'); + const valueSlot = tgpu['~unstable'].slot(1).$name('valueSlot'); - const derivedFn = utgpu.derived(() => { - return utgpu + const derivedFn = tgpu['~unstable'].derived(() => { + return tgpu['~unstable'] .fn([], d.f32) .does(() => valueSlot.value) .$name('innerFn'); @@ -195,7 +194,7 @@ describe('TgpuDerived', () => { const derivedFnWith2 = derivedFn.with(valueSlot, 2); - const mainFn = utgpu + const mainFn = tgpu['~unstable'] .fn([]) .does(() => { derivedFn.value();