From 19ffdee614e9d0db46dd1fbb3ae96364e3adda4d Mon Sep 17 00:00:00 2001 From: rhdong Date: Tue, 21 Jan 2025 12:48:38 -0800 Subject: [PATCH] [Fix] l2_exp random fail in half-float32 mixed precision on self-neighboring --- .../distance/detail/distance_ops/l2_exp.cuh | 34 ++++++++++++------- python/cuvs/cuvs/test/test_distance.py | 5 ++- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/cpp/src/distance/detail/distance_ops/l2_exp.cuh b/cpp/src/distance/detail/distance_ops/l2_exp.cuh index 04817aa8b..f49771605 100644 --- a/cpp/src/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/src/distance/detail/distance_ops/l2_exp.cuh @@ -28,14 +28,14 @@ namespace cuvs::distance::detail::ops { * for round-off error tolerance. * @tparam DataT */ -template <typename DataT> -__device__ constexpr DataT get_clamp_precision() +template <typename DataT, typename AccT> +__device__ constexpr AccT get_clamp_precision() { switch (sizeof(DataT)) { - case 2: return 1e-3; - case 4: return 1e-6; - case 8: return 1e-15; - default: return 0; + case 2: return AccT{1e-3}; + case 4: return AccT{1e-6}; + case 8: return AccT{1e-15}; + default: return AccT{0}; } } @@ -46,19 +46,27 @@ struct l2_exp_cutlass_op { __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} - inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept + inline __device__ AccT operator()(AccT aNorm, AccT bNorm, AccT accVal) const noexcept { - AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; + AccT outVal = aNorm + bNorm - AccT(2.0) * accVal; /** * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal) * can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead. */ - outVal = outVal * AccT(!((outVal * outVal < get_clamp_precision<DataT>()) * (aNorm == bNorm))); + outVal = + outVal * AccT(!((outVal * outVal < get_clamp_precision<DataT, AccT>()) * (aNorm == bNorm))); return sqrt ? 
raft::sqrt(outVal * static_cast<AccT>(outVal > AccT(0))) : outVal; } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } + __device__ AccT operator()(DataT aData) const noexcept + { + if constexpr (std::is_same_v<DataT, half> && std::is_same_v<AccT, float>) { + return __half2float(aData); + } else { + return aData; + } + } }; /** @@ -121,9 +129,9 @@ struct l2_exp_distance_op { * (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal * instead. */ - acc[i][j] = - val * static_cast<AccT>((val > AccT(0))) * - static_cast<AccT>(!((val * val < get_clamp_precision<DataT>()) * (regxn[i] == regyn[j]))); + acc[i][j] = val * static_cast<AccT>((val > AccT(0))) * + static_cast<AccT>( + !((val * val < get_clamp_precision<DataT, AccT>()) * (regxn[i] == regyn[j]))); } } if (sqrt) { diff --git a/python/cuvs/cuvs/test/test_distance.py b/python/cuvs/cuvs/test/test_distance.py index 483d5d201..370dd773a 100644 --- a/python/cuvs/cuvs/test/test_distance.py +++ b/python/cuvs/cuvs/test/test_distance.py @@ -21,6 +21,7 @@ from cuvs.distance import pairwise_distance +@pytest.mark.parametrize("times", range(20)) @pytest.mark.parametrize("n_rows", [50, 100]) @pytest.mark.parametrize("n_cols", [10, 50]) @pytest.mark.parametrize( @@ -43,7 +44,7 @@ @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("order", ["F", "C"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.float16]) -def test_distance(n_rows, n_cols, inplace, order, metric, dtype): +def test_distance(n_rows, n_cols, inplace, order, metric, dtype, times): input1 = np.random.random_sample((n_rows, n_cols)) input1 = np.asarray(input1, order=order).astype(dtype) @@ -79,7 +80,5 @@ def test_distance(n_rows, n_cols, inplace, order, metric, dtype): actual = output_device.copy_to_host() tol = 1e-3 - if np.issubdtype(dtype, np.float16): - tol = 1e-1 assert np.allclose(expected, actual, atol=tol, rtol=tol)