caret 6.0-86: Function groupKFold does not return the requested number of folds #1150
Comments
Having the same problem. Here is an example: I request 15 folds and get fewer, and a different number on each call. Note that all groups in this example are the same size:
set.seed(1)
x <- rep(letters[1:20], each = 20)
length(caret::groupKFold(x, k = 15))
#> [1] 12
length(caret::groupKFold(x, k = 15))
#> [1] 11
length(caret::groupKFold(x, k = 15))
#> [1] 8
length(caret::groupKFold(x, k = 15))
#> [1] 11
Created on 2020-06-17 by the reprex package (v0.3.0)
It seems that
set.seed(1)
x <- rep(letters[1:20], each = 20)
folds <- caret::groupKFold(x, k = 15)
folds_i <- lapply(folds, function(i) unique(x[i]))
folds_i <- unname(unlist(folds_i))
table(folds_i)
#> folds_i
#> a b c d e f g h i j k l m n o p q r s t
#> 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
Created on 2020-06-17 by the reprex package (v0.3.0)
In the docs, point 3 can be seen here:
Perhaps these constraints could be made controllable by the user, for example when the number of folds is more important than group balance? Thanks!
@mattansb I'm not sure reason 3 could be the cause of this odd behavior, because even with … I think you could be right about reason 2 being the cause, though. It seems possible, but I haven't tested it. In any case, I would at the very least consider this a bug, because no warning is given to the user when fewer than … In the meantime, I'm using this code for leave-one-group-out:
leave_one_group_out <- function(group_list) {
  # One fold per unique group; each element holds the training row indices,
  # i.e. every row except those belonging to the left-out group
  result <- list()
  fold_i <- 0
  for (left_out_group in unique(group_list)) {
    fold_i <- fold_i + 1
    result[[fold_i]] <- which(group_list != left_out_group)
  }
  return(result)
}
In my (possibly mistaken) understanding, something odd happens in the call to sample() inside groupKFold:
function (group, k = length(unique(group)))
{
g_unique <- unique(group)
m <- length(g_unique)
if (k > m) {
stop("`k` should be less than ", m)
}
g_folds <- sample(k, size = m, replace = TRUE)
out <- split(seq_along(group), g_folds[match(group, g_unique)])
names(out) <- paste0("Fold", gsub(" ", "0",
format(seq_along(out))))
lapply(out, function(z) seq_along(group)[-z])
}
Why not call sample() this way instead? (The requested number of folds would then be returned on every call, as expected.)
function (group, k = length(unique(group)))
{
g_unique <- unique(group)
m <- length(g_unique)
if (k > m) {
stop("`k` should be less than ", m)
}
g_folds <- sample(m, size = k, replace = FALSE)
out <- split(seq_along(group), g_folds[match(group, g_unique)])
names(out) <- paste0("Fold", gsub(" ", "0",
format(seq_along(out))))
lapply(out, function(z) seq_along(group)[-z])
}
Furthermore, I would change the error message from:
stop("`k` should be less than ", m)
to:
stop("`k` should be less than or equal to ", m)
Anthony
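A minimal base-R check of the proposed change, assuming groups are looked up by position in g_unique exactly as in the code above. With k equal to m, sample(m, size = k, replace = FALSE) is a permutation and every group gets its own fold. With k smaller than m it returns only k labels, so g_folds[match(group, g_unique)] is NA for every group after position k, and split() silently drops NA, so those groups are never held out. The helper proposed_held_out() is hypothetical and reproduces only the fold-assignment step.

```r
# Hypothetical helper: reproduces only the fold-assignment step of the
# proposed sample(m, size = k, replace = FALSE) and reports which groups
# end up in some held-out set.
proposed_held_out <- function(group, k) {
  g_unique <- unique(group)
  m <- length(g_unique)
  g_folds <- sample(m, size = k, replace = FALSE)   # only k labels for m groups
  fold_of_row <- g_folds[match(group, g_unique)]    # NA for groups after position k
  held_out <- split(seq_along(group), fold_of_row)  # split() drops the NA rows
  list(
    n_folds = length(held_out),
    groups_held_out = sort(unique(as.character(group[unlist(held_out)])))
  )
}

set.seed(1)
x <- rep(letters[1:20], each = 20)
res <- proposed_held_out(x, k = 15)
res$n_folds                              # 15
setdiff(unique(x), res$groups_held_out)  # "p" "q" "r" "s" "t" are never held out
```

So the proposal returns the requested number of folds, but with k < m only k of the m groups are ever used for validation.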
I encountered the same problem. I think @AnthonyTedde is right: it is due to the sample() call inside groupKFold. Say m = 30 and k = 10. The current code draws 30 elements from 1:10 with replacement, so if … In order to set …
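The shortfall can be quantified. In the 6.0-86 code quoted above, each of the m groups independently gets a fold label from sample(k, size = m, replace = TRUE), and a fold exists only if its label is drawn at least once. Below is a short sketch of the expected fold count with a Monte Carlo check; expected_folds() and simulate_folds() are hypothetical helper names. For the 20-letter reprex with k = 15 the expectation is about 11.2, consistent with the 8 to 12 folds reported there, and for the 10-group report it is about 6.5, close to the 7 folds observed.

```r
# Expected number of distinct fold labels (i.e. non-empty folds) when each of
# m groups gets a label drawn uniformly with replacement from 1:k, as in
# sample(k, size = m, replace = TRUE). A particular label is never drawn with
# probability (1 - 1/k)^m, so by linearity of expectation:
expected_folds <- function(m, k) k * (1 - (1 - 1 / k)^m)

expected_folds(m = 20, k = 15)  # about 11.2
expected_folds(m = 10, k = 10)  # about 6.5
expected_folds(m = 30, k = 10)  # about 9.6

# Monte Carlo check: average number of distinct labels over many draws
simulate_folds <- function(m, k, reps = 10000) {
  mean(replicate(reps, length(unique(sample(k, size = m, replace = TRUE)))))
}
set.seed(1)
sim <- simulate_folds(m = 20, k = 15)
sim  # close to expected_folds(20, 15)
```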
This issue is still a problem. I think simply changing the sample() call to replace = FALSE solves the problem.
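One caveat on flipping only the replace argument, based on the documented behavior of base R's sample(): drawing m labels from 1:k without replacement requires m <= k, while groupKFold already requires k <= m. The flipped call therefore works only when k == m (where it is just a permutation, i.e. leave-one-group-out) and raises an error otherwise, as this small check shows.

```r
m <- 30  # number of groups
k <- 10  # requested number of folds (groupKFold requires k <= m)

# Drawing m labels from 1:k without replacement is impossible when m > k,
# so this call raises an error instead of returning fold labels.
flipped_fails <- tryCatch({
  sample(k, size = m, replace = FALSE)
  FALSE
}, error = function(e) TRUE)
flipped_fails  # TRUE

# When k == m, the same call is just a random permutation of 1:m,
# i.e. leave-one-group-out with the groups in random fold order.
set.seed(1)
perm <- sample(m, size = m, replace = FALSE)
setequal(perm, seq_len(m))  # TRUE
```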
I used the other answers to write functions for a complete grouped K-fold (i.e., one that gives all possible splits and all combinations of groups within them):
Example:
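The functions from that comment did not survive, so what follows is only one possible reading of "complete": enumerate every combination of p held-out groups (leave-p-groups-out) with utils::combn(). The helper leave_p_groups_out() and its split-naming scheme are hypothetical; as with groupKFold, each split holds training row indices. The number of splits grows as choose(m, p), so this quickly becomes expensive for many groups.

```r
# Hypothetical helper, not the commenter's code: hold out every possible
# combination of p groups and return the training row indices for each split.
leave_p_groups_out <- function(group, p = 1) {
  g <- as.character(group)
  g_unique <- unique(g)
  if (p < 1 || p >= length(g_unique)) {
    stop("`p` must be between 1 and ", length(g_unique) - 1)
  }
  held_sets <- utils::combn(g_unique, p, simplify = FALSE)
  # Training rows: every row whose group is not in the held-out set
  splits <- lapply(held_sets, function(held) which(!g %in% held))
  names(splits) <- vapply(held_sets, paste, character(1), collapse = "+")
  splits
}

x4 <- rep(c("a", "b", "c", "d"), each = 3)
splits <- leave_p_groups_out(x4, p = 2)
length(splits)  # 6, i.e. choose(4, 2)
names(splits)   # "a+b" "a+c" "a+d" "b+c" "b+d" "c+d"
```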
This still appears to be an issue two years later. If groupKFold works as intended, the documentation is not helpful. My initial take is that the function does not implement grouped k-fold cross-validation in any way I am familiar with. For example, under any definition of grouped cross-validation, how could this happen?
IMO, grouped k-fold should simply do k-fold cross-validation on the groups. Something like:
Am I missing something?
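A minimal sketch of "k-fold cross-validation on the groups", assuming the same input and output conventions as groupKFold (a grouping vector in, a named list of training row indices out). group_kfold_sketch() is a hypothetical name, not caret's API. Instead of drawing labels with replacement, it deals the labels 1:k round-robin over the m groups and then shuffles them, so every fold is non-empty and fold sizes differ by at most one group.

```r
# Hypothetical sketch, not caret's implementation.
group_kfold_sketch <- function(group, k = length(unique(group))) {
  g_unique <- unique(group)
  m <- length(g_unique)
  if (k > m) {
    stop("`k` should be less than or equal to ", m)
  }
  # Deal fold labels 1..k round-robin over the m groups, then shuffle them:
  # every label is used at least once and fold sizes differ by at most one group
  g_folds <- sample(rep_len(seq_len(k), m))
  # Map each row to its group's fold and collect the held-out rows per fold
  held_out <- split(seq_along(group), g_folds[match(group, g_unique)])
  names(held_out) <- paste0("Fold", gsub(" ", "0", format(seq_along(held_out))))
  # Return training indices: all rows except the held-out ones
  lapply(held_out, function(z) seq_along(group)[-z])
}

set.seed(1)
x <- rep(letters[1:20], each = 20)
folds <- group_kfold_sketch(x, k = 15)
length(folds)  # 15 on every call
```

Balancing by number of groups rather than number of rows is a simplification; with very unequal group sizes, folds can still differ considerably in row count.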
I am still encountering this problem in caret v6.0-93. I get what I expect (the number of folds I requested) using an "old" version of this function, found in #540 and copied here (function from topepo):
Perhaps the change in #1108 is where the behavior changed. Regardless, thanks, topepo, for the …
Still encountering this problem in … Also, I think it is inconsistent that …
The function groupKFold in caret 6.0-86 does not return the requested number of folds, even though the number of groups allows it. caret 6.0-84 returns the correct number of folds.
Example:
R Version 4.0.1
Caret Version 6.0-86
require(caret)
set.seed(1)
# Vector with 10 groups
group <- as.factor(sort(rep(seq(1, 10), 10)))
folds <- caret::groupKFold(group, k = 10)
# Number of folds I expected: 10 folds, each with 1 group left out
# Number of folds I got: 7
length(folds)
Caret Version 6.0-84
require(caret)
set.seed(1)
# Vector with 10 groups
group <- as.factor(sort(rep(seq(1, 10), 10)))
folds <- caret::groupKFold(group, k = 10)
# Number of folds I expected: 10 folds, each with 1 group left out
# Number of folds I got: 10
length(folds)
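To make reports like this easier to verify, here is a hypothetical helper, check_group_folds(), that summarizes any list of training-index vectors against the grouping vector: how many folds there are, whether each group is held out exactly once, and whether any group leaks across a split. It is exercised here on hand-built leave-one-group-out folds for the same data, so it does not require caret.

```r
# Hypothetical helper: summarise a list of training-index vectors, such as the
# value of caret::groupKFold(group, k), against the grouping vector.
check_group_folds <- function(folds, group) {
  g <- as.character(group)
  all_rows <- seq_along(g)
  held_groups <- lapply(folds, function(train) unique(g[setdiff(all_rows, train)]))
  train_groups <- lapply(folds, function(train) unique(g[train]))
  times_held <- table(factor(unlist(held_groups), levels = unique(g)))
  list(
    n_folds = length(folds),
    # each group should be held out in exactly one fold
    each_group_held_out_once = all(times_held == 1),
    # no group should have rows in both the training and held-out part of a fold
    no_leakage = all(vapply(seq_along(folds), function(i)
      length(intersect(held_groups[[i]], train_groups[[i]])) == 0, logical(1)))
  )
}

# Reference case: hand-built leave-one-group-out folds for the data above
group <- as.factor(sort(rep(seq(1, 10), 10)))
logo <- lapply(levels(group), function(lv) which(group != lv))
str(check_group_folds(logo, group))
```

With caret installed, check_group_folds(caret::groupKFold(group, k = 10), group)$n_folds shows directly whether all 10 requested folds came back.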