Skip to content

Commit

Permalink
Simplify resource binding APIs.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Feb 8, 2025
1 parent d0fd807 commit 202c5dc
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 367 deletions.
10 changes: 4 additions & 6 deletions examples/othello/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,10 @@ impl TensorOpExt for TensorOp {
);

let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
.touch(2, input.resource_key())
.touch(3, output.resource_key())
.bind(0, input.meta_binding())
.bind(1, output.meta_binding())
.bind(2, input.binding())
.bind(3, output.binding())
.bind_meta(0, &input)
.bind_meta(1, &output)
.bind(2, &input)
.bind(3, &output)
.build()];

Ok(Self::Atom {
Expand Down
10 changes: 4 additions & 6 deletions examples/puzzle15/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,10 @@ impl TensorOpExt for TensorOp {
);

let bindings = vec![BindGroupBuilder::new(&key, context, &pipeline.layout)
.touch(2, input.resource_key())
.touch(3, output.resource_key())
.bind(0, input.meta_binding())
.bind(1, output.meta_binding())
.bind(2, input.binding())
.bind(3, output.binding())
.bind_meta(0, &input)
.bind_meta(1, &output)
.bind(2, &input)
.bind(3, &output)
.build()];

Ok(Self::Atom {
Expand Down
26 changes: 18 additions & 8 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ use web_rwkv_derive::{Deref, DerefMut};
use wgpu::{
util::{BufferInitDescriptor, DeviceExt},
Adapter, BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingResource, Buffer, BufferDescriptor,
BufferUsages, ComputePipeline, ComputePipelineDescriptor, Device, DeviceDescriptor, Features,
Instance, Limits, MemoryHints, PipelineLayoutDescriptor, PowerPreference, Queue,
RequestAdapterOptions, ShaderModuleDescriptor,
BindGroupLayoutDescriptor, BindGroupLayoutEntry, Buffer, BufferDescriptor, BufferUsages,
ComputePipeline, ComputePipelineDescriptor, Device, DeviceDescriptor, Features, Instance,
Limits, MemoryHints, PipelineLayoutDescriptor, PowerPreference, Queue, RequestAdapterOptions,
ShaderModuleDescriptor,
};

use crate::tensor::{
cache::{ResourceCache, SharedResourceCache},
shape::{IntoBytes, Shape},
ResourceKey, View,
ResourceKey, TensorResource, View,
};

pub trait InstanceExt {
Expand Down Expand Up @@ -269,14 +269,24 @@ impl<'a, 'b> BindGroupBuilder<'a, 'b> {

/// Mark a resource as being touched.
/// How resources are touched determines whether the bind group can be found in cache.
pub fn touch(mut self, binding: u32, resource: ResourceKey) -> Self {
self.key.bindings.push((binding, resource));
pub fn touch(mut self, binding: u32, tensor: &'a impl TensorResource) -> Self {
let key = tensor.resource_key();
self.key.bindings.push((binding, key));
self
}

/// Insert an entry into the bind group.
pub fn bind(mut self, binding: u32, resource: BindingResource<'a>) -> Self {
pub fn bind(mut self, binding: u32, tensor: &'a impl TensorResource) -> Self {
let resource = tensor.binding();
self.entries.push(BindGroupEntry { binding, resource });
self.touch(binding, tensor)
}

/// Insert an entry into the bind group.
pub fn bind_meta(mut self, binding: u32, tensor: &'a impl TensorResource) -> Self {
let resource = tensor.meta_binding();
self.entries.push(BindGroupEntry { binding, resource });
// self.touch(binding, tensor)
self
}

Expand Down
75 changes: 43 additions & 32 deletions src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ pub trait TensorReshape: Sized {
) -> Result<Self, TensorError>;
}

pub trait TensorResource {
/// Retrieve the key identifying a resource.
fn resource_key(&self) -> ResourceKey;
/// Binding for metadata of the tensor (shape, stride, etc.).
fn meta_binding(&self) -> BindingResource;
/// Binding for actual data of the tensor.
fn binding(&self) -> BindingResource;
}

/// A tensor on either CPU or GPU.
#[derive(Debug)]
pub struct Tensor<D: Device, T: Scalar> {
Expand Down Expand Up @@ -245,26 +254,6 @@ pub mod kind {
pub struct ReadWrite;
}

impl TensorGpuData {
#[inline]
pub fn meta_binding(&self) -> BindingResource {
BindingResource::Buffer(BufferBinding {
buffer: &self.meta,
offset: 0,
size: None,
})
}

#[inline]
pub fn binding(&self) -> BindingResource {
BindingResource::Buffer(BufferBinding {
buffer: &self.buffer,
offset: 0,
size: None,
})
}
}

pub type TensorCpu<T> = Tensor<Cpu<T>, T>;
pub type TensorGpu<T, K> = Tensor<Gpu<K>, T>;

Expand Down Expand Up @@ -509,7 +498,7 @@ impl<T: Scalar, K: Kind> TensorReshape for TensorGpu<T, K> {
}
}

impl<T: Scalar, K: Kind> TensorGpu<T, K> {
impl<T: Scalar, K: Kind> TensorResource for TensorGpu<T, K> {
#[inline]
fn resource_key(&self) -> ResourceKey {
let id = self.id;
Expand All @@ -521,6 +510,26 @@ impl<T: Scalar, K: Kind> TensorGpu<T, K> {
ResourceKey { id, view }
}

#[inline]
fn meta_binding(&self) -> BindingResource {
BindingResource::Buffer(BufferBinding {
buffer: &self.meta,
offset: 0,
size: None,
})
}

#[inline]
fn binding(&self) -> BindingResource {
BindingResource::Buffer(BufferBinding {
buffer: &self.buffer,
offset: 0,
size: None,
})
}
}

impl<T: Scalar, K: Kind> TensorGpu<T, K> {
#[cfg(not(target_arch = "wasm32"))]
pub fn back_in_place(&self) -> TensorCpu<T> {
use crate::context::ContextEvent;
Expand Down Expand Up @@ -959,9 +968,19 @@ impl<T: Scalar> TensorGpuView<'_, T> {
pub fn layout(&self, binding: u32, read_only: bool) -> BindGroupLayoutEntry {
self.tensor.layout(binding, read_only)
}
}

impl<T: Scalar> TensorResource for TensorGpuView<'_, T> {
#[inline]
pub fn meta_binding(&self) -> BindingResource {
fn resource_key(&self) -> ResourceKey {
ResourceKey {
id: self.tensor.id,
view: self.view,
}
}

#[inline]
fn meta_binding(&self) -> BindingResource {
BindingResource::Buffer(BufferBinding {
buffer: &self.meta,
offset: 0,
Expand All @@ -970,16 +989,8 @@ impl<T: Scalar> TensorGpuView<'_, T> {
}

#[inline]
pub fn binding(&self) -> BindingResource {
self.data().binding()
}

#[inline]
pub fn resource_key(&self) -> ResourceKey {
ResourceKey {
id: self.tensor.id,
view: self.view,
}
fn binding(&self) -> BindingResource {
self.tensor.binding()
}
}

Expand Down
Loading

0 comments on commit 202c5dc

Please sign in to comment.