From 47416d7f417382e17b3b4c45d098fbeaff640bc9 Mon Sep 17 00:00:00 2001 From: Jinsol Park Date: Thu, 23 May 2024 08:25:27 -0700 Subject: [PATCH] Fix RandomForestClassifier return type (#5896) Closes #5637 ``` import cuml from cuml.datasets import make_classification X, y = make_classification() clf = cuml.ensemble.RandomForestClassifier().fit(X,y) print(clf.predict(X[:5]).dtype) ``` Result is ``` int64 ``` Authors: - Jinsol Park (https://github.com/jinsolp) Approvers: - Dante Gama Dessavre (https://github.com/dantegd) URL: https://github.com/rapidsai/cuml/pull/5896 --- python/cuml/ensemble/randomforestclassifier.pyx | 3 ++- python/cuml/tests/test_random_forest.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index ba16335dad..23a1bae940 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -1,6 +1,6 @@ # -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -550,6 +550,7 @@ class RandomForestClassifier(BaseRandomForestModel, domain="cuml_python") @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], return_values=[('dense', '(n_samples, 1)')]) + @cuml.internals.api_base_return_array(get_output_dtype=True) def predict(self, X, predict_model="GPU", threshold=0.5, algo='auto', convert_dtype=True, fil_sparse_format='auto') -> CumlArray: diff --git a/python/cuml/tests/test_random_forest.py b/python/cuml/tests/test_random_forest.py index d7f6ff6705..b18d6ec8ab 100644 --- a/python/cuml/tests/test_random_forest.py +++ b/python/cuml/tests/test_random_forest.py @@ -1382,3 +1382,11 @@ def test_rf_min_samples_split_with_small_float(estimator, make_data): # Does not error clf.fit(X, y) + + +def test_rf_predict_returns_int(): + + X, y = make_classification() + clf = cuml.ensemble.RandomForestClassifier().fit(X, y) + pred = clf.predict(X) + assert pred.dtype == np.int64