From f16fcccf79fd21d5f2230ff31ab23f610d4a658e Mon Sep 17 00:00:00 2001
From: Jinfeng Li
Date: Wed, 22 May 2024 10:05:37 -0700
Subject: [PATCH] Support double precision in MNMG Logistic Regression (#5898)

Github issue: https://github.com/rapidsai/cuml/issues/5589

Authors:
  - Jinfeng Li (https://github.com/lijinf2)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: https://github.com/rapidsai/cuml/pull/5898
---
 cpp/include/cuml/linear_model/qn_mg.hpp       |  25 +--
 cpp/src/glm/qn_mg.cu                          | 147 +++++++++++++-----
 .../linear_model/logistic_regression_mg.pyx   | 109 +++++++++++--
 .../dask/test_dask_logistic_regression.py     | 146 +++++++++--------
 4 files changed, 296 insertions(+), 131 deletions(-)

diff --git a/cpp/include/cuml/linear_model/qn_mg.hpp b/cpp/include/cuml/linear_model/qn_mg.hpp
index 048d65c322..aa7c3226c5 100644
--- a/cpp/include/cuml/linear_model/qn_mg.hpp
+++ b/cpp/include/cuml/linear_model/qn_mg.hpp
@@ -37,9 +37,10 @@ namespace opg {
  * @param[in] labels: labels data
  * @returns host vector that stores the distinct labels
 */
-std::vector<float> getUniquelabelsMG(const raft::handle_t& handle,
-                                     Matrix::PartDescriptor& input_desc,
-                                     std::vector<Matrix::Data<float>*>& labels);
+template <typename T>
+std::vector<T> getUniquelabelsMG(const raft::handle_t& handle,
+                                 Matrix::PartDescriptor& input_desc,
+                                 std::vector<Matrix::Data<T>*>& labels);
 
 /**
  * @brief performs MNMG fit operation for the logistic regression using quasi newton methods
@@ -55,16 +56,17 @@ std::vector<float> getUniquelabelsMG(const raft::handle_t& handle,
  * @param[out] f: host pointer holding the final objective value
  * @param[out] num_iters: host pointer holding the actual number of iterations taken
  */
+template <typename T>
 void qnFit(raft::handle_t& handle,
-           std::vector<Matrix::Data<float>*>& input_data,
+           std::vector<Matrix::Data<T>*>& input_data,
            Matrix::PartDescriptor& input_desc,
-           std::vector<Matrix::Data<float>*>& labels,
-           float* coef,
+           std::vector<Matrix::Data<T>*>& labels,
+           T* coef,
            const qn_params& pams,
            bool X_col_major,
            bool standardization,
            int n_classes,
-           float* f,
+           T* f,
            int* num_iters);
 
 /**
@@ -86,18 +88,19 @@ void qnFit(raft::handle_t& handle,
  * @param[out] f: host pointer holding the final objective value
  * @param[out] num_iters: host pointer holding the actual number of iterations taken
  */
+template <typename T>
 void qnFitSparse(raft::handle_t& handle,
-                 std::vector<Matrix::Data<float>*>& input_values,
+                 std::vector<Matrix::Data<T>*>& input_values,
                  int* input_cols,
                  int* input_row_ids,
                  int X_nnz,
                  Matrix::PartDescriptor& input_desc,
-                 std::vector<Matrix::Data<float>*>& labels,
-                 float* coef,
+                 std::vector<Matrix::Data<T>*>& labels,
+                 T* coef,
                  const qn_params& pams,
                  bool standardization,
                  int n_classes,
-                 float* f,
+                 T* f,
                  int* num_iters);
 
 };  // namespace opg
diff --git a/cpp/src/glm/qn_mg.cu b/cpp/src/glm/qn_mg.cu
index 786df4c1ea..0c679c55f4 100644
--- a/cpp/src/glm/qn_mg.cu
+++ b/cpp/src/glm/qn_mg.cu
@@ -183,42 +183,76 @@ void qnFit_impl(raft::handle_t& handle,
                  input_desc.uniqueRanks().size());
 }
 
-std::vector<float> getUniquelabelsMG(const raft::handle_t& handle,
-                                     Matrix::PartDescriptor& input_desc,
-                                     std::vector<Matrix::Data<float>*>& labels)
+template <typename T>
+std::vector<T> getUniquelabelsMG(const raft::handle_t& handle,
+                                 Matrix::PartDescriptor& input_desc,
+                                 std::vector<Matrix::Data<T>*>& labels)
 {
   RAFT_EXPECTS(labels.size() == 1,
                "getUniqueLabelsMG currently does not accept more than one data chunk");
-  Matrix::Data<float>* data_y = labels[0];
-  int n_rows                  = input_desc.totalElementsOwnedBy(input_desc.rank);
-  return distinct_mg<float>(handle, data_y->ptr, n_rows);
+  Matrix::Data<T>* data_y = labels[0];
+  size_t n_rows           = input_desc.totalElementsOwnedBy(input_desc.rank);
+  return distinct_mg<T>(handle, data_y->ptr, n_rows);
 }
 
+template std::vector<float> getUniquelabelsMG(const raft::handle_t& handle,
+                                              Matrix::PartDescriptor& input_desc,
+                                              std::vector<Matrix::Data<float>*>& labels);
+
+template std::vector<double> getUniquelabelsMG(const raft::handle_t& handle,
+                                               Matrix::PartDescriptor& input_desc,
+                                               std::vector<Matrix::Data<double>*>& labels);
+
+template <typename T>
 void qnFit(raft::handle_t& handle,
-           std::vector<Matrix::Data<float>*>& input_data,
+           std::vector<Matrix::Data<T>*>& input_data,
            Matrix::PartDescriptor& input_desc,
-           std::vector<Matrix::Data<float>*>& labels,
-           float* coef,
+           std::vector<Matrix::Data<T>*>& labels,
+           T* coef,
            const qn_params& pams,
            bool X_col_major,
            bool standardization,
           int n_classes,
-           float* f,
+           T* f,
           int* num_iters)
 {
-  qnFit_impl<float>(handle,
-                    input_data,
-                    input_desc,
-                    labels,
-                    coef,
-                    pams,
-                    X_col_major,
-                    standardization,
-                    n_classes,
-                    f,
-                    num_iters);
+  qnFit_impl<T>(handle,
+                input_data,
+                input_desc,
+                labels,
+                coef,
+                pams,
+                X_col_major,
+                standardization,
+                n_classes,
+                f,
+                num_iters);
 }
 
+template void qnFit(raft::handle_t& handle,
+                    std::vector<Matrix::Data<float>*>& input_data,
+                    Matrix::PartDescriptor& input_desc,
+                    std::vector<Matrix::Data<float>*>& labels,
+                    float* coef,
+                    const qn_params& pams,
+                    bool X_col_major,
+                    bool standardization,
+                    int n_classes,
+                    float* f,
+                    int* num_iters);
+
+template void qnFit(raft::handle_t& handle,
+                    std::vector<Matrix::Data<double>*>& input_data,
+                    Matrix::PartDescriptor& input_desc,
+                    std::vector<Matrix::Data<double>*>& labels,
+                    double* coef,
+                    const qn_params& pams,
+                    bool X_col_major,
+                    bool standardization,
+                    int n_classes,
+                    double* f,
+                    int* num_iters);
+
 template <typename T, typename I>
 void qnFitSparse_impl(const raft::handle_t& handle,
                       const qn_params& pams,
@@ -269,18 +303,19 @@ void qnFitSparse_impl(const raft::handle_t& handle,
   return;
 }
 
+template <typename T>
 void qnFitSparse(raft::handle_t& handle,
-                 std::vector<Matrix::Data<float>*>& input_values,
+                 std::vector<Matrix::Data<T>*>& input_values,
                  int* input_cols,
                  int* input_row_ids,
                  int X_nnz,
                  Matrix::PartDescriptor& input_desc,
-                 std::vector<Matrix::Data<float>*>& labels,
-                 float* coef,
+                 std::vector<Matrix::Data<T>*>& labels,
+                 T* coef,
                  const qn_params& pams,
                  bool standardization,
                  int n_classes,
-                 float* f,
+                 T* f,
                  int* num_iters)
 {
   RAFT_EXPECTS(input_values.size() == 1,
@@ -289,25 +324,53 @@ void qnFitSparse(raft::handle_t& handle,
   auto data_input_values = input_values[0];
   auto data_y            = labels[0];
 
-  qnFitSparse_impl<float>(handle,
-                          pams,
-                          data_input_values->ptr,
-                          input_cols,
-                          input_row_ids,
-                          X_nnz,
-                          standardization,
-                          data_y->ptr,
-                          input_desc.totalElementsOwnedBy(input_desc.rank),
-                          input_desc.N,
-                          n_classes,
-                          coef,
-                          f,
-                          num_iters,
-                          input_desc.M,
-                          input_desc.rank,
-                          input_desc.uniqueRanks().size());
+  qnFitSparse_impl<T>(handle,
+                      pams,
+                      data_input_values->ptr,
+                      input_cols,
+                      input_row_ids,
+                      X_nnz,
+                      standardization,
+                      data_y->ptr,
+                      input_desc.totalElementsOwnedBy(input_desc.rank),
+                      input_desc.N,
+                      n_classes,
+                      coef,
+                      f,
+                      num_iters,
+                      input_desc.M,
+                      input_desc.rank,
+                      input_desc.uniqueRanks().size());
 }
 
+template void qnFitSparse(raft::handle_t& handle,
+                          std::vector<Matrix::Data<float>*>& input_values,
+                          int* input_cols,
+                          int* input_row_ids,
+                          int X_nnz,
+                          Matrix::PartDescriptor& input_desc,
+                          std::vector<Matrix::Data<float>*>& labels,
+                          float* coef,
+                          const qn_params& pams,
+                          bool standardization,
+                          int n_classes,
+                          float* f,
+                          int* num_iters);
+
+template void qnFitSparse(raft::handle_t& handle,
+                          std::vector<Matrix::Data<double>*>& input_values,
+                          int* input_cols,
+                          int* input_row_ids,
+                          int X_nnz,
+                          Matrix::PartDescriptor& input_desc,
+                          std::vector<Matrix::Data<double>*>& labels,
+                          double* coef,
+                          const qn_params& pams,
+                          bool standardization,
+                          int n_classes,
+                          double* f,
+                          int* num_iters);
+
 };  // namespace opg
 };  // namespace GLM
 };  // namespace ML
diff --git a/python/cuml/linear_model/logistic_regression_mg.pyx b/python/cuml/linear_model/logistic_regression_mg.pyx
index 0a6988a804..834ee1f41d 100644
--- a/python/cuml/linear_model/logistic_regression_mg.pyx
+++ b/python/cuml/linear_model/logistic_regression_mg.pyx
@@ -80,11 +80,29 @@ cdef extern from "cuml/linear_model/qn_mg.hpp" namespace "ML::GLM::opg" nogil:
         float *f,
         int *num_iters) except +
 
+    cdef void qnFit(
+        handle_t& handle,
+        vector[doubleData_t *] input_data,
+        PartDescriptor &input_desc,
+        vector[doubleData_t *] labels,
+        double *coef,
+        const qn_params& pams,
+        bool X_col_major,
+        bool standardization,
+        int n_classes,
+        double *f,
+        int *num_iters) except +
+
     cdef vector[float] getUniquelabelsMG(
         const handle_t& handle,
         PartDescriptor &input_desc,
         vector[floatData_t*] labels) except+
 
+    cdef vector[double] getUniquelabelsMG(
+        const handle_t& handle,
+        PartDescriptor &input_desc,
+        vector[doubleData_t*] labels) except+
+
     cdef void qnFitSparse(
         handle_t& handle,
         vector[floatData_t *] input_values,
@@ -100,6 +118,21 @@ cdef extern from "cuml/linear_model/qn_mg.hpp" namespace "ML::GLM::opg" nogil:
         float *f,
         int *num_iters) except +
 
+    cdef void qnFitSparse(
+        handle_t& handle,
+        vector[doubleData_t *] input_values,
+        int *input_cols,
+        int *input_row_ids,
+        int X_nnz,
+        PartDescriptor &input_desc,
+        vector[doubleData_t *] labels,
+        double *coef,
+        const qn_params& pams,
+        bool standardization,
+        int n_classes,
+        double *f,
+        int *num_iters) except +
+
 
 class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
@@ -199,14 +232,25 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
         cdef handle_t* handle_ = self.handle.getHandle()
         cdef float objective32
+        cdef double objective64
         cdef int num_iters
 
         cdef vector[float] c_classes_
-        c_classes_ = getUniquelabelsMG(
-            handle_[0],
-            deref(input_desc),
-            deref(y))
-        self.classes_ = np.sort(list(c_classes_)).astype('float32')
+        cdef vector[double] c_classes_64
+        if self.dtype == np.float32:
+            c_classes_ = getUniquelabelsMG(
+                handle_[0],
+                deref(input_desc),
+                deref(y))
+            self.classes_ = np.sort(list(c_classes_)).astype(np.float32)
+        elif self.dtype == np.float64:
+            c_classes_64 = getUniquelabelsMG(
+                handle_[0],
+                deref(input_desc),
+                deref(y))
+            self.classes_ = np.sort(list(c_classes_64))
+        else:
+            assert False, "dtypes other than float32 and float64 are currently not supported yet."
 
         self._num_classes = len(self.classes_)
         self.loss = "sigmoid" if self._num_classes <= 2 else "softmax"
@@ -220,6 +264,7 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
 
         if self.dtype == np.float32:
             if sparse_input is False:
+
                 qnFit(
                     handle_[0],
                     deref(X),
@@ -227,9 +272,9 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
                     deref(y),
                     mat_coef_ptr,
                     qnpams,
-                    self.is_col_major,
-                    self.standardization,
-                    self._num_classes,
+                    self.is_col_major,
+                    self.standardization,
+                    self._num_classes,
                     &objective32,
                     &num_iters)
 
@@ -245,20 +290,60 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):
                     deref(X_values),
                     X_cols,
                     X_row_ids,
-                    X_nnz,
+                    X_nnz,
                     deref(input_desc),
                     deref(y),
                     mat_coef_ptr,
                     qnpams,
-                    self.standardization,
-                    self._num_classes,
+                    self.standardization,
+                    self._num_classes,
                     &objective32,
                     &num_iters)
 
             self.solver_model.objective = objective32
 
+        elif self.dtype == np.float64:
+            if sparse_input is False:
+
+                qnFit(
+                    handle_[0],
+                    deref(X),
+                    deref(input_desc),
+                    deref(y),
+                    mat_coef_ptr,
+                    qnpams,
+                    self.is_col_major,
+                    self.standardization,
+                    self._num_classes,
+                    &objective64,
+                    &num_iters)
+
+            else:
+                assert len(X) == 4
+                X_values = X[0]
+                X_cols = X[1]
+                X_row_ids = X[2]
+                X_nnz = X[3]
+
+                qnFitSparse(
+                    handle_[0],
+                    deref(X_values),
+                    X_cols,
+                    X_row_ids,
+                    X_nnz,
+                    deref(input_desc),
+                    deref(y),
+                    mat_coef_ptr,
+                    qnpams,
+                    self.standardization,
+                    self._num_classes,
+                    &objective64,
+                    &num_iters)
+
+            self.solver_model.objective = objective64
+
         else:
-            assert False, "dtypes other than float32 are currently not supported yet. See issue: https://github.com/rapidsai/cuml/issues/5589"
+            assert False, "dtypes other than float32 and float64 are currently not supported yet."
 
         self.solver_model.num_iters = num_iters
 
diff --git a/python/cuml/tests/dask/test_dask_logistic_regression.py b/python/cuml/tests/dask/test_dask_logistic_regression.py
index a512d78d4f..9d46fa0147 100644
--- a/python/cuml/tests/dask/test_dask_logistic_regression.py
+++ b/python/cuml/tests/dask/test_dask_logistic_regression.py
@@ -187,7 +187,7 @@ def imp():
 
 @pytest.mark.mg
 @pytest.mark.parametrize("n_parts", [2])
-@pytest.mark.parametrize("datatype", [np.float32])
+@pytest.mark.parametrize("datatype", [np.float32, np.float64])
 def test_lbfgs_toy(n_parts, datatype, client):
     def imp():
         import cuml.comm.serialize  # NOQA
@@ -217,16 +217,7 @@ def imp():
     from numpy.testing import assert_array_equal
 
     assert_array_equal(preds, y, strict=True)
-
-    # assert error on float64
-    X = X.astype(np.float64)
-    y = y.astype(np.float64)
-    X_df, y_df = _prep_training_data(client, X, y, n_parts)
-    with pytest.raises(
-        RuntimeError,
-        match="dtypes other than float32 are currently not supported yet. 
See issue: https://github.com/rapidsai/cuml/issues/5589", - ): - lr.fit(X_df, y_df) + assert lr.dtype == datatype def test_lbfgs_init(client): @@ -426,25 +417,30 @@ def array_to_numpy(ary): @pytest.mark.parametrize("fit_intercept", [False, True]) @pytest.mark.parametrize("delayed", [True, False]) def test_lbfgs(n_parts, fit_intercept, delayed, client): - _test_lbfgs( + datatype = np.float32 if fit_intercept else np.float64 + + lr = _test_lbfgs( nrows=1e5, ncols=20, n_parts=n_parts, fit_intercept=fit_intercept, - datatype=np.float32, + datatype=datatype, delayed=delayed, client=client, ) + assert lr.dtype == datatype + @pytest.mark.parametrize("fit_intercept", [False, True]) def test_noreg(fit_intercept, client): + datatype = np.float64 if fit_intercept else np.float32 lr = _test_lbfgs( nrows=1e5, ncols=20, n_parts=23, fit_intercept=fit_intercept, - datatype=np.float32, + datatype=datatype, delayed=True, client=client, penalty="none", @@ -458,6 +454,8 @@ def test_noreg(fit_intercept, client): assert l1_strength == 0.0 assert l2_strength == 0.0 + assert lr.dtype == datatype + def test_n_classes_small(client): def assert_small(X, y, n_classes): @@ -502,13 +500,14 @@ def assert_small(X, y, n_classes): @pytest.mark.parametrize("fit_intercept", [False, True]) @pytest.mark.parametrize("n_classes", [8]) def test_n_classes(n_parts, fit_intercept, n_classes, client): + datatype = np.float32 if fit_intercept else np.float64 nrows = int(1e5) if n_classes < 5 else int(2e5) lr = _test_lbfgs( nrows=nrows, ncols=20, n_parts=n_parts, fit_intercept=fit_intercept, - datatype=np.float32, + datatype=datatype, delayed=True, client=client, penalty="l2", @@ -516,15 +515,16 @@ def test_n_classes(n_parts, fit_intercept, n_classes, client): ) assert lr._num_classes == n_classes + assert lr.dtype == datatype @pytest.mark.mg @pytest.mark.parametrize("fit_intercept", [False, True]) -@pytest.mark.parametrize("datatype", [np.float32]) @pytest.mark.parametrize("delayed", [True]) @pytest.mark.parametrize("n_classes", [2, 8]) @pytest.mark.parametrize("C", [1.0, 10.0]) -def test_l1(fit_intercept, datatype, delayed, n_classes, C, client): +def test_l1(fit_intercept, delayed, n_classes, C, client): + datatype = np.float64 if fit_intercept else np.float32 nrows = int(1e5) if n_classes < 5 else int(2e5) lr = _test_lbfgs( nrows=nrows, @@ -543,16 +543,20 @@ def test_l1(fit_intercept, datatype, delayed, n_classes, C, client): assert l1_strength == 1.0 / lr.C assert l2_strength == 0.0 + assert lr.dtype == datatype + @pytest.mark.mg @pytest.mark.parametrize("fit_intercept", [False, True]) -@pytest.mark.parametrize("datatype", [np.float32]) +@pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("delayed", [True]) @pytest.mark.parametrize("n_classes", [2, 8]) @pytest.mark.parametrize("l1_ratio", [0.2, 0.8]) def test_elasticnet( fit_intercept, datatype, delayed, n_classes, l1_ratio, client ): + datatype = np.float32 if fit_intercept else np.float64 + nrows = int(1e5) if n_classes < 5 else int(2e5) lr = _test_lbfgs( nrows=nrows, @@ -573,24 +577,24 @@ def test_elasticnet( assert l1_strength == lr.l1_ratio * strength assert l2_strength == (1.0 - lr.l1_ratio) * strength + assert lr.dtype == datatype + @pytest.mark.mg @pytest.mark.parametrize("fit_intercept", [False, True]) @pytest.mark.parametrize( - "regularization", + "reg_dtype", [ - ("none", 1.0, None), - ("l2", 2.0, None), - ("l1", 2.0, None), - ("elasticnet", 2.0, 0.2), + (("none", 1.0, None), np.float32), + (("l2", 2.0, None), np.float64), + (("l1", 
2.0, None), np.float32), + (("elasticnet", 2.0, 0.2), np.float64), ], ) -@pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("n_classes", [2, 8]) -def test_sparse_from_dense( - fit_intercept, regularization, datatype, n_classes, client -): - penalty, C, l1_ratio = regularization +def test_sparse_from_dense(fit_intercept, reg_dtype, n_classes, client): + penalty, C, l1_ratio = reg_dtype[0] + datatype = reg_dtype[1] nrows = int(1e5) if n_classes < 5 else int(2e5) run_test = partial( @@ -609,17 +613,11 @@ def test_sparse_from_dense( convert_to_sparse=True, ) - if datatype == np.float32: - run_test() - else: - with pytest.raises( - RuntimeError, - match="dtypes other than float32 are currently not supported", - ): - run_test() + lr = run_test() + assert lr.dtype == datatype -@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_sparse_nlp20news(dtype, nlp_20news, client): X, y = nlp_20news @@ -686,21 +684,22 @@ def test_exception_one_label(fit_intercept, client): @pytest.mark.mg @pytest.mark.parametrize("fit_intercept", [False, True]) @pytest.mark.parametrize( - "regularization", + "reg_dtype", [ - ("none", 1.0, None), - ("l2", 2.0, None), - ("l1", 2.0, None), - ("elasticnet", 2.0, 0.2), + (("none", 1.0, None), np.float64), + (("l2", 2.0, None), np.float32), + (("l1", 2.0, None), np.float64), + (("elasticnet", 2.0, 0.2), np.float32), ], ) -@pytest.mark.parametrize("datatype", [np.float32]) @pytest.mark.parametrize("delayed", [False]) @pytest.mark.parametrize("n_classes", [2, 8]) def test_standardization_on_normal_dataset( - fit_intercept, regularization, datatype, delayed, n_classes, client + fit_intercept, reg_dtype, delayed, n_classes, client ): + regularization = reg_dtype[0] + datatype = reg_dtype[1] penalty = regularization[0] C = regularization[1] l1_ratio = regularization[2] @@ -708,7 +707,7 @@ def test_standardization_on_normal_dataset( nrows = int(1e5) if n_classes < 5 else int(2e5) # test correctness compared with scikit-learn - _test_lbfgs( + lr = _test_lbfgs( nrows=nrows, ncols=20, n_parts=2, @@ -722,26 +721,29 @@ def test_standardization_on_normal_dataset( l1_ratio=l1_ratio, standardization=True, ) + assert lr.dtype == datatype @pytest.mark.mg @pytest.mark.parametrize("fit_intercept", [False, True]) @pytest.mark.parametrize( - "regularization", + "reg_dtype", [ - ("none", 1.0, None), - ("l2", 2.0, None), - ("l1", 2.0, None), - ("elasticnet", 2.0, 0.2), + (("none", 1.0, None), np.float32), + (("l2", 2.0, None), np.float32), + (("l1", 2.0, None), np.float64), + (("elasticnet", 2.0, 0.2), np.float64), ], ) -@pytest.mark.parametrize("datatype", [np.float32]) @pytest.mark.parametrize("delayed", [False]) @pytest.mark.parametrize("ncol_and_nclasses", [(2, 2), (6, 4), (100, 10)]) def test_standardization_on_scaled_dataset( - fit_intercept, regularization, datatype, delayed, ncol_and_nclasses, client + fit_intercept, reg_dtype, delayed, ncol_and_nclasses, client ): + regularization = reg_dtype[0] + datatype = reg_dtype[1] + penalty = regularization[0] C = regularization[1] l1_ratio = regularization[2] @@ -896,25 +898,30 @@ def to_dask_data(X_train, X_test, y_train, y_test): total_tol=tolerance, ) + assert mgon.dtype == datatype + assert mgoff.dtype == datatype + @pytest.mark.mg @pytest.mark.parametrize("fit_intercept", [True, False]) @pytest.mark.parametrize( - "regularization", + "reg_dtype", [ - ("none", 1.0, None), - ("l2", 2.0, None), - ("l1", 2.0, None), - ("elasticnet", 2.0, 0.2), + 
(("none", 1.0, None), np.float64), + (("l2", 2.0, None), np.float64), + (("l1", 2.0, None), np.float32), + (("elasticnet", 2.0, 0.2), np.float32), ], ) -def test_standardization_example(fit_intercept, regularization, client): +def test_standardization_example(fit_intercept, reg_dtype, client): + regularization = reg_dtype[0] + datatype = reg_dtype[1] + n_rows = int(1e5) n_cols = 20 n_info = 10 n_classes = 4 - datatype = np.float32 n_parts = 2 max_iter = 5 # cannot set this too large. Observed GPU-specific coefficients when objective converges at 0. @@ -989,19 +996,25 @@ def test_standardization_example(fit_intercept, regularization, client): total_tol=tolerance, ) + assert lr_on.dtype == datatype + assert lr_off.dtype == datatype + @pytest.mark.mg @pytest.mark.parametrize("fit_intercept", [True, False]) @pytest.mark.parametrize( - "regularization", + "reg_dtype", [ - ("none", 1.0, None), - ("l2", 2.0, None), - ("l1", 2.0, None), - ("elasticnet", 2.0, 0.2), + (("none", 1.0, None), np.float64), + (("l2", 2.0, None), np.float32), + (("l1", 2.0, None), np.float64), + (("elasticnet", 2.0, 0.2), np.float32), ], ) -def test_standardization_sparse(fit_intercept, regularization, client): +def test_standardization_sparse(fit_intercept, reg_dtype, client): + regularization = reg_dtype[0] + datatype = reg_dtype[1] + n_rows = 10000 n_cols = 25 n_info = 15 @@ -1009,7 +1022,6 @@ def test_standardization_sparse(fit_intercept, regularization, client): nnz = int(n_rows * n_cols * 0.3) # number of non-zero values tolerance = 0.005 - datatype = np.float32 n_parts = 10 max_iter = 5 # cannot set this too large. Observed GPU-specific coefficients when objective converges at 0. @@ -1089,3 +1101,5 @@ def make_classification_with_nnz( assert array_equal( lron_intercept_origin, sg.intercept_, unit_tol=tolerance ) + + assert lr_on.dtype == datatype