Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] l2_exp random fail in half-float32 mixed precision on self-neighboring #596

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions cpp/src/distance/detail/distance_ops/l2_exp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ namespace cuvs::distance::detail::ops {
* for round-off error tolerance.
* @tparam DataT
*/
template <typename DataT>
__device__ constexpr DataT get_clamp_precision()
template <typename DataT, typename AccT>
__device__ constexpr AccT get_clamp_precision()
{
switch (sizeof(DataT)) {
case 2: return 1e-3;
case 4: return 1e-6;
case 8: return 1e-15;
default: return 0;
case 2: return AccT{1e-3};
case 4: return AccT{1e-6};
case 8: return AccT{1e-15};
default: return AccT{0};
}
}

Expand All @@ -46,19 +46,27 @@ struct l2_exp_cutlass_op {

__device__ l2_exp_cutlass_op() noexcept : sqrt(false) {}
__device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {}
inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept
inline __device__ AccT operator()(AccT aNorm, AccT bNorm, AccT accVal) const noexcept
{
AccT outVal = aNorm + bNorm - DataT(2.0) * accVal;
AccT outVal = aNorm + bNorm - AccT(2.0) * accVal;

/**
* Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal)
* can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead.
*/
outVal = outVal * AccT(!((outVal * outVal < get_clamp_precision<AccT>()) * (aNorm == bNorm)));
outVal =
outVal * AccT(!((outVal * outVal < get_clamp_precision<DataT, AccT>()) * (aNorm == bNorm)));
return sqrt ? raft::sqrt(outVal * static_cast<AccT>(outVal > AccT(0))) : outVal;
}

__device__ AccT operator()(DataT aData) const noexcept { return aData; }
__device__ AccT operator()(DataT aData) const noexcept
{
if constexpr (std::is_same_v<DataT, half> && std::is_same_v<AccT, float>) {
return __half2float(aData);
} else {
return aData;
}
}
};

/**
Expand Down Expand Up @@ -121,9 +129,9 @@ struct l2_exp_distance_op {
* (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal
* instead.
*/
acc[i][j] =
val * static_cast<AccT>((val > AccT(0))) *
static_cast<AccT>(!((val * val < get_clamp_precision<AccT>()) * (regxn[i] == regyn[j])));
acc[i][j] = val * static_cast<AccT>((val > AccT(0))) *
static_cast<AccT>(
!((val * val < get_clamp_precision<DataT, AccT>()) * (regxn[i] == regyn[j])));
}
}
if (sqrt) {
Expand Down
5 changes: 2 additions & 3 deletions python/cuvs/cuvs/test/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from cuvs.distance import pairwise_distance


@pytest.mark.parametrize("times", range(20))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this `times` parameter used for? I don't see it used in the test itself.

Are you just trying to run this test multiple times here to stress test it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's just for running the test multiple times so that the failure is reliably reproduced in a single run, because the failure probability is empirically about 10%.

@pytest.mark.parametrize("n_rows", [50, 100])
@pytest.mark.parametrize("n_cols", [10, 50])
@pytest.mark.parametrize(
Expand All @@ -43,7 +44,7 @@
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize("order", ["F", "C"])
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.float16])
def test_distance(n_rows, n_cols, inplace, order, metric, dtype):
def test_distance(n_rows, n_cols, inplace, order, metric, dtype, times):
input1 = np.random.random_sample((n_rows, n_cols))
input1 = np.asarray(input1, order=order).astype(dtype)

Expand Down Expand Up @@ -79,7 +80,5 @@ def test_distance(n_rows, n_cols, inplace, order, metric, dtype):
actual = output_device.copy_to_host()

tol = 1e-3
if np.issubdtype(dtype, np.float16):
tol = 1e-1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I added this reduced tolerance because I was seeing failures - is this no longer needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I ran it successfully on my local machines. I think this change can help us catch potential real failures in the future, so I made it.


assert np.allclose(expected, actual, atol=tol, rtol=tol)
Loading