Prediction in classification problems - unclear which level is predicted #57

Closed
szimmer opened this issue Aug 28, 2024 · 2 comments
Labels: bug (an unexpected problem or unintended behavior)

szimmer commented Aug 28, 2024

The problem

I'm predicting a variable with levels 0 and 1, where I've ordered the factor levels so that 1 is first. When using orbital, the predicted probability returned is the probability of "0", but I would expect it to be the probability of the first level ("1").

In the example below, you can see the difference between predicting with predict() on the fitted workflow and predicting with the orbital object. I expected the orbital object to return .pred_1, but its single .pred column matches .pred_0.

Reproducible example

library(orbital)
library(tidymodels)
library(dplyr)

hotels <- 
  readr::read_csv("https://tidymodels.org/start/case-study/hotels.csv") %>%
  mutate(across(where(is.character), as.factor)) %>%
  mutate(children=if_else(children=="children", 1, 0) %>% factor(levels=c(1,0))) %>%
  select(-arrival_date)
#> Rows: 50000 Columns: 23
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> chr  (11): hotel, children, meal, country, market_segment, distribution_chan...
#> dbl  (11): lead_time, stays_in_weekend_nights, stays_in_week_nights, adults,...
#> date  (1): arrival_date
#> 
#> ℹ Use `spec()` to retrieve the full column specification for this data.
#> ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
hotels %>% count(children)
#> # A tibble: 2 × 2
#>   children     n
#>   <fct>    <int>
#> 1 1         4038
#> 2 0        45962
lr_mod <- 
  logistic_reg() %>% 
  set_engine("glm")

lr_recipe <- 
  recipe(children ~ ., data = hotels) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

lr_workflow <- 
  workflow() %>% 
  add_model(lr_mod) %>% 
  add_recipe(lr_recipe)

wf_fit <- fit(lr_workflow, hotels)
#> Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
yhat_fit <- predict(wf_fit, hotels, type="prob")

orb_obj <- orbital(wf_fit)

yhat_orb <- predict(orb_obj, hotels)

tibble(yhat_fit, yhat_orb)
#> # A tibble: 50,000 × 3
#>    .pred_1 .pred_0  .pred
#>      <dbl>   <dbl>  <dbl>
#>  1 0.0154   0.985  0.985 
#>  2 0.113    0.887  0.887 
#>  3 0.0204   0.980  0.980 
#>  4 0.0362   0.964  0.964 
#>  5 0.793    0.207  0.207 
#>  6 0.00922  0.991  0.991 
#>  7 0.944    0.0561 0.0561
#>  8 0.487    0.513  0.513 
#>  9 0.0681   0.932  0.932 
#> 10 0.103    0.897  0.897 
#> # ℹ 49,990 more rows

Created on 2024-08-28 with reprex v2.1.0
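A quick way to confirm which level the single .pred column refers to is to compare it against each probability column returned by the workflow. The sketch below does this on the first three rows of the printed output above, copied by hand (so the values are rounded and a loose tolerance is used); with the full reprex objects, yhat_fit and yhat_orb could be used directly instead of these hand-built data frames.

```r
# First three rows copied by hand from the printed reprex output
# (values are rounded as printed, hence the loose tolerance below).
yhat_fit <- data.frame(
  .pred_1 = c(0.0154, 0.113, 0.0204),
  .pred_0 = c(0.985, 0.887, 0.980)
)
yhat_orb <- data.frame(.pred = c(0.985, 0.887, 0.980))

# For each probability column from the workflow, check whether it
# matches the single `.pred` column returned by orbital.
matches <- vapply(
  yhat_fit,
  function(col) isTRUE(all.equal(col, yhat_orb$.pred, tolerance = 1e-3)),
  logical(1)
)
matches
```

Only .pred_0 matches, which is the mismatch described in the report.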

Session info
sessionInfo()
#> R version 4.4.1 (2024-06-14 ucrt)
#> Platform: x86_64-w64-mingw32/x64
#> Running under: Windows 10 x64 (build 19045)
#> 
#> Matrix products: default
#> 
#> 
#> locale:
#> [1] LC_COLLATE=English_United States.utf8 
#> [2] LC_CTYPE=English_United States.utf8   
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C                          
#> [5] LC_TIME=English_United States.utf8    
#> 
#> time zone: America/New_York
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] yardstick_1.3.1    workflowsets_1.1.0 workflows_1.1.4    tune_1.2.1        
#>  [5] tidyr_1.3.1        tibble_3.2.1       rsample_1.2.1      recipes_1.1.0     
#>  [9] purrr_1.0.2        parsnip_1.2.1      modeldata_1.4.0    infer_1.0.7       
#> [13] ggplot2_3.5.1      dplyr_1.1.4        dials_1.3.0        scales_1.3.0      
#> [17] broom_1.0.6        tidymodels_1.2.0   orbital_0.2.0     
#> 
#> loaded via a namespace (and not attached):
#>  [1] tidyselect_1.2.1    timeDate_4032.109   R.utils_2.12.3     
#>  [4] fastmap_1.2.0       reprex_2.1.0        digest_0.6.36      
#>  [7] rpart_4.1.23        timechange_0.3.0    lifecycle_1.0.4    
#> [10] survival_3.6-4      magrittr_2.0.3      compiler_4.4.1     
#> [13] rlang_1.1.4         tools_4.4.1         utf8_1.2.4         
#> [16] yaml_2.3.8          data.table_1.15.4   knitr_1.47         
#> [19] curl_5.2.1          bit_4.0.5           DiceDesign_1.10    
#> [22] R.cache_0.16.0      withr_3.0.0         R.oo_1.26.0        
#> [25] nnet_7.3-19         grid_4.4.1          fansi_1.0.6        
#> [28] colorspace_2.1-0    future_1.34.0       globals_0.16.3     
#> [31] iterators_1.0.14    MASS_7.3-60.2       cli_3.6.3          
#> [34] crayon_1.5.3        rmarkdown_2.27      generics_0.1.3     
#> [37] rstudioapi_0.16.0   future.apply_1.11.2 tzdb_0.4.0         
#> [40] splines_4.4.1       parallel_4.4.1      vctrs_0.6.5        
#> [43] hardhat_1.4.0       Matrix_1.7-0        hms_1.1.3          
#> [46] bit64_4.0.5         listenv_0.9.1       foreach_1.5.2      
#> [49] gower_1.0.1         glue_1.7.0          parallelly_1.38.0  
#> [52] codetools_0.2-20    lubridate_1.9.3     gtable_0.3.5       
#> [55] munsell_0.5.1       GPfit_1.0-8         styler_1.10.3      
#> [58] pillar_1.9.0        furrr_0.3.1         htmltools_0.5.8.1  
#> [61] ipred_0.9-15        lava_1.8.0          R6_2.5.1           
#> [64] lhs_1.2.0           tidypredict_0.5     vroom_1.6.5        
#> [67] evaluate_0.24.0     lattice_0.22-6      readr_2.1.5        
#> [70] R.methodsS3_1.8.2   backports_1.5.0     class_7.3-22       
#> [73] Rcpp_1.0.12         prodlim_2024.06.25  xfun_0.45          
#> [76] fs_1.6.4            pkgconfig_2.0.3
EmilHvitfeldt (Member) commented

What is happening here is that {orbital} does not support any classification models yet. For some reason, the model was still accepted and treated as a regression model, which is a bug and will be fixed.

We are tracking classification models here: #46

Thanks for reporting!
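For context on why the old output matched .pred_0: when base R's glm() is fitted with family = binomial and a factor response, the first factor level denotes failure and all other levels success, so predict(type = "response") returns the probability of the second level. With levels ordered c(1, 0), that is the probability of "0", which is consistent with the single .pred column in the original report. The sketch below shows this convention on simulated data; the data and variable names are assumptions for illustration, not taken from the issue.

```r
# Simulated data (illustrative assumption): a binary outcome whose factor
# levels are ordered c("1", "0"), like `children` in the reprex.
set.seed(1)
n <- 500
x <- rnorm(n)
y <- factor(ifelse(runif(n) < plogis(-1 + 2 * x), 1, 0), levels = c(1, 0))

fit <- glm(y ~ x, family = binomial)

# glm() treats the first level ("1") as failure and models the
# probability of the second level ("0").
p_second <- predict(fit, type = "response")  # P(y == "0")
p_first <- 1 - p_second                      # P(y == "1")

# With an intercept in the model, the fitted probabilities average to the
# observed share of the modelled level, confirming which level p_second is.
c(mean_fitted = mean(p_second), share_of_0 = mean(y == "0"))
```

Because "1" was simulated to become more likely as x grows, the fitted slope for x comes out negative here: the model describes the probability of "0".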

EmilHvitfeldt added the bug label (an unexpected problem or unintended behavior) on Aug 28, 2024
EmilHvitfeldt (Member) commented Dec 14, 2024

{orbital} now handles prediction with logistic_reg() as you would expect. Hard predictions are made with respect to the first factor level, to align with how the rest of tidymodels works.

library(orbital)
library(tidymodels)
library(dplyr)

hotels <- 
  readr::read_csv("https://tidymodels.org/start/case-study/hotels.csv") %>%
  mutate(across(where(is.character), as.factor)) %>%
  mutate(children=if_else(children=="children", 1, 0) %>% factor(levels=c(1,0))) %>%
  select(-arrival_date)

lr_mod <- 
  logistic_reg() %>% 
  set_engine("glm")

lr_recipe <- 
  recipe(children ~ ., data = hotels) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors())

lr_workflow <- 
  workflow() %>% 
  add_model(lr_mod) %>% 
  add_recipe(lr_recipe)

wf_fit <- fit(lr_workflow, hotels)

predict(wf_fit, hotels, type="prob")
#> # A tibble: 50,000 × 2
#>    .pred_1 .pred_0
#>      <dbl>   <dbl>
#>  1 0.0154   0.985 
#>  2 0.113    0.887 
#>  3 0.0204   0.980 
#>  4 0.0362   0.964 
#>  5 0.793    0.207 
#>  6 0.00922  0.991 
#>  7 0.944    0.0561
#>  8 0.487    0.513 
#>  9 0.0681   0.932 
#> 10 0.103    0.897 
#> # ℹ 49,990 more rows

orb_obj <- orbital(wf_fit)
predict(orb_obj, hotels)
#> # A tibble: 50,000 × 1
#>    .pred_class
#>    <chr>      
#>  1 0          
#>  2 0          
#>  3 0          
#>  4 0          
#>  5 1          
#>  6 0          
#>  7 1          
#>  8 0          
#>  9 0          
#> 10 0          
#> # ℹ 49,990 more rows

orb_obj <- orbital(wf_fit, type = "prob")
predict(orb_obj, hotels)
#> # A tibble: 50,000 × 2
#>    .pred_1 .pred_0
#>      <dbl>   <dbl>
#>  1 0.0154   0.985 
#>  2 0.113    0.887 
#>  3 0.0204   0.980 
#>  4 0.0362   0.964 
#>  5 0.793    0.207 
#>  6 0.00922  0.991 
#>  7 0.944    0.0561
#>  8 0.487    0.513 
#>  9 0.0681   0.932 
#> 10 0.103    0.897 
#> # ℹ 49,990 more rows

orb_obj <- orbital(wf_fit, type = c("class", "prob"))
predict(orb_obj, hotels)
#> # A tibble: 50,000 × 3
#>    .pred_class .pred_1 .pred_0
#>    <chr>         <dbl>   <dbl>
#>  1 0           0.0154   0.985 
#>  2 0           0.113    0.887 
#>  3 0           0.0204   0.980 
#>  4 0           0.0362   0.964 
#>  5 1           0.793    0.207 
#>  6 0           0.00922  0.991 
#>  7 1           0.944    0.0561
#>  8 0           0.487    0.513 
#>  9 0           0.0681   0.932 
#> 10 0           0.103    0.897 
#> # ℹ 49,990 more rows
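The .pred_class column above is consistent with choosing, in each row, the level with the larger probability. The sketch below checks this on rows 1, 5 and 8 of the printed output, copied by hand; it is only a consistency check and is not necessarily how {orbital} computes .pred_class internally.

```r
# Rows 1, 5 and 8 of the printed class probabilities, copied by hand.
probs <- data.frame(
  .pred_1 = c(0.0154, 0.793, 0.487),
  .pred_0 = c(0.985, 0.207, 0.513)
)
# Level names in factor order, matching the column order above.
lvls <- c("1", "0")

# Pick the level with the larger probability in each row. For two classes
# whose probabilities sum to one, this equals a 0.5 cutoff on .pred_1.
pred_class <- lvls[max.col(as.matrix(probs), ties.method = "first")]
pred_class
```

This gives "0", "1", "0", matching the printed .pred_class values for those rows, including row 8 where the two probabilities are close.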
