Skip to content

Commit

Permalink
Fix normalization.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Dec 8, 2024
1 parent f092ad7 commit eab225f
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 169 deletions.
14 changes: 7 additions & 7 deletions src/shaders/rms_norm.wgsl → src/shaders/normalize.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

var<workgroup> sketch: array<vec4<f32>, BLOCK_SIZE>;
var<workgroup> mean: f32;
var<workgroup> rms: f32;
var<workgroup> norm: f32;

fn pack4x16float(x: vec4<f32>) -> vec2<u32> {
return vec2<u32>(pack2x16float(x.xy), pack2x16float(x.zw));
Expand Down Expand Up @@ -102,16 +102,16 @@ fn rms_norm(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
reduce_sum(index, 1u);

if index == 0u {
rms = inverseSqrt(dot(sketch[0], vec4<f32>(1.0)) / f32(shape[0]) + EPS);
norm = inverseSqrt(dot(sketch[0], vec4<f32>(1.0)) / f32(shape[0]) + EPS);
}
workgroupBarrier();

for (var i = index; i < stride; i += BLOCK_SIZE) {
#ifdef FP16
let value = unpack4x16float(x[bb + i]) * rms;
let value = unpack4x16float(x[bb + i]) * norm;
x[bb + i] = pack4x16float(fma(value, unpack4x16float(w[i]), unpack4x16float(b[i])));
#else
let value = x[bb + i] * rms;
let value = x[bb + i] * norm;
x[bb + i] = fma(value, unpack4x16float(w[i]), unpack4x16float(b[i]));
#endif
}
Expand Down Expand Up @@ -147,15 +147,15 @@ fn l2_norm(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
reduce_sum(index, 1u);

if index == 0u {
l2 = inverseSqrt(dot(sketch[0], vec4<f32>(1.0)) + EPS);
norm = inverseSqrt(dot(sketch[0], vec4<f32>(1.0)) + EPS);
}
workgroupBarrier();

for (var i = index; i < stride; i += BLOCK_SIZE) {
#ifdef FP16
x[bb + i] = pack4x16float(unpack4x16float(x[bb + i]) * l2);
x[bb + i] = pack4x16float(unpack4x16float(x[bb + i]) * norm);
#else
x[bb + i] = x[bb + i] * l2;
x[bb + i] = x[bb + i] * norm;
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const NUM_SUBGROUPS: u32 = BLOCK_SIZE / MIN_SUBGROUP_SIZE;

var<workgroup> sketch: array<vec4<f32>, NUM_SUBGROUPS>;
var<workgroup> mean: f32;
var<workgroup> rms: f32;
var<workgroup> norm: f32;

fn pack4x16float(x: vec4<f32>) -> vec2<u32> {
return vec2<u32>(pack2x16float(x.xy), pack2x16float(x.zw));
Expand Down Expand Up @@ -146,16 +146,16 @@ fn rms_norm(
#endif

if index == 0u {
rms = inverseSqrt(dot(sketch[0], vec4<f32>(1.0)) / f32(shape[0]) + EPS);
norm = inverseSqrt(dot(sketch[0], vec4<f32>(1.0)) / f32(shape[0]) + EPS);
}
workgroupBarrier();

for (var i = index; i < stride; i += BLOCK_SIZE) {
#ifdef FP16
let value = unpack4x16float(x[bb + i]) * rms;
let value = unpack4x16float(x[bb + i]) * norm;
x[bb + i] = pack4x16float(fma(value, unpack4x16float(w[i]), unpack4x16float(b[i])));
#else
let value = x[bb + i] * rms;
let value = x[bb + i] * norm;
x[bb + i] = fma(value, unpack4x16float(w[i]), unpack4x16float(b[i]));
#endif
}
Expand Down Expand Up @@ -212,15 +212,15 @@ fn l2_norm(
#endif

if index == 0u {
l2 = inverseSqrt(dot(sketch[0], vec4<f32>(1.0)) + EPS);
norm = inverseSqrt(dot(sketch[0], vec4<f32>(1.0)) + EPS);
}
workgroupBarrier();

for (var i = index; i < stride; i += BLOCK_SIZE) {
#ifdef FP16
x[bb + i] = pack4x16float(unpack4x16float(x[bb + i]) * l2);
x[bb + i] = pack4x16float(unpack4x16float(x[bb + i]) * norm);
#else
x[bb + i] = x[bb + i] * l2;
x[bb + i] = x[bb + i] * norm;
#endif
}
}
Loading

0 comments on commit eab225f

Please sign in to comment.