Skip to content

Commit

Permalink
Function values in extra buffer, fix reloading
Browse files Browse the repository at this point in the history
  • Loading branch information
mhochsteger committed Oct 17, 2024
1 parent 811909e commit 5dae764
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 28 deletions.
60 changes: 37 additions & 23 deletions webgpu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,32 +68,46 @@ def render(time):
def cleanup():
print("cleanup")
global gpu, input_handler, mesh_object
if input_handler is not None:
input_handler.unregister_callbacks()
del mesh_object
del input_handler
del gpu
if "input_handler" in globals():
if input_handler is not None:
input_handler.unregister_callbacks()
del mesh_object
del input_handler
del gpu


def reload_package(package_name):
"""Reload python package and all submodules (searches in modules for references to other submodules)"""
import importlib
import os
import types

package = importlib.import_module(package_name)
assert hasattr(package, "__package__")
file_name = package.__file__
package_dir = os.path.dirname(file_name) + os.sep
reloaded_modules = {file_name: package}

def reload_recursive(module):
module = importlib.reload(module)

for var in vars(module).values():
if isinstance(var, types.ModuleType):
file_name = getattr(var, "__file__", None)
if file_name is not None and file_name.startswith(package_dir):
if file_name not in reloaded_modules:
reloaded_modules[file_name] = reload_recursive(var)

return module

reload_recursive(package)
return reloaded_modules


async def reload():
print("reload")
cleanup()
import glob
import importlib
import os
reload_package("webgpu")
from webgpu.main import main

import webgpu
import webgpu.colormap
import webgpu.gpu
import webgpu.main
import webgpu.mesh
import webgpu.utils

dirname = os.path.dirname(__file__)
for filename in glob.glob(os.path.join(dirname, "*.py")):
if filename.endswith("__init__.py"):
continue
module_name = os.path.basename(filename)[:-3]
webgpu.__dict__[module_name] = importlib.reload(webgpu.__dict__[module_name])
webgpu = importlib.reload(webgpu)
await webgpu.main.main()
await main()
7 changes: 5 additions & 2 deletions webgpu/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,20 @@ def _create_buffers(self):

trigs = []
edges = []
trig_function_values = []
for t in m.Elements2D():
for i in range(3):
trigs.append(t.vertices[i].nr - 1)
edges.append(t.vertices[i].nr - 1)
edges.append(t.vertices[(i + 1) % 3].nr - 1)
trig_function_values.append(vertices[4 * (t.vertices[i].nr - 1)])
trigs.append(t.index)

data = {
"vertices": js.Float32Array.new(vertices),
"edges": js.Int32Array.new(edges),
"trigs": js.Int32Array.new(trigs),
"trig_function_values": js.Float32Array.new(trig_function_values),
}

buffers = {}
Expand All @@ -59,7 +62,7 @@ def _create_buffers(self):

def get_binding_layout(self):
layouts = []
for name in ["vertices", "edges", "trigs"]:
for name in self._buffers.keys():
binding = getattr(Binding, name.upper())
layouts.append(
{
Expand All @@ -72,7 +75,7 @@ def get_binding_layout(self):

def get_binding(self):
resources = []
for name in ["vertices", "edges", "trigs"]:
for name in self._buffers.keys():
binding = getattr(Binding, name.upper())
resources.append(
{"binding": binding, "resource": {"buffer": self._buffers[name]}}
Expand Down
7 changes: 4 additions & 3 deletions webgpu/shader.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct Uniforms {
@group(0) @binding(3) var<storage> vertices : array<vec3<f32>>;
@group(0) @binding(4) var<storage> edges : array<Edge>;
@group(0) @binding(5) var<storage> trigs : array<Trig>;
@group(0) @binding(6) var<storage> trig_function_values : array<f32>;

struct VertexOutput1d {
@builtin(position) fragPosition: vec4<f32>,
Expand Down Expand Up @@ -88,9 +89,9 @@ fn mainVertexEdge(@builtin(vertex_index) vertexId: u32, @builtin(instance_index)
@fragment
fn mainFragmentTrig(@location(0) p: vec3<f32>, @location(1) lam: vec2<f32>, @location(2) id: u32) -> @location(0) vec4<f32> {
let verts = trigs[id].v;
let v0 = vertices[ verts[0] ].x;
let v1 = vertices[ verts[1] ].x;
let v2 = vertices[ verts[2] ].x;
let v0 = trig_function_values[ 3 * id ];
let v1 = trig_function_values[ 3 * id + 1];
let v2 = trig_function_values[ 3 * id + 2];

checkClipping(p);

Expand Down
1 change: 1 addition & 0 deletions webgpu/uniforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class Binding:
VERTICES = 3
EDGES = 4
TRIGS = 5
TRIG_FUNCTION_VALUES = 6


class ClippingPlaneUniform(ct.Structure):
Expand Down

0 comments on commit 5dae764

Please sign in to comment.