From df8501a20ed9a2e9eda1e7a7200da9279c438cdf Mon Sep 17 00:00:00 2001 From: Donovan Hutchence Date: Wed, 11 Sep 2024 16:13:24 +0100 Subject: [PATCH] Support GS spherical harmonics (#6946) --- src/framework/parsers/ply.js | 28 +---- src/scene/gsplat/gsplat-data.js | 44 ------- src/scene/gsplat/gsplat-material.js | 8 ++ src/scene/gsplat/gsplat.js | 127 ++++++++++++++++++++ src/scene/gsplat/shader-generator-gsplat.js | 118 +++++++++++++++++- 5 files changed, 251 insertions(+), 74 deletions(-) diff --git a/src/framework/parsers/ply.js b/src/framework/parsers/ply.js index 526dd159fec..b3c4051294a 100644 --- a/src/framework/parsers/ply.js +++ b/src/framework/parsers/ply.js @@ -1,7 +1,6 @@ import { GSplatData } from '../../scene/gsplat/gsplat-data.js'; import { GSplatCompressedData } from '../../scene/gsplat/gsplat-compressed-data.js'; import { GSplatResource } from './gsplat-resource.js'; -import { Mat4 } from '../../core/math/mat4.js'; /** * @import { AssetRegistry } from '../asset/asset-registry.js' @@ -461,24 +460,8 @@ const readPly = async (reader, propertyFilter = null) => { return new GSplatData(elements); }; -// filter out element data we're not going to use -const defaultElements = [ - 'x', 'y', 'z', - 'f_dc_0', 'f_dc_1', 'f_dc_2', 'opacity', - 'rot_0', 'rot_1', 'rot_2', 'rot_3', - 'scale_0', 'scale_1', 'scale_2', - // compressed format elements - 'min_x', 'min_y', 'min_z', - 'max_x', 'max_y', 'max_z', - 'min_scale_x', 'min_scale_y', 'min_scale_z', - 'max_scale_x', 'max_scale_y', 'max_scale_z', - 'packed_position', 'packed_rotation', 'packed_scale', 'packed_color' -]; - -const defaultElementsSet = new Set(defaultElements); -const defaultElementFilter = val => defaultElementsSet.has(val); - -const mat4 = new Mat4(); +// by default load everything +const defaultElementFilter = val => true; class PlyParser { /** @type {GraphicsDevice} */ @@ -517,13 +500,6 @@ class PlyParser { readPly(response.body.getReader(), asset.data.elementFilter ?? defaultElementFilter) .then((gsplatData) => { if (!gsplatData.isCompressed) { - - // perform Z scale - if (asset.data.performZScale ?? true) { - mat4.setScale(-1, -1, 1); - gsplatData.transform(mat4); - } - // reorder data if (asset.data.reorder ?? true) { gsplatData.reorderData(); diff --git a/src/scene/gsplat/gsplat-data.js b/src/scene/gsplat/gsplat-data.js index 2850ca7a3e0..0658e5507c7 100644 --- a/src/scene/gsplat/gsplat-data.js +++ b/src/scene/gsplat/gsplat-data.js @@ -10,10 +10,8 @@ import { BoundingBox } from '../../core/shape/bounding-box.js'; * @import { Vec4 } from '../../core/math/vec4.js' */ -const vec3 = new Vec3(); const mat4 = new Mat4(); const quat = new Quat(); -const quat2 = new Quat(); const aabb = new BoundingBox(); const aabb2 = new BoundingBox(); @@ -121,48 +119,6 @@ class GSplatData { result.setFromTransformedAabb(aabb, mat4); } - /** - * Transform splat data by the given matrix. - * - * @param {Mat4} mat - The matrix. - */ - transform(mat) { - const x = this.getProp('x'); - const y = this.getProp('y'); - const z = this.getProp('z'); - - if (x && y && z) { - for (let i = 0; i < this.numSplats; ++i) { - // transform center - vec3.set(x[i], y[i], z[i]); - mat.transformPoint(vec3, vec3); - x[i] = vec3.x; - y[i] = vec3.y; - z[i] = vec3.z; - } - } - - const rx = this.getProp('rot_1'); - const ry = this.getProp('rot_2'); - const rz = this.getProp('rot_3'); - const rw = this.getProp('rot_0'); - - if (rx && ry && rz && rw) { - quat2.setFromMat4(mat); - - for (let i = 0; i < this.numSplats; ++i) { - // transform orientation - quat.set(rx[i], ry[i], rz[i], rw[i]).mul2(quat2, quat); - rx[i] = quat.x; - ry[i] = quat.y; - rz[i] = quat.z; - rw[i] = quat.w; - } - } - - // TODO: transform SH - } - // access a named property getProp(name, elementName = 'vertex') { return this.getElement(elementName)?.properties.find(p => p.name === name)?.storage; diff --git a/src/scene/gsplat/gsplat-material.js b/src/scene/gsplat/gsplat-material.js index 2ef9798c893..f9e0589c3eb 100644 --- a/src/scene/gsplat/gsplat-material.js +++ b/src/scene/gsplat/gsplat-material.js @@ -6,6 +6,8 @@ import { getProgramLibrary } from '../shader-lib/get-program-library.js'; import { gsplat } from './shader-generator-gsplat.js'; const splatMainVS = /* glsl */ ` + uniform vec3 view_position; + uniform sampler2D splatColor; varying mediump vec2 texCoord; @@ -59,6 +61,12 @@ const splatMainVS = /* glsl */ ` texCoord = vertex_position.xy * scale / 2.0; + #ifdef USE_SH1 + vec4 worldCenter = matrix_model * vec4(center, 1.0); + vec3 viewDir = normalize((worldCenter.xyz / worldCenter.w - view_position) * mat3(matrix_model)); + color.xyz = max(color.xyz + evalSH(viewDir), 0.0); + #endif + #ifndef DITHER_NONE id = float(splatId); #endif diff --git a/src/scene/gsplat/gsplat.js b/src/scene/gsplat/gsplat.js index cb2b0008eaf..5f9954b6686 100644 --- a/src/scene/gsplat/gsplat.js +++ b/src/scene/gsplat/gsplat.js @@ -26,6 +26,14 @@ const _m0 = new Vec3(); const _m1 = new Vec3(); const _m2 = new Vec3(); +const getSHData = (gsplatData) => { + const result = []; + for (let i = 0; i < 45; ++i) { + result.push(gsplatData.getProp(`f_rest_${i}`)); + } + return result; +}; + /** @ignore */ class GSplat { device; @@ -47,6 +55,21 @@ class GSplat { /** @type {Texture} */ transformBTexture; + /** @type {Boolean} */ + hasSH; + + /** @type {Texture | undefined} */ + sh1to3Texture; + + /** @type {Texture | undefined} */ + sh4to7Texture; + + /** @type {Texture | undefined} */ + sh8to11Texture; + + /** @type {Texture | undefined} */ + sh12to15Texture; + /** * @param {GraphicsDevice} device - The graphics device. * @param {GSplatData} gsplatData - The splat data. @@ -71,12 +94,27 @@ class GSplat { // write texture data this.updateColorData(gsplatData); this.updateTransformData(gsplatData); + + // initialize SH data + this.hasSH = getSHData(gsplatData).every(x => x); + if (this.hasSH) { + this.sh1to3Texture = this.createTexture('splatSH_1to3', PIXELFORMAT_RGBA32U, size); + this.sh4to7Texture = this.createTexture('splatSH_4to7', PIXELFORMAT_RGBA32U, size); + this.sh8to11Texture = this.createTexture('splatSH_8to11', PIXELFORMAT_RGBA32U, size); + this.sh12to15Texture = this.createTexture('splatSH_12to15', PIXELFORMAT_RGBA32U, size); + + this.updateSHData(gsplatData); + } } destroy() { this.colorTexture?.destroy(); this.transformATexture?.destroy(); this.transformBTexture?.destroy(); + this.sh1to3Texture?.destroy(); + this.sh4to7Texture?.destroy(); + this.sh8to11Texture?.destroy(); + this.sh12to15Texture?.destroy(); } /** @@ -88,6 +126,16 @@ class GSplat { result.setParameter('transformA', this.transformATexture); result.setParameter('transformB', this.transformBTexture); result.setParameter('tex_params', new Float32Array([this.numSplats, this.colorTexture.width, 0, 0])); + + if (this.hasSH) { + result.setDefine('USE_SH1', true); + result.setDefine('USE_SH2', true); + result.setDefine('USE_SH3', true); + result.setParameter('splatSH_1to3', this.sh1to3Texture); + result.setParameter('splatSH_4to7', this.sh4to7Texture); + result.setParameter('splatSH_8to11', this.sh8to11Texture); + result.setParameter('splatSH_12to15', this.sh12to15Texture); + } return result; } @@ -236,6 +284,85 @@ class GSplat { _m2.dot(_m2) ); } + + /** + * @param {import('./gsplat-data.js').GSplatData} gsplatData - The source data + */ + updateSHData(gsplatData) { + const sh1to3Data = this.sh1to3Texture.lock(); + const sh4to7Data = this.sh4to7Texture.lock(); + const sh8to11Data = this.sh8to11Texture.lock(); + const sh12to15Data = this.sh12to15Texture.lock(); + + const src = getSHData(gsplatData); + + /** + * @param {number} value - The value to pack. + * @param {number} bits - The number of bits to use. + * @returns {number} The packed value. + */ + const packUnorm = (value, bits) => { + const t = (1 << bits) - 1; + return Math.max(0, Math.min(t, Math.floor(value * t + 0.5))); + }; + + /** + * @param {number} coef - Index of the coefficient to pack + * @param {number} idx - Index of the splat to pack + * @param {number} m - Scaling factor to normalize the set of coefficients + * @returns {number} The packed value. + */ + const pack = (coef, idx, m) => { + const r = src[coef][idx] / m; + const g = src[coef + 15][idx] / m; + const b = src[coef + 30][idx] / m; + + return packUnorm(r * 0.5 + 0.5, 11) << 21 | + packUnorm(g * 0.5 + 0.5, 10) << 11 | + packUnorm(b * 0.5 + 0.5, 11); + }; + + const float32 = new Float32Array(1); + const uint32 = new Uint32Array(float32.buffer); + + for (let i = 0; i < gsplatData.numSplats; ++i) { + let m = Math.abs(src[0][i]); + for (let j = 1; j < 45; ++j) { + m = Math.max(m, Math.abs(src[j][i])); + } + + if (m === 0) { + continue; + } + + float32[0] = m; + + sh1to3Data[i * 4 + 0] = uint32[0]; + sh1to3Data[i * 4 + 1] = pack(0, i, m); + sh1to3Data[i * 4 + 2] = pack(1, i, m); + sh1to3Data[i * 4 + 3] = pack(2, i, m); + + sh4to7Data[i * 4 + 0] = pack(3, i, m); + sh4to7Data[i * 4 + 1] = pack(4, i, m); + sh4to7Data[i * 4 + 2] = pack(5, i, m); + sh4to7Data[i * 4 + 3] = pack(6, i, m); + + sh8to11Data[i * 4 + 0] = pack(7, i, m); + sh8to11Data[i * 4 + 1] = pack(8, i, m); + sh8to11Data[i * 4 + 2] = pack(9, i, m); + sh8to11Data[i * 4 + 3] = pack(10, i, m); + + sh12to15Data[i * 4 + 0] = pack(11, i, m); + sh12to15Data[i * 4 + 1] = pack(12, i, m); + sh12to15Data[i * 4 + 2] = pack(13, i, m); + sh12to15Data[i * 4 + 3] = pack(14, i, m); + } + + this.sh1to3Texture.unlock(); + this.sh4to7Texture.unlock(); + this.sh8to11Texture.unlock(); + this.sh12to15Texture.unlock(); + } } export { GSplat }; diff --git a/src/scene/gsplat/shader-generator-gsplat.js b/src/scene/gsplat/shader-generator-gsplat.js index 4a889844218..50c7dfb6642 100644 --- a/src/scene/gsplat/shader-generator-gsplat.js +++ b/src/scene/gsplat/shader-generator-gsplat.js @@ -107,6 +107,117 @@ const splatCoreVS = /* glsl */ ` return vec4(v1, v2); } + + + // Spherical Harmonics + + vec3 unpack111011(uint bits) { + return vec3( + float(bits >> 21u) / 2047.0, + float((bits >> 11u) & 0x3ffu) / 1023.0, + float(bits & 0x7ffu) / 2047.0 + ); + } + + // fetch quantized spherical harmonic coefficients + void fetchScale(in highp usampler2D sampler, out float scale, out vec3 a, out vec3 b, out vec3 c) { + uvec4 t = texelFetch(sampler, splatUV, 0); + scale = uintBitsToFloat(t.x); + a = unpack111011(t.y) * 2.0 - 1.0; + b = unpack111011(t.z) * 2.0 - 1.0; + c = unpack111011(t.w) * 2.0 - 1.0; + } + + // fetch quantized spherical harmonic coefficients + void fetch(in highp usampler2D sampler, out vec3 a, out vec3 b, out vec3 c, out vec3 d) { + uvec4 t = texelFetch(sampler, splatUV, 0); + a = unpack111011(t.x) * 2.0 - 1.0; + b = unpack111011(t.y) * 2.0 - 1.0; + c = unpack111011(t.z) * 2.0 - 1.0; + d = unpack111011(t.w) * 2.0 - 1.0; + } + + #if defined(USE_SH1) + #define SH_C1 0.4886025119029199f + + uniform highp usampler2D splatSH_1to3; + #if defined(USE_SH2) + #define SH_C2_0 1.0925484305920792f + #define SH_C2_1 -1.0925484305920792f + #define SH_C2_2 0.31539156525252005f + #define SH_C2_3 -1.0925484305920792f + #define SH_C2_4 0.5462742152960396f + + uniform highp usampler2D splatSH_4to7; + uniform highp usampler2D splatSH_8to11; + #if defined(USE_SH3) + #define SH_C3_0 -0.5900435899266435f + #define SH_C3_1 2.890611442640554f + #define SH_C3_2 -0.4570457994644658f + #define SH_C3_3 0.3731763325901154f + #define SH_C3_4 -0.4570457994644658f + #define SH_C3_5 1.445305721320277f + #define SH_C3_6 -0.5900435899266435f + + uniform highp usampler2D splatSH_12to15; + #endif + #endif + #endif + + vec3 evalSH(in vec3 dir) { + vec3 result = vec3(0.0); + + // see https://github.com/graphdeco-inria/gaussian-splatting/blob/main/utils/sh_utils.py + #if defined(USE_SH1) + // 1st degree + float x = dir.x; + float y = dir.y; + float z = dir.z; + + float scale; + vec3 sh1, sh2, sh3; + fetchScale(splatSH_1to3, scale, sh1, sh2, sh3); + result += SH_C1 * (-sh1 * y + sh2 * z - sh3 * x); + + #if defined(USE_SH2) + // 2nd degree + float xx = x * x; + float yy = y * y; + float zz = z * z; + float xy = x * y; + float yz = y * z; + float xz = x * z; + + vec3 sh4, sh5, sh6, sh7; + vec3 sh8, sh9, sh10, sh11; + fetch(splatSH_4to7, sh4, sh5, sh6, sh7); + fetch(splatSH_8to11, sh8, sh9, sh10, sh11); + result += + sh4 * (SH_C2_0 * xy) * + + sh5 * (SH_C2_1 * yz) + + sh6 * (SH_C2_2 * (2.0 * zz - xx - yy)) + + sh7 * (SH_C2_3 * xz) + + sh8 * (SH_C2_4 * (xx - yy)); + + #if defined(USE_SH3) + // 3rd degree + vec3 sh12, sh13, sh14, sh15; + fetch(splatSH_12to15, sh12, sh13, sh14, sh15); + result += + sh9 * (SH_C3_0 * y * (3.0 * xx - yy)) + + sh10 * (SH_C3_1 * xy * z) + + sh11 * (SH_C3_2 * y * (4.0 * zz - xx - yy)) + + sh12 * (SH_C3_3 * z * (2.0 * zz - 3.0 * xx - 3.0 * yy)) + + sh13 * (SH_C3_4 * x * (4.0 * zz - xx - yy)) + + sh14 * (SH_C3_5 * z * (xx - yy)) + + sh15 * (SH_C3_6 * x * (xx - 3.0 * yy)); + #endif + #endif + result *= scale; + #endif + + return result; + } `; const splatCoreFS = /* glsl */ ` @@ -163,8 +274,8 @@ class GSplatShaderGenerator { const shaderPassDefines = shaderPassInfo.shaderDefines; const defines = - `${shaderPassDefines - }#define DITHER_${options.dither.toUpperCase()}\n` + + `${shaderPassDefines}\n` + + `#define DITHER_${options.dither.toUpperCase()}\n` + `#define TONEMAP_${options.toneMapping === TONEMAP_LINEAR ? 'DISABLED' : 'ENABLED'}\n`; const vs = defines + splatCoreVS + options.vertex; @@ -174,8 +285,7 @@ class GSplatShaderGenerator { ShaderGenerator.gammaCode(options.gamma) + splatCoreFS + options.fragment; - const defineMap = new Map(); - options.defines.forEach(value => defineMap.set(value, true)); + const defineMap = new Map(options.defines); return ShaderUtils.createDefinition(device, { name: 'SplatShader',