Skip to content

Commit

Permalink
also renamed param in R for consistency with other R plot functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sebhrusen committed Oct 31, 2023
1 parent 445b209 commit 4c798ce
Show file tree
Hide file tree
Showing 12 changed files with 121 additions and 110 deletions.
93 changes: 50 additions & 43 deletions h2o-r/h2o-package/R/models.R
Original file line number Diff line number Diff line change
Expand Up @@ -5572,7 +5572,7 @@ h2o.cross_validation_predictions <- function(object) {
#' partial dependence the mean response (probabilities) is returned rather than the mean of the log class probability.
#'
#' @param object An \linkS4class{H2OModel} object.
#' @param data An H2OFrame object used for scoring and constructing the plot.
#' @param newdata An H2OFrame object used for scoring and constructing the plot.
#' @param cols Feature(s) for which partial dependence will be calculated.
#' @param destination_key A key reference to the created partial dependence tables in H2O.
#' @param nbins Number of bins used. For categorical columns make sure the number of bins exceeds the level count.
Expand Down Expand Up @@ -5613,19 +5613,20 @@ h2o.cross_validation_predictions <- function(object) {
#' iris_gbm <- h2o.gbm(x = c(1:4), y = 5, training_frame = iris_hex)
#'
#' # one target class
#' h2o.partialPlot(object = iris_gbm, data = iris_hex, cols="Petal.Length", targets=c("setosa"))
#' h2o.partialPlot(object = iris_gbm, newdata = iris_hex, cols="Petal.Length", targets=c("setosa"))
#' # three target classes
#' h2o.partialPlot(object = iris_gbm, data = iris_hex, cols="Petal.Length",
#' h2o.partialPlot(object = iris_gbm, newdata = iris_hex, cols="Petal.Length",
#' targets=c("setosa", "virginica", "versicolor"))
#' }
#' @export

h2o.partialPlot <- function(object, data, cols, destination_key, nbins=20, plot = TRUE, plot_stddev = TRUE,
weight_column=-1, include_na=FALSE, user_splits=NULL, col_pairs_2dpdp=NULL, save_to=NULL,
row_index=-1, targets=NULL) {
h2o.partialPlot <- function(object, newdata, cols, destination_key, nbins=20, plot = TRUE, plot_stddev = TRUE,
weight_column=-1, include_na=FALSE, user_splits=NULL, col_pairs_2dpdp=NULL, save_to=NULL,
row_index=-1, targets=NULL, ...) {
varargs <- list(...)
if(!is(object, "H2OModel")) stop("object must be an H2Omodel")
if( is(object, "H2OOrdinalModel")) stop("object must be a regression model or binary and multinomial classfier")
if(!is(data, "H2OFrame")) stop("data must be H2OFrame")
if(!is(newdata, "H2OFrame")) stop("newdata must be H2OFrame")
if(!is.numeric(nbins) | !(nbins > 0) ) stop("nbins must be a positive numeric")
if(!is.logical(plot)) stop("plot must be a logical value")
if(!is.logical(plot_stddev)) stop("plot must be a logical value")
Expand All @@ -5636,61 +5637,67 @@ h2o.partialPlot <- function(object, data, cols, destination_key, nbins=20, plot
if(!is.character(targets[i])) stop("targets parameter must be a list of string values")
}
}
for (arg in names(varargs)) {
if (arg == 'data') {
warning("argument 'data' is deprecated; please use 'newdata' instead.")
if (missing(newdata)) newdata <- varargs$data else warning("ignoring 'data' as 'newdata' was also provided.")
}
}

noPairs = missing(col_pairs_2dpdp)
noCols = missing(cols)
if(noCols && noPairs) cols = object@parameters$x # set to default only if both are missing
noPairs <- missing(col_pairs_2dpdp)
noCols <- missing(cols)
if(noCols && noPairs) cols <- object@parameters$x # set to default only if both are missing

y = object@parameters$y
numCols = 0
numColPairs = 0
y <- object@parameters$y
numCols <- 0
numColPairs <- 0
if (!missing(cols)) { # check valid cols in cols for 1d pdp
x <- cols
args <- .verify_dataxy(data, x, y)
args <- .verify_dataxy(newdata, x, y)
}
cpairs <- NULL
if (!missing(col_pairs_2dpdp)) { # verify valid cols for 2d pdp
for (onePair in col_pairs_2dpdp) {
pargs <- .verify_dataxy(data, onePair, y)
pargs <- .verify_dataxy(newdata, onePair, y)
cpairs <-
c(cpairs, paste0("[", paste (pargs$x, collapse = ','), "]"))
}
numColPairs = length(cpairs)
numColPairs <- length(cpairs)
}

if (is.numeric(weight_column) && (weight_column != -1)) {
stop("weight_column should be a column name of your data frame.")
} else if (is.character(weight_column)) { # weight_column_index is column name
if (!weight_column %in% h2o.names(data))
if (!weight_column %in% h2o.names(newdata))
stop("weight_column_index should be one of your columns in your data frame.")
else
weight_column <- match(weight_column, h2o.names(data))-1
weight_column <- match(weight_column, h2o.names(newdata))-1
}

if (!is.numeric(row_index)) {
stop("row_index should be numeric.")
}

parms = list()
parms <- list()
if (!missing(col_pairs_2dpdp)) {
parms$col_pairs_2dpdp <- paste0("[", paste (cpairs, collapse = ','), "]")
}
if (!missing(cols)) {
parms$cols <- paste0("[", paste (args$x, collapse = ','), "]")
numCols = length(cols)
numCols <- length(cols)
}
if(is.null(targets)){
num_1d_pp_data <- numCols
} else {
num_1d_pp_data <- numCols * length(targets)
}
noCols = missing(cols)
noCols <- missing(cols)
parms$model_id <- attr(object, "model_id")
parms$frame_id <- attr(data, "id")
parms$frame_id <- attr(newdata, "id")
parms$nbins <- nbins
parms$weight_column_index <- weight_column
parms$add_missing_na <- include_na
parms$row_index = row_index
parms$row_index <- row_index

if (is.null(user_splits) || length(user_splits) == 0) {
parms$user_cols <- NULL
Expand All @@ -5700,15 +5707,15 @@ h2o.partialPlot <- function(object, data, cols, destination_key, nbins=20, plot
user_cols <- c()
user_values <- c()
user_num_splits <- c()
column_names <- h2o.names(data)
column_names <- h2o.names(newdata)
for (ind in c(1:length(user_splits))) {
aList <- user_splits[[ind]]
csname = aList[1]
csname <- aList[1]
if (csname %in% column_names) {
if (h2o.isnumeric(data[csname]) || h2o.isfactor(data[csname]) || h2o.getTypes(data)[[which(names(data) == csname)]] == "time") {
if (h2o.isnumeric(newdata[csname]) || h2o.isfactor(newdata[csname]) || h2o.getTypes(newdata)[[which(names(data) == csname)]] == "time") {
nVal <- length(aList)-1
if (h2o.isfactor(data[csname])) {
domains <- h2o.levels(data[csname]) # enum values
if (h2o.isfactor(newdata[csname])) {
domains <- h2o.levels(newdata[csname]) # enum values
tempVal <- aList[2:length(aList)]
intVals <- c(1:length(tempVal))
for (eleind in c(1:nVal)) {
Expand Down Expand Up @@ -5769,7 +5776,7 @@ h2o.partialPlot <- function(object, data, cols, destination_key, nbins=20, plot
max_upper <- max(max_upper, pp[,2] + pp[,3])
if (i <= num_1d_pp_data) {
if(is.null(targets)){
col_name_index = i
col_name_index <- i
title <- paste("Partial dependency plot for", cols[col_name_index])
} else if(!is.null(targets)){
if(length(cols) > 1 && i %% length(cols) == 0) {
Expand Down Expand Up @@ -5798,8 +5805,8 @@ h2o.partialPlot <- function(object, data, cols, destination_key, nbins=20, plot
}
}
}
col_types = unlist(h2o.getTypes(data))
col_names = names(data)
col_types <- unlist(h2o.getTypes(newdata))
col_names <- names(newdata)

pp.plot.1d <- function(pp) {
if(!all(is.na(pp))) {
Expand All @@ -5820,8 +5827,8 @@ h2o.partialPlot <- function(object, data, cols, destination_key, nbins=20, plot
## Plot one standard deviation above and below the mean
if(plot_stddev) {
## Added upper and lower std dev confidence bound
upper = y + stddev
lower = y - stddev
upper <- y + stddev
lower <- y - stddev
plot(pp[,1:2], type = line_type, pch=pch, medpch=pch, medcol="red", medlty=0, staplelty=0, boxlty=0, col="red", main = attr(pp,"description"), ylim = c(min(lower), max(upper)))
pp.plot.1d.plotNA(pp, type, "red")
polygon(pp.plot.1d.proccessDataForPolygon(c(pp[,1], rev(pp[,1])), c(lower, rev(upper))) , col = adjustcolor("red", alpha.f = 0.1), border = F)
Expand All @@ -5840,7 +5847,7 @@ h2o.partialPlot <- function(object, data, cols, destination_key, nbins=20, plot

pp.plot.1d.plotNA <- function(pp, type, color) {
## Plot NA value if numerical
NAsIds = which(is.na(pp[,1:1]))
NAsIds <- which(is.na(pp[,1:1]))
if (type != "enum" && include_na && length(NAsIds) != 0) {
points(pp[,1:1],array(pp[NAsIds, 2:2], dim = c(length(pp[,1:1]), 1)), col=color, type="l", lty=5)
if (is.null(targets)) {
Expand Down Expand Up @@ -5880,10 +5887,10 @@ h2o.partialPlot <- function(object, data, cols, destination_key, nbins=20, plot

pp.plot.1d.proccessDataForPolygon <- function(X, Y) {
  ## Build the two-column (X, Y) vertex matrix handed to polygon().
  ##
  ## X: vector of x coordinates, may contain NA.
  ## Y: vector of y coordinates, parallel to X (same length).
  ## Returns cbind(X, Y) with every row whose X is NA removed.
  ##
  ## polygon can't handle NAs, so vertices with a missing x are dropped.
  ## The filtering is done exactly once with a logical mask: applying the
  ## index removal a second time to the already-shortened vectors would
  ## discard valid points. A mask also covers the "no NA" case without a
  ## guard (x[-integer(0)] would otherwise return an empty vector).
  keep <- !is.na(X)
  X <- X[keep]
  Y <- Y[keep]
  return(cbind(X, Y))
}
Expand Down Expand Up @@ -5972,10 +5979,10 @@ h2o.partialPlot <- function(object, data, cols, destination_key, nbins=20, plot
## Plot one standard deviation above and below the mean
if (plot_stddev) {
## Added upper and lower std dev confidence bound
upper = pp[, 3] + pp[, 4]
lower = pp[, 3] - pp[, 4]
Zupper = matrix(upper, ncol=dim(XX)[2], byrow=F)
Zlower = matrix(lower, ncol=dim(XX)[2], byrow=F)
upper <- pp[, 3] + pp[, 4]
lower <- pp[, 3] - pp[, 4]
Zupper <- matrix(upper, ncol=dim(XX)[2], byrow=F)
Zlower <- matrix(lower, ncol=dim(XX)[2], byrow=F)
rgl::open3d()
plot3Drgl::persp3Drgl(XX, YY, ZZ, theta=30, phi=15, axes=TRUE,scale=2, box=TRUE, nticks=5,
ticktype="detailed", xlab=names(pp)[1], ylab=names(pp)[2], zlab="2D partial plots",
Expand Down Expand Up @@ -6015,7 +6022,7 @@ h2o.partialPlot <- function(object, data, cols, destination_key, nbins=20, plot
pp.plot.save.2d <- function(pp, nBins=nbins, user_cols=NULL, user_num_splits=NULL) {
# If user accidentally provides one of the most common suffixes in R, it is removed.
save_to <- gsub(replacement = "", pattern = "(\\.png)|(\\.jpg)|(\\.pdf)", x = save_to)
colnames = paste0(names(pp)[1], "_", names(pp)[2])
colnames <- paste0(names(pp)[1], "_", names(pp)[2])
destination_file <- paste0(save_to,"_",colnames,'.png')
pp.plot.2d(pp, nbins, user_cols, user_num_splits)
rgl::snapshot3d(destination_file)
Expand All @@ -6033,7 +6040,7 @@ h2o.partialPlot <- function(object, data, cols, destination_key, nbins=20, plot
from <- 1
to <- length(targets)
for(i in 1:numCols) {
pp = pps[from:to]
pp <- pps[from:to]
pp.plot.1d.multinomial(pp)
if(!is.null(save_to)){
pp.plot.save.1d.multinomial(pp)
Expand Down
2 changes: 1 addition & 1 deletion h2o-r/tests/testdir_jira/runit_pubdev_4897.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ test.pubdev_4897 <- function() {
model <- h2o.gbm(x = "Var2", y = "Var1", training_frame = data, ntrees = 50)
pdf("plot")
dev.control(displaylist="enable")
h2o.partialPlot(object = model, data = data, cols = "Var2", plot = TRUE)
h2o.partialPlot(object = model, newdata = data, cols = "Var2", plot = TRUE)
recordedPlot <- recordPlot()
dev.off()
unlink("plot")
Expand Down
2 changes: 1 addition & 1 deletion h2o-r/tests/testdir_jira/runit_pubdev_6122.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ test.pdp.save <- function() {
model <- h2o.gbm(x=cols, y = "CAPSULE", training_frame = data)

temp_filename_no_extension <- tempfile(pattern = "pdp", tmpdir = tempdir(), fileext = "")
plot <- h2o.partialPlot(object = model, data = data, save_to = temp_filename_no_extension)
plot <- h2o.partialPlot(object = model, newdata = data, save_to = temp_filename_no_extension)
expect_false(is.null(plot))

check_file <- function(feature){
Expand Down
4 changes: 2 additions & 2 deletions h2o-r/tests/testdir_misc/runit_PUBDEV-6775-2D-pdp.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ test <- function() {
temp_filename_no_extension <- tempfile(pattern = "pdp", tmpdir = tempdir(), fileext = "")
## Calculate partial dependence using h2o.partialPlot for columns "AGE" and "RACE"
prostate_drf = h2o.randomForest(x = c("AGE", "RACE"), y = "CAPSULE", training_frame = prostate_hex, ntrees = 50, seed = 12345)
h2o_pp_1d_2d = h2o.partialPlot(object = prostate_drf, data = prostate_hex, cols = c("RACE", "AGE"),
h2o_pp_1d_2d = h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, cols = c("RACE", "AGE"),
col_pairs_2dpdp=list(c("RACE", "AGE"), c("AGE", "PSA")), plot = T,
save_to=temp_filename_no_extension)
h2o_pp_2d_only = h2o.partialPlot(object = prostate_drf, data = prostate_hex, col_pairs_2dpdp=list(c("RACE", "AGE"), c("AGE", "PSA")),
h2o_pp_2d_only = h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, col_pairs_2dpdp=list(c("RACE", "AGE"), c("AGE", "PSA")),
plot = FALSE)
# compare 2d pdp results from 2d pdp only and from 1d and 2d pdps
assert_partialPlots_twoDTable_equal(h2o_pp_1d_2d[[3]],h2o_pp_2d_only[[1]]) # 2d pdp RACE and AGE
Expand Down
6 changes: 3 additions & 3 deletions h2o-r/tests/testdir_misc/runit_PUBDEV_6438_2D_pdp.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ test <- function() {
user_splits_list = list(c("AGE", ageSplit))
temp_filename_no_extension <- tempfile(pattern = "pdp", tmpdir = tempdir(), fileext = "")
## Calculate partial dependence using h2o.partialPlot for columns "AGE" and "RACE"
h2o_pp_1d_2d = h2o.partialPlot(object = prostate_drf, data = prostate_hex, cols = c("RACE", "AGE"),
h2o_pp_1d_2d = h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, cols = c("RACE", "AGE"),
col_pairs_2dpdp=list(c("RACE", "AGE"), c("AGE", "PSA")), plot = TRUE,
user_splits=user_splits_list, save_to=temp_filename_no_extension)
if (file.exists(temp_filename_no_extension))
file.remove(temp_filename_no_extension)

h2o_pp_2d_only = h2o.partialPlot(object = prostate_drf, data = prostate_hex, col_pairs_2dpdp=list(c("RACE", "AGE"), c("AGE", "PSA")), plot = FALSE,
h2o_pp_2d_only = h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, col_pairs_2dpdp=list(c("RACE", "AGE"), c("AGE", "PSA")), plot = FALSE,
user_splits=user_splits_list)
h2o_pp_1d_only = h2o.partialPlot(object = prostate_drf, data = prostate_hex, cols = c("RACE", "AGE"), plot = FALSE, user_splits=user_splits_list)
h2o_pp_1d_only = h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, cols = c("RACE", "AGE"), plot = FALSE, user_splits=user_splits_list)

# compare 1d pdp results from 1d pdp only and from 1d and 2d pdps
assert_partialPlots_twoDTable_equal(h2o_pp_1d_2d[[1]],h2o_pp_1d_only[[1]]) # compare RACE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,29 @@ test <- function() {
seed=1234
prostate_drf = h2o.randomForest(x = c("AGE", "RACE"), y = "CAPSULE", training_frame = prostate_hex, ntrees = 25, seed = seed)

h2o.partialPlot(object = prostate_drf, data = prostate_hex, cols = "RACE", plot = TRUE, include_na = TRUE, plot_stddev = TRUE)
h2o.partialPlot(object = prostate_drf, data = prostate_hex, cols = "RACE", plot = TRUE, include_na = FALSE, plot_stddev = TRUE)
h2o.partialPlot(object = prostate_drf, data = prostate_hex, cols = "RACE", plot = TRUE, include_na = FALSE, plot_stddev = FALSE)
h2o.partialPlot(object = prostate_drf, data = prostate_hex, cols = "RACE", plot = TRUE, include_na = TRUE, plot_stddev = FALSE)
h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, cols = "RACE", plot = TRUE, include_na = TRUE, plot_stddev = TRUE)
h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, cols = "RACE", plot = TRUE, include_na = FALSE, plot_stddev = TRUE)
h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, cols = "RACE", plot = TRUE, include_na = FALSE, plot_stddev = FALSE)
h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, cols = "RACE", plot = TRUE, include_na = TRUE, plot_stddev = FALSE)

# 1D multiple cols
h2o.partialPlot(object = prostate_drf, data = prostate_hex, cols = c("AGE", "RACE", "DCAPS"), plot = TRUE, include_na = TRUE, plot_stddev = TRUE)
h2o.partialPlot(object = prostate_drf, data = prostate_hex, cols = c("AGE", "RACE", "DCAPS"), plot = TRUE, include_na = FALSE, plot_stddev = TRUE)
h2o.partialPlot(object = prostate_drf, data = prostate_hex, cols = c("AGE", "RACE", "DCAPS"), plot = TRUE, include_na = FALSE, plot_stddev = FALSE)
h2o.partialPlot(object = prostate_drf, data = prostate_hex, cols = c("AGE", "RACE", "DCAPS"), plot = TRUE, include_na = TRUE, plot_stddev = FALSE)
h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, cols = c("AGE", "RACE", "DCAPS"), plot = TRUE, include_na = TRUE, plot_stddev = TRUE)
h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, cols = c("AGE", "RACE", "DCAPS"), plot = TRUE, include_na = FALSE, plot_stddev = TRUE)
h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, cols = c("AGE", "RACE", "DCAPS"), plot = TRUE, include_na = FALSE, plot_stddev = FALSE)
h2o.partialPlot(object = prostate_drf, newdata = prostate_hex, cols = c("AGE", "RACE", "DCAPS"), plot = TRUE, include_na = TRUE, plot_stddev = FALSE)

# 1D multiple cols && targets
iris[,'random'] <- as.factor(as.data.frame(unlist(sample(x = 1:4, size = length(iris[[1]]), replace=TRUE)))[[1]])
iris_hex <- as.h2o(iris)
iris_gbm <- h2o.gbm(x = c(1:4,6), y = 5, training_frame = iris_hex)

# one column
h2o.partialPlot(object = iris_gbm, data = iris_hex, cols = "Petal.Length", targets = c("setosa"), plot = TRUE, include_na = TRUE, plot_stddev = TRUE )
h2o.partialPlot(object = iris_gbm, data = iris_hex, cols = "Petal.Length", targets = c("setosa", "virginica", "versicolor"), plot = TRUE, include_na = TRUE, plot_stddev = TRUE)
h2o.partialPlot(object = iris_gbm, newdata = iris_hex, cols = "Petal.Length", targets = c("setosa"), plot = TRUE, include_na = TRUE, plot_stddev = TRUE )
h2o.partialPlot(object = iris_gbm, newdata = iris_hex, cols = "Petal.Length", targets = c("setosa", "virginica", "versicolor"), plot = TRUE, include_na = TRUE, plot_stddev = TRUE)

# two columns
h2o.partialPlot(object = iris_gbm, data = iris_hex, cols=c("Petal.Length", "Sepal.Length"), targets=c("setosa"))
h2o.partialPlot(object = iris_gbm, data = iris_hex, cols=c("Petal.Length", "Sepal.Length"), targets=c("setosa"), include_na = FALSE, plot_stddev = TRUE)
h2o.partialPlot(object = iris_gbm, newdata = iris_hex, cols=c("Petal.Length", "Sepal.Length"), targets=c("setosa"))
h2o.partialPlot(object = iris_gbm, newdata = iris_hex, cols=c("Petal.Length", "Sepal.Length"), targets=c("setosa"), include_na = FALSE, plot_stddev = TRUE)

}

Expand Down
Loading

0 comments on commit 4c798ce

Please sign in to comment.