Skip to content

Commit

Permalink
Merge pull request #116 from mayer79/remove-factor
Browse files Browse the repository at this point in the history
Disallow factor predictions
  • Loading branch information
mayer79 authored Jul 7, 2024
2 parents 3c37e6f + 2278089 commit df1e51e
Show file tree
Hide file tree
Showing 24 changed files with 177 additions and 534 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: hstats
Title: Interaction Statistics
Version: 1.1.2
Version: 1.2.0
Authors@R:
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre"))
Description: Fast, model-agnostic implementation of different H-statistics
Expand Down
11 changes: 11 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# hstats 1.2.0

## Major changes

- Factor-valued predictions are no longer possible.
- Consequently, also removed "classification_error" loss.

## Minor changes

- Code simplifications.

# hstats 1.1.2

## ICE plots
Expand Down
15 changes: 5 additions & 10 deletions R/average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,19 @@
#' or as discrete vector with corresponding levels.
#' The latter case is turned into a dummy matrix by a fast version of
#' `model.matrix(~ as.factor(y) + 0)`.
#' - "classification_error": Misclassification error. Both the
#' observed values `y` and the predictions can be character/factor. This
#' loss function can be used in non-probabilistic classification settings.
#' BUT: Probabilistic classification (with "mlogloss") is clearly preferred in most
#' situations.
#' - A function with signature `f(actual, predicted)`, returning a numeric
#' vector or matrix of the same length as the input.
#'
#' @inheritParams hstats
#' @param y Vector/matrix of the response, or the corresponding column names in `X`.
#' @param loss One of "squared_error", "logloss", "mlogloss", "poisson",
#' "gamma", "absolute_error", "classification_error". Alternatively, a loss function
#' "gamma", or "absolute_error". Alternatively, a loss function
#' can be provided that turns observed and predicted values into a numeric vector or
#' matrix of unit losses of the same length as `X`.
#' For "mlogloss", the response `y` can either be a dummy matrix or a discrete vector.
#' The latter case is handled via a fast version of `model.matrix(~ as.factor(y) + 0)`.
#' For "classification_error", both predictions and responses can be non-numeric.
#' For "squared_error", both predictions and responses can be factors with identical
#' levels. In this case, squared error is evaulated for each one-hot-encoded column.
#' For "squared_error", the response can be a factor with levels in column order of
#' the predictions. In this case, squared error is evaluated for each one-hot-encoded column.
#' @param agg_cols Should multivariate losses be summed up? Default is `FALSE`.
#' In combination with the squared error loss, `agg_cols = TRUE` gives
#' the Brier score for (probabilistic) classification.
Expand Down Expand Up @@ -116,7 +110,8 @@ average_loss.default <- function(object, X, y,
#' @describeIn average_loss Method for "ranger" models.
#' @export
average_loss.ranger <- function(object, X, y,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
pred_fun = function(m, X, ...)
stats::predict(m, X, ...)$predictions,
loss = "squared_error",
agg_cols = FALSE,
BY = NULL, by_size = 4L,
Expand Down
2 changes: 1 addition & 1 deletion R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ hstats.default <- function(object, X, v = NULL,
}

# Predictions ("F" in Friedman and Popescu) always calculated (cheap)
f <- wcenter(prepare_pred(pred_fun(object, X, ...), ohe = TRUE), w = w)
f <- wcenter(prepare_pred(pred_fun(object, X, ...)), w = w)
mean_f2 <- wcolMeans(f^2, w = w) # A vector

# Initialize first progress bar
Expand Down
1 change: 0 additions & 1 deletion R/ice.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ ice.default <- function(object, v, X, pred_fun = stats::predict,
grid = grid,
pred_fun = pred_fun,
pred_only = FALSE,
ohe = TRUE,
...
)
pred <- ice_out[["pred"]]
Expand Down
113 changes: 58 additions & 55 deletions R/losses.R
Original file line number Diff line number Diff line change
@@ -1,18 +1,43 @@
#' Input Checks for Losses
#'
#' Internal function with general input checks.
#'
#' @noRd
#' @keywords internal
#'
#' @param actual Actual values.
#' @param predicted Predictions.
#' @returns `TRUE`
check_loss <- function(actual, predicted) {
stopifnot(
is.vector(actual) || is.matrix(actual),
is.vector(predicted) || is.matrix(predicted),
is.numeric(actual) || is.logical(actual),
is.numeric(predicted) || is.logical(predicted),
NROW(actual) == NROW(predicted),
NCOL(actual) == 1L || NCOL(actual) == NCOL(predicted)
)
return(TRUE)
}

#' Squared Error Loss
#'
#' Internal function. Calculates squared error.
#'
#' @noRd
#' @keywords internal
#'
#' @param actual A numeric vector/matrix, or factor.
#' @param predicted A numeric vector/matrix, or factor.
#' @param actual A numeric vector or matrix, or a factor with levels in the same order
#' as the column names of `predicted`.
#' @param predicted A numeric vector or matrix.
#' @returns Vector or matrix of numeric losses.
loss_squared_error <- function(actual, predicted) {
actual <- drop(prepare_pred(actual, ohe = TRUE))
predicted <- prepare_pred(predicted, ohe = TRUE)

(actual - predicted)^2
if (is.factor(actual)) {
actual <- fdummy(actual)
}
check_loss(actual, predicted)

return((drop(actual) - predicted)^2)
}

#' Absolute Error Loss
Expand All @@ -26,13 +51,9 @@ loss_squared_error <- function(actual, predicted) {
#' @param predicted A numeric vector/matrix.
#' @returns Vector or matrix of numeric losses.
loss_absolute_error <- function(actual, predicted) {
actual <- drop(prepare_pred(actual))
predicted <- prepare_pred(predicted)
if (is.factor(actual) || is.factor(predicted)) {
stop("Absolute loss does not make sense for factors.")
}

abs(actual - predicted)
check_loss(actual, predicted)

return(abs(drop(actual) - predicted))
}

#' Poisson Deviance Loss
Expand All @@ -46,20 +67,18 @@ loss_absolute_error <- function(actual, predicted) {
#' @param predicted A numeric vector/matrix with non-negative values.
#' @returns Vector or matrix of numeric losses.
loss_poisson <- function(actual, predicted) {
actual <- drop(prepare_pred(actual))
predicted <- prepare_pred(predicted)
if (is.factor(actual) || is.factor(predicted)) {
stop("Poisson loss does not make sense for factors.")
}
check_loss(actual, predicted)
stopifnot(
all(predicted >= 0),
all(actual >= 0)
)

actual <- drop(actual)

out <- predicted
p <- actual > 0
out[p] <- (actual * log(actual / predicted) - (actual - predicted))[p]
2 * out
return(2 * out)
}

#' Gamma Deviance Loss
Expand All @@ -73,34 +92,15 @@ loss_poisson <- function(actual, predicted) {
#' @param predicted A numeric vector/matrix with positive values.
#' @returns Vector or matrix of numeric losses.
loss_gamma <- function(actual, predicted) {
actual <- drop(prepare_pred(actual))
predicted <- prepare_pred(predicted)
if (is.factor(actual) || is.factor(predicted)) {
stop("Gamma loss does not make sense for factors.")
}
check_loss(actual, predicted)
stopifnot(
all(predicted > 0),
all(actual > 0)
)

-2 * (log(actual / predicted) - (actual - predicted) / predicted)
}

#' Classification Error Loss
#'
#' Internal function. Calculates per-row misclassification errors.
#'
#' @noRd
#' @keywords internal
#'
#' @param actual A vector/factor/matrix with discrete values.
#' @param predicted A vector/factor/matrix with discrete values.
#' @returns Vector or matrix of numeric losses.
loss_classification_error <- function(actual, predicted) {
actual <- drop(prepare_pred(actual))
predicted <- prepare_pred(predicted)
actual <- drop(actual)

(actual != predicted) * 1.0
return(-2 * (log(actual / predicted) - (actual - predicted) / predicted))
}

#' Log Loss
Expand All @@ -115,19 +115,17 @@ loss_classification_error <- function(actual, predicted) {
#' @param predicted A numeric vector/matrix with values between 0 and 1.
#' @returns Vector or matrix of numeric losses.
loss_logloss <- function(actual, predicted) {
actual <- drop(prepare_pred(actual))
predicted <- prepare_pred(predicted)
if (is.factor(actual) || is.factor(predicted)) {
stop("Log loss does not make sense for factors.")
}
check_loss(actual, predicted)
stopifnot(
all(predicted >= 0),
all(predicted <= 1),
all(actual >= 0),
all(actual <= 1)
)

-xlogy(actual, predicted) - xlogy(1 - actual, 1 - predicted)
actual <- drop(actual)

return(-xlogy(actual, predicted) - xlogy(1 - actual, 1 - predicted))
}

#' Multi-Column Log Loss
Expand All @@ -139,28 +137,33 @@ loss_logloss <- function(actual, predicted) {
#' @keywords internal
#'
#' @param actual A numeric matrix with values between 0 and 1, or a
#' discrete vector that will be one-hot-encoded by a fast version of
#' `model.matrix(~ as.factor(actual) + 0)`.
#' factor, or a discrete numeric vector that will be one-hot-encoded by a
#' fast version of `model.matrix(~ as.factor(actual) + 0)`.
#' The column order of `predicted` must be in line with this!
#' @param predicted A numeric matrix with values between 0 and 1.
#' @returns `TRUE` (or an error message).
#' @returns A numeric vector of losses.
loss_mlogloss <- function(actual, predicted) {
actual <- prepare_pred(actual)
predicted <- prepare_pred(predicted)
if (NCOL(actual) == 1L) { # not only for factors
actual <- fdummy(actual)
}

stopifnot(
is.matrix(actual),
is.matrix(predicted),

is.numeric(actual) || is.logical(actual),
is.numeric(predicted) || is.logical(predicted),

dim(actual) == dim(predicted),
ncol(predicted) >= 2L,
ncol(actual) == ncol(predicted),

all(predicted >= 0),
all(predicted <= 1),
all(actual >= 0),
all(actual <= 1)
)

unname(-rowSums(xlogy(actual, predicted)))
return(unname(-rowSums(xlogy(actual, predicted))))
}

#' Calculates x*log(y)
Expand All @@ -176,6 +179,7 @@ loss_mlogloss <- function(actual, predicted) {
xlogy <- function(x, y) {
out <- x * log(y)
out[x == 0] <- 0

return(out)
}

Expand All @@ -198,7 +202,6 @@ get_loss_fun <- function(loss) {
poisson = loss_poisson,
gamma = loss_gamma,
absolute_error = loss_absolute_error,
classification_error = loss_classification_error,
stop("Unknown loss function.")
)
}
5 changes: 2 additions & 3 deletions R/pd_raw.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ pd_raw <- function(object, v, X, grid, pred_fun = stats::predict,
#' predictions. Otherwise, a list with two elements: `pred` (predictions)
#' and `grid_pred` (the corresponding grid values in the same mode as the input,
#' but replicated over `X`).
#' @param ohe Should factor output be one-hot encoded? Default is `FALSE`.
#' @returns
#' Either a vector/matrix of predictions or a list with predictions and grid.
ice_raw <- function(object, v, X, grid, pred_fun = stats::predict,
pred_only = TRUE, ohe = FALSE, ...) {
pred_only = TRUE, ...) {
D1 <- length(v) == 1L
n <- nrow(X)
n_grid <- NROW(grid)
Expand All @@ -79,7 +78,7 @@ ice_raw <- function(object, v, X, grid, pred_fun = stats::predict,
}

# Calculate matrix/vector of predictions
pred <- prepare_pred(pred_fun(object, X_pred, ...), ohe = ohe)
pred <- prepare_pred(pred_fun(object, X_pred, ...))

if (pred_only) {
return(pred)
Expand Down
13 changes: 7 additions & 6 deletions R/perm_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' Calculates permutation importance for a set of features or a set of feature groups.
#' By default, importance is calculated for all columns in `X` (except column names
#' used as response `y` or case weight `w`).
#' used as response `y` or as case weight `w`).
#'
#' The permutation importance of a feature is defined as the increase in the average
#' loss when shuffling the corresponding feature values before calculating predictions.
Expand Down Expand Up @@ -31,7 +31,7 @@
#' # MODEL 1: Linear regression
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' s <- perm_importance(fit, X = iris, y = "Sepal.Length")

#'
#' s
#' s$M
#' s$SE # Standard errors are available thanks to repeated shuffling
Expand Down Expand Up @@ -105,7 +105,7 @@ perm_importance.default <- function(object, X, y, v = NULL,
X <- X[ix, , drop = FALSE]
if (is.vector(y) || is.factor(y)) {
y <- y[ix]
} else {
} else { # matrix case
y <- y[ix, , drop = FALSE]
}
if (!is.null(w)) {
Expand All @@ -128,8 +128,8 @@ perm_importance.default <- function(object, X, y, v = NULL,
X <- rep_rows(X, ind)
if (is.vector(y) || is.factor(y)) {
y <- y[ind]
} else {
y <- rep_rows(y, ind)
} else { # matrix case
y <- y[ind, , drop = FALSE]
}
}

Expand Down Expand Up @@ -206,7 +206,8 @@ perm_importance.default <- function(object, X, y, v = NULL,
#' @describeIn perm_importance Method for "ranger" models.
#' @export
perm_importance.ranger <- function(object, X, y, v = NULL,
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
pred_fun = function(m, X, ...)
stats::predict(m, X, ...)$predictions,
loss = "squared_error", m_rep = 4L,
agg_cols = FALSE,
normalize = FALSE, n_max = 10000L,
Expand Down
Loading

0 comments on commit df1e51e

Please sign in to comment.