Skip to content

Commit

Permalink
Merge pull request #112 from tidymodels/updates
Browse files Browse the repository at this point in the history
Updates
  • Loading branch information
edgararuiz authored Jan 17, 2023
2 parents 730459c + 059f873 commit 58351b7
Showing 11 changed files with 365 additions and 183 deletions.
8 changes: 4 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
Package: tidypredict
Title: Run Predictions Inside the Database
Version: 0.4.9.9000
Version: 0.4.9.9001
Authors@R:
c(
person("Max", "Kuhn", , "max@rstudio.com", role = c("aut", "cre")),
person("Edgar", "Ruiz", email = "edgar@rstudio.com", role = c("aut"))
person("Edgar", "Ruiz", email = "edgar@posit.co", role = c("aut", "cre")),
person("Max", "Kuhn", email = "max@posit.co", role = c("aut"))
)
Description: It parses a fitted 'R' model object, and returns a formula in
'Tidy Eval' code that calculates the predictions. It works with
@@ -51,5 +51,5 @@ VignetteBuilder:
Config/Needs/website:
tidyverse/tidytemplate
Encoding: UTF-8
RoxygenNote: 7.2.0.9000
RoxygenNote: 7.2.3
Config/testthat/edition: 3
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# tidypredict (development version)

- Addresses issues with XGBoost models

- Improvements to XGBoost tests

# tidypredict 0.4.9

- Fixes issue handling GLM Binomial earth models (#97)
2 changes: 1 addition & 1 deletion R/model-earth.R
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ parse_model.earth <- function(model) {

pm <- list()
pm$general$model <- "earth"
pm$general$type <- "tree"
pm$general$type <- "regression"
pm$general$version <- 2
pm$general$is_glm <- 0
if (is_glm) {
72 changes: 50 additions & 22 deletions R/model-xgboost.R
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ get_xgb_path <- function(row_id, tree) {

get_xgb_tree <- function(tree) {
paths <- seq_len(nrow(tree))[tree[, "Feature"] == "Leaf"]
map(
x <- map(
paths,
~ {
list(
@@ -45,32 +45,40 @@ get_xgb_tree <- function(tree) {
)
}
)
x
}

# S3 generic: hands `model` off to the `get_xgb_trees` method registered
# for its class.
get_xgb_trees <- function(model) {
  # Dispatch happens on the first argument, i.e. `model`.
  UseMethod("get_xgb_trees")
}

get_xgb_trees.xgb.Booster <- function(model) {
xgb_dump_text_with_stats <- xgboost::xgb.dump(model, dump_format = "text", with_stats = TRUE)
xd <- xgboost::xgb.dump(
model = model,
dump_format = "text",
with_stats = TRUE
)
feature_names <- model$feature_names
get_xgb_trees.character(xgb_dump_text_with_stats, feature_names)
get_xgb_trees_character(xd, feature_names)
}

get_xgb_trees.character <- function(xgb_dump_text_with_stats, feature_names) {
get_xgb_trees_character <- function(xd, feature_names) {
feature_names_tbl <- data.frame(
Feature = as.character(0:(length(feature_names) - 1)),
feature_name = feature_names,
stringsAsFactors = FALSE
)
trees <- xgboost::xgb.model.dt.tree(text = xgb_dump_text_with_stats)
trees <- xgboost::xgb.model.dt.tree(text = xd)
trees <- as.data.frame(trees)
trees$original_order <- 1:nrow(trees)
trees <- merge(trees, feature_names_tbl, by = "Feature", all.x = TRUE)
trees <- trees[order(trees$original_order), !names(trees) %in% "original_order"]
trees[, c("Yes", "No", "Missing")] <- lapply(trees[, c("Yes", "No", "Missing")], function(x) sub("^.*-", "", x))
trees[, c("Yes", "No", "Missing")] <- lapply(trees[, c("Yes", "No", "Missing")], function(x) as.integer(x) + 1)
purrr::map(split(trees, trees$Tree), get_xgb_tree)
trees[, c("Yes", "No", "Missing")] <-
lapply(trees[, c("Yes", "No", "Missing")], function(x) sub("^.*-", "", x))
trees[, c("Yes", "No", "Missing")] <-
lapply(trees[, c("Yes", "No", "Missing")], function(x) as.integer(x) + 1)

trees_split <- split(trees, trees$Tree)
trees_rows <- purrr::map_dbl(trees_split, nrow)
trees_filtered <- trees_split[trees_rows > 1]

purrr::map(trees_filtered, get_xgb_tree)
}

#' @export
@@ -97,14 +105,26 @@ get_xgb_case <- function(path, prediction) {
cl <- map(
path,
~ {
if (.x$op == "less" & .x$missing) i <- expr((!!sym(.x$col) >= !!as.numeric(.x$val) | is.na(!!sym(.x$col))))
if (.x$op == "more-equal" & .x$missing) i <- expr((!!sym(.x$col) < !!as.numeric(.x$val) | is.na(!!sym(.x$col))))
if (.x$op == "less" & !.x$missing) i <- expr(!!sym(.x$col) >= !!as.numeric(.x$val))
if (.x$op == "more-equal" & !.x$missing) i <- expr(!!sym(.x$col) < !!as.numeric(.x$val))
if (.x$op == "less" & .x$missing) {
i <- expr((!!sym(.x$col) >= !!as.numeric(.x$val) | is.na(!!sym(.x$col))))
}
if (.x$op == "more-equal" & .x$missing) {
i <- expr((!!sym(.x$col) < !!as.numeric(.x$val) | is.na(!!sym(.x$col))))
}
if (.x$op == "less" & !.x$missing) {
i <- expr(!!sym(.x$col) >= !!as.numeric(.x$val))
}
if (.x$op == "more-equal" & !.x$missing) {
i <- expr(!!sym(.x$col) < !!as.numeric(.x$val))
}
i
}
)
cl <- if (length(cl) > 0) reduce(cl, function(x, y) expr(!!x & !!y)) else TRUE
cl <- if (length(cl) > 0) {
reduce(cl, function(x, y) expr(!!x & !!y))
} else {
TRUE
}
expr(!!cl ~ !!prediction)
}

@@ -132,18 +152,26 @@ build_fit_formula_xgb <- function(parsedmodel) {
if (is.null(objective)) {
assigned <- 1
f <- expr(!!f + !!base_score)
warning("If the objective is a custom function, please explicitly apply it to the output.")
} else if (objective %in% c("reg:squarederror")) {
warning(
paste(
"If the objective is a custom function, please",
"explicitly apply it to the output."
)
)
} else if (objective %in% c("reg:squarederror", "binary:logitraw")) {
assigned <- 1
f <- expr(!!f + !!base_score)
} else if (objective %in% c("binary:logitraw")) {
assigned <- 1
} else if (objective %in% c("binary:logistic", "reg:logistic")) {
assigned <- 1
f <- expr(1 - 1 / (1 + exp(!!f + log(!!base_score / (1 - !!base_score)))))
}
if (assigned == 0) {
stop("Only objectives 'binary:logistic', 'reg:squarederror', 'reg:logistic', 'binary:logitraw' are supported yet.")
stop(
paste0(
"Only objectives 'binary:logistic', 'reg:squarederror',",
"'reg:logistic', 'binary:logitraw' are supported yet."
)
)
}
f
}
3 changes: 2 additions & 1 deletion R/test-predictions.R
Original file line number Diff line number Diff line change
@@ -179,14 +179,15 @@ xgb_booster <- function(model, df = model$model, threshold = 0.000000000001,
include_intervals = FALSE, max_rows = NULL, xg_df = NULL) {
if (is.numeric(max_rows)) df <- head(df, max_rows)
base <- predict(model, xg_df)
if("model_fit" %in% class(model)) base <- base$.pred
te <- tidypredict_to_column(
df,
model,
add_interval = FALSE,
vars = c("fit_te", "upr_te", "lwr_te")
)
raw_results <- cbind(base, te)
raw_results$fit_diff <- raw_results$fit - raw_results$fit_te
raw_results$fit_diff <- raw_results$base - raw_results$fit_te
raw_results$fit_threshold <- raw_results$fit_diff > threshold

rowid <- seq_len(nrow(raw_results))
7 changes: 6 additions & 1 deletion man/tidypredict-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion tests/testthat/_snaps/earth.md
Original file line number Diff line number Diff line change
@@ -3,5 +3,12 @@
Code
tidypredict_fit(pm)
Output
list()
1 - 1/(1 + exp(2.913526 + (ifelse(age > 32, age - 32, 0) * -0.0375715) +
(ifelse(pclass == "2nd", 1, 0) * ifelse(sex == "male", 1,
0) * -1.7680945) + (ifelse(pclass == "3rd", 1, 0) * -5.030056) +
(ifelse(pclass == "3rd", 1, 0) * ifelse(sibsp < 4, 4 - sibsp,
0) * 0.6186527) + (ifelse(pclass == "3rd", 1, 0) * ifelse(sex ==
"male", 1, 0) * 1.2226954) + (ifelse(sex == "male", 1, 0) *
-3.1856245) + (ifelse(sex == "male", 1, 0) * ifelse(age <
16, 16 - age, 0) * 0.241814)))

Loading

0 comments on commit 58351b7

Please sign in to comment.