Skip to content

Commit

Permalink
Support GS spherical harmonics (playcanvas#6946)
Browse files Browse the repository at this point in the history
  • Loading branch information
slimbuck committed Sep 11, 2024
1 parent 9dc3fd9 commit df8501a
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 74 deletions.
28 changes: 2 additions & 26 deletions src/framework/parsers/ply.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { GSplatData } from '../../scene/gsplat/gsplat-data.js';
import { GSplatCompressedData } from '../../scene/gsplat/gsplat-compressed-data.js';
import { GSplatResource } from './gsplat-resource.js';
import { Mat4 } from '../../core/math/mat4.js';

/**
* @import { AssetRegistry } from '../asset/asset-registry.js'
Expand Down Expand Up @@ -461,24 +460,8 @@ const readPly = async (reader, propertyFilter = null) => {
return new GSplatData(elements);
};

// filter out element data we're not going to use
const defaultElements = [
'x', 'y', 'z',
'f_dc_0', 'f_dc_1', 'f_dc_2', 'opacity',
'rot_0', 'rot_1', 'rot_2', 'rot_3',
'scale_0', 'scale_1', 'scale_2',
// compressed format elements
'min_x', 'min_y', 'min_z',
'max_x', 'max_y', 'max_z',
'min_scale_x', 'min_scale_y', 'min_scale_z',
'max_scale_x', 'max_scale_y', 'max_scale_z',
'packed_position', 'packed_rotation', 'packed_scale', 'packed_color'
];

const defaultElementsSet = new Set(defaultElements);
const defaultElementFilter = val => defaultElementsSet.has(val);

const mat4 = new Mat4();
// by default load everything
const defaultElementFilter = val => true;

class PlyParser {
/** @type {GraphicsDevice} */
Expand Down Expand Up @@ -517,13 +500,6 @@ class PlyParser {
readPly(response.body.getReader(), asset.data.elementFilter ?? defaultElementFilter)
.then((gsplatData) => {
if (!gsplatData.isCompressed) {

// perform Z scale
if (asset.data.performZScale ?? true) {
mat4.setScale(-1, -1, 1);
gsplatData.transform(mat4);
}

// reorder data
if (asset.data.reorder ?? true) {
gsplatData.reorderData();
Expand Down
44 changes: 0 additions & 44 deletions src/scene/gsplat/gsplat-data.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ import { BoundingBox } from '../../core/shape/bounding-box.js';
* @import { Vec4 } from '../../core/math/vec4.js'
*/

const vec3 = new Vec3();
const mat4 = new Mat4();
const quat = new Quat();
const quat2 = new Quat();
const aabb = new BoundingBox();
const aabb2 = new BoundingBox();

Expand Down Expand Up @@ -121,48 +119,6 @@ class GSplatData {
result.setFromTransformedAabb(aabb, mat4);
}

/**
* Transform splat data by the given matrix.
*
* @param {Mat4} mat - The matrix.
*/
transform(mat) {
const x = this.getProp('x');
const y = this.getProp('y');
const z = this.getProp('z');

if (x && y && z) {
for (let i = 0; i < this.numSplats; ++i) {
// transform center
vec3.set(x[i], y[i], z[i]);
mat.transformPoint(vec3, vec3);
x[i] = vec3.x;
y[i] = vec3.y;
z[i] = vec3.z;
}
}

const rx = this.getProp('rot_1');
const ry = this.getProp('rot_2');
const rz = this.getProp('rot_3');
const rw = this.getProp('rot_0');

if (rx && ry && rz && rw) {
quat2.setFromMat4(mat);

for (let i = 0; i < this.numSplats; ++i) {
// transform orientation
quat.set(rx[i], ry[i], rz[i], rw[i]).mul2(quat2, quat);
rx[i] = quat.x;
ry[i] = quat.y;
rz[i] = quat.z;
rw[i] = quat.w;
}
}

// TODO: transform SH
}

// access a named property
getProp(name, elementName = 'vertex') {
return this.getElement(elementName)?.properties.find(p => p.name === name)?.storage;
Expand Down
8 changes: 8 additions & 0 deletions src/scene/gsplat/gsplat-material.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import { getProgramLibrary } from '../shader-lib/get-program-library.js';
import { gsplat } from './shader-generator-gsplat.js';

const splatMainVS = /* glsl */ `
uniform vec3 view_position;
uniform sampler2D splatColor;
varying mediump vec2 texCoord;
Expand Down Expand Up @@ -59,6 +61,12 @@ const splatMainVS = /* glsl */ `
texCoord = vertex_position.xy * scale / 2.0;
#ifdef USE_SH1
vec4 worldCenter = matrix_model * vec4(center, 1.0);
vec3 viewDir = normalize((worldCenter.xyz / worldCenter.w - view_position) * mat3(matrix_model));
color.xyz = max(color.xyz + evalSH(viewDir), 0.0);
#endif
#ifndef DITHER_NONE
id = float(splatId);
#endif
Expand Down
127 changes: 127 additions & 0 deletions src/scene/gsplat/gsplat.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ const _m0 = new Vec3();
const _m1 = new Vec3();
const _m2 = new Vec3();

const getSHData = (gsplatData) => {
const result = [];
for (let i = 0; i < 45; ++i) {
result.push(gsplatData.getProp(`f_rest_${i}`));
}
return result;
};

/** @ignore */
class GSplat {
device;
Expand All @@ -47,6 +55,21 @@ class GSplat {
/** @type {Texture} */
transformBTexture;

/** @type {Boolean} */
hasSH;

/** @type {Texture | undefined} */
sh1to3Texture;

/** @type {Texture | undefined} */
sh4to7Texture;

/** @type {Texture | undefined} */
sh8to11Texture;

/** @type {Texture | undefined} */
sh12to15Texture;

/**
* @param {GraphicsDevice} device - The graphics device.
* @param {GSplatData} gsplatData - The splat data.
Expand All @@ -71,12 +94,27 @@ class GSplat {
// write texture data
this.updateColorData(gsplatData);
this.updateTransformData(gsplatData);

// initialize SH data
this.hasSH = getSHData(gsplatData).every(x => x);
if (this.hasSH) {
this.sh1to3Texture = this.createTexture('splatSH_1to3', PIXELFORMAT_RGBA32U, size);
this.sh4to7Texture = this.createTexture('splatSH_4to7', PIXELFORMAT_RGBA32U, size);
this.sh8to11Texture = this.createTexture('splatSH_8to11', PIXELFORMAT_RGBA32U, size);
this.sh12to15Texture = this.createTexture('splatSH_12to15', PIXELFORMAT_RGBA32U, size);

this.updateSHData(gsplatData);
}
}

destroy() {
this.colorTexture?.destroy();
this.transformATexture?.destroy();
this.transformBTexture?.destroy();
this.sh1to3Texture?.destroy();
this.sh4to7Texture?.destroy();
this.sh8to11Texture?.destroy();
this.sh12to15Texture?.destroy();
}

/**
Expand All @@ -88,6 +126,16 @@ class GSplat {
result.setParameter('transformA', this.transformATexture);
result.setParameter('transformB', this.transformBTexture);
result.setParameter('tex_params', new Float32Array([this.numSplats, this.colorTexture.width, 0, 0]));

if (this.hasSH) {
result.setDefine('USE_SH1', true);
result.setDefine('USE_SH2', true);
result.setDefine('USE_SH3', true);
result.setParameter('splatSH_1to3', this.sh1to3Texture);
result.setParameter('splatSH_4to7', this.sh4to7Texture);
result.setParameter('splatSH_8to11', this.sh8to11Texture);
result.setParameter('splatSH_12to15', this.sh12to15Texture);
}
return result;
}

Expand Down Expand Up @@ -236,6 +284,85 @@ class GSplat {
_m2.dot(_m2)
);
}

/**
* @param {import('./gsplat-data.js').GSplatData} gsplatData - The source data
*/
updateSHData(gsplatData) {
const sh1to3Data = this.sh1to3Texture.lock();
const sh4to7Data = this.sh4to7Texture.lock();
const sh8to11Data = this.sh8to11Texture.lock();
const sh12to15Data = this.sh12to15Texture.lock();

const src = getSHData(gsplatData);

/**
* @param {number} value - The value to pack.
* @param {number} bits - The number of bits to use.
* @returns {number} The packed value.
*/
const packUnorm = (value, bits) => {
const t = (1 << bits) - 1;
return Math.max(0, Math.min(t, Math.floor(value * t + 0.5)));
};

/**
* @param {number} coef - Index of the coefficient to pack
* @param {number} idx - Index of the splat to pack
* @param {number} m - Scaling factor to normalize the set of coefficients
* @returns {number} The packed value.
*/
const pack = (coef, idx, m) => {
const r = src[coef][idx] / m;
const g = src[coef + 15][idx] / m;
const b = src[coef + 30][idx] / m;

return packUnorm(r * 0.5 + 0.5, 11) << 21 |
packUnorm(g * 0.5 + 0.5, 10) << 11 |
packUnorm(b * 0.5 + 0.5, 11);
};

const float32 = new Float32Array(1);
const uint32 = new Uint32Array(float32.buffer);

for (let i = 0; i < gsplatData.numSplats; ++i) {
let m = Math.abs(src[0][i]);
for (let j = 1; j < 45; ++j) {
m = Math.max(m, Math.abs(src[j][i]));
}

if (m === 0) {
continue;
}

float32[0] = m;

sh1to3Data[i * 4 + 0] = uint32[0];
sh1to3Data[i * 4 + 1] = pack(0, i, m);
sh1to3Data[i * 4 + 2] = pack(1, i, m);
sh1to3Data[i * 4 + 3] = pack(2, i, m);

sh4to7Data[i * 4 + 0] = pack(3, i, m);
sh4to7Data[i * 4 + 1] = pack(4, i, m);
sh4to7Data[i * 4 + 2] = pack(5, i, m);
sh4to7Data[i * 4 + 3] = pack(6, i, m);

sh8to11Data[i * 4 + 0] = pack(7, i, m);
sh8to11Data[i * 4 + 1] = pack(8, i, m);
sh8to11Data[i * 4 + 2] = pack(9, i, m);
sh8to11Data[i * 4 + 3] = pack(10, i, m);

sh12to15Data[i * 4 + 0] = pack(11, i, m);
sh12to15Data[i * 4 + 1] = pack(12, i, m);
sh12to15Data[i * 4 + 2] = pack(13, i, m);
sh12to15Data[i * 4 + 3] = pack(14, i, m);
}

this.sh1to3Texture.unlock();
this.sh4to7Texture.unlock();
this.sh8to11Texture.unlock();
this.sh12to15Texture.unlock();
}
}

export { GSplat };
Loading

0 comments on commit df8501a

Please sign in to comment.