diff --git a/libs/lczero-common b/libs/lczero-common
index 55e1b382ef..e05fb7a505 160000
--- a/libs/lczero-common
+++ b/libs/lczero-common
@@ -1 +1 @@
-Subproject commit 55e1b382efadd57903e37f2a2e29caef3ea85799
+Subproject commit e05fb7a505554682acc8a197eb797c26b6db161d
diff --git a/src/neural/backends/metal/mps/NetworkGraph.h b/src/neural/backends/metal/mps/NetworkGraph.h
index 2664b68c7d..60e2f660d7 100644
--- a/src/neural/backends/metal/mps/NetworkGraph.h
+++ b/src/neural/backends/metal/mps/NetworkGraph.h
@@ -132,12 +132,27 @@ static MPSImageFeatureChannelFormat fcFormat = MPSImageFeatureChannelFormatFloat
                                         alpha:(float)alpha
                                         label:(NSString * __nonnull)label;
 
+-(nonnull MPSGraphTensor *) relativePositionEncodingWithTensor:(MPSGraphTensor * __nonnull)tensor
+                                                     mapTensor:(MPSGraphTensor * __nonnull)rpeMapTensor
+                                                       weights:(float * __nonnull)rpeWeights
+                                                         depth:(NSUInteger)depth
+                                                         heads:(NSUInteger)heads
+                                                       queries:(NSUInteger)queries
+                                                          keys:(NSUInteger)keys
+                                                          type:(NSUInteger)type
+                                                         label:(NSString * __nonnull)label;
+
+-(nonnull MPSGraphTensor *) getRpeMapTensor;
+
 -(nonnull MPSGraphTensor *) scaledMHAMatmulWithQueries:(MPSGraphTensor * __nonnull)queries
                                               withKeys:(MPSGraphTensor * __nonnull)keys
                                             withValues:(MPSGraphTensor * __nonnull)values
                                                  heads:(NSUInteger)heads
                                                 parent:(MPSGraphTensor * __nonnull)parent
                                                smolgen:(lczero::MultiHeadWeights::Smolgen * __nullable)smolgen
+                                                  rpeQ:(float * __nullable)rpeQ
+                                                  rpeK:(float * __nullable)rpeK
+                                                  rpeV:(float * __nullable)rpeV
                                      smolgenActivation:(NSString * __nullable)smolgenActivation
                                                  label:(NSString * __nonnull)label;
diff --git a/src/neural/backends/metal/mps/NetworkGraph.mm b/src/neural/backends/metal/mps/NetworkGraph.mm
index 0befa256e6..7de08d1ce4 100644
--- a/src/neural/backends/metal/mps/NetworkGraph.mm
+++ b/src/neural/backends/metal/mps/NetworkGraph.mm
@@ -534,6 +534,9 @@ -(nonnull MPSGraphTensor *) addEncoderLayerWithParent:(MPSGraphTensor * __nonnul
                                      heads:heads
                                     parent:parent
                                    smolgen:encoder.mha.has_smolgen ? &encoder.mha.smolgen : nil
+                                      rpeQ:encoder.mha.rpe_q.size() > 0 ? encoder.mha.rpe_q.data() : nil
+                                      rpeK:encoder.mha.rpe_k.size() > 0 ? encoder.mha.rpe_k.data() : nil
+                                      rpeV:encoder.mha.rpe_v.size() > 0 ? encoder.mha.rpe_v.data() : nil
                          smolgenActivation:smolgenActivation
                                      label:[NSString stringWithFormat:@"%@/mha", label]];
 
@@ -746,12 +749,135 @@ -(nonnull MPSGraphTensor *) transposeChannelsWithTensor:(MPSGraphTensor * __nonn
                           name:[NSString stringWithFormat:@"%@/reshape", label]];
 }
 
+-(nonnull MPSGraphTensor *) relativePositionEncodingWithTensor:(MPSGraphTensor * __nonnull)tensor
+                                                     mapTensor:(MPSGraphTensor * __nonnull)rpeMapTensor
+                                                       weights:(float * __nonnull)rpeWeights
+                                                         depth:(NSUInteger)depth
+                                                         heads:(NSUInteger)heads
+                                                       queries:(NSUInteger)queries
+                                                          keys:(NSUInteger)keys
+                                                          type:(NSUInteger)type
+                                                         label:(NSString * __nonnull)label
+{
+  // RPE weights factorization.
+  NSData * rpeWeightsData = [NSData dataWithBytesNoCopy:(void *)rpeWeights
+                                                 length:depth * heads * 15 * 15 * sizeof(float)
+                                           freeWhenDone:NO];
+
+  // Leela weights are transposed prior to storage, so they have to be transposed back.
+  MPSGraphTensor * rpeTensor = [self variableWithData:rpeWeightsData
+                                                shape:@[@(15 * 15), @(depth * heads)]
+                                             dataType:MPSDataTypeFloat32
+                                                 name:[NSString stringWithFormat:@"%@/weights", label]];
+
+  rpeTensor = [self transposeTensor:rpeTensor dimension:0 withDimension:1 name:[NSString stringWithFormat:@"%@/transpose", label]];
+
+  rpeTensor = [self matrixMultiplicationWithPrimaryTensor:rpeTensor
+                                          secondaryTensor:rpeMapTensor
+                                                     name:[NSString stringWithFormat:@"%@/factorize_matmul", label]];
+
+  rpeTensor = [self reshapeTensor:rpeTensor
+                        withShape:@[@(depth), @(heads), @(queries), @(keys)]
+                             name:[NSString stringWithFormat:@"%@/reshape", label]];
+
+  // Permutations to implement einsum.
+  // First permute rpeTensor so the contraction dimension lands in position for
+  // the matmul (D for the Q and K types, K for the V type).
+  if (type == 0) {
+    // RPE-Q
+    // rpe: [D, H, Q, K] -> [H, Q, D, K]
+    rpeTensor = [self transposeTensor:rpeTensor dimension:0 withDimension:1 name:[NSString stringWithFormat:@"%@/transpose_1", label]];
+    rpeTensor = [self transposeTensor:rpeTensor dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_2", label]];
+  } else if (type == 1) {
+    // RPE-K
+    // rpe: [D, H, Q, K] -> [H, K, D, Q]
+    rpeTensor = [self transposeTensor:rpeTensor dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/transpose_1", label]];
+    rpeTensor = [self transposeTensor:rpeTensor dimension:0 withDimension:1 name:[NSString stringWithFormat:@"%@/transpose_2", label]];
+    rpeTensor = [self transposeTensor:rpeTensor dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_3", label]];
+  } else if (type == 2) {
+    // RPE-V
+    // rpe: [D, H, Q, K] -> [H, Q, K, D]
+    rpeTensor = [self transposeTensor:rpeTensor dimension:0 withDimension:1 name:[NSString stringWithFormat:@"%@/transpose_1", label]];
+    rpeTensor = [self transposeTensor:rpeTensor dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_2", label]];
+    rpeTensor = [self transposeTensor:rpeTensor dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/transpose_3", label]];
+  }
+
+  // Second transpose Nabc -> abNc to allow abNc × abcd -> abNd, where N is the batch dimension.
+  // x: [B, H, Q, D] -> [H, Q, B, D]  # RPE-Q
+  // x: [B, H, K, D] -> [H, K, B, D]  # RPE-K
+  // x: [B, H, Q, K] -> [H, Q, B, K]  # RPE-V
+  tensor = [self transposeTensor:tensor dimension:0 withDimension:1 name:[NSString stringWithFormat:@"%@/a_transpose_1", label]];
+  tensor = [self transposeTensor:tensor dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/a_transpose_2", label]];
+
+  // Finally, the batched matrix multiplication.
+  // x: [H, Q, B, D] x [H, Q, D, K] -> [H, Q, B, K]  # RPE-Q
+  // x: [H, K, B, D] x [H, K, D, Q] -> [H, K, B, Q]  # RPE-K
+  // x: [H, Q, B, K] x [H, Q, K, D] -> [H, Q, B, D]  # RPE-V
+  tensor = [self matrixMultiplicationWithPrimaryTensor:tensor
+                                       secondaryTensor:rpeTensor
+                                                  name:[NSString stringWithFormat:@"%@/rpe/matmul", label]];
+
+  // Reverse the batch transposition back to Nabd.
+  tensor = [self transposeTensor:tensor dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/a_transpose_4", label]];
+  tensor = [self transposeTensor:tensor dimension:0 withDimension:1 name:[NSString stringWithFormat:@"%@/a_transpose_5", label]];
+
+  if (type == 1) {
+    // RPE-K needs one more transposition to get back to BHQK.
+    // x: [B, H, K, Q] -> [B, H, Q, K]  # RPE-K
+    return [self transposeTensor:tensor dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/rpe/transpose_6", label]];
+  }
+
+  // x: [B, H, Q, K]  # RPE-Q or RPE-K
+  // x: [B, H, Q, D]  # RPE-V
+  return tensor;
+}
+
+-(nonnull MPSGraphTensor *) getRpeMapTensor
+{
+  // RPE weights factorizer tensor.
+  static MPSGraphTensor * rpeMapTensor = nil;
+
+  @synchronized (self) {
+    if (rpeMapTensor == nil) {
+      int rows = 15 * 15;
+      int cols = 64 * 64;
+      int row, col;
+      std::vector<float> rpeMap(rows * cols);
+      // Maps the 15 * 15 relative-distance pairs onto the 64 * 64 pairs of squares.
+      // Distance pairs are mapped to rows, square pairs to columns.
+      for (NSUInteger i = 0; i < 8; i++) {
+        for (NSUInteger j = 0; j < 8; j++) {
+          for (NSUInteger k = 0; k < 8; k++) {
+            for (NSUInteger l = 0; l < 8; l++) {
+              row = 15 * (i - k + 7) + (j - l + 7);
+              col = 64 * (i * 8 + j) + k * 8 + l;
+              rpeMap[row * cols + col] = 1.0f;
+            }
+          }
+        }
+      }
+      NSData * rpeMapData = [NSData dataWithBytesNoCopy:(void *)rpeMap.data()
+                                                 length:rows * cols * sizeof(float)
+                                           freeWhenDone:NO];
+
+      rpeMapTensor = [self variableWithData:rpeMapData
+                                      shape:@[@(rows), @(cols)]
+                                   dataType:MPSDataTypeFloat32
+                                       name:@"rpe_factor"];
+    }
+  }
+  return rpeMapTensor;
+}
+
 -(nonnull MPSGraphTensor *) scaledMHAMatmulWithQueries:(MPSGraphTensor * __nonnull)queries
                                               withKeys:(MPSGraphTensor * __nonnull)keys
                                             withValues:(MPSGraphTensor * __nonnull)values
                                                  heads:(NSUInteger)heads
                                                 parent:(MPSGraphTensor * __nonnull)parent
                                                smolgen:(lczero::MultiHeadWeights::Smolgen * __nullable)smolgen
+                                                  rpeQ:(float * __nullable)rpeQ
+                                                  rpeK:(float * __nullable)rpeK
+                                                  rpeV:(float * __nullable)rpeV
                                      smolgenActivation:(NSString * __nullable)smolgenActivation
                                                  label:(NSString * __nonnull)label
 {
@@ -769,10 +895,45 @@ -(nonnull MPSGraphTensor *) scaledMHAMatmulWithQueries:(MPSGraphTensor * __nonnu
   values = [self transposeTensor:values dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_v", label]];
 
   // Scaled attention matmul.
-  keys = [self transposeTensor:keys dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/transpose_k_2", label]];
+  MPSGraphTensor * transposedKeys = [self transposeTensor:keys dimension:2 withDimension:3 name:[NSString stringWithFormat:@"%@/transpose_k_2", label]];
   MPSGraphTensor * attn = [self matrixMultiplicationWithPrimaryTensor:queries
-                                                      secondaryTensor:keys
+                                                      secondaryTensor:transposedKeys
                                                                  name:[NSString stringWithFormat:@"%@/matmul_qk", label]];
+
+  if (rpeQ != nil || rpeK != nil) {
+    MPSGraphTensor * rpeMapTensor = [self getRpeMapTensor];
+
+    // Apply the RPELogits to each of Q and K.
+    if (rpeQ != nil) {
+      MPSGraphTensor * rpeQTensor = [self relativePositionEncodingWithTensor:queries
+                                                                   mapTensor:rpeMapTensor
+                                                                     weights:rpeQ
+                                                                       depth:depth
+                                                                       heads:heads
+                                                                     queries:64
+                                                                        keys:64
+                                                                        type:0  // Q-type
+                                                                       label:[NSString stringWithFormat:@"%@/rpeQ", label]];
+      attn = [self additionWithPrimaryTensor:attn
+                             secondaryTensor:rpeQTensor
+                                        name:[NSString stringWithFormat:@"%@/rpeQ_add", label]];
+    }
+    if (rpeK != nil) {
+      MPSGraphTensor * rpeKTensor = [self relativePositionEncodingWithTensor:keys
+                                                                   mapTensor:rpeMapTensor
+                                                                     weights:rpeK
+                                                                       depth:depth
+                                                                       heads:heads
+                                                                     queries:64
+                                                                        keys:64
+                                                                        type:1  // K-type
+                                                                       label:[NSString stringWithFormat:@"%@/rpeK", label]];
+      attn = [self additionWithPrimaryTensor:attn
+                             secondaryTensor:rpeKTensor
+                                        name:[NSString stringWithFormat:@"%@/rpeK_add", label]];
+    }
+  }
+
   attn = [self divisionWithPrimaryTensor:attn
                          secondaryTensor:[self constantWithScalar:sqrt(depth)
                                                             shape:@[@1]
@@ -849,13 +1010,30 @@ -(nonnull MPSGraphTensor *) scaledMHAMatmulWithQueries:(MPSGraphTensor * __nonnu
   attn = [self applyActivationWithTensor:attn activation:@"softmax" label:label];
 
   // matmul(scaled_attention_weights, v).
-  attn = [self matrixMultiplicationWithPrimaryTensor:attn
-                                     secondaryTensor:values
-                                                name:[NSString stringWithFormat:@"%@/matmul_v", label]];
+  MPSGraphTensor * output = [self matrixMultiplicationWithPrimaryTensor:attn
+                                                        secondaryTensor:values
+                                                                   name:[NSString stringWithFormat:@"%@/matmul_v", label]];
+
+  if (rpeV != nil) {
+    MPSGraphTensor * rpeMapTensor = [self getRpeMapTensor];
+    // output = output + RPEValue(head_depth, name=name+'/rpe_v')(attention_weights)
+    MPSGraphTensor * rpeVTensor = [self relativePositionEncodingWithTensor:attn
+                                                                 mapTensor:rpeMapTensor
+                                                                   weights:rpeV
+                                                                     depth:depth
+                                                                     heads:heads
+                                                                   queries:64
+                                                                      keys:64
+                                                                      type:2  // V-type
+                                                                     label:[NSString stringWithFormat:@"%@/rpeV", label]];
+    output = [self additionWithPrimaryTensor:output
+                             secondaryTensor:rpeVTensor
+                                        name:[NSString stringWithFormat:@"%@/rpeV_add", label]];
+  }
 
-  attn = [self transposeTensor:attn dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_a", label]];
+  output = [self transposeTensor:output dimension:1 withDimension:2 name:[NSString stringWithFormat:@"%@/transpose_a", label]];
 
-  return [self reshapeTensor:attn withShape:@[@(-1), @64, @(dmodel)] name:[NSString stringWithFormat:@"%@/reshape_a", label]];
+  return [self reshapeTensor:output withShape:@[@(-1), @64, @(dmodel)] name:[NSString stringWithFormat:@"%@/reshape_a", label]];
 }
 
 -(nonnull MPSGraphTensor *) scaledQKMatmulWithQueries:(MPSGraphTensor * __nonnull)queries
diff --git a/src/neural/backends/metal/network_metal.cc b/src/neural/backends/metal/network_metal.cc
index 0a45eb74da..a7013e8b81 100644
--- a/src/neural/backends/metal/network_metal.cc
+++ b/src/neural/backends/metal/network_metal.cc
@@ -160,9 +160,11 @@ MetalNetwork::MetalNetwork(const WeightsFile& file, const OptionsDict& options)
                        "' does not exist in this net.");
   }
 
-  auto embedding = static_cast<InputEmbedding>(file.format().network_format().input_embedding());
-  builder_->build(kInputPlanes, weights, embedding, attn_body, attn_policy_, conv_policy_,
-                  wdl_, moves_left_, activations, policy_head, value_head);
+  auto embedding = static_cast<InputEmbedding>(
+      file.format().network_format().input_embedding());
+  builder_->build(kInputPlanes, weights, embedding, attn_body, attn_policy_,
+                  conv_policy_, wdl_, moves_left_, activations, policy_head,
+                  value_head);
 }
 
 void MetalNetwork::forwardEval(InputsOutputs* io, int batchSize) {
diff --git a/src/neural/network_legacy.cc b/src/neural/network_legacy.cc
index 53846353c6..44f79c4517 100644
--- a/src/neural/network_legacy.cc
+++ b/src/neural/network_legacy.cc
@@ -142,7 +142,10 @@ BaseWeights::MHA::MHA(const pblczero::Weights::MHA& mha)
       dense_w(LayerAdapter(mha.dense_w()).as_vector()),
       dense_b(LayerAdapter(mha.dense_b()).as_vector()),
       smolgen(Smolgen(mha.smolgen())),
-      has_smolgen(mha.has_smolgen()) {}
+      has_smolgen(mha.has_smolgen()),
+      rpe_q(LayerAdapter(mha.rpe_q()).as_vector()),
+      rpe_k(LayerAdapter(mha.rpe_k()).as_vector()),
+      rpe_v(LayerAdapter(mha.rpe_v()).as_vector()) {}
 
 BaseWeights::FFN::FFN(const pblczero::Weights::FFN& ffn)
     : dense1_w(LayerAdapter(ffn.dense1_w()).as_vector()),
diff --git a/src/neural/network_legacy.h b/src/neural/network_legacy.h
index 72ce67544f..7ad5db82ba 100644
--- a/src/neural/network_legacy.h
+++ b/src/neural/network_legacy.h
@@ -81,6 +81,9 @@ struct BaseWeights {
     Vec dense_b;
     Smolgen smolgen;
    bool has_smolgen;
+    Vec rpe_q;
+    Vec rpe_k;
+    Vec rpe_v;
   };
 
   struct FFN {
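
Reviewer notes (not part of the patch):

The patch wires relative position encoding (RPE) into scaled MHA at three points: RPE-Q and RPE-K terms are added to the QK logits before the division by sqrt(depth) and the softmax, and an RPE-V term is added to the attention output after the matmul with V. A sketch of the intended math, with the smolgen path omitted and rpe_Q, rpe_K, rpe_V as shorthand for the three relativePositionEncodingWithTensor: calls (my notation, not names from the patch):

$$\mathrm{attn} = \operatorname{softmax}\!\left(\frac{QK^{\top} + \mathrm{rpe}_Q(Q) + \mathrm{rpe}_K(K)}{\sqrt{d}}\right), \qquad \mathrm{out} = \mathrm{attn}\,V + \mathrm{rpe}_V(\mathrm{attn})$$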
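Each relativePositionEncodingWithTensor: call realizes one batched einsum via the transpose/matmul sequences in the diff. Writing b = batch, h = heads, q/k = board squares, d = head depth, and W for the factorized [d, h, q, k] weight tensor, the three types correspond to (derived from the shape comments above; worth double-checking in review):

$$\mathrm{rpe}_Q(Q)_{bhqk} = \sum_d Q_{bhqd} W_{dhqk}, \quad \mathrm{rpe}_K(K)_{bhqk} = \sum_d K_{bhkd} W_{dhqk}, \quad \mathrm{rpe}_V(A)_{bhqd} = \sum_k A_{bhqk} W_{dhqk}$$

MPSGraph's matrixMultiplication batches over the leading dimensions and contracts only the trailing pair, hence the permutations before and after each matmul.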
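For reference, a minimal standalone C++ sketch of the factorization map built in getRpeMapTensor (the function name BuildRpeMap is illustrative, not from the patch):

```cpp
#include <vector>

// One-hot matrix mapping each of the 64 * 64 (query, key) square pairs
// to one of the 15 * 15 relative-displacement buckets. Rank and file
// deltas each range over -7..7, hence 15 buckets per axis.
std::vector<float> BuildRpeMap() {
  const int kRows = 15 * 15;  // relative (delta rank, delta file) buckets
  const int kCols = 64 * 64;  // (query square, key square) pairs
  std::vector<float> map(kRows * kCols, 0.0f);
  for (int i = 0; i < 8; i++) {        // query rank
    for (int j = 0; j < 8; j++) {      // query file
      for (int k = 0; k < 8; k++) {    // key rank
        for (int l = 0; l < 8; l++) {  // key file
          const int row = 15 * (i - k + 7) + (j - l + 7);
          const int col = 64 * (i * 8 + j) + k * 8 + l;
          map[row * kCols + col] = 1.0f;
        }
      }
    }
  }
  return map;
}
```

Since every column selects exactly one displacement bucket, the factorize_matmul in relativePositionEncodingWithTensor: expands the learned [depth * heads, 15 * 15] table into the full [depth, heads, 64, 64] bias, so the network file stores 225 weights per head-depth pair instead of 4096.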