Fix arguments of average_loss() #82

Merged 1 commit on Oct 19, 2023
1 change: 1 addition & 0 deletions NEWS.md
@@ -22,6 +22,7 @@
  - `average_loss()` also returns a "hstats_matrix" object with `print()` and `plot()` method. The values can be extracted via `$M`.
  - The default `v` of `hstats()` and `perm_importance()` is now `NULL`. Internally, it is set to `colnames(X)` (minus the column names of `w` and `y` if passed as name).
  - Missing grid values: `partial_dep()` and `ice()` have received a `na.rm` argument that controls if missing values are dropped during grid creation. The default `TRUE` is compatible with earlier releases.
+ - The position of some function arguments have changed.

  # hstats 0.3.0

25 changes: 17 additions & 8 deletions R/average_loss.R
@@ -115,16 +115,19 @@ average_loss.default <- function(object, X, y,
  #' @export
  average_loss.ranger <- function(object, X, y,
  pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
- loss = "squared_error",
+ loss = "squared_error",
+ agg_cols = FALSE,
  BY = NULL, by_size = 4L,
  w = NULL, ...) {
  average_loss.default(
  object = object,
  X = X,
  y = y,
- pred_fun = pred_fun,
- BY = BY,
+ pred_fun = pred_fun,
  loss = loss,
+ agg_cols = agg_cols,
+ BY = BY,
+ by_size = by_size,
  w = w,
  ...
  )
@@ -134,7 +137,8 @@ average_loss.ranger <- function(object, X, y,
  #' @export
  average_loss.Learner <- function(object, X, y,
  pred_fun = NULL,
- loss = "squared_error",
+ loss = "squared_error",
+ agg_cols = FALSE,
  BY = NULL, by_size = 4L,
  w = NULL, ...) {
  if (is.null(pred_fun)) {
@@ -145,8 +149,10 @@ average_loss.Learner <- function(object, X, y,
  X = X,
  y = y,
  pred_fun = pred_fun,
- BY = BY,
- loss = loss,
+ loss = loss,
+ agg_cols = agg_cols,
+ BY = BY,
+ by_size = by_size,
  w = w,
  ...
  )
@@ -158,7 +164,8 @@ average_loss.explainer <- function(object,
  X = object[["data"]],
  y = object[["y"]],
  pred_fun = object[["predict_function"]],
- loss = "squared_error",
+ loss = "squared_error",
+ agg_cols = FALSE,
  BY = NULL,
  by_size = 4L,
  w = object[["weights"]],
@@ -168,8 +175,10 @@ average_loss.explainer <- function(object,
  X = X,
  y = y,
  pred_fun = pred_fun,
- BY = BY,
  loss = loss,
+ agg_cols = agg_cols,
+ BY = BY,
+ by_size = by_size,
  w = w,
  ...
  )
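The wrappers above now forward `agg_cols` and `by_size` to `average_loss.default()` by name. The following base R sketch illustrates why that matters. It does not use hstats: the generic `toy_loss()` and the classes `"buggy"` and `"fixed"` are hypothetical stand-ins, chosen only to show that a wrapper method which declares an argument but does not pass it on silently falls back to the default value.

```r
# Hypothetical toy generic; all names here are illustrative, not hstats code.
toy_loss <- function(object, ...) UseMethod("toy_loss")

# Default method: does the actual work.
toy_loss.default <- function(object, y, loss = "squared_error",
                             by_size = 4L, ...) {
  pred <- object$pred
  err <- switch(
    loss,
    squared_error = (y - pred)^2,
    absolute_error = abs(y - pred),
    stop("unknown loss")
  )
  list(value = mean(err), by_size = by_size)
}

# Wrapper that declares `by_size` but never passes it on.
toy_loss.buggy <- function(object, y, loss = "squared_error",
                           by_size = 4L, ...) {
  toy_loss.default(object = object, y = y, loss = loss, ...)
}

# Wrapper that forwards every declared argument by name.
toy_loss.fixed <- function(object, y, loss = "squared_error",
                           by_size = 4L, ...) {
  toy_loss.default(object = object, y = y, loss = loss, by_size = by_size, ...)
}

y <- c(1, 2, 3)
buggy <- structure(list(pred = c(1, 2, 5)), class = "buggy")
fixed <- structure(list(pred = c(1, 2, 5)), class = "fixed")

res_buggy <- toy_loss(buggy, y = y, by_size = 2L)
res_fixed <- toy_loss(fixed, y = y, by_size = 2L)

res_buggy$by_size  # 4: the user's value was dropped by the wrapper
res_fixed$by_size  # 2: the user's value reaches the default method
```

Forwarding by name rather than by position also keeps the wrappers correct if the argument order of the default method changes later.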
40 changes: 22 additions & 18 deletions R/hstats.R
@@ -33,9 +33,6 @@
  #' (such as `type = "response"` in a GLM, or `reshape = TRUE` in a multiclass XGBoost
  #' model) can be passed via `...`. The default, [stats::predict()], will work in
  #' most cases.
- #' @param n_max If `X` has more than `n_max` rows, a random sample of `n_max` rows is
- #' selected from `X`. In this case, set a random seed for reproducibility.
- #' @param w Optional vector of case weights. Can also be a column name of `X`.
  #' @param pairwise_m Number of features for which pairwise statistics are to be
  #' calculated. The features are selected based on Friedman and Popescu's overall
  #' interaction strength \eqn{H^2_j}. Set to to 0 to avoid pairwise calculations.
@@ -50,6 +47,9 @@
  #' speed-up for dense features, mainly for one-way statistics.
  #' Note that the quantiles are calculated after subsampling to `n_max` rows.
  #' @param eps Threshold below which numerator values are set to 0. Default is 1e-10.
+ #' @param n_max If `X` has more than `n_max` rows, a random sample of `n_max` rows is
+ #' selected from `X`. In this case, set a random seed for reproducibility.
+ #' @param w Optional vector of case weights. Can also be a column name of `X`.
  #' @param verbose Should a progress bar be shown? The default is `TRUE`.
  #' @param ... Additional arguments passed to `pred_fun(object, X, ...)`,
  #' for instance `type = "response"` in a [glm()] model, or `reshape = TRUE` in a
@@ -139,9 +139,10 @@ hstats <- function(object, ...) {
  #' @describeIn hstats Default hstats method.
  #' @export
  hstats.default <- function(object, X, v = NULL,
- pred_fun = stats::predict, n_max = 500L,
- w = NULL, pairwise_m = 5L, threeway_m = 0L,
- quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
+ pred_fun = stats::predict,
+ pairwise_m = 5L, threeway_m = 0L,
+ quant_approx = NULL, eps = 1e-10,
+ n_max = 500L, w = NULL, verbose = TRUE, ...) {
  stopifnot(
  is.matrix(X) || is.data.frame(X),
  is.function(pred_fun)
@@ -275,19 +276,20 @@ hstats.default <- function(object, X, v = NULL,
  #' @export
  hstats.ranger <- function(object, X, v = NULL,
  pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
- n_max = 500L, w = NULL, pairwise_m = 5L, threeway_m = 0L,
- quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
+ pairwise_m = 5L, threeway_m = 0L,
+ quant_approx = NULL, eps = 1e-10,
+ n_max = 500L, w = NULL, verbose = TRUE, ...) {
  hstats.default(
  object = object,
  X = X,
  v = v,
  pred_fun = pred_fun,
- n_max = n_max,
- w = w,
  pairwise_m = pairwise_m,
  threeway_m = threeway_m,
  quant_approx = quant_approx,
  eps = eps,
+ n_max = n_max,
+ w = w,
  verbose = verbose,
  ...
  )
@@ -297,8 +299,9 @@ hstats.ranger <- function(object, X, v = NULL,
  #' @export
  hstats.Learner <- function(object, X, v = NULL,
  pred_fun = NULL,
- n_max = 500L, w = NULL, pairwise_m = 5L, threeway_m = 0L,
- quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
+ pairwise_m = 5L, threeway_m = 0L,
+ quant_approx = NULL, eps = 1e-10,
+ n_max = 500L, w = NULL, verbose = TRUE, ...) {
  if (is.null(pred_fun)) {
  pred_fun <- mlr3_pred_fun(object, X = X)
  }
@@ -307,12 +310,12 @@ hstats.Learner <- function(object, X, v = NULL,
  X = X,
  v = v,
  pred_fun = pred_fun,
- n_max = n_max,
- w = w,
  pairwise_m = pairwise_m,
  threeway_m = threeway_m,
  quant_approx = quant_approx,
  eps = eps,
+ n_max = n_max,
+ w = w,
  verbose = verbose,
  ...
  )
@@ -323,20 +326,21 @@ hstats.Learner <- function(object, X, v = NULL,
  hstats.explainer <- function(object, X = object[["data"]],
  v = NULL,
  pred_fun = object[["predict_function"]],
- n_max = 500L, w = object[["weights"]],
  pairwise_m = 5L, threeway_m = 0L,
- quant_approx = NULL, eps = 1e-10, verbose = TRUE, ...) {
+ quant_approx = NULL, eps = 1e-10,
+ n_max = 500L, w = object[["weights"]],
+ verbose = TRUE, ...) {
  hstats.default(
  object = object[["model"]],
  X = X,
  v = v,
  pred_fun = pred_fun,
- n_max = n_max,
- w = w,
  pairwise_m = pairwise_m,
  threeway_m = threeway_m,
  quant_approx = quant_approx,
  eps = eps,
+ n_max = n_max,
+ w = w,
  verbose = verbose,
  ...
  )
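Moving `n_max` and `w` behind `eps` changes what a positional call means, which is presumably why NEWS.md notes the changed argument positions. The sketch below shows the effect with two hypothetical functions, `f_old()` and `f_new()`, that only mimic a before/after ordering. They are not the real `hstats()` signatures.

```r
# Hypothetical stand-ins for a signature before and after reordering.
f_old <- function(X, v = NULL, pred_fun = identity, n_max = 500L, w = NULL,
                  pairwise_m = 5L) {
  list(n_max = n_max, pairwise_m = pairwise_m)
}
f_new <- function(X, v = NULL, pred_fun = identity, pairwise_m = 5L,
                  n_max = 500L, w = NULL) {
  list(n_max = n_max, pairwise_m = pairwise_m)
}

X <- data.frame(a = 1:3)

# Positional call: the fourth argument changes meaning between versions.
pos_old <- f_old(X, NULL, identity, 200L)
pos_new <- f_new(X, NULL, identity, 200L)
pos_old$n_max       # 200
pos_new$n_max       # 500, because 200 was matched to pairwise_m
pos_new$pairwise_m  # 200

# Named call: same result under both orderings.
named_old <- f_old(X, n_max = 200L)
named_new <- f_new(X, n_max = 200L)
identical(named_old, named_new)  # TRUE
```

Calls that name every argument after `X` are unaffected by this kind of reordering.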
4 changes: 2 additions & 2 deletions R/ice.R
@@ -144,7 +144,7 @@ ice.ranger <- function(object, v, X,
  BY = NULL, grid = NULL, grid_size = 49L,
  trim = c(0.01, 0.99),
  strategy = c("uniform", "quantile"), na.rm = TRUE,
- n_max = 100, ...) {
+ n_max = 100L, ...) {
  ice.default(
  object = object,
  v = v,
@@ -194,7 +194,7 @@ ice.explainer <- function(object, v = v, X = object[["data"]],
  BY = NULL, grid = NULL, grid_size = 49L,
  trim = c(0.01, 0.99),
  strategy = c("uniform", "quantile"), na.rm = TRUE,
- n_max = 100, ...) {
+ n_max = 100L, ...) {
  ice.default(
  object = object[["model"]],
  v = v,
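The change from `100` to `100L` affects only the storage type of the default, which brings it in line with the other integer defaults in these signatures (`49L`, `500L`). A short base R check of the difference, offered as an illustration rather than as the motivation stated by the author:

```r
# A bare number is a double; the L suffix makes an integer literal.
is.integer(100)       # FALSE
is.integer(100L)      # TRUE
typeof(100)           # "double"
typeof(100L)          # "integer"

# The values compare equal, but the objects are not identical.
100 == 100L           # TRUE
identical(100, 100L)  # FALSE
```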
3 changes: 3 additions & 0 deletions man/average_loss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 13 additions & 13 deletions man/hstats.Rd


4 changes: 2 additions & 2 deletions man/ice.Rd
