From f10f69d6f35511861a2b55f4c759f54c804fe23a Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 19 Jul 2023 17:01:31 -0700 Subject: [PATCH 1/4] initial survival workflows tests --- tests/testthat/test-survival-workflows.R | 116 +++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 tests/testthat/test-survival-workflows.R diff --git a/tests/testthat/test-survival-workflows.R b/tests/testthat/test-survival-workflows.R new file mode 100644 index 00000000..51974fce --- /dev/null +++ b/tests/testthat/test-survival-workflows.R @@ -0,0 +1,116 @@ +skip_if_not_installed("censored", minimum_version = "0.2.0.9000") +skip_if_not_installed("parsnip", minimum_version = "1.1.0.9003") +skip_if_not_installed("tune", minimum_version = "1.1.1.9001") + +library(tidymodels) +library(censored) + +lung <- lung |> + tidyr::drop_na() |> + dplyr::mutate(surv = Surv(time, status), .keep = "unused") + +test_that("can `fit()` a censored workflow with a formula", { + mod <- survival_reg() + mod <- set_engine(mod, "survival") + mod <- set_mode(mod, "censored regression") + + workflow <- workflow() + workflow <- add_formula(workflow, surv ~ .) + workflow <- add_model(workflow, mod) + + wf_fit <- fit(workflow, lung) + + expect_s3_class(wf_fit$fit$fit, "model_fit") + + expect_equal( + coef(wf_fit$fit$fit$fit), + coef(survreg(surv ~ ., data = lung, model = TRUE)) + ) +}) + +test_that("can `fit()` a censored workflow with a recipe", { + rec <- recipes::recipe(surv ~ ., lung) + + mod <- survival_reg() + mod <- set_engine(mod, "survival") + mod <- set_mode(mod, "censored regression") + + workflow <- workflow() + workflow <- add_recipe(workflow, rec) + workflow <- add_model(workflow, mod) + + wf_fit <- fit(workflow, lung) + + expect_s3_class(wf_fit$fit$fit, "model_fit") + + expect_equal( + coef(wf_fit$fit$fit$fit), + coef(survreg(surv ~ ., data = lung, model = TRUE)) + ) +}) + +test_that("can `predict()` a censored workflow with a formula", { + mod <- survival_reg() + mod <- set_engine(mod, "survival") + mod <- set_mode(mod, "censored regression") + + workflow <- workflow() + workflow <- add_formula(workflow, surv ~ .) + workflow <- add_model(workflow, mod) + + wf_fit <- fit(workflow, lung) + + preds <- predict(wf_fit, new_data = lung) + + expect_identical(names(preds), ".pred_time") + expect_type(preds$.pred_time, "double") + + preds <- predict(wf_fit, new_data = lung, type = "survival", eval_time = c(100, 200)) + + expect_identical(names(preds), ".pred") + expect_type(preds$.pred, "list") + expect_true( + all(purrr::map_lgl( + preds$.pred, + ~ identical(names(.x), c(".eval_time", ".pred_survival")) + )) + ) + + expect_error( + predict(wf_fit, new_data = lung, type = "numeric") + ) +}) + +test_that("can `predict()` a censored workflow with a recipe", { + rec <- recipes::recipe(surv ~ ., lung) + + mod <- survival_reg() + mod <- set_engine(mod, "survival") + mod <- set_mode(mod, "censored regression") + + workflow <- workflow() + workflow <- add_recipe(workflow, rec) + workflow <- add_model(workflow, mod) + + wf_fit <- fit(workflow, lung) + + preds <- predict(wf_fit, new_data = lung) + + expect_identical(names(preds), ".pred_time") + expect_type(preds$.pred_time, "double") + + preds <- predict(wf_fit, new_data = lung, type = "survival", eval_time = c(100, 200)) + + expect_identical(names(preds), ".pred") + expect_type(preds$.pred, "list") + expect_true( + all(purrr::map_lgl( + preds$.pred, + ~ identical(names(.x), c(".eval_time", ".pred_survival")) + )) + ) + + expect_error( + predict(wf_fit, new_data = lung, type = "numeric") + ) +}) From b0f066f2314e531bf804922b95c6e1fcd0a9f448 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 21 Jul 2023 14:50:37 -0700 Subject: [PATCH 2/4] switch to proportional_hazards(engine = "glmnet") --- tests/testthat/test-survival-workflows.R | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/tests/testthat/test-survival-workflows.R b/tests/testthat/test-survival-workflows.R index 51974fce..352fad26 100644 --- a/tests/testthat/test-survival-workflows.R +++ b/tests/testthat/test-survival-workflows.R @@ -10,9 +10,7 @@ lung <- lung |> dplyr::mutate(surv = Surv(time, status), .keep = "unused") test_that("can `fit()` a censored workflow with a formula", { - mod <- survival_reg() - mod <- set_engine(mod, "survival") - mod <- set_mode(mod, "censored regression") + mod <- proportional_hazards(engine = "glmnet", penalty = 0.1) workflow <- workflow() workflow <- add_formula(workflow, surv ~ .) @@ -23,17 +21,15 @@ test_that("can `fit()` a censored workflow with a formula", { expect_s3_class(wf_fit$fit$fit, "model_fit") expect_equal( - coef(wf_fit$fit$fit$fit), - coef(survreg(surv ~ ., data = lung, model = TRUE)) + wf_fit$fit$fit$fit$fit$beta, + censored::coxnet_train(surv ~ ., data = lung)$fit$beta ) }) test_that("can `fit()` a censored workflow with a recipe", { rec <- recipes::recipe(surv ~ ., lung) - mod <- survival_reg() - mod <- set_engine(mod, "survival") - mod <- set_mode(mod, "censored regression") + mod <- proportional_hazards(engine = "glmnet", penalty = 0.1) workflow <- workflow() workflow <- add_recipe(workflow, rec) @@ -44,15 +40,13 @@ test_that("can `fit()` a censored workflow with a recipe", { expect_s3_class(wf_fit$fit$fit, "model_fit") expect_equal( - coef(wf_fit$fit$fit$fit), - coef(survreg(surv ~ ., data = lung, model = TRUE)) + wf_fit$fit$fit$fit$fit$beta, + censored::coxnet_train(surv ~ ., data = lung)$fit$beta ) }) test_that("can `predict()` a censored workflow with a formula", { - mod <- survival_reg() - mod <- set_engine(mod, "survival") - mod <- set_mode(mod, "censored regression") + mod <- proportional_hazards(engine = "glmnet", penalty = 0.1) workflow <- workflow() workflow <- add_formula(workflow, surv ~ .) @@ -84,9 +78,7 @@ test_that("can `predict()` a censored workflow with a formula", { test_that("can `predict()` a censored workflow with a recipe", { rec <- recipes::recipe(surv ~ ., lung) - mod <- survival_reg() - mod <- set_engine(mod, "survival") - mod <- set_mode(mod, "censored regression") + mod <- proportional_hazards(engine = "glmnet", penalty = 0.1) workflow <- workflow() workflow <- add_recipe(workflow, rec) From 943a61f9c09a7225b1c11903604ecdc6e6d93b42 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 21 Jul 2023 15:00:45 -0700 Subject: [PATCH 3/4] add add_variables() tests --- tests/testthat/test-survival-workflows.R | 47 ++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/testthat/test-survival-workflows.R b/tests/testthat/test-survival-workflows.R index 352fad26..4df35f0c 100644 --- a/tests/testthat/test-survival-workflows.R +++ b/tests/testthat/test-survival-workflows.R @@ -26,6 +26,23 @@ test_that("can `fit()` a censored workflow with a formula", { ) }) +test_that("can `fit()` a censored workflow with variables", { + mod <- proportional_hazards(engine = "glmnet", penalty = 0.1) + + workflow <- workflow() + workflow <- add_variables(workflow, outcomes = surv, predictors = everything()) + workflow <- add_model(workflow, mod) + + wf_fit <- fit(workflow, lung) + + expect_s3_class(wf_fit$fit$fit, "model_fit") + + expect_equal( + wf_fit$fit$fit$fit$fit$beta, + censored::coxnet_train(surv ~ ., data = lung)$fit$beta + ) +}) + test_that("can `fit()` a censored workflow with a recipe", { rec <- recipes::recipe(surv ~ ., lung) @@ -75,6 +92,36 @@ test_that("can `predict()` a censored workflow with a formula", { ) }) +test_that("can `predict()` a censored workflow with a recipe", { + mod <- proportional_hazards(engine = "glmnet", penalty = 0.1) + + workflow <- workflow() + workflow <- add_variables(workflow, outcomes = surv, predictors = everything()) + workflow <- add_model(workflow, mod) + + wf_fit <- fit(workflow, lung) + + preds <- predict(wf_fit, new_data = lung) + + expect_identical(names(preds), ".pred_time") + expect_type(preds$.pred_time, "double") + + preds <- predict(wf_fit, new_data = lung, type = "survival", eval_time = c(100, 200)) + + expect_identical(names(preds), ".pred") + expect_type(preds$.pred, "list") + expect_true( + all(purrr::map_lgl( + preds$.pred, + ~ identical(names(.x), c(".eval_time", ".pred_survival")) + )) + ) + + expect_error( + predict(wf_fit, new_data = lung, type = "numeric") + ) +}) + test_that("can `predict()` a censored workflow with a recipe", { rec <- recipes::recipe(surv ~ ., lung) From 34052e0f6e309a3df1659cff95f48b41ff19e915 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Thu, 9 Nov 2023 15:04:34 +0000 Subject: [PATCH 4/4] move data into tests to avoid changing `lung` for other test files and to make the tests more self-contained --- tests/testthat/test-survival-workflows.R | 28 ++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test-survival-workflows.R b/tests/testthat/test-survival-workflows.R index 4df35f0c..24db8a67 100644 --- a/tests/testthat/test-survival-workflows.R +++ b/tests/testthat/test-survival-workflows.R @@ -5,11 +5,11 @@ skip_if_not_installed("tune", minimum_version = "1.1.1.9001") library(tidymodels) library(censored) -lung <- lung |> - tidyr::drop_na() |> - dplyr::mutate(surv = Surv(time, status), .keep = "unused") - test_that("can `fit()` a censored workflow with a formula", { + lung <- lung |> + tidyr::drop_na() |> + dplyr::mutate(surv = Surv(time, status), .keep = "unused") + mod <- proportional_hazards(engine = "glmnet", penalty = 0.1) workflow <- workflow() @@ -27,6 +27,10 @@ test_that("can `fit()` a censored workflow with a formula", { }) test_that("can `fit()` a censored workflow with variables", { + lung <- lung |> + tidyr::drop_na() |> + dplyr::mutate(surv = Surv(time, status), .keep = "unused") + mod <- proportional_hazards(engine = "glmnet", penalty = 0.1) workflow <- workflow() @@ -44,6 +48,10 @@ test_that("can `fit()` a censored workflow with variables", { }) test_that("can `fit()` a censored workflow with a recipe", { + lung <- lung |> + tidyr::drop_na() |> + dplyr::mutate(surv = Surv(time, status), .keep = "unused") + rec <- recipes::recipe(surv ~ ., lung) mod <- proportional_hazards(engine = "glmnet", penalty = 0.1) @@ -63,6 +71,10 @@ test_that("can `fit()` a censored workflow with a recipe", { }) test_that("can `predict()` a censored workflow with a formula", { + lung <- lung |> + tidyr::drop_na() |> + dplyr::mutate(surv = Surv(time, status), .keep = "unused") + mod <- proportional_hazards(engine = "glmnet", penalty = 0.1) workflow <- workflow() @@ -93,6 +105,10 @@ test_that("can `predict()` a censored workflow with a formula", { }) test_that("can `predict()` a censored workflow with a recipe", { + lung <- lung |> + tidyr::drop_na() |> + dplyr::mutate(surv = Surv(time, status), .keep = "unused") + mod <- proportional_hazards(engine = "glmnet", penalty = 0.1) workflow <- workflow() @@ -123,6 +139,10 @@ test_that("can `predict()` a censored workflow with a recipe", { }) test_that("can `predict()` a censored workflow with a recipe", { + lung <- lung |> + tidyr::drop_na() |> + dplyr::mutate(surv = Surv(time, status), .keep = "unused") + rec <- recipes::recipe(surv ~ ., lung) mod <- proportional_hazards(engine = "glmnet", penalty = 0.1)