diff --git a/src/arm/64/sse.S b/src/arm/64/sse.S index f4c3818cd5..34490c5adf 100644 --- a/src/arm/64/sse.S +++ b/src/arm/64/sse.S @@ -41,6 +41,9 @@ add x12, x3, x3 add x8, x1, x1, lsl 1 add x9, x3, x3, lsl 1 +.elseif \width >= 64 + mov w8, #(\width) + sxtw x9, w8 .endif movi v17.4s, #0 mov w10, #(\height) @@ -149,7 +152,7 @@ L(wsse_w16): RET_SUM endfunc -.macro LOAD_32X4 +.macro LOAD_32X4 vert=1 ldp q0, q22, [x0] ldp q4, q26, [x2] add x0, x0, x1 @@ -166,9 +169,18 @@ endfunc ldp q7, q29, [x2] add x0, x0, x1 add x2, x2, x3 +.if \vert == 1 ldp q16, q19, [x4] add x4, x4, x5 subs w10, w10, #4 +.else + sub x0, x0, x1, lsl 2 + sub x2, x2, x3, lsl 2 + add x0, x0, #32 + add x2, x2, #32 + ldp q16, q19, [x4] + add x4, x4, #32 +.endif mov v18.d[0], v16.d[1] mov v20.d[0], v19.d[1] .endm @@ -245,10 +257,33 @@ L(wsse_w32): RET_SUM endfunc +function weighted_sse_64x64_neon, export=1 + INIT 64, 64 +L(wsse_w32up): + LOAD_32X4 vert=0 + WEIGHTED_SSE_32X4 + subs w8, w8, #32 + bne L(wsse_w32up) + mov w8, w9 + sub x0, x0, x9 + sub x2, x2, x9 + add x0, x0, x1, lsl 2 + add x2, x2, x3, lsl 2 + sub x4, x4, x9 + add x4, x4, x5 + subs w10, w10, #4 + bne L(wsse_w32up) + RET_SUM +endfunc + .macro weighted_sse width, height function weighted_sse_\width\()x\height\()_neon, export=1 INIT \width, \height +.if \width <= 32 b L(wsse_w\width) +.else + b L(wsse_w32up) +.endif endfunc .endm @@ -264,3 +299,8 @@ weighted_sse 16, 64 weighted_sse 32, 8 weighted_sse 32, 16 weighted_sse 32, 64 +weighted_sse 64, 16 +weighted_sse 64, 32 +weighted_sse 64, 128 +weighted_sse 128, 64 +weighted_sse 128, 128 diff --git a/src/asm/aarch64/dist/sse.rs b/src/asm/aarch64/dist/sse.rs index e2509294b7..4c8ad4afe1 100644 --- a/src/asm/aarch64/dist/sse.rs +++ b/src/asm/aarch64/dist/sse.rs @@ -52,7 +52,13 @@ declare_asm_sse_fn![ rav1e_weighted_sse_32x8_neon, rav1e_weighted_sse_32x16_neon, rav1e_weighted_sse_32x32_neon, - rav1e_weighted_sse_32x64_neon + rav1e_weighted_sse_32x64_neon, + rav1e_weighted_sse_64x16_neon, + rav1e_weighted_sse_64x32_neon, + rav1e_weighted_sse_64x64_neon, + rav1e_weighted_sse_64x128_neon, + rav1e_weighted_sse_128x64_neon, + rav1e_weighted_sse_128x128_neon ]; /// # Panics @@ -132,6 +138,12 @@ static SSE_FNS_NEON: [Option; DIST_FNS_LENGTH] = { out[BLOCK_32X16 as usize] = Some(rav1e_weighted_sse_32x16_neon); out[BLOCK_32X32 as usize] = Some(rav1e_weighted_sse_32x32_neon); out[BLOCK_32X64 as usize] = Some(rav1e_weighted_sse_32x64_neon); + out[BLOCK_64X16 as usize] = Some(rav1e_weighted_sse_64x16_neon); + out[BLOCK_64X32 as usize] = Some(rav1e_weighted_sse_64x32_neon); + out[BLOCK_64X64 as usize] = Some(rav1e_weighted_sse_64x64_neon); + out[BLOCK_64X128 as usize] = Some(rav1e_weighted_sse_64x128_neon); + out[BLOCK_128X64 as usize] = Some(rav1e_weighted_sse_128x64_neon); + out[BLOCK_128X128 as usize] = Some(rav1e_weighted_sse_128x128_neon); out };