Skip to content

Commit

Permalink
feat: add support for graviton4 (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
gswirski authored Jan 4, 2025
1 parent 7ee6d7f commit f825c23
Show file tree
Hide file tree
Showing 3 changed files with 412 additions and 77 deletions.
49 changes: 13 additions & 36 deletions arch/arm64-sve/rpo/library.c
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#include <stddef.h>
#include <arm_sve.h>
#include "library.h"
#include "rpo_hash.h"
#include "rpo_hash_128bit.h"
#include "rpo_hash_256bit.h"

// The STATE_WIDTH of RPO hash is 12x u64 elements.
// The current generation of SVE-enabled processors - Neoverse V1
Expand Down Expand Up @@ -31,48 +32,24 @@

bool add_constants_and_apply_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector

if (vl != 4) {

if (vl == 2) {
return add_constants_and_apply_sbox_128(state, constants);
} else if (vl == 4) {
return add_constants_and_apply_sbox_256(state, constants);
} else {
return false;
}

svbool_t ptrue = svptrue_b64();

svuint64_t state1 = svld1(ptrue, state + 0*vl);
svuint64_t state2 = svld1(ptrue, state + 1*vl);

svuint64_t const1 = svld1(ptrue, constants + 0*vl);
svuint64_t const2 = svld1(ptrue, constants + 1*vl);

add_constants(ptrue, &state1, &const1, &state2, &const2, state+8, constants+8);
apply_sbox(ptrue, &state1, &state2, state+8);

svst1(ptrue, state + 0*vl, state1);
svst1(ptrue, state + 1*vl, state2);

return true;
}

bool add_constants_and_apply_inv_sbox(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = svcntd(); // number of u64 numbers in one SVE vector

if (vl != 4) {
if (vl == 2) {
return add_constants_and_apply_inv_sbox_128(state, constants);
} else if (vl == 4) {
return add_constants_and_apply_inv_sbox_256(state, constants);
} else {
return false;
}

svbool_t ptrue = svptrue_b64();

svuint64_t state1 = svld1(ptrue, state + 0 * vl);
svuint64_t state2 = svld1(ptrue, state + 1 * vl);

svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
svuint64_t const2 = svld1(ptrue, constants + 1 * vl);

add_constants(ptrue, &state1, &const1, &state2, &const2, state + 8, constants + 8);
apply_inv_sbox(ptrue, &state1, &state2, state + 8);

svst1(ptrue, state + 0 * vl, state1);
svst1(ptrue, state + 1 * vl, state2);

return true;
}
318 changes: 318 additions & 0 deletions arch/arm64-sve/rpo/rpo_hash_128bit.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
#ifndef RPO_SVE_RPO_HASH_128_H
#define RPO_SVE_RPO_HASH_128_H

#include <arm_sve.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define STATE_WIDTH 12

#define COPY_128(NAME, VIN1, VIN2, VIN3, VIN4, SIN) \
svuint64_t NAME ## _1 = VIN1; \
svuint64_t NAME ## _2 = VIN2; \
svuint64_t NAME ## _3 = VIN3; \
svuint64_t NAME ## _4 = VIN4; \
uint64_t NAME ## _tail[4]; \
memcpy(NAME ## _tail, SIN, 4 * sizeof(uint64_t))

#define MULTIPLY_128(PRED, DEST, OP) \
mul_128(PRED, &DEST ## _1, &OP ## _1, &DEST ## _2, &OP ## _2, &DEST ## _3, &OP ## _3, &DEST ## _4, &OP ## _4, DEST ## _tail, OP ## _tail)

#define SQUARE_128(PRED, NAME) \
sq_128(PRED, &NAME ## _1, &NAME ## _2, &NAME ## _3, &NAME ## _4, NAME ## _tail)

#define SQUARE_DEST_128(PRED, DEST, SRC) \
COPY_128(DEST, SRC ## _1, SRC ## _2, SRC ## _3, SRC ## _4, SRC ## _tail); \
SQUARE_128(PRED, DEST);

#define POW_ACC_128(PRED, NAME, CNT, TAIL) \
for (size_t i = 0; i < CNT; i++) { \
SQUARE_128(PRED, NAME); \
} \
MULTIPLY_128(PRED, NAME, TAIL);

#define POW_ACC_DEST(PRED, DEST, CNT, HEAD, TAIL) \
COPY_128(DEST, HEAD ## _1, HEAD ## _2, HEAD ## _3, HEAD ## _4, HEAD ## _tail); \
POW_ACC_128(PRED, DEST, CNT, TAIL)

extern inline void add_constants_128(
svbool_t pg,
svuint64_t *state1,
svuint64_t *const1,
svuint64_t *state2,
svuint64_t *const2,
svuint64_t *state3,
svuint64_t *const3,
svuint64_t *state4,
svuint64_t *const4,

uint64_t *state_tail,
uint64_t *const_tail
) {
uint64_t Ms = 0xFFFFFFFF00000001ull;
svuint64_t Mv = svindex_u64(Ms, 0);

uint64_t p_1 = Ms - const_tail[0];
uint64_t p_2 = Ms - const_tail[1];
uint64_t p_3 = Ms - const_tail[2];
uint64_t p_4 = Ms - const_tail[3];

uint64_t x_1, x_2, x_3, x_4;
uint32_t adj_1 = -__builtin_sub_overflow(state_tail[0], p_1, &x_1);
uint32_t adj_2 = -__builtin_sub_overflow(state_tail[1], p_2, &x_2);
uint32_t adj_3 = -__builtin_sub_overflow(state_tail[2], p_3, &x_3);
uint32_t adj_4 = -__builtin_sub_overflow(state_tail[3], p_4, &x_4);

state_tail[0] = x_1 - (uint64_t)adj_1;
state_tail[1] = x_2 - (uint64_t)adj_2;
state_tail[2] = x_3 - (uint64_t)adj_3;
state_tail[3] = x_4 - (uint64_t)adj_4;

svuint64_t p1 = svsub_x(pg, Mv, *const1);
svuint64_t p2 = svsub_x(pg, Mv, *const2);
svuint64_t p3 = svsub_x(pg, Mv, *const3);
svuint64_t p4 = svsub_x(pg, Mv, *const4);

svuint64_t x1 = svsub_x(pg, *state1, p1);
svuint64_t x2 = svsub_x(pg, *state2, p2);
svuint64_t x3 = svsub_x(pg, *state3, p3);
svuint64_t x4 = svsub_x(pg, *state4, p4);

svbool_t pt1 = svcmplt_u64(pg, *state1, p1);
svbool_t pt2 = svcmplt_u64(pg, *state2, p2);
svbool_t pt3 = svcmplt_u64(pg, *state3, p3);
svbool_t pt4 = svcmplt_u64(pg, *state4, p4);

*state1 = svsub_m(pt1, x1, (uint32_t)-1);
*state2 = svsub_m(pt2, x2, (uint32_t)-1);
*state3 = svsub_m(pt3, x3, (uint32_t)-1);
*state4 = svsub_m(pt4, x4, (uint32_t)-1);
}

extern inline void mul_128(
svbool_t pg,
svuint64_t *r1,
const svuint64_t *op1,
svuint64_t *r2,
const svuint64_t *op2,
svuint64_t *r3,
const svuint64_t *op3,
svuint64_t *r4,
const svuint64_t *op4,
uint64_t *r_tail,
const uint64_t *op_tail
) {
__uint128_t x_1 = r_tail[0];
__uint128_t x_2 = r_tail[1];
__uint128_t x_3 = r_tail[2];
__uint128_t x_4 = r_tail[3];

x_1 *= (__uint128_t) op_tail[0];
x_2 *= (__uint128_t) op_tail[1];
x_3 *= (__uint128_t) op_tail[2];
x_4 *= (__uint128_t) op_tail[3];

uint64_t x0_1 = x_1;
uint64_t x0_2 = x_2;
uint64_t x0_3 = x_3;
uint64_t x0_4 = x_4;

svuint64_t l1 = svmul_x(pg, *r1, *op1);
svuint64_t l2 = svmul_x(pg, *r2, *op2);
svuint64_t l3 = svmul_x(pg, *r3, *op3);
svuint64_t l4 = svmul_x(pg, *r4, *op4);

uint64_t x1_1 = (x_1 >> 64);
uint64_t x1_2 = (x_2 >> 64);
uint64_t x1_3 = (x_3 >> 64);
uint64_t x1_4 = (x_4 >> 64);

uint64_t a_1, a_2, a_3, a_4;
uint64_t e_1 = __builtin_add_overflow(x0_1, (x0_1 << 32), &a_1);
uint64_t e_2 = __builtin_add_overflow(x0_2, (x0_2 << 32), &a_2);
uint64_t e_3 = __builtin_add_overflow(x0_3, (x0_3 << 32), &a_3);
uint64_t e_4 = __builtin_add_overflow(x0_4, (x0_4 << 32), &a_4);

svuint64_t ls1 = svlsl_x(pg, l1, 32);
svuint64_t ls2 = svlsl_x(pg, l2, 32);
svuint64_t ls3 = svlsl_x(pg, l3, 32);
svuint64_t ls4 = svlsl_x(pg, l4, 32);

svuint64_t a1 = svadd_x(pg, l1, ls1);
svuint64_t a2 = svadd_x(pg, l2, ls2);
svuint64_t a3 = svadd_x(pg, l3, ls3);
svuint64_t a4 = svadd_x(pg, l4, ls4);

svbool_t e1 = svcmplt(pg, a1, l1);
svbool_t e2 = svcmplt(pg, a2, l2);
svbool_t e3 = svcmplt(pg, a3, l3);
svbool_t e4 = svcmplt(pg, a4, l4);

svuint64_t as1 = svlsr_x(pg, a1, 32);
svuint64_t as2 = svlsr_x(pg, a2, 32);
svuint64_t as3 = svlsr_x(pg, a3, 32);
svuint64_t as4 = svlsr_x(pg, a4, 32);

svuint64_t b1 = svsub_x(pg, a1, as1);
svuint64_t b2 = svsub_x(pg, a2, as2);
svuint64_t b3 = svsub_x(pg, a3, as3);
svuint64_t b4 = svsub_x(pg, a4, as4);

b1 = svsub_m(e1, b1, 1);
b2 = svsub_m(e2, b2, 1);
b3 = svsub_m(e3, b3, 1);
b4 = svsub_m(e4, b4, 1);

uint64_t b_1 = a_1 - (a_1 >> 32) - e_1;
uint64_t b_2 = a_2 - (a_2 >> 32) - e_2;
uint64_t b_3 = a_3 - (a_3 >> 32) - e_3;
uint64_t b_4 = a_4 - (a_4 >> 32) - e_4;

uint64_t r_1, r_2, r_3, r_4;
uint32_t c_1 = __builtin_sub_overflow(x1_1, b_1, &r_1);
uint32_t c_2 = __builtin_sub_overflow(x1_2, b_2, &r_2);
uint32_t c_3 = __builtin_sub_overflow(x1_3, b_3, &r_3);
uint32_t c_4 = __builtin_sub_overflow(x1_4, b_4, &r_4);

svuint64_t h1 = svmulh_x(pg, *r1, *op1);
svuint64_t h2 = svmulh_x(pg, *r2, *op2);
svuint64_t h3 = svmulh_x(pg, *r3, *op3);
svuint64_t h4 = svmulh_x(pg, *r4, *op4);

svuint64_t tr1 = svsub_x(pg, h1, b1);
svuint64_t tr2 = svsub_x(pg, h2, b2);
svuint64_t tr3 = svsub_x(pg, h3, b3);
svuint64_t tr4 = svsub_x(pg, h4, b4);

svbool_t c1 = svcmplt_u64(pg, h1, b1);
svbool_t c2 = svcmplt_u64(pg, h2, b2);
svbool_t c3 = svcmplt_u64(pg, h3, b3);
svbool_t c4 = svcmplt_u64(pg, h4, b4);

*r1 = svsub_m(c1, tr1, (uint32_t) -1);
*r2 = svsub_m(c2, tr2, (uint32_t) -1);
*r3 = svsub_m(c3, tr3, (uint32_t) -1);
*r4 = svsub_m(c4, tr4, (uint32_t) -1);

uint32_t minus1_1 = 0 - c_1;
uint32_t minus1_2 = 0 - c_2;
uint32_t minus1_3 = 0 - c_3;
uint32_t minus1_4 = 0 - c_4;

r_tail[0] = r_1 - (uint64_t)minus1_1;
r_tail[1] = r_2 - (uint64_t)minus1_2;
r_tail[2] = r_3 - (uint64_t)minus1_3;
r_tail[3] = r_4 - (uint64_t)minus1_4;
}

extern inline void sq_128(svbool_t pg, svuint64_t *a, svuint64_t *b, svuint64_t *c, svuint64_t *d, uint64_t *e) {
mul_128(pg, a, a, b, b, c, c, d, d, e, e);
}

extern inline void apply_sbox_128(
svbool_t pg,
svuint64_t *state1,
svuint64_t *state2,
svuint64_t *state3,
svuint64_t *state4,
uint64_t *state_tail
) {
COPY_128(x, *state1, *state2, *state3, *state4, state_tail); // copy input to x
SQUARE_128(pg, x); // x contains input^2
mul_128(pg, state1, &x_1, state2, &x_2, state3, &x_3, state4, &x_4, state_tail, x_tail); // state contains input^3
SQUARE_128(pg, x); // x contains input^4
mul_128(pg, state1, &x_1, state2, &x_2, state3, &x_3, state4, &x_4, state_tail, x_tail); // state contains input^7
}

extern inline void apply_inv_sbox_128(
svbool_t pg,
svuint64_t *state1,
svuint64_t *state2,
svuint64_t *state3,
svuint64_t *state4,
uint64_t *state_tail
) {
// base^10
COPY_128(t1, *state1, *state2, *state3, *state4, state_tail);
SQUARE_128(pg, t1);

// base^100
SQUARE_DEST_128(pg, t2, t1);

// base^100100
POW_ACC_DEST(pg, t3, 3, t2, t2);

// base^100100100100
POW_ACC_DEST(pg, t4, 6, t3, t3);

// compute base^100100100100100100100100
POW_ACC_DEST(pg, t5, 12, t4, t4);

// compute base^100100100100100100100100100100
POW_ACC_DEST(pg, t6, 6, t5, t3);

// compute base^1001001001001001001001001001000100100100100100100100100100100
POW_ACC_DEST(pg, t7, 31, t6, t6);

// compute base^1001001001001001001001001001000110110110110110110110110110110111
SQUARE_128(pg, t7);
MULTIPLY_128(pg, t7, t6);
SQUARE_128(pg, t7);
SQUARE_128(pg, t7);
MULTIPLY_128(pg, t7, t1);
MULTIPLY_128(pg, t7, t2);
mul_128(pg, state1, &t7_1, state2, &t7_2, state3, &t7_3, state4, &t7_4, state_tail, t7_tail);
}

bool add_constants_and_apply_sbox_128(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = 2; // number of u64 numbers in one 128 bit SVE vector
svbool_t ptrue = svptrue_b64();

svuint64_t state1 = svld1(ptrue, state + 0 * vl);
svuint64_t state2 = svld1(ptrue, state + 1 * vl);
svuint64_t state3 = svld1(ptrue, state + 2 * vl);
svuint64_t state4 = svld1(ptrue, state + 3 * vl);

svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
svuint64_t const2 = svld1(ptrue, constants + 1 * vl);
svuint64_t const3 = svld1(ptrue, constants + 2 * vl);
svuint64_t const4 = svld1(ptrue, constants + 3 * vl);

add_constants_128(ptrue, &state1, &const1, &state2, &const2, &state3, &const3, &state4, &const4, state + 8, constants + 8);
apply_sbox_128(ptrue, &state1, &state2, &state3, &state4, state + 8);

svst1(ptrue, state + 0 * vl, state1);
svst1(ptrue, state + 1 * vl, state2);
svst1(ptrue, state + 2 * vl, state3);
svst1(ptrue, state + 3 * vl, state4);

return true;
}

bool add_constants_and_apply_inv_sbox_128(uint64_t state[STATE_WIDTH], uint64_t constants[STATE_WIDTH]) {
const uint64_t vl = 2; // number of u64 numbers in one 128 bit SVE vector
svbool_t ptrue = svptrue_b64();

svuint64_t state1 = svld1(ptrue, state + 0 * vl);
svuint64_t state2 = svld1(ptrue, state + 1 * vl);
svuint64_t state3 = svld1(ptrue, state + 2 * vl);
svuint64_t state4 = svld1(ptrue, state + 3 * vl);

svuint64_t const1 = svld1(ptrue, constants + 0 * vl);
svuint64_t const2 = svld1(ptrue, constants + 1 * vl);
svuint64_t const3 = svld1(ptrue, constants + 2 * vl);
svuint64_t const4 = svld1(ptrue, constants + 3 * vl);

add_constants_128(ptrue, &state1, &const1, &state2, &const2, &state3, &const3, &state4, &const4, state + 8, constants + 8);
apply_inv_sbox_128(ptrue, &state1, &state2, &state3, &state4, state + 8);

svst1(ptrue, state + 0 * vl, state1);
svst1(ptrue, state + 1 * vl, state2);
svst1(ptrue, state + 2 * vl, state3);
svst1(ptrue, state + 3 * vl, state4);

return true;
}

#endif //RPO_SVE_RPO_HASH_128_H
Loading

0 comments on commit f825c23

Please sign in to comment.