From 38b37dfb3ff1aa3af9e3d0d58128a0c0301dd4e2 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 20 Nov 2024 16:37:48 -0500 Subject: [PATCH 1/3] update partykit_tree_info() to handle classification outputs --- R/model-partykit.R | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/R/model-partykit.R b/R/model-partykit.R index beac61e..9fa5cda 100644 --- a/R/model-partykit.R +++ b/R/model-partykit.R @@ -1,9 +1,18 @@ partykit_tree_info <- function(model) { model_nodes <- map(seq_along(model), ~ model[[.x]]) is_split <- map_lgl(model_nodes, ~ class(.x$node[1]) == "partynode") - # non-cat model - mean_resp <- map_dbl(model_nodes, ~ mean(.x$fitted[, "(response)"])) - prediction <- ifelse(!is_split, mean_resp, NA) + if (is.numeric(model_nodes[[1]]$fitted[["(response)"]])) { + mean_resp <- map_dbl(model_nodes, ~ mean(.x$fitted[, "(response)"])) + prediction <- ifelse(!is_split, mean_resp, NA) + } else { + stat_mode <- function(x) { + counts <- sort(table(x)) + names(counts)[1] + } + mode_resp <- map_chr(model_nodes, ~ stat_mode(.x$fitted[, "(response)"])) + prediction <- ifelse(!is_split, mode_resp, NA) + } + party_nodes <- map(seq_along(model), ~ partykit::nodeapply(model, .x)) kids <- map(party_nodes, ~ { From 1a0a4b6796e8ba5ccc5b6bf56856d1ab9538d09b Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 17 Dec 2024 14:48:06 -0800 Subject: [PATCH 2/3] make sure ties are handled correctly in partykit classification ties --- R/model-partykit.R | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/R/model-partykit.R b/R/model-partykit.R index 9fa5cda..f78db21 100644 --- a/R/model-partykit.R +++ b/R/model-partykit.R @@ -6,7 +6,11 @@ partykit_tree_info <- function(model) { prediction <- ifelse(!is_split, mean_resp, NA) } else { stat_mode <- function(x) { - counts <- sort(table(x)) + counts <- rev(sort(table(x))) + if (counts[[1]] == counts[[2]]) { + ties <- counts[counts[1] == counts] + return(names(rev(ties))[1]) + } names(counts)[1] } mode_resp <- map_chr(model_nodes, ~ stat_mode(.x$fitted[, "(response)"])) From d8bd0fd81753829033ffe47ae5afd83e153136d8 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Tue, 17 Dec 2024 14:51:40 -0800 Subject: [PATCH 3/3] add .extract_partykit_classprob() function --- DESCRIPTION | 2 +- NAMESPACE | 1 + R/model-partykit.R | 61 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 31c6738..2fae367 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -50,4 +50,4 @@ VignetteBuilder: Config/Needs/website: tidyverse/tidytemplate Config/testthat/edition: 3 Encoding: UTF-8 -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.2 diff --git a/NAMESPACE b/NAMESPACE index 31453ea..befb0e8 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -39,6 +39,7 @@ S3method(tidypredict_test,party) S3method(tidypredict_test,randomForest) S3method(tidypredict_test,ranger) S3method(tidypredict_test,xgb.Booster) +export(.extract_partykit_classprob) export(acceptable_formula) export(as_parsed_model) export(parse_model) diff --git a/R/model-partykit.R b/R/model-partykit.R index f78db21..6a65e0d 100644 --- a/R/model-partykit.R +++ b/R/model-partykit.R @@ -101,3 +101,64 @@ tidypredict_fit.party <- function(model) { parsedmodel <- parse_model(model) build_fit_formula_rf(parsedmodel)[[1]] } + +# For {orbital} +#' @keywords internal +#' @export +.extract_partykit_classprob <- function(model) { + extract_classprob <- function(model) { + mod <- model$fitted + response <- mod[["(response)"]] + weights <- mod[["(weights)"]] + + lvls <- levels(response) + weights_sum <- tapply(weights, response, sum) + weights_sum[is.na(weights_sum)] <- 0 + res <- weights_sum / sum(weights) + names(res) <- lvls + res + } + + preds <- map(seq_along(model), ~extract_classprob(model[[.x]])) + preds <- matrix( + unlist(preds), + nrow = length(preds), + byrow = TRUE, + dimnames = list(NULL, names(preds[[1]])) + ) + + generate_one_tree <- function(tree_info) { + paths <- tree_info$nodeID[tree_info[, "terminal"]] + paths <- map( + paths, + ~ { + prediction <- tree_info$prediction[tree_info$nodeID == .x] + if (is.null(prediction)) cli::cli_abort("Prediction column not found") + if (is.factor(prediction)) prediction <- as.character(prediction) + list( + prediction = prediction, + path = get_ra_path(.x, tree_info, FALSE) + ) + } + ) + + classes <- attr(model$terms, "dataClasses") + pm <- list() + pm$general$model <- "party" + pm$general$type <- "tree" + pm$general$version <- 2 + pm$trees <- list(paths) + parsedmodel <- as_parsed_model(pm) + + build_fit_formula_rf(parsedmodel)[[1]] + } + + tree_info <- partykit_tree_info(model) + + res <- list() + for (i in seq_len(ncol(preds))) { + tree_info$prediction <- preds[, i] + res[[i]] <- generate_one_tree(tree_info) + } + res +}