Skip to content

Commit

Permalink
finish arch dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
Martmists-GH committed Dec 14, 2024
1 parent bc112cf commit 1a4cf8c
Show file tree
Hide file tree
Showing 34 changed files with 1,611 additions and 510 deletions.
4 changes: 1 addition & 3 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ kotlin {
"fma3_avx2",
"fma3_avx",
"fma3_sse4_2",
"fma4",
"sse2",
"sse3",
"sse4_1",
Expand All @@ -130,10 +129,9 @@ kotlin {
"fma3_avx2" -> arrayOf("-mfma") + flagsFor("avx2")
"fma3_avx" -> arrayOf("-mfma") + flagsFor("avx")
"fma3_sse4_2" -> arrayOf("-mfma") + flagsFor("sse4_2")
"fma4" -> arrayOf("-mfma4") + flagsFor("sse4_2")
"sse2" -> arrayOf("-msse2")
"sse3" -> arrayOf("-msse3") + flagsFor("sse2")
"sse4_1" -> arrayOf("-msse4.1") + flagsFor("sse3")
"sse4_1" -> arrayOf("-msse4.1") + flagsFor("ssse3")
"sse4_2" -> arrayOf("-msse4.2") + flagsFor("sse4_1")
"ssse3" -> arrayOf("-mssse3") + flagsFor("sse3")

Expand Down
4 changes: 4 additions & 0 deletions src/commonMain/kotlin/com/martmists/ndarray/simd/F64Array.kt
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,17 @@ interface F64Array {
* @see kotlin.math.ln
*/
fun logInPlace() = transformInPlace(::ln)
// Alias
fun lnInPlace() = logInPlace()

/**
* Computes ln(x) for each element in the array.
*
* @return the computed array
*/
fun log(): F64Array = copy().apply { logInPlace() }
// Alias
fun ln() = log()

/**
* Computes ln(1 + x) for each element in the array in place.
Expand Down
79 changes: 79 additions & 0 deletions src/commonTest/kotlin/MathTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import com.martmists.ndarray.simd.F64Array
import com.martmists.ndarray.simd.impl.F64LargeDenseFlatArrayImpl
import com.martmists.ndarray.simd.pow
import kotlin.math.*
import kotlin.test.Test
import kotlin.test.assertContentEquals

class MathTest {
// FIXME: disabled until we can test within certain precision
// @Test
// fun `Test Array sqrt`() {
// val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
// val arr2 = arr1.sqrt()
// assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { sqrt(it.toDouble()) })
// }
//
// @Test
// fun `Test Array pow`() {
// val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
// val arr2 = arr1.pow(2.0)
// assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { it.toDouble().pow(2.0) })
// }
//
// @Test
// fun `Test Array ipow`() {
// val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
// val arr2 = (2.0).pow(arr1)
// assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { (2.0).pow(it) })
// }
//
// @Test
// fun `Test Array log`() {
// val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
// val arr2 = arr1.ln()
// assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { ln(it.toDouble()) })
// }
//
// @Test
// fun `Test Array logbase`() {
// val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
// val arr2 = arr1.logBase(5.0)
// assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { log(it.toDouble(), 5.0) })
// }
//
// @Test
// fun `Test Array exp`() {
// val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
// val arr2 = arr1.exp()
// assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { exp(it.toDouble()) })
// }
//
// @Test
// fun `Test Array expm1`() {
// val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
// val arr2 = arr1.expm1()
// assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { expm1(it.toDouble()) })
// }
//
// @Test
// fun `Test Array log1p`() {
// val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
// val arr2 = arr1.log1p()
// assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { ln(1.0 + it) })
// }
//
// @Test
// fun `Test Array log2`() {
// val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
// val arr2 = arr1.log2()
// assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { log2(it.toDouble()) })
// }
//
// @Test
// fun `Test Array log10`() {
// val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
// val arr2 = arr1.log10()
// assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { log10(it.toDouble()) })
// }
}
58 changes: 58 additions & 0 deletions src/commonTest/kotlin/ProcedureTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import com.martmists.ndarray.simd.F64Array
import com.martmists.ndarray.simd.impl.F64LargeDenseFlatArrayImpl
import com.martmists.ndarray.simd.pow
import kotlin.math.*
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals

class ProcedureTest {
@Test
fun `Test Array sum`() {
val arr = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { 1.0 }, 0, F64Array.simdSize)
val res = arr.sum()
assertEquals(F64Array.simdSize.toDouble(), res)
}

@Test
fun `Test Array min`() {
val arr = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it * 2.0 }, 0, F64Array.simdSize)
val res = arr.min()
assertEquals(0.0, res)
}

@Test
fun `Test Array max`() {
val arr = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it * 5.0 }, 0, F64Array.simdSize)
val res = arr.max()
assertEquals((F64Array.simdSize * 5 - 5).toDouble(), res)
}

@Test
fun `Test Array prod`() {
val arr = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it + 1.0 }, 0, F64Array.simdSize)
val res = arr.product()

fun factorial(n: Double): Double {
if (n == 1.0) return n
return n * factorial(n - 1)
}

assertEquals(factorial(F64Array.simdSize.toDouble()), res)
}

@Test
fun `Test Array mean`() {
val arr = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
val res = arr.mean()
assertEquals((F64Array.simdSize.toDouble() - 1) / 2, res)
}

@Test
fun `Test Array coerce`() {
val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it * 5.0 }, 0, F64Array.simdSize)
val arr2 = arr1.coerce(2.5, 17.0)
assertEquals(2.5, arr2.min())
assertEquals(17.0, arr2.max())
}
}
37 changes: 37 additions & 0 deletions src/commonTest/kotlin/RoundingTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import com.martmists.ndarray.simd.F64Array
import com.martmists.ndarray.simd.impl.F64LargeDenseFlatArrayImpl
import com.martmists.ndarray.simd.pow
import kotlin.math.*
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals

class RoundingTest {
@Test
fun `Test Array floor`() {
val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { 0.1 * it }, 0, F64Array.simdSize)
val arr2 = arr1.floor()
assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { floor(0.1 * it) })
}

@Test
fun `Test Array ceil`() {
val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { 0.1 * it }, 0, F64Array.simdSize)
val arr2 = arr1.ceil()
assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { ceil(0.1 * it) })
}

@Test
fun `Test Array trunc`() {
val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { 0.1 * it }, 0, F64Array.simdSize)
val arr2 = arr1.trunc()
assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { truncate(0.1 * it) })
}

@Test
fun `Test Array round`() {
val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { 0.2 * it }, 0, F64Array.simdSize)
val arr2 = arr1.round()
assertContentEquals(arr2.toDoubleArray(), DoubleArray(F64Array.simdSize) { round(0.2 * it) })
}
}
11 changes: 11 additions & 0 deletions src/commonTest/kotlin/TrigonometryTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import com.martmists.ndarray.simd.F64Array
import com.martmists.ndarray.simd.impl.F64LargeDenseFlatArrayImpl
import com.martmists.ndarray.simd.pow
import kotlin.math.*
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals

class TrigonometryTest {
// TODO
}
17 changes: 17 additions & 0 deletions src/commonTest/kotlin/VectorTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import com.martmists.ndarray.simd.F64Array
import com.martmists.ndarray.simd.impl.F64LargeDenseFlatArrayImpl
import com.martmists.ndarray.simd.pow
import kotlin.math.*
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals

class VectorTest {
@Test
fun `Test Array dot Array`() {
val arr1 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it.toDouble() }, 0, F64Array.simdSize)
val arr2 = F64LargeDenseFlatArrayImpl(DoubleArray(F64Array.simdSize) { it - 1.0 }, 0, F64Array.simdSize)
val res = arr1 dot arr2
assertEquals(res, (0 until F64Array.simdSize).sumOf { it * (it - 1.0) })
}
}
56 changes: 56 additions & 0 deletions src/lib/arch/avx.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
#include "../cpp/arithmetic_priv.h"
#include "../cpp/bitwise_priv.h"
#include "../cpp/compare_priv.h"
#include "../cpp/math_priv.h"
#include "../cpp/misc_priv.h"
#include "../cpp/procedure_priv.h"
#include "../cpp/rounding_priv.h"
#include "../cpp/trigonometry_priv.h"
#include "../cpp/vector_priv.h"

// Arithmetic
template void _vec_add_scalar::operator()<xsimd::avx>(xsimd::avx, double *, double, int);
Expand Down Expand Up @@ -43,3 +49,53 @@ template void _vec_gte_vec::operator()<xsimd::avx>(xsimd::avx, double *, double
template void _vec_isnan::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_isinf::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_isfinite::operator()<xsimd::avx>(xsimd::avx, double *, int);

// Math
template void _vec_sqrt::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_pow::operator()<xsimd::avx>(xsimd::avx, double *, double, int);
template void _vec_ipow::operator()<xsimd::avx>(xsimd::avx, double *, double, int);
template void _vec_log::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_logbase::operator()<xsimd::avx>(xsimd::avx, double *, double, int);
template void _vec_exp::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_expm1::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_log1p::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_log2::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_log10::operator()<xsimd::avx>(xsimd::avx, double *, int);

// Misc
template void _vec_copy::operator()<xsimd::avx>(xsimd::avx, double *, double *, int);
template int _get_simd_size::operator()<xsimd::avx>(xsimd::avx);

// Procedure
template double _vec_sum::operator()<xsimd::avx>(xsimd::avx, double *, int);
template double _vec_min::operator()<xsimd::avx>(xsimd::avx, double *, int);
template double _vec_max::operator()<xsimd::avx>(xsimd::avx, double *, int);
template double _vec_prod::operator()<xsimd::avx>(xsimd::avx, double *, int);
template double _vec_var::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_coerce::operator()<xsimd::avx>(xsimd::avx, double *, int, double, double);

// Rounding
template void _vec_floor::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_ceil::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_trunc::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_round::operator()<xsimd::avx>(xsimd::avx, double *, int);

// Trigonometry
template void _vec_sin::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_cos::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_tan::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_asin::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_acos::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_atan::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_atan2::operator()<xsimd::avx>(xsimd::avx, double *, double *, int);
template void _vec_sinh::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_cosh::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_tanh::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_asinh::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_acosh::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_atanh::operator()<xsimd::avx>(xsimd::avx, double *, int);
template void _vec_hypot::operator()<xsimd::avx>(xsimd::avx, double *, double *, int);

// Vector
template double _vec_dot::operator()<xsimd::avx>(xsimd::avx, double *, double *, int);
template void _vec_matmul::operator()<xsimd::avx>(xsimd::avx, double *, double *, double *, int, int, int);
56 changes: 56 additions & 0 deletions src/lib/arch/avx2.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
#include "../cpp/arithmetic_priv.h"
#include "../cpp/bitwise_priv.h"
#include "../cpp/compare_priv.h"
#include "../cpp/math_priv.h"
#include "../cpp/misc_priv.h"
#include "../cpp/procedure_priv.h"
#include "../cpp/rounding_priv.h"
#include "../cpp/trigonometry_priv.h"
#include "../cpp/vector_priv.h"

// Arithmetic
template void _vec_add_scalar::operator()<xsimd::avx2>(xsimd::avx2, double *, double, int);
Expand Down Expand Up @@ -43,3 +49,53 @@ template void _vec_gte_vec::operator()<xsimd::avx2>(xsimd::avx2, double *, doubl
template void _vec_isnan::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_isinf::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_isfinite::operator()<xsimd::avx2>(xsimd::avx2, double *, int);

// Math
template void _vec_sqrt::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_pow::operator()<xsimd::avx2>(xsimd::avx2, double *, double, int);
template void _vec_ipow::operator()<xsimd::avx2>(xsimd::avx2, double *, double, int);
template void _vec_log::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_logbase::operator()<xsimd::avx2>(xsimd::avx2, double *, double, int);
template void _vec_exp::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_expm1::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_log1p::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_log2::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_log10::operator()<xsimd::avx2>(xsimd::avx2, double *, int);

// Misc
template void _vec_copy::operator()<xsimd::avx2>(xsimd::avx2, double *, double *, int);
template int _get_simd_size::operator()<xsimd::avx2>(xsimd::avx2);

// Procedure
template double _vec_sum::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template double _vec_min::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template double _vec_max::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template double _vec_prod::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template double _vec_var::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_coerce::operator()<xsimd::avx2>(xsimd::avx2, double *, int, double, double);

// Rounding
template void _vec_floor::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_ceil::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_trunc::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_round::operator()<xsimd::avx2>(xsimd::avx2, double *, int);

// Trigonometry
template void _vec_sin::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_cos::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_tan::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_asin::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_acos::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_atan::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_atan2::operator()<xsimd::avx2>(xsimd::avx2, double *, double *, int);
template void _vec_sinh::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_cosh::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_tanh::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_asinh::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_acosh::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_atanh::operator()<xsimd::avx2>(xsimd::avx2, double *, int);
template void _vec_hypot::operator()<xsimd::avx2>(xsimd::avx2, double *, double *, int);

// Vector
template double _vec_dot::operator()<xsimd::avx2>(xsimd::avx2, double *, double *, int);
template void _vec_matmul::operator()<xsimd::avx2>(xsimd::avx2, double *, double *, double *, int, int, int);
Loading

0 comments on commit 1a4cf8c

Please sign in to comment.