Skip to content

Commit

Permalink
add float64 to more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lijinf2 committed May 20, 2024
1 parent 8ea37c5 commit 428ce45
Showing 1 changed file with 34 additions and 14 deletions.
48 changes: 34 additions & 14 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,25 +417,30 @@ def array_to_numpy(ary):
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("delayed", [True, False])
def test_lbfgs(n_parts, fit_intercept, delayed, client):
_test_lbfgs(
datatype = np.float32 if fit_intercept else np.float64

lr = _test_lbfgs(
nrows=1e5,
ncols=20,
n_parts=n_parts,
fit_intercept=fit_intercept,
datatype=np.float32,
datatype=datatype,
delayed=delayed,
client=client,
)

assert lr.dtype == datatype


@pytest.mark.parametrize("fit_intercept", [False, True])
def test_noreg(fit_intercept, client):
datatype = np.float64 if fit_intercept else np.float32
lr = _test_lbfgs(
nrows=1e5,
ncols=20,
n_parts=23,
fit_intercept=fit_intercept,
datatype=np.float32,
datatype=datatype,
delayed=True,
client=client,
penalty="none",
Expand All @@ -449,6 +454,8 @@ def test_noreg(fit_intercept, client):
assert l1_strength == 0.0
assert l2_strength == 0.0

assert lr.dtype == datatype


def test_n_classes_small(client):
def assert_small(X, y, n_classes):
Expand Down Expand Up @@ -493,29 +500,31 @@ def assert_small(X, y, n_classes):
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("n_classes", [8])
def test_n_classes(n_parts, fit_intercept, n_classes, client):
datatype = np.float32 if fit_intercept else np.float64
nrows = int(1e5) if n_classes < 5 else int(2e5)
lr = _test_lbfgs(
nrows=nrows,
ncols=20,
n_parts=n_parts,
fit_intercept=fit_intercept,
datatype=np.float32,
datatype=datatype,
delayed=True,
client=client,
penalty="l2",
n_classes=n_classes,
)

assert lr._num_classes == n_classes
assert lr.dtype == datatype


@pytest.mark.mg
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("datatype", [np.float32])
@pytest.mark.parametrize("delayed", [True])
@pytest.mark.parametrize("n_classes", [2, 8])
@pytest.mark.parametrize("C", [1.0, 10.0])
def test_l1(fit_intercept, datatype, delayed, n_classes, C, client):
def test_l1(fit_intercept, delayed, n_classes, C, client):
datatype = np.float64 if fit_intercept else np.float32
nrows = int(1e5) if n_classes < 5 else int(2e5)
lr = _test_lbfgs(
nrows=nrows,
Expand All @@ -534,16 +543,20 @@ def test_l1(fit_intercept, datatype, delayed, n_classes, C, client):
assert l1_strength == 1.0 / lr.C
assert l2_strength == 0.0

assert lr.dtype == datatype


@pytest.mark.mg
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("datatype", [np.float32])
@pytest.mark.parametrize("datatype", [np.float32, np.float64])
@pytest.mark.parametrize("delayed", [True])
@pytest.mark.parametrize("n_classes", [2, 8])
@pytest.mark.parametrize("l1_ratio", [0.2, 0.8])
def test_elasticnet(
fit_intercept, datatype, delayed, n_classes, l1_ratio, client
):
datatype = np.float32 if fit_intercept else np.float64

nrows = int(1e5) if n_classes < 5 else int(2e5)
lr = _test_lbfgs(
nrows=nrows,
Expand All @@ -564,6 +577,8 @@ def test_elasticnet(
assert l1_strength == lr.l1_ratio * strength
assert l2_strength == (1.0 - lr.l1_ratio) * strength

assert lr.dtype == datatype


@pytest.mark.mg
@pytest.mark.parametrize("fit_intercept", [False, True])
Expand Down Expand Up @@ -890,21 +905,23 @@ def to_dask_data(X_train, X_test, y_train, y_test):
@pytest.mark.mg
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize(
"regularization",
"reg_dtype",
[
("none", 1.0, None),
("l2", 2.0, None),
("l1", 2.0, None),
("elasticnet", 2.0, 0.2),
(("none", 1.0, None), np.float64),
(("l2", 2.0, None), np.float64),
(("l1", 2.0, None), np.float32),
(("elasticnet", 2.0, 0.2), np.float32),
],
)
def test_standardization_example(fit_intercept, regularization, client):
def test_standardization_example(fit_intercept, reg_dtype, client):
regularization = reg_dtype[0]
datatype = reg_dtype[1]

n_rows = int(1e5)
n_cols = 20
n_info = 10
n_classes = 4

datatype = np.float32
n_parts = 2
max_iter = 5 # cannot set this too large. Observed GPU-specific coefficients when objective converges at 0.

Expand Down Expand Up @@ -979,6 +996,9 @@ def test_standardization_example(fit_intercept, regularization, client):
total_tol=tolerance,
)

assert lr_on.dtype == datatype
assert lr_off.dtype == datatype


@pytest.mark.mg
@pytest.mark.parametrize("fit_intercept", [True, False])
Expand Down

0 comments on commit 428ce45

Please sign in to comment.