-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathnearmiss_impl.R
91 lines (79 loc) · 2.89 KB
/
nearmiss_impl.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#' Remove Points Near Other Classes
#'
#' Undersamples the larger classes using the NearMiss-1 algorithm. Rows are
#' only removed; no synthetic observations are generated. For every class with
#' more rows than the target size, the observations kept are those with the
#' smallest mean distance to their `k` nearest neighbors in the other classes.
#'
#' @inheritParams step_nearmiss
#' @param df data.frame or tibble. Must have 1 factor variable and remaining
#' numeric variables.
#' @param var Character, name of the factor variable holding the class labels.
#' @param k An integer. Number of nearest neighbors, taken from the other
#' classes, used to compute each observation's mean distance.
#'
#' @return A data.frame or tibble, depending on type of `df`.
#' @export
#'
#' @details
#' All columns used in this function must be numeric with no missing data.
#'
#' The target size of a class is the size of the smallest class multiplied by
#' `under_ratio`; classes already at or below that size are left untouched.
#'
#' @references Inderjeet Mani and I Zhang. knn approach to unbalanced data
#' distributions: a case study involving information extraction. In Proceedings
#' of workshop on learning from imbalanced datasets, 2003.
#'
#' @seealso [step_nearmiss()] for step function of this method
#' @family Direct Implementations
#'
#' @examples
#' circle_numeric <- circle_example[, c("x", "y", "class")]
#'
#' res <- nearmiss(circle_numeric, var = "class")
#'
#' res <- nearmiss(circle_numeric, var = "class", k = 10)
#'
#' res <- nearmiss(circle_numeric, var = "class", under_ratio = 1.5)
nearmiss <- function(df, var, k = 5, under_ratio = 1) {
  check_data_frame(df)
  check_var(var, df)
  check_number_whole(k, min = 1)
  check_number_decimal(under_ratio)
  predictors <- setdiff(colnames(df), var)
  # drop = FALSE keeps a data frame (with column names) even when a base
  # data.frame has a single predictor column.
  check_numeric(df[, predictors, drop = FALSE])
  check_na(select(df, -all_of(var)))
  nearmiss_impl(
    df,
    var,
    ignore_vars = character(),
    k = k,
    under_ratio = under_ratio
  )
}
# Internal worker for `nearmiss()`.
#
# For each class that `downsample_count()` reports as too large, the rows of
# that class are queried against all rows of the other classes with
# `RANN::nn2()`. The rows with the smallest mean distance to their `k` nearest
# other-class neighbours are kept; the remaining rows of that class are removed
# from `df`.
#
# @param df data.frame or tibble containing the class column `var`.
# @param var Name of the class column.
# @param ignore_vars Character vector of columns excluded from the distance
#   computation (they are still present in the returned data).
# @param k Number of nearest neighbours used for the mean distance.
# @param under_ratio Passed to `downsample_count()` to set the target size.
# @return `df` with the discarded rows removed.
nearmiss_impl <- function(df, var, ignore_vars, k = 5, under_ratio = 1) {
  classes <- downsample_count(df, var, under_ratio)

  # The distance columns are the same for every class, so subset only once.
  df_only <- df[, !names(df) %in% ignore_vars, drop = FALSE]

  deleted_rows <- integer()
  for (i in seq_along(classes)) {
    class_name <- names(classes)[i]
    class_mat <- subset_to_matrix(df_only, var, class_name)
    other_mat <- subset_to_matrix(df_only, var, class_name, FALSE)
    if (nrow(other_mat) <= k) {
      cli::cli_abort("Not enough danger observations of {.val {names(classes)[i]}} to perform NEARMISS.")
    }
    # `ignore_vars` were already removed from `df_only`, so every column of
    # these matrices is a predictor.
    dists <- RANN::nn2(other_mat, class_mat, k = k)$nn.dists

    n_keep <- nrow(class_mat) - classes[[i]]
    # rank() (not order()) gives each row's position among the sorted mean
    # distances; ties are broken by row order so the kept count is stable.
    selected_ind <- rank(rowMeans(dists), ties.method = "first") <= n_keep

    class_rows <- which(df[[var]] %in% class_name)
    deleted_rows <- c(deleted_rows, class_rows[!selected_ind])
  }
  if (length(deleted_rows) > 0) {
    df <- df[-deleted_rows, , drop = FALSE]
  }
  df
}
# Number of rows to remove from each class that exceeds the target size.
#
# The target size is the size of the smallest non-empty class times `ratio`.
# Returns a named 1-d table holding, for every class larger than the target,
# its count minus the target; classes at or below the target are omitted.
downsample_count <- function(data, var, ratio) {
  counts <- table(data[[var]])
  # Unused factor levels have a count of 0; letting them define the minimum
  # would set the target to 0 and wipe out every observed class.
  counts <- counts[counts > 0]
  if (length(counts) == 0) {
    return(counts)
  }
  ratio_target <- min(counts) * ratio
  counts[counts > ratio_target] - ratio_target
}
# Predictor matrix for the rows whose `var` label equals `class`
# (`equal = TRUE`) or differs from it (`equal = FALSE`). The `var` column
# itself is excluded. Rows with a missing label are never selected.
subset_to_matrix <- function(data, var, class, equal = TRUE) {
  labels <- data[[var]]
  matches <- if (equal) labels == class else labels != class
  # which() discards NA comparisons instead of producing all-NA rows.
  rows <- which(matches)
  # drop = FALSE keeps a named one-column matrix when only one predictor exists.
  as.matrix(data[rows, names(data) != var, drop = FALSE])
}