diff --git a/.github/workflows/r-test.yml b/.github/workflows/r-test.yml index cf2fb148..0a8464bc 100644 --- a/.github/workflows/r-test.yml +++ b/.github/workflows/r-test.yml @@ -39,7 +39,7 @@ jobs: - name: Create a CRAN-ready version of the R package run: | - Rscript cran-bootstrap.R 0 0 + Rscript cran-bootstrap.R 0 0 1 - uses: r-lib/actions/check-r-package@v2 with: diff --git a/DESCRIPTION b/DESCRIPTION index 6b3277dc..7d3bdc53 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: stochtree -Title: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference +Title: Stochastic Tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference Version: 0.1.0 Authors@R: c( @@ -10,7 +10,11 @@ Authors@R: person("Jingyu", "He", role = "aut"), person("stochtree contributors", role = c("cph")) ) -Description: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference. +Description: Flexible stochastic tree ensemble software. Robust implementations of + Bayesian Additive Regression Trees (Chipman, George, McCulloch (2010) <doi:10.1214/09-AOAS285>) + for supervised learning and Bayesian Causal Forests (BCF) (Hahn, Murray, Carvalho (2020) <doi:10.1214/19-BA1195>) + for causal inference. Enables model serialization and parallel sampling + and provides a low-level interface for custom stochastic forest samplers. License: MIT + file LICENSE Encoding: UTF-8 Roxygen: list(markdown = TRUE) diff --git a/LICENSE b/LICENSE index 163f07f2..1941a4d0 100644 --- a/LICENSE +++ b/LICENSE @@ -1,2 +1,2 @@ -YEAR: 2024 -COPYRIGHT HOLDER: stochtree authors \ No newline at end of file +YEAR: 2025 +COPYRIGHT HOLDER: stochtree contributors \ No newline at end of file diff --git a/LICENSE.md b/LICENSE.md index 3c81b245..5e7f9a94 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -1,6 +1,6 @@ # MIT License -Copyright (c) 2024 stochtree authors +Copyright (c) 2023-2025 stochtree authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/NAMESPACE b/NAMESPACE index bcb23c2d..a18eabe7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -28,7 +28,9 @@ export(createCppRNG) export(createForest) export(createForestDataset) export(createForestModel) +export(createForestModelConfig) export(createForestSamples) +export(createGlobalModelConfig) export(createOutcome) export(createPreprocessorFromJson) export(createPreprocessorFromJsonString) diff --git a/NEWS.md b/NEWS.md index aa0f54d0..397c25ab 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,16 @@ # stochtree 0.1.0 -* Initial CRAN submission. +* Initial release on CRAN. 
+* Support for sampling stochastic tree ensembles using two algorithms: MCMC and Grow-From-Root (GFR) +* High-level model types supported: + * Supervised learning with constant leaves or user-specified leaf regression models + * Causal effect estimation with binary, continuous, or multivariate treatments +* Additional high-level modeling features: + * Forest-based variance function estimation (heteroskedasticity) + * Additive (univariate or multivariate) group random effects + * Multi-chain sampling and support for parallelism + * "Warm-start" initialization of MCMC forest samplers via the Grow-From-Root (GFR) algorithm + * Automated preprocessing / handling of categorical variables +* Low-level interface: + * Ability to combine a forest sampler with other (additive) model terms, without using C++ + * Combine and sample an arbitrary number of forests or random effects terms diff --git a/R/bart.R b/R/bart.R index 1c814dd3..ba33ae84 100644 --- a/R/bart.R +++ b/R/bart.R @@ -199,7 +199,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (num_mcmc == 0) keep_gfr <- T # Check if previous model JSON is provided and parse it if so - # TODO: check that `previous_model_warmstart_sample_num` is <= the number of samples in this previous model has_prev_model <- !is.null(previous_model_json) if (has_prev_model) { previous_bart_model <- createBARTModelFromJsonString(previous_model_json) @@ -222,6 +221,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (previous_bart_model$model_params$has_rfx) { previous_rfx_samples <- previous_bart_model$rfx_samples } else previous_rfx_samples <- NULL + previous_model_num_samples <- previous_bart_model$model_params$num_samples + if (previous_model_warmstart_sample_num > previous_model_num_samples) { + stop("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`") + } } else { previous_y_bar <- NULL previous_y_scale <- NULL @@ -230,6 +233,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train previous_rfx_samples <- NULL previous_forest_samples_mean <- NULL previous_forest_samples_variance <- NULL + previous_model_num_samples <- 0 } # Determine whether conditional mean, variance, or both will be modeled @@ -540,11 +544,22 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train # Sampling data structures feature_types <- as.integer(feature_types) + global_model_config <- createGlobalModelConfig(global_error_variance=current_sigma2) if (include_mean_forest) { - forest_model_mean <- createForestModel(forest_dataset_train, feature_types, num_trees_mean, nrow(X_train), alpha_mean, beta_mean, min_samples_leaf_mean, max_depth_mean) + forest_model_config_mean <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_mean, num_features=ncol(X_train), + num_observations=nrow(X_train), variable_weights=variable_weights_mean, leaf_dimension=leaf_dimension, + alpha=alpha_mean, beta=beta_mean, min_samples_leaf=min_samples_leaf_mean, max_depth=max_depth_mean, + leaf_model_type=leaf_model_mean_forest, leaf_model_scale=current_leaf_scale, + cutpoint_grid_size=cutpoint_grid_size) + forest_model_mean <- createForestModel(forest_dataset_train, forest_model_config_mean, global_model_config) } if (include_variance_forest) { - forest_model_variance <- createForestModel(forest_dataset_train, feature_types, num_trees_variance, nrow(X_train), alpha_variance, beta_variance, min_samples_leaf_variance, 
max_depth_variance) + forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train), + num_observations=nrow(X_train), variable_weights=variable_weights_variance, leaf_dimension=1, + alpha=alpha_variance, beta=beta_variance, min_samples_leaf=min_samples_leaf_variance, + max_depth=max_depth_variance, leaf_model_type=leaf_model_variance_forest, + cutpoint_grid_size=cutpoint_grid_size) + forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config) } # Container of forest samples @@ -601,11 +616,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (requires_basis) init_values_mean_forest <- rep(0., ncol(leaf_basis_train)) else init_values_mean_forest <- 0. active_forest_mean$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mean, leaf_model_mean_forest, init_values_mean_forest) + active_forest_mean$adjust_residual(forest_dataset_train, outcome_train, forest_model_mean, requires_basis, F) } # Initialize the leaves of each tree in the variance forest if (include_variance_forest) { active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init) + } # Run GFR (warm start) if specified @@ -624,26 +641,28 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (include_mean_forest) { forest_model_mean$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mean, active_forest_mean, - rng, feature_types, leaf_model_mean_forest, current_leaf_scale, variable_weights_mean, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, + active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = T ) } if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, - rng, feature_types, leaf_model_variance_forest, current_leaf_scale, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, + active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = T ) } if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf) { leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf) current_leaf_scale <- as.matrix(leaf_scale_double) if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double + forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) } if (has_rfx) { rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) @@ -663,6 +682,7 @@ bart 
<- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (sample_sigma_leaf) { leaf_scale_double <- leaf_scale_samples[forest_ind + 1] current_leaf_scale <- as.matrix(leaf_scale_double) + forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) } } if (include_variance_forest) { @@ -673,7 +693,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init) resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) } - if (sample_sigma_global) current_sigma2 <- global_var_samples[forest_ind + 1] + if (sample_sigma_global) { + current_sigma2 <- global_var_samples[forest_ind + 1] + global_model_config$update_global_error_variance(current_sigma2) + } } else if (has_prev_model) { if (include_mean_forest) { resetActiveForest(active_forest_mean, previous_forest_samples_mean, previous_model_warmstart_sample_num - 1) @@ -681,21 +704,28 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (sample_sigma_leaf && (!is.null(previous_leaf_var_samples))) { leaf_scale_double <- previous_leaf_var_samples[previous_model_warmstart_sample_num] current_leaf_scale <- as.matrix(leaf_scale_double) + forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) } } if (include_variance_forest) { resetActiveForest(active_forest_variance, previous_forest_samples_variance, previous_model_warmstart_sample_num - 1) resetForestModel(forest_model_variance, active_forest_variance, forest_dataset_train, outcome_train, FALSE) } - # TODO: also initialize from previous RFX samples - # if (has_rfx) { - # rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, - # sigma_xi_init, sigma_xi_shape, sigma_xi_scale) - # rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) - # } + if (has_rfx) { + if (is.null(previous_rfx_samples)) { + warning("`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started") + rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale) + rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) + } else { + resetRandomEffectsModel(rfx_model, previous_rfx_samples, previous_model_warmstart_sample_num - 1, sigma_alpha_init) + resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + } + } if (sample_sigma_global) { if (!is.null(previous_global_var_samples)) { current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num] + global_model_config$update_global_error_variance(current_sigma2) } } } else { @@ -705,6 +735,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE) if (sample_sigma_leaf) { current_leaf_scale <- as.matrix(sigma_leaf_init) + forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) } } if (include_variance_forest) { @@ -717,7 +748,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train sigma_xi_init, sigma_xi_shape, sigma_xi_scale) rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) } - if (sample_sigma_global) current_sigma2 <- 
sigma2_init + if (sample_sigma_global) { + current_sigma2 <- sigma2_init + global_model_config$update_global_error_variance(current_sigma2) + } } for (i in (num_gfr+1):num_samples) { is_mcmc <- i > (num_gfr + num_burnin) @@ -746,26 +780,28 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train if (include_mean_forest) { forest_model_mean$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mean, active_forest_mean, - rng, feature_types, leaf_model_mean_forest, current_leaf_scale, variable_weights_mean, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean, + active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = F ) } if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, - rng, feature_types, leaf_model_variance_forest, current_leaf_scale, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, + active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = F ) } if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf) { leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf) current_leaf_scale <- as.matrix(leaf_scale_double) if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double + forest_model_config_mean$update_leaf_model_scale(current_leaf_scale) } if (has_rfx) { rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng) @@ -1097,8 +1133,10 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] #' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, -#' rfx_group_ids_train = rfx_group_ids_train, rfx_group_ids_test = rfx_group_ids_test, -#' rfx_basis_train = rfx_basis_train, rfx_basis_test = rfx_basis_test, +#' rfx_group_ids_train = rfx_group_ids_train, +#' rfx_group_ids_test = rfx_group_ids_test, +#' rfx_basis_train = rfx_basis_train, +#' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100) #' rfx_samples <- getRandomEffectSamples(bart_model) getRandomEffectSamples.bartmodel <- function(object, ...){ @@ -1154,6 +1192,10 @@ getRandomEffectSamples.bartmodel <- function(object, ...){ saveBARTModelToJson <- function(object){ jsonobj <- createCppJson() + if (!inherits(object, "bartmodel")) { + stop("`object` must be a BART model") + } + if (is.null(object$model_params)) { stop("This BCF model has not yet been sampled") } diff --git a/R/bcf.R b/R/bcf.R index 49fa908a..9118333c 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -263,12 
+263,18 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id b_forest <- variance_forest_params_updated$var_forest_prior_scale keep_vars_variance <- variance_forest_params_updated$keep_vars drop_vars_variance <- variance_forest_params_updated$drop_vars + + # Check if there are enough GFR samples to seed num_chains samplers + if (num_gfr > 0) { + if (num_chains > num_gfr) { + stop("num_chains > num_gfr, meaning we do not have enough GFR samples to seed num_chains distinct MCMC chains") + } + } # Override keep_gfr if there are no MCMC samples if (num_mcmc == 0) keep_gfr <- T # Check if previous model JSON is provided and parse it if so - # TODO: check that `previous_model_warmstart_sample_num` is <= the number of samples in this previous model has_prev_model <- !is.null(previous_model_json) if (has_prev_model) { previous_bcf_model <- createBCFModelFromJsonString(previous_model_json) @@ -280,8 +286,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id previous_forest_samples_variance <- previous_bcf_model$forests_variance } else previous_forest_samples_variance <- NULL if (previous_bcf_model$model_params$sample_sigma_global) { - previous_global_var_samples <- previous_bcf_model$sigma2_samples*( - previous_var_scale / (previous_y_scale*previous_y_scale) + previous_global_var_samples <- previous_bcf_model$sigma2_samples / ( + previous_y_scale*previous_y_scale ) } else previous_global_var_samples <- NULL if (previous_bcf_model$model_params$sample_sigma_leaf_mu) { @@ -300,10 +306,13 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id previous_b_1_samples <- NULL previous_b_0_samples <- NULL } + previous_model_num_samples <- previous_bcf_model$model_params$num_samples + if (previous_model_warmstart_sample_num > previous_model_num_samples) { + stop("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`") + } } else { previous_y_bar <- NULL previous_y_scale <- NULL - previous_var_scale <- NULL previous_global_var_samples <- NULL previous_leaf_var_mu_samples <- NULL previous_leaf_var_tau_samples <- NULL @@ -687,8 +696,20 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) current_leaf_scale_tau <- as.matrix(sigma_leaf_tau) + # Set mu and tau leaf models / dimensions + leaf_model_mu_forest <- 0 + leaf_dimension_mu_forest <- 1 + if (ncol(Z_train) > 1) { + leaf_model_tau_forest <- 2 + leaf_dimension_tau_forest <- ncol(Z_train) + } else { + leaf_model_tau_forest <- 1 + leaf_dimension_tau_forest <- 1 + } + # Set variance leaf model type (currently only one option) leaf_model_variance_forest <- 3 + leaf_dimension_variance_forest <- 1 # Random effects prior parameters if (has_rfx) { @@ -763,10 +784,26 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id rng <- createCppRNG(random_seed) # Sampling data structures - forest_model_mu <- createForestModel(forest_dataset_train, feature_types, num_trees_mu, nrow(X_train), alpha_mu, beta_mu, min_samples_leaf_mu, max_depth_mu) - forest_model_tau <- createForestModel(forest_dataset_train, feature_types, num_trees_tau, nrow(X_train), alpha_tau, beta_tau, min_samples_leaf_tau, max_depth_tau) + global_model_config <- createGlobalModelConfig(global_error_variance=current_sigma2) + forest_model_config_mu <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_mu, num_features=ncol(X_train), + 
num_observations=nrow(X_train), variable_weights=variable_weights_mu, leaf_dimension=leaf_dimension_mu_forest, + alpha=alpha_mu, beta=beta_mu, min_samples_leaf=min_samples_leaf_mu, max_depth=max_depth_mu, + leaf_model_type=leaf_model_mu_forest, leaf_model_scale=current_leaf_scale_mu, + cutpoint_grid_size=cutpoint_grid_size) + forest_model_config_tau <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_tau, num_features=ncol(X_train), + num_observations=nrow(X_train), variable_weights=variable_weights_tau, leaf_dimension=leaf_dimension_tau_forest, + alpha=alpha_tau, beta=beta_tau, min_samples_leaf=min_samples_leaf_tau, max_depth=max_depth_tau, + leaf_model_type=leaf_model_tau_forest, leaf_model_scale=current_leaf_scale_tau, + cutpoint_grid_size=cutpoint_grid_size) + forest_model_mu <- createForestModel(forest_dataset_train, forest_model_config_mu, global_model_config) + forest_model_tau <- createForestModel(forest_dataset_train, forest_model_config_tau, global_model_config) if (include_variance_forest) { - forest_model_variance <- createForestModel(forest_dataset_train, feature_types, num_trees_variance, nrow(X_train), alpha_variance, beta_variance, min_samples_leaf_variance, max_depth_variance) + forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train), + num_observations=nrow(X_train), variable_weights=variable_weights_variance, + leaf_dimension=leaf_dimension_variance_forest, alpha=alpha_variance, beta=beta_variance, + min_samples_leaf=min_samples_leaf_variance, max_depth=max_depth_variance, + leaf_model_type=leaf_model_variance_forest, cutpoint_grid_size=cutpoint_grid_size) + forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config) } # Container of forest samples @@ -810,26 +847,28 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample the prognostic forest forest_model_mu$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mu, active_forest_mu, - rng, feature_types, 0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu, + active_forest = active_forest_mu, rng = rng, forest_model_config = forest_model_config_mu, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = T ) # Sample variance parameters (if requested) if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf_mu) { leaf_scale_mu_double <- sampleLeafVarianceOneIteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu) current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) if (keep_sample) leaf_scale_mu_samples[sample_counter] <- leaf_scale_mu_double + forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) } # Sample the treatment forest forest_model_tau$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_tau, active_forest_tau, - rng, feature_types, 1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T + forest_dataset = forest_dataset_train, residual = 
outcome_train, forest_samples = forest_samples_tau, + active_forest = active_forest_tau, rng = rng, forest_model_config = forest_model_config_tau, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = T ) # Sample coding parameters (if requested) @@ -872,19 +911,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample variance parameters (if requested) if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, - rng, feature_types, leaf_model_variance_forest, current_leaf_scale_mu, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, + active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = T ) } if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf_tau) { leaf_scale_tau_double <- sampleLeafVarianceOneIteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau) current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) if (keep_sample) leaf_scale_tau_samples[sample_counter] <- leaf_scale_tau_double + forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) } # Sample random effects parameters (if requested) @@ -907,10 +948,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (sample_sigma_leaf_mu) { leaf_scale_mu_double <- leaf_scale_mu_samples[forest_ind + 1] current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) } if (sample_sigma_leaf_tau) { leaf_scale_tau_double <- leaf_scale_tau_samples[forest_ind + 1] current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) } if (include_variance_forest) { resetActiveForest(active_forest_variance, forest_samples_variance, forest_ind) @@ -931,7 +974,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) } - if (sample_sigma_global) current_sigma2 <- global_var_samples[forest_ind + 1] + if (sample_sigma_global) { + current_sigma2 <- global_var_samples[forest_ind + 1] + global_model_config$update_global_error_variance(current_sigma2) + } } else if (has_prev_model) { resetActiveForest(active_forest_mu, previous_forest_samples_mu, previous_model_warmstart_sample_num - 1) resetForestModel(forest_model_mu, active_forest_mu, forest_dataset_train, outcome_train, TRUE) @@ -944,10 +990,12 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id if (sample_sigma_leaf_mu && (!is.null(previous_leaf_var_mu_samples))) { leaf_scale_mu_double <- previous_leaf_var_mu_samples[previous_model_warmstart_sample_num] current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) + forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) } if (sample_sigma_leaf_tau &&
(!is.null(previous_leaf_var_tau_samples))) { leaf_scale_tau_double <- previous_leaf_var_tau_samples[previous_model_warmstart_sample_num] current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) + forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) } if (adaptive_coding) { if (!is.null(previous_b_1_samples)) { @@ -964,16 +1012,22 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) } - # TODO: also initialize from previous RFX samples - # if (has_rfx) { - # rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, - # sigma_xi_init, sigma_xi_shape, sigma_xi_scale) - # rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) - # } + if (has_rfx) { + if (is.null(previous_rfx_samples)) { + warning("`previous_model_json` did not have any random effects samples, so the RFX sampler will be run from scratch while the forests and any other parameters are warm started") + rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init, + sigma_xi_init, sigma_xi_shape, sigma_xi_scale) + rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train) + } else { + resetRandomEffectsModel(rfx_model, previous_rfx_samples, previous_model_warmstart_sample_num - 1, sigma_alpha_init) + resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples) + } + } if (sample_sigma_global) { if (!is.null(previous_global_var_samples)) { current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num] } + global_model_config$update_global_error_variance(current_sigma2) } } else { resetActiveForest(active_forest_mu) @@ -984,9 +1038,11 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id resetForestModel(forest_model_tau, active_forest_tau, forest_dataset_train, outcome_train, TRUE) if (sample_sigma_leaf_mu) { current_leaf_scale_mu <- as.matrix(sigma_leaf_mu) + forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) } if (sample_sigma_leaf_tau) { current_leaf_scale_tau <- as.matrix(sigma_leaf_tau) + forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) } if (include_variance_forest) { resetActiveForest(active_forest_variance) @@ -1009,7 +1065,10 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id } forest_model_tau$propagate_basis_update(forest_dataset_train, outcome_train, active_forest_tau) } - if (sample_sigma_global) current_sigma2 <- sigma2_init + if (sample_sigma_global) { + current_sigma2 <- sigma2_init + global_model_config$update_global_error_variance(current_sigma2) + } } for (i in (num_gfr+1):num_samples) { is_mcmc <- i > (num_gfr + num_burnin) @@ -1038,26 +1097,28 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample the prognostic forest forest_model_mu$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_mu, active_forest_mu, - rng, feature_types, 0, current_leaf_scale_mu, variable_weights_mu, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mu, + active_forest = active_forest_mu, rng = rng, forest_model_config = forest_model_config_mu, + global_model_config = global_model_config, 
keep_forest = keep_sample, gfr = F ) # Sample variance parameters (if requested) if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf_mu) { leaf_scale_mu_double <- sampleLeafVarianceOneIteration(active_forest_mu, rng, a_leaf_mu, b_leaf_mu) current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double) if (keep_sample) leaf_scale_mu_samples[sample_counter] <- leaf_scale_mu_double + forest_model_config_mu$update_leaf_model_scale(current_leaf_scale_mu) } # Sample the treatment forest forest_model_tau$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_tau, active_forest_tau, - rng, feature_types, 1, current_leaf_scale_tau, variable_weights_tau, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_tau, + active_forest = active_forest_tau, rng = rng, forest_model_config = forest_model_config_tau, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = F ) # Sample coding parameters (if requested) @@ -1100,19 +1161,21 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id # Sample variance parameters (if requested) if (include_variance_forest) { forest_model_variance$sample_one_iteration( - forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance, - rng, feature_types, leaf_model_variance_forest, current_leaf_scale_mu, variable_weights_variance, - a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T + forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance, + active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance, + global_model_config = global_model_config, keep_forest = keep_sample, gfr = F ) } if (sample_sigma_global) { current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global) if (keep_sample) global_var_samples[sample_counter] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) } if (sample_sigma_leaf_tau) { leaf_scale_tau_double <- sampleLeafVarianceOneIteration(active_forest_tau, rng, a_leaf_tau, b_leaf_tau) current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double) if (keep_sample) leaf_scale_tau_samples[sample_counter] <- leaf_scale_tau_double + forest_model_config_tau$update_leaf_model_scale(current_leaf_scale_tau) } # Sample random effects parameters (if requested) @@ -1345,7 +1408,8 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id #' mu_train <- mu_x[train_inds] #' tau_test <- tau_x[test_inds] #' tau_train <- tau_x[train_inds] -#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train) +#' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, +#' propensity_train = pi_train) #' preds <- predict(bcf_model, X_test, Z_test, pi_test) #' plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", #' ylab = "actual", main = "Prognostic function") @@ -1533,9 +1597,11 @@ predict.bcfmodel <- function(object, X, Z, propensity = NULL, rfx_group_ids = NU #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- 
list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, @@ -1622,9 +1688,11 @@ getRandomEffectSamples.bcfmodel <- function(object, ...){ #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, @@ -1633,7 +1701,7 @@ getRandomEffectSamples.bcfmodel <- function(object, ...){ saveBCFModelToJson <- function(object){ jsonobj <- createCppJson() - if (class(object) != "bcfmodel") { + if (!inherits(object, "bcfmodel")) { stop("`object` must be a BCF model") } @@ -1785,9 +1853,11 @@ saveBCFModelToJson <- function(object){ #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, @@ -1861,9 +1931,11 @@ saveBCFModelToJsonFile <- function(object, filename){ #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, @@ -1939,9 +2011,11 @@ saveBCFModelToJsonString <- function(object){ #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = 
rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, @@ -2102,9 +2176,11 @@ createBCFModelFromJson <- function(json_object){ #' mu_params <- list(sample_sigma_leaf = TRUE) #' tau_params <- list(sample_sigma_leaf = FALSE) #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100, #' mu_forest_params = mu_params, @@ -2181,9 +2257,11 @@ createBCFModelFromJsonFile <- function(json_filename){ #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100) #' # bcf_json <- saveBCFModelToJsonString(bcf_model) @@ -2259,9 +2337,11 @@ createBCFModelFromJsonString <- function(json_string){ #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100) #' # bcf_json_list <- list(saveBCFModelToJson(bcf_model)) @@ -2469,9 +2549,11 @@ createBCFModelFromCombinedJson <- function(json_object_list){ #' rfx_term_test <- rfx_term[test_inds] #' rfx_term_train <- rfx_term[train_inds] #' bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, -#' propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, +#' propensity_train = pi_train, +#' rfx_group_ids_train = rfx_group_ids_train, #' rfx_basis_train = rfx_basis_train, X_test = X_test, -#' Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, +#' Z_test = Z_test, propensity_test = pi_test, +#' rfx_group_ids_test = rfx_group_ids_test, #' rfx_basis_test = rfx_basis_test, #' num_gfr = 100, num_burnin = 0, num_mcmc = 100) #' # bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) diff --git a/R/config.R b/R/config.R new file mode 100644 index 00000000..08693674 
--- /dev/null +++ b/R/config.R @@ -0,0 +1,395 @@ +#' Object used to get / set parameters and other model configuration options +#' for a forest model in the "low-level" stochtree interface +#' +#' @description +#' The "low-level" stochtree interface enables a high degree of sampler +#' customization, in which users employ R wrappers around C++ objects +#' like ForestDataset, Outcome, CppRng, and ForestModel to run the +#' Gibbs sampler of a BART model with custom modifications. +#' ForestModelConfig allows users to specify / query the parameters of a +#' forest model they wish to run. + +ForestModelConfig <- R6::R6Class( + classname = "ForestModelConfig", + cloneable = FALSE, + public = list( + + #' @field feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + feature_types = NULL, + + #' @field num_trees Number of trees in the forest being sampled + num_trees = NULL, + + #' @field num_features Number of features in training dataset + num_features = NULL, + + #' @field num_observations Number of observations in training dataset + num_observations = NULL, + + #' @field leaf_dimension Dimension of the leaf model + leaf_dimension = NULL, + + #' @field alpha Root node split probability in tree prior + alpha = NULL, + + #' @field beta Depth prior penalty in tree prior + beta = NULL, + + #' @field min_samples_leaf Minimum number of samples in a tree leaf + min_samples_leaf = NULL, + + #' @field max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. + max_depth = NULL, + + #' @field leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression) + leaf_model_type = NULL, + + #' @field leaf_model_scale Scale parameter used in Gaussian leaf models + leaf_model_scale = NULL, + + #' @field variable_weights Vector specifying sampling probability for all p covariates in ForestDataset + variable_weights = NULL, + + #' @field variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`) + variance_forest_shape = NULL, + + #' @field variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`) + variance_forest_scale = NULL, + + #' @field cutpoint_grid_size Number of unique cutpoints to consider + cutpoint_grid_size = NULL, + + #' Create a new ForestModelConfig object. + #' + #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + #' @param num_trees Number of trees in the forest being sampled + #' @param num_features Number of features in training dataset + #' @param num_observations Number of observations in training dataset + #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset + #' @param leaf_dimension Dimension of the leaf model (default: `1`) + #' @param alpha Root node split probability in tree prior (default: `0.95`) + #' @param beta Depth prior penalty in tree prior (default: `2.0`) + #' @param min_samples_leaf Minimum number of samples in a tree leaf (default: `5`) + #' @param max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`.
+ #' @param leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: `1`. + #' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_type = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models. + #' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. + #' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. + #' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`) + #' + #' @return A new ForestModelConfig object. + initialize = function(feature_types = NULL, num_trees = NULL, num_features = NULL, + num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, + alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, + leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, + variance_forest_scale = 1.0, cutpoint_grid_size = 100) { + if (is.null(feature_types)) { + if (is.null(num_features)) { + stop("Neither of `num_features` nor `feature_types` (a vector from which `num_features` can be inferred) was provided. Please provide at least one of these inputs when creating a ForestModelConfig object.") + } + warning("`feature_types` not provided, will be assumed to be numeric") + feature_types <- rep(0, num_features) + } else { + if (is.null(num_features)) { + num_features <- length(feature_types) + } + } + if (is.null(variable_weights)) { + warning("`variable_weights` not provided, will be assumed to be equal-weighted") + variable_weights <- rep(1/num_features, num_features) + } + if (num_features != length(feature_types)) { + stop("`feature_types` must have `num_features` total elements") + } + if (num_features != length(variable_weights)) { + stop("`variable_weights` must have `num_features` total elements") + } + self$feature_types <- feature_types + self$variable_weights <- variable_weights + self$num_trees <- num_trees + self$num_features <- num_features + self$num_observations <- num_observations + self$leaf_dimension <- leaf_dimension + self$alpha <- alpha + self$beta <- beta + self$min_samples_leaf <- min_samples_leaf + self$max_depth <- max_depth + self$variance_forest_shape <- variance_forest_shape + self$variance_forest_scale <- variance_forest_scale + self$cutpoint_grid_size <- cutpoint_grid_size + + if (!(as.integer(leaf_model_type) == leaf_model_type)) { + stop("`leaf_model_type` must be an integer between 0 and 3") + } + if ((leaf_model_type < 0) | (leaf_model_type > 3)) { + stop("`leaf_model_type` must be an integer between 0 and 3") + } + self$leaf_model_type <- leaf_model_type + + if (is.null(leaf_model_scale)) { + self$leaf_model_scale <- diag(1/num_trees, leaf_dimension) + } else if (is.matrix(leaf_model_scale)) { + if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) { + stop("`leaf_model_scale` must be a square matrix") + } + if (ncol(leaf_model_scale) != leaf_dimension) { + stop("`leaf_model_scale` must have `leaf_dimension` rows and columns") + } + self$leaf_model_scale <- leaf_model_scale + } else { + if (leaf_model_scale <= 0) { + stop("`leaf_model_scale` must be positive, if provided as scalar") + } + self$leaf_model_scale <- diag(leaf_model_scale, leaf_dimension) + } + }, + + #' @description
#' Update feature types + #' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + update_feature_types = function(feature_types) { + stopifnot(length(feature_types) == self$num_features) + self$feature_types <- feature_types + }, + + #' @description + #' Update variable weights + #' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset + update_variable_weights = function(variable_weights) { + stopifnot(length(variable_weights) == self$num_features) + self$variable_weights <- variable_weights + }, + + #' @description + #' Update root node split probability in tree prior + #' @param alpha Root node split probability in tree prior + update_alpha = function(alpha) { + self$alpha <- alpha + }, + + #' @description + #' Update depth prior penalty in tree prior + #' @param beta Depth prior penalty in tree prior + update_beta = function(beta) { + self$beta <- beta + }, + + #' @description + #' Update minimum number of samples in a tree leaf + #' @param min_samples_leaf Minimum number of samples in a tree leaf + update_min_samples_leaf = function(min_samples_leaf) { + self$min_samples_leaf <- min_samples_leaf + }, + + #' @description + #' Update maximum depth of any tree in the ensemble + #' @param max_depth Maximum depth of any tree in the ensemble in the model + update_max_depth = function(max_depth) { + self$max_depth <- max_depth + }, + + #' @description + #' Update scale parameter used in Gaussian leaf models + #' @param leaf_model_scale Scale parameter used in Gaussian leaf models + update_leaf_model_scale = function(leaf_model_scale) { + if (is.matrix(leaf_model_scale)) { + if (ncol(leaf_model_scale) != nrow(leaf_model_scale)) { + stop("`leaf_model_scale` must be a square matrix") + } + if (ncol(leaf_model_scale) != self$leaf_dimension) { + stop("`leaf_model_scale` must have `leaf_dimension` rows and columns") + } + self$leaf_model_scale <- leaf_model_scale + } else { + if (leaf_model_scale <= 0) { + stop("`leaf_model_scale` must be positive, if provided as scalar") + } + self$leaf_model_scale <- diag(leaf_model_scale, self$leaf_dimension) + } + }, + + #' @description + #' Update shape parameter for IG leaf models + #' @param variance_forest_shape Shape parameter for IG leaf models + update_variance_forest_shape = function(variance_forest_shape) { + self$variance_forest_shape <- variance_forest_shape + }, + + #' @description + #' Update scale parameter for IG leaf models + #' @param variance_forest_scale Scale parameter for IG leaf models + update_variance_forest_scale = function(variance_forest_scale) { + self$variance_forest_scale <- variance_forest_scale + }, + + #' @description + #' Update number of unique cutpoints to consider + #' @param cutpoint_grid_size Number of unique cutpoints to consider + update_cutpoint_grid_size = function(cutpoint_grid_size) { + self$cutpoint_grid_size <- cutpoint_grid_size + }, + + #' @description + #' Query feature types for this ForestModelConfig object + #' @returns Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) + get_feature_types = function() { + return(self$feature_types) + }, + + #' @description + #' Query variable weights for this ForestModelConfig object + #' @returns Vector specifying sampling probability for all p covariates in ForestDataset + get_variable_weights = function() { + return(self$variable_weights) + }, + + #' @description + #' Query root node split probability in tree prior for this ForestModelConfig object + #' @returns Root node split probability in tree prior + get_alpha = function() { + return(self$alpha) + }, + + #' @description + #' Query depth prior penalty in tree prior for this ForestModelConfig object + #' @returns Depth prior penalty in tree prior + get_beta = function() { + return(self$beta) + }, + + #' @description + #' Query minimum number of samples in a tree leaf for this ForestModelConfig object + #' @returns Minimum number of samples in a tree leaf + get_min_samples_leaf = function() { + return(self$min_samples_leaf) + }, + + #' @description + #' Query maximum depth of any tree in the ensemble for this ForestModelConfig object + #' @returns Maximum depth of any tree in the ensemble in the model + get_max_depth = function() { + return(self$max_depth) + }, + + #' @description + #' Query scale parameter used in Gaussian leaf models for this ForestModelConfig object + #' @returns Scale parameter used in Gaussian leaf models + get_leaf_model_scale = function() { + return(self$leaf_model_scale) + }, + + #' @description + #' Query shape parameter for IG leaf models for this ForestModelConfig object + #' @returns Shape parameter for IG leaf models + get_variance_forest_shape = function() { + return(self$variance_forest_shape) + }, + + #' @description + #' Query scale parameter for IG leaf models for this ForestModelConfig object + #' @returns Scale parameter for IG leaf models + get_variance_forest_scale = function() { + return(self$variance_forest_scale) + }, + + #' @description + #' Query number of unique cutpoints to consider for this ForestModelConfig object + #' @returns Number of unique cutpoints to consider + get_cutpoint_grid_size = function() { + return(self$cutpoint_grid_size) + } + ) +) + +#' Object used to get / set global parameters and other global model +#' configuration options in the "low-level" stochtree interface +#' +#' @description +#' The "low-level" stochtree interface enables a high degree of sampler +#' customization, in which users employ R wrappers around C++ objects +#' like ForestDataset, Outcome, CppRng, and ForestModel to run the +#' Gibbs sampler of a BART model with custom modifications. +#' GlobalModelConfig allows users to specify / query the global parameters +#' of a model they wish to run. + +GlobalModelConfig <- R6::R6Class( + classname = "GlobalModelConfig", + cloneable = FALSE, + public = list( + + #' @field global_error_variance Global error variance parameter + global_error_variance = NULL, + + #' Create a new GlobalModelConfig object. + #' + #' @param global_error_variance Global error variance parameter (default: `1.0`) + #' + #' @return A new GlobalModelConfig object.
+ initialize = function(global_error_variance = 1.0) { + self$global_error_variance <- global_error_variance + }, + + #' @description + #' Update global error variance parameter + #' @param global_error_variance Global error variance parameter + update_global_error_variance = function(global_error_variance) { + self$global_error_variance <- global_error_variance + }, + + #' @description + #' Query global error variance parameter for this GlobalModelConfig object + #' @returns Global error variance parameter + get_global_error_variance = function() { + return(self$global_error_variance) + } + ) +) + +#' Create a forest model config object +#' +#' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) +#' @param num_trees Number of trees in the forest being sampled +#' @param num_features Number of features in training dataset +#' @param num_observations Number of observations in training dataset +#' @param variable_weights Vector specifying sampling probability for all p covariates in ForestDataset +#' @param leaf_dimension Dimension of the leaf model (default: `1`) +#' @param alpha Root node split probability in tree prior (default: `0.95`) +#' @param beta Depth prior penalty in tree prior (default: `2.0`) +#' @param min_samples_leaf Minimum number of samples in a tree leaf (default: `5`) +#' @param max_depth Maximum depth of any tree in the ensemble in the model. Setting to `-1` does not enforce any depth limits on trees. Default: `-1`. +#' @param leaf_model_type Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: `1`. +#' @param leaf_model_scale Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when `leaf_model_type = 2`). Calibrated internally as `1/num_trees`, propagated along diagonal if needed for multivariate leaf models. +#' @param variance_forest_shape Shape parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`. +#' @param variance_forest_scale Scale parameter for IG leaf models (applicable when `leaf_model_type = 3`). Default: `1`.
+#' @param cutpoint_grid_size Number of unique cutpoints to consider (default: `100`) +#' @return ForestModelConfig object +#' @export +#' +#' @examples +#' config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100) +createForestModelConfig <- function(feature_types = NULL, num_trees = NULL, num_features = NULL, + num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, + alpha = 0.95, beta = 2.0, min_samples_leaf = 5, max_depth = -1, + leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1.0, + variance_forest_scale = 1.0, cutpoint_grid_size = 100){ + return(invisible(( + ForestModelConfig$new(feature_types, num_trees, num_features, num_observations, + variable_weights, leaf_dimension, alpha, beta, min_samples_leaf, + max_depth, leaf_model_type, leaf_model_scale, variance_forest_shape, + variance_forest_scale, cutpoint_grid_size) + ))) +} + +#' Create a global model config object +#' +#' @param global_error_variance Global error variance parameter (default: `1.0`) +#' @return GlobalModelConfig object +#' @export +#' +#' @examples +#' config <- createGlobalModelConfig(global_error_variance = 100) +createGlobalModelConfig <- function(global_error_variance = 1.0){ + return(invisible(( + GlobalModelConfig$new(global_error_variance) + ))) +} diff --git a/R/cpp11.R b/R/cpp11.R index 8ad8ba24..7188e9f7 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -556,12 +556,12 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums) .Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums) } -sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) { - invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized)) +sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest) { + invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest)) } -sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) { - invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized)) +sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, 
keep_forest) {
+ invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest))
}

sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {
diff --git a/R/forest.R
index 224e8c7c..6af83bbe 100644
--- a/R/forest.R
+++ b/R/forest.R
@@ -558,6 +558,9 @@ Forest <- R6::R6Class(
 #' @field forest_ptr External pointer to a C++ TreeEnsemble class
 forest_ptr = NULL,

+ #' @field internal_forest_is_empty Whether the forest has yet to be "initialized" (i.e., its trees given leaf values), which must happen before its `predict` function can be called.
+ internal_forest_is_empty = TRUE,
+
 #' @description
 #' Create a new Forest object.
 #' @param num_trees Number of trees in the forest
@@ -567,6 +570,7 @@
 #' @return A new `Forest` object.
 initialize = function(num_trees, leaf_dimension=1, is_leaf_constant=F, is_exponentiated=F) {
 self$forest_ptr <- active_forest_cpp(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
+ self$internal_forest_is_empty <- TRUE
 },

 #' @description
@@ -610,6 +614,7 @@
 #' @param leaf_value Constant leaf value(s) to be fixed for each tree in the ensemble indexed by `forest_num`. Can be either a single number or a vector, depending on the forest's leaf dimension.
 set_root_leaves = function(leaf_value) {
 stopifnot(!is.null(self$forest_ptr))
+ stopifnot(self$internal_forest_is_empty)

 # Set leaf values
 if (length(leaf_value) == 1) {
@@ -621,6 +626,8 @@
 } else {
 stop("leaf_value must be a numeric value or vector of length >= 1")
 }
+
+ self$internal_forest_is_empty <- FALSE
 },

 #' @description
@@ -636,12 +643,15 @@
 stopifnot(!is.null(outcome$data_ptr))
 stopifnot(!is.null(forest_model$tracker_ptr))
 stopifnot(!is.null(self$forest_ptr))
+ stopifnot(self$internal_forest_is_empty)

 # Initialize the model
 initialize_forest_model_active_forest_cpp(
 dataset$data_ptr, outcome$data_ptr, self$forest_ptr,
 forest_model$tracker_ptr, leaf_value, leaf_model_int
 )
+
+ self$internal_forest_is_empty <- FALSE
 },

 #' @description
@@ -745,6 +755,23 @@
 #' @return Average maximum depth
 average_max_depth = function() {
 return(ensemble_average_max_depth_active_forest_cpp(self$forest_ptr))
+ },
+
+ #' @description
+ #' When a forest object is created, it is "empty" in the sense that none
+ #' of its component trees have leaves with values. There are two ways to
+ #' "initialize" a Forest object. First, the `set_root_leaves()` method
+ #' simply initializes every tree in the forest to a single node carrying
+ #' the same (user-specified) leaf value. Second, the `prepare_for_sampler()`
+ #' method initializes every tree in the forest to a single node with the
+ #' same value and also propagates this information through to a ForestModel
+ #' object, which must be synchronized with a Forest during a forest
+ #' sampler loop.
+ #' @return `TRUE` if a Forest has not yet been initialized with a constant
+ #' root value, `FALSE` if the forest has already been
+ #' initialized / grown.
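# A minimal sketch of the initialization contract tracked by `internal_forest_is_empty`,
# assuming a 10-tree, constant-leaf forest (the leaf value 0. is illustrative):
#   f <- createForest(num_trees = 10, leaf_dimension = 1, is_leaf_constant = TRUE)
#   f$is_empty()           # TRUE: no tree carries leaf values yet
#   f$set_root_leaves(0.)  # every tree becomes a single root node with value 0
#   f$is_empty()           # FALSE: the forest can now be predicted from or sampled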
+ is_empty = function() { + return(self$internal_forest_is_empty) } ) ) @@ -818,6 +845,7 @@ createForest <- function(num_trees, leaf_dimension=1, is_leaf_constant=F, is_exp resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NULL) { if (is.null(forest_samples)) { root_reset_active_forest_cpp(active_forest$forest_ptr) + active_forest$internal_forest_is_empty = TRUE } else { if (is.null(forest_num)) { stop("`forest_num` must be specified if `forest_samples` is provided") @@ -860,14 +888,25 @@ resetActiveForest <- function(active_forest, forest_samples=NULL, forest_num=NUL #' y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) #' outcome <- createOutcome(y) #' rng <- createCppRNG(1234) -#' forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) +#' global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) +#' forest_model_config <- createForestModelConfig(feature_types=feature_types, +#' num_trees=num_trees, num_observations=n, +#' num_features=p, alpha=alpha, beta=beta, +#' min_samples_leaf=min_samples_leaf, +#' max_depth=max_depth, +#' variable_weights=variable_weights, +#' cutpoint_grid_size=cutpoint_grid_size, +#' leaf_model_type=leaf_model, +#' leaf_model_scale=leaf_scale) +#' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) #' active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -#' forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) +#' forest_samples <- createForestSamples(num_trees, leaf_dimension, +#' is_leaf_constant, is_exponentiated) +#' active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) #' forest_model$sample_one_iteration( #' forest_dataset, outcome, forest_samples, active_forest, -#' rng, feature_types, leaf_model, leaf_scale, variable_weights, -#' a_forest, b_forest, sigma2, cutpoint_grid_size, keep_forest = TRUE, -#' gfr = FALSE, pre_initialized = TRUE +#' rng, forest_model_config, global_model_config, +#' keep_forest = TRUE, gfr = FALSE #' ) #' resetActiveForest(active_forest, forest_samples, 0) #' resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) diff --git a/R/kernel.R b/R/kernel.R index becbb43b..0e79b47e 100644 --- a/R/kernel.R +++ b/R/kernel.R @@ -48,9 +48,8 @@ #' computeForestLeafIndices(bart_model, X, "mean", c(1,3,9)) computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) { # Extract relevant forest container - object_name <- class(model_object)[1] - stopifnot(object_name %in% c("bartmodel", "bcfmodel", "ForestSamples")) - model_type <- ifelse(object_name=="bartmodel", "bart", ifelse(object_name=="bcfmodel", "bcf", "forest_samples")) + stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples")))) + model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples")) if (model_type == "bart") { stopifnot(forest_type %in% c("mean", "variance")) if (forest_type=="mean") { @@ -143,8 +142,8 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL, #' computeForestLeafVariances(bart_model, "mean", c(1,3,5)) computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NULL) { # Extract relevant forest container - stopifnot(class(model_object) %in% c("bartmodel", "bcfmodel")) - 
model_type <- ifelse(class(model_object)=="bartmodel", "bart", "bcf") + stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel")))) + model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", "bcf") if (model_type == "bart") { stopifnot(forest_type %in% c("mean", "variance")) if (forest_type=="mean") { @@ -234,9 +233,8 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU #' computeForestMaxLeafIndex(bart_model, X, "mean", c(1,3,9)) computeForestMaxLeafIndex <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) { # Extract relevant forest container - object_name <- class(model_object)[1] - stopifnot(object_name %in% c("bartmodel", "bcfmodel", "ForestSamples")) - model_type <- ifelse(object_name=="bartmodel", "bart", ifelse(object_name=="bcfmodel", "bcf", "forest_samples")) + stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples")))) + model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples")) if (model_type == "bart") { stopifnot(forest_type %in% c("mean", "variance")) if (forest_type=="mean") { diff --git a/R/model.R b/R/model.R index 5dc1de69..b8da4cf2 100644 --- a/R/model.R +++ b/R/model.R @@ -65,34 +65,39 @@ ForestModel <- R6::R6Class( #' @param forest_samples Container of forest samples #' @param active_forest "Active" forest updated by the sampler in each iteration #' @param rng Wrapper around C++ random number generator - #' @param feature_types Vector specifying the type of all p covariates in `forest_dataset` (0 = numeric, 1 = ordered categorical, 2 = unordered categorical) - #' @param leaf_model_int Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression) - #' @param leaf_model_scale Scale parameter used in the leaf node model (should be a q x q matrix where q is the dimensionality of the basis and is only >1 when `leaf_model_int = 2`) - #' @param variable_weights Vector specifying sampling probability for all p covariates in `forest_dataset` - #' @param a_forest Shape parameter on variance forest model (if applicable) - #' @param b_forest Scale parameter on variance forest model (if applicable) - #' @param global_scale Global variance parameter - #' @param cutpoint_grid_size (Optional) Number of unique cutpoints to consider (default: `500`, currently only used when `GFR = TRUE`) - #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `T`. - #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `T`. - #' @param pre_initialized (Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: `F`. 
- sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest, rng, feature_types, - leaf_model_int, leaf_model_scale, variable_weights, - a_forest, b_forest, global_scale, cutpoint_grid_size = 500, - keep_forest = T, gfr = T, pre_initialized = F) { + #' @param forest_model_config ForestModelConfig object containing forest model parameters and settings + #' @param global_model_config GlobalModelConfig object containing global model parameters and settings + #' @param keep_forest (Optional) Whether the updated forest sample should be saved to `forest_samples`. Default: `TRUE`. + #' @param gfr (Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: `TRUE`. + sample_one_iteration = function(forest_dataset, residual, forest_samples, active_forest, + rng, forest_model_config, global_model_config, keep_forest = T, gfr = T) { + if (active_forest$is_empty()) { + stop("`active_forest` has not yet been initialized, which is necessary to run the sampler. Please set constant values for `active_forest`'s leaves using either the `set_root_leaves` or `prepare_for_sampler` methods.") + } + + # Unpack parameters from model config object + feature_types <- forest_model_config$feature_types + leaf_model_int <- forest_model_config$leaf_model_type + leaf_model_scale <- forest_model_config$leaf_model_scale + variable_weights <- forest_model_config$variable_weights + a_forest <- forest_model_config$variance_forest_shape + b_forest <- forest_model_config$variance_forest_scale + global_scale <- global_model_config$global_error_variance + cutpoint_grid_size <- forest_model_config$cutpoint_grid_size + if (gfr) { sample_gfr_one_iteration_cpp( forest_dataset$data_ptr, residual$data_ptr, forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, - variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest, pre_initialized + variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest ) } else { sample_mcmc_one_iteration_cpp( forest_dataset$data_ptr, residual$data_ptr, forest_samples$forest_container_ptr, active_forest$forest_ptr, self$tracker_ptr, self$tree_prior_ptr, rng$rng_ptr, feature_types, cutpoint_grid_size, leaf_model_scale, - variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest, pre_initialized + variable_weights, a_forest, b_forest, global_scale, leaf_model_int, keep_forest ) } }, @@ -182,14 +187,9 @@ createCppRNG <- function(random_seed = -1){ #' Create a forest model object #' -#' @param forest_dataset `ForestDataset` object, used to initialize forest sampling data structures -#' @param feature_types Feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) -#' @param num_trees Number of trees in the forest being sampled -#' @param n Number of observations in `forest_dataset` -#' @param alpha Root node split probability in tree prior -#' @param beta Depth prior penalty in tree prior -#' @param min_samples_leaf Minimum number of samples in a tree leaf -#' @param max_depth Maximum depth of any tree in the ensemble in the mean model. Setting to ``-1`` does not enforce any depth limits on trees. 
+#' @param forest_dataset ForestDataset object, used to initialize forest sampling data structures +#' @param forest_model_config ForestModelConfig object containing forest model parameters and settings +#' @param global_model_config GlobalModelConfig object containing global model parameters and settings #' #' @return `ForestModel` object #' @export @@ -205,10 +205,18 @@ createCppRNG <- function(random_seed = -1){ #' feature_types <- as.integer(rep(0, p)) #' X <- matrix(runif(n*p), ncol = p) #' forest_dataset <- createForestDataset(X) -#' forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) -createForestModel <- function(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) { +#' forest_model_config <- createForestModelConfig(feature_types=feature_types, +#' num_trees=num_trees, num_features=p, +#' num_observations=n, alpha=alpha, beta=beta, +#' min_samples_leaf=min_samples_leaf, +#' max_depth=max_depth, leaf_model_type=1) +#' global_model_config <- createGlobalModelConfig(global_error_variance=1.0) +#' forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) +createForestModel <- function(forest_dataset, forest_model_config, global_model_config) { return(invisible(( - ForestModel$new(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) + ForestModel$new(forest_dataset, forest_model_config$feature_types, forest_model_config$num_trees, + forest_model_config$num_observations, forest_model_config$alpha, forest_model_config$beta, + forest_model_config$min_samples_leaf, forest_model_config$max_depth) ))) } diff --git a/_pkgdown.yml b/_pkgdown.yml index f0a68688..43b7c995 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -79,6 +79,10 @@ reference: - createForestModel - ForestSamples - createForestSamples + - ForestModelConfig + - createForestModelConfig + - GlobalModelConfig + - createGlobalModelConfig - CppRNG - createCppRNG - calibrateInverseGammaErrorVariance diff --git a/cran-bootstrap.R b/cran-bootstrap.R index 3a47f19f..e06eba11 100644 --- a/cran-bootstrap.R +++ b/cran-bootstrap.R @@ -14,25 +14,28 @@ # include_vignettes : 1 to include the vignettes folder in the R package subfolder # 0 to exclude vignettes (overriden to 1 if pkgdown_build = 1 below) # -# pkgdown_build : 1 to include pkgdown specific files (R_Readm) -# 0 to exclude vignettes +# pkgdown_build : 1 to include pkgdown specific files (R_README.md, _pkgdown.yml) +# 0 to exclude pkgdown specific files +# +# include_tests : 1 to include unit tests +# 0 to exclude unit tests # # Run this script from the command line via # -# Explicitly include vignettes and build pkgdown site -# --------------------------------------------------- -# Rscript cran-bootstrap.R 1 1 +# Explicitly include vignettes and unit tests and build pkgdown site +# ------------------------------------------------------------------ +# Rscript cran-bootstrap.R 1 1 1 # -# Explicitly include vignettes but don't build pkgdown site -# --------------------------------------------------------- -# Rscript cran-bootstrap.R 1 0 +# Explicitly include vignettes and unit tests but don't build pkgdown site +# ------------------------------------------------------------------------ +# Rscript cran-bootstrap.R 1 0 1 # -# Explicitly exclude vignettes and don't build pkgdown site -# --------------------------------------------------------- -# Rscript cran-bootstrap.R 0 0 +# Explicitly exclude vignettes and unit tests 
and don't build pkgdown site
+# ------------------------------------------------------------------------
+# Rscript cran-bootstrap.R 0 0 0
 #
-# Exclude vignettes and pkgdown by default
-# ----------------------------------------
+# Exclude vignettes, unit tests, and pkgdown by default
+# -----------------------------------------------------
 # Rscript cran-bootstrap.R

 # Unpack command line arguments
@@ -40,9 +43,11 @@ args <- commandArgs(trailingOnly = T)
 if (length(args) > 0){
 include_vignettes <- as.logical(as.integer(args[1]))
 pkgdown_build <- as.logical(as.integer(args[2]))
+ include_tests <- as.logical(as.integer(args[3]))
 } else{
 include_vignettes <- F
 pkgdown_build <- F
+ include_tests <- F
 }

 # Create the stochtree_cran folder
@@ -95,10 +100,14 @@ if (pkgdown_build) {
 }

 # Handle tests separately (move from test/R/ folder to tests/ folder)
-test_files_src <- list.files("test/R", recursive = TRUE, full.names = TRUE)
-test_files_dst <- file.path(cran_dir, gsub("test/R", "tests", test_files_src))
-pkg_core_files <- c(pkg_core_files, test_files_src)
-pkg_core_files_dst <- c(pkg_core_files_dst, test_files_dst)
+if (include_tests) {
+ test_files_src <- list.files("test/R", recursive = TRUE, full.names = TRUE)
+ test_files_dst <- file.path(cran_dir, gsub("test/R", "tests", test_files_src))
+ pkg_core_files <- c(pkg_core_files, test_files_src)
+ pkg_core_files_dst <- c(pkg_core_files_dst, test_files_dst)
+}
+
+# Copy over all core package files
 if (all(file.exists(pkg_core_files))) {
 n_removed <- suppressWarnings(sum(file.remove(pkg_core_files_dst)))
 if (n_removed > 0) {
@@ -142,6 +151,23 @@ if (!include_vignettes) {
 writeLines(description_lines, cran_description)
 }

+# Remove testthat deps from DESCRIPTION if no tests
+if (!include_tests) {
+ cran_description <- file.path(cran_dir, "DESCRIPTION")
+ description_lines <- readLines(cran_description)
+ if (include_vignettes) {
+ suggestion_match <- grep("testthat (>= 3.0.0)", description_lines, fixed = TRUE)
+ suggestion_lines <- suggestion_match
+ } else {
+ suggestion_begin <- grep("Suggests:", description_lines)
+ suggestion_end <- grep("SystemRequirements:", description_lines) - 1
+ suggestion_lines <- suggestion_begin:suggestion_end
+ }
+ testthat_config_line <- grep("Config/testthat/edition:", description_lines)
+ description_lines <- description_lines[-c(suggestion_lines, testthat_config_line)]
+ writeLines(description_lines, cran_description)
+}
+
 # Remove vignettes from _pkgdown.yml if no vignettes
 if ((!include_vignettes) & (pkgdown_build)) {
 pkgdown_yml <- file.path(cran_dir, "_pkgdown.yml")
diff --git a/inst/COPYRIGHTS
index 6d246808..71270347 100644
--- a/inst/COPYRIGHTS
+++ b/inst/COPYRIGHTS
@@ -1,15 +1,34 @@
 stochtree
 Copyright 2023-2025 stochtree contributors

+Several stochtree C++ header and source files include or are inspired by code
+in several open-source decision tree libraries: xgboost, LightGBM, and treelite.
+Copyright and license information for each of these three projects is detailed
+further below and in comments in each of the files.
+File: src/include/stochtree/category_tracker.h [xgboost] +File: src/include/stochtree/common.h [xgboost] +File: src/include/stochtree/ensemble.h [xgboost] +File: src/include/stochtree/io.h [LightGBM] +File: src/include/stochtree/log.h [LightGBM] +File: src/include/stochtree/meta.h [LightGBM] +File: src/include/stochtree/partition_tracker.h [LightGBM, xgboost] +File: src/include/stochtree/tree.h [xgboost, treelite] + This project includes software from the xgboost project (Apache, 2.0). * Copyright 2015-2024, XGBoost Contributors This project includes software from the LightGBM project (MIT). * Copyright (c) 2016 Microsoft Corporation +This project includes software from the treelite project (Apache, 2.0). +* Copyright (c) 2017-2023 by [treelite] Contributors + This project includes software from the fast_double_parser project (Apache, 2.0). * Copyright (c) Daniel Lemire +This project includes software from the JSON for Modern C++ project (MIT). +* Copyright © 2013-2025 Niels Lohmann + This project includes software from the Eigen project (MPL, 2.0), whose headers carry the following copyrights: File: Eigen/Core Copyright (C) 2008 Gael Guennebaud diff --git a/man/Forest.Rd b/man/Forest.Rd index 075460f4..fd5cc045 100644 --- a/man/Forest.Rd +++ b/man/Forest.Rd @@ -10,6 +10,8 @@ Wrapper around a C++ tree ensemble \if{html}{\out{
}}
\describe{
\item{\code{forest_ptr}}{External pointer to a C++ TreeEnsemble class}
+
+\item{\code{internal_forest_is_empty}}{Whether the forest has yet to be "initialized" (i.e., its trees given leaf values), which must happen before its \code{predict} function can be called.}
}
\if{html}{\out{
}} } @@ -32,6 +34,7 @@ Wrapper around a C++ tree ensemble \item \href{#method-Forest-get_forest_split_counts}{\code{Forest$get_forest_split_counts()}} \item \href{#method-Forest-tree_max_depth}{\code{Forest$tree_max_depth()}} \item \href{#method-Forest-average_max_depth}{\code{Forest$average_max_depth()}} +\item \href{#method-Forest-is_empty}{\code{Forest$is_empty()}} } } \if{html}{\out{
}} @@ -361,4 +364,27 @@ Average the maximum depth of each tree in the forest Average maximum depth } } +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Forest-is_empty}{}}} +\subsection{Method \code{is_empty()}}{ +When a forest object is created, it is "empty" in the sense that none +of its component trees have leaves with values. There are two ways to +"initialize" a Forest object. First, the \code{set_root_leaves()} method +simply initializes every tree in the forest to a single node carrying +the same (user-specified) leaf value. Second, the \code{prepare_for_sampler()} +method initializes every tree in the forest to a single node with the +same value and also propagates this information through to a ForestModel +object, which must be synchronized with a Forest during a forest +sampler loop. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Forest$is_empty()}\if{html}{\out{
}}
+}
+
+\subsection{Returns}{
+\code{TRUE} if a Forest has not yet been initialized with a constant
+root value, \code{FALSE} if the forest has already been
+initialized / grown.
+}
+}
}
diff --git a/man/ForestModel.Rd
index 61b688a3..d317da72 100644
--- a/man/ForestModel.Rd
+++ b/man/ForestModel.Rd
@@ -85,17 +85,10 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR)
 forest_samples,
 active_forest,
 rng,
- feature_types,
- leaf_model_int,
- leaf_model_scale,
- variable_weights,
- a_forest,
- b_forest,
- global_scale,
- cutpoint_grid_size = 500,
+ forest_model_config,
+ global_model_config,
 keep_forest = T,
- gfr = T,
- pre_initialized = F
+ gfr = T
)}\if{html}{\out{}}
}

@@ -112,27 +105,13 @@ Run a single iteration of the forest sampling algorithm (MCMC or GFR)
 \item{\code{rng}}{Wrapper around C++ random number generator}

-\item{\code{feature_types}}{Vector specifying the type of all p covariates in \code{forest_dataset} (0 = numeric, 1 = ordered categorical, 2 = unordered categorical)}
-
-\item{\code{leaf_model_int}}{Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression)}
-
-\item{\code{leaf_model_scale}}{Scale parameter used in the leaf node model (should be a q x q matrix where q is the dimensionality of the basis and is only >1 when \code{leaf_model_int = 2})}
-
-\item{\code{variable_weights}}{Vector specifying sampling probability for all p covariates in \code{forest_dataset}}
-
-\item{\code{a_forest}}{Shape parameter on variance forest model (if applicable)}
-
-\item{\code{b_forest}}{Scale parameter on variance forest model (if applicable)}
-
-\item{\code{global_scale}}{Global variance parameter}
-
-\item{\code{cutpoint_grid_size}}{(Optional) Number of unique cutpoints to consider (default: \code{500}, currently only used when \code{GFR = TRUE})}
+\item{\code{forest_model_config}}{ForestModelConfig object containing forest model parameters and settings}

-\item{\code{keep_forest}}{(Optional) Whether the updated forest sample should be saved to \code{forest_samples}. Default: \code{T}.}
+\item{\code{global_model_config}}{GlobalModelConfig object containing global model parameters and settings}

-\item{\code{gfr}}{(Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: \code{T}.}
+\item{\code{keep_forest}}{(Optional) Whether the updated forest sample should be saved to \code{forest_samples}. Default: \code{TRUE}.}

-\item{\code{pre_initialized}}{(Optional) Whether or not the leaves are pre-initialized outside of the sampling loop (before any samples are drawn). In multi-forest implementations like BCF, this is true, though in the single-forest supervised learning implementation, we can let C++ do the initialization. Default: \code{F}.}
+\item{\code{gfr}}{(Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. 
Default: \code{TRUE}.}
}
\if{html}{\out{}}
}
diff --git a/man/ForestModelConfig.Rd
new file mode 100644
index 00000000..e899c8b1
--- /dev/null
+++ b/man/ForestModelConfig.Rd
@@ -0,0 +1,431 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/config.R
+\name{ForestModelConfig}
+\alias{ForestModelConfig}
+\title{Object used to get / set parameters and other model configuration options
+for a forest model in the "low-level" stochtree interface}
+\value{
+Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
+
+Vector specifying sampling probability for all p covariates in ForestDataset
+
+Root node split probability in tree prior
+
+Depth prior penalty in tree prior
+
+Minimum number of samples in a tree leaf
+
+Maximum depth of any tree in the ensemble in the model
+
+Scale parameter used in Gaussian leaf models
+
+Shape parameter for IG leaf models
+
+Scale parameter for IG leaf models
+
+Number of unique cutpoints to consider
+}
+\description{
+The "low-level" stochtree interface enables a high degree of sampler
+customization, in which users employ R wrappers around C++ objects
+like ForestDataset, Outcome, CppRng, and ForestModel to run the
+Gibbs sampler of a BART model with custom modifications.
+ForestModelConfig allows users to specify / query the parameters of a
+forest model they wish to run.
+}
+\section{Public fields}{
+\if{html}{\out{
}} +\describe{ +\item{\code{feature_types}}{Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} + +\item{\code{num_trees}}{Number of trees in the forest being sampled} + +\item{\code{num_features}}{Number of features in training dataset} + +\item{\code{num_observations}}{Number of observations in training dataset} + +\item{\code{leaf_dimension}}{Dimension of the leaf model} + +\item{\code{alpha}}{Root node split probability in tree prior} + +\item{\code{beta}}{Depth prior penalty in tree prior} + +\item{\code{min_samples_leaf}}{Minimum number of samples in a tree leaf} + +\item{\code{max_depth}}{Maximum depth of any tree in the ensemble in the model. Setting to \code{-1} does not enforce any depth limits on trees.} + +\item{\code{leaf_model_type}}{Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression)} + +\item{\code{leaf_model_scale}}{Scale parameter used in Gaussian leaf models} + +\item{\code{variable_weights}}{Vector specifying sampling probability for all p covariates in ForestDataset} + +\item{\code{variance_forest_shape}}{Shape parameter for IG leaf models (applicable when \code{leaf_model_type = 3})} + +\item{\code{variance_forest_scale}}{Scale parameter for IG leaf models (applicable when \code{leaf_model_type = 3})} + +\item{\code{cutpoint_grid_size}}{Number of unique cutpoints to consider +Create a new ForestModelConfig object.} +} +\if{html}{\out{
}} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-ForestModelConfig-new}{\code{ForestModelConfig$new()}} +\item \href{#method-ForestModelConfig-update_feature_types}{\code{ForestModelConfig$update_feature_types()}} +\item \href{#method-ForestModelConfig-update_variable_weights}{\code{ForestModelConfig$update_variable_weights()}} +\item \href{#method-ForestModelConfig-update_alpha}{\code{ForestModelConfig$update_alpha()}} +\item \href{#method-ForestModelConfig-update_beta}{\code{ForestModelConfig$update_beta()}} +\item \href{#method-ForestModelConfig-update_min_samples_leaf}{\code{ForestModelConfig$update_min_samples_leaf()}} +\item \href{#method-ForestModelConfig-update_max_depth}{\code{ForestModelConfig$update_max_depth()}} +\item \href{#method-ForestModelConfig-update_leaf_model_scale}{\code{ForestModelConfig$update_leaf_model_scale()}} +\item \href{#method-ForestModelConfig-update_variance_forest_shape}{\code{ForestModelConfig$update_variance_forest_shape()}} +\item \href{#method-ForestModelConfig-update_variance_forest_scale}{\code{ForestModelConfig$update_variance_forest_scale()}} +\item \href{#method-ForestModelConfig-update_cutpoint_grid_size}{\code{ForestModelConfig$update_cutpoint_grid_size()}} +\item \href{#method-ForestModelConfig-get_feature_types}{\code{ForestModelConfig$get_feature_types()}} +\item \href{#method-ForestModelConfig-get_variable_weights}{\code{ForestModelConfig$get_variable_weights()}} +\item \href{#method-ForestModelConfig-get_alpha}{\code{ForestModelConfig$get_alpha()}} +\item \href{#method-ForestModelConfig-get_beta}{\code{ForestModelConfig$get_beta()}} +\item \href{#method-ForestModelConfig-get_min_samples_leaf}{\code{ForestModelConfig$get_min_samples_leaf()}} +\item \href{#method-ForestModelConfig-get_max_depth}{\code{ForestModelConfig$get_max_depth()}} +\item \href{#method-ForestModelConfig-get_leaf_model_scale}{\code{ForestModelConfig$get_leaf_model_scale()}} +\item \href{#method-ForestModelConfig-get_variance_forest_shape}{\code{ForestModelConfig$get_variance_forest_shape()}} +\item \href{#method-ForestModelConfig-get_variance_forest_scale}{\code{ForestModelConfig$get_variance_forest_scale()}} +\item \href{#method-ForestModelConfig-get_cutpoint_grid_size}{\code{ForestModelConfig$get_cutpoint_grid_size()}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-new}{}}} +\subsection{Method \code{new()}}{ +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$new( + feature_types = NULL, + num_trees = NULL, + num_features = NULL, + num_observations = NULL, + variable_weights = NULL, + leaf_dimension = 1, + alpha = 0.95, + beta = 2, + min_samples_leaf = 5, + max_depth = -1, + leaf_model_type = 1, + leaf_model_scale = NULL, + variance_forest_shape = 1, + variance_forest_scale = 1, + cutpoint_grid_size = 100 +)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}}
+\describe{
+\item{\code{feature_types}}{Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)}
+
+\item{\code{num_trees}}{Number of trees in the forest being sampled}
+
+\item{\code{num_features}}{Number of features in training dataset}
+
+\item{\code{num_observations}}{Number of observations in training dataset}
+
+\item{\code{variable_weights}}{Vector specifying sampling probability for all p covariates in ForestDataset}
+
+\item{\code{leaf_dimension}}{Dimension of the leaf model (default: \code{1})}
+
+\item{\code{alpha}}{Root node split probability in tree prior (default: \code{0.95})}
+
+\item{\code{beta}}{Depth prior penalty in tree prior (default: \code{2.0})}
+
+\item{\code{min_samples_leaf}}{Minimum number of samples in a tree leaf (default: \code{5})}
+
+\item{\code{max_depth}}{Maximum depth of any tree in the ensemble in the model. Setting to \code{-1} does not enforce any depth limits on trees. Default: \code{-1}.}
+
+\item{\code{leaf_model_type}}{Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: \code{1}.}
+
+\item{\code{leaf_model_scale}}{Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when \code{leaf_model_type = 2}). Calibrated internally as \code{1/num_trees}, propagated along diagonal if needed for multivariate leaf models.}
+
+\item{\code{variance_forest_shape}}{Shape parameter for IG leaf models (applicable when \code{leaf_model_type = 3}). Default: \code{1}.}
+
+\item{\code{variance_forest_scale}}{Scale parameter for IG leaf models (applicable when \code{leaf_model_type = 3}). Default: \code{1}.}
+
+\item{\code{cutpoint_grid_size}}{Number of unique cutpoints to consider (default: \code{100})}
+}
+\if{html}{\out{
}} +} +\subsection{Returns}{ +A new ForestModelConfig object. +} +} +\if{html}{\out{
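A minimal construction sketch for this class (the dimensions are illustrative, and feature_types / variable_weights are assumed to have length num_features):

config <- ForestModelConfig$new(
  feature_types = as.integer(rep(0, 5)),
  num_trees = 10, num_features = 5, num_observations = 100,
  variable_weights = rep(1 / 5, 5)
)
config$get_alpha()        # 0.95, the default root node split probability
config$update_alpha(0.9)  # adjust the tree prior between sampler iterations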
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_feature_types}{}}} +\subsection{Method \code{update_feature_types()}}{ +Update feature types +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_feature_types(feature_types)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{feature_types}}{Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_variable_weights}{}}} +\subsection{Method \code{update_variable_weights()}}{ +Update variable weights +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_variable_weights(variable_weights)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{variable_weights}}{Vector specifying sampling probability for all p covariates in ForestDataset} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_alpha}{}}} +\subsection{Method \code{update_alpha()}}{ +Update root node split probability in tree prior +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_alpha(alpha)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{alpha}}{Root node split probability in tree prior} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_beta}{}}} +\subsection{Method \code{update_beta()}}{ +Update depth prior penalty in tree prior +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_beta(beta)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{beta}}{Depth prior penalty in tree prior} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_min_samples_leaf}{}}}
+\subsection{Method \code{update_min_samples_leaf()}}{
+Update minimum number of samples in a tree leaf
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{ForestModelConfig$update_min_samples_leaf(min_samples_leaf)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{min_samples_leaf}}{Minimum number of samples in a tree leaf} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_max_depth}{}}}
+\subsection{Method \code{update_max_depth()}}{
+Update maximum depth of any tree in the ensemble
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{ForestModelConfig$update_max_depth(max_depth)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{max_depth}}{Maximum depth of any tree in the ensemble in the model} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_leaf_model_scale}{}}} +\subsection{Method \code{update_leaf_model_scale()}}{ +Update scale parameter used in Gaussian leaf models +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_leaf_model_scale(leaf_model_scale)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{leaf_model_scale}}{Scale parameter used in Gaussian leaf models} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_variance_forest_shape}{}}} +\subsection{Method \code{update_variance_forest_shape()}}{ +Update shape parameter for IG leaf models +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_variance_forest_shape(variance_forest_shape)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{variance_forest_shape}}{Shape parameter for IG leaf models} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_variance_forest_scale}{}}} +\subsection{Method \code{update_variance_forest_scale()}}{ +Update scale parameter for IG leaf models +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_variance_forest_scale(variance_forest_scale)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{variance_forest_scale}}{Scale parameter for IG leaf models} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-update_cutpoint_grid_size}{}}} +\subsection{Method \code{update_cutpoint_grid_size()}}{ +Update number of unique cutpoints to consider +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$update_cutpoint_grid_size(cutpoint_grid_size)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{cutpoint_grid_size}}{Number of unique cutpoints to consider} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_feature_types}{}}} +\subsection{Method \code{get_feature_types()}}{ +Query feature types for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_feature_types()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_variable_weights}{}}} +\subsection{Method \code{get_variable_weights()}}{ +Query variable weights for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_variable_weights()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_alpha}{}}} +\subsection{Method \code{get_alpha()}}{ +Query root node split probability in tree prior for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_alpha()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_beta}{}}} +\subsection{Method \code{get_beta()}}{ +Query depth prior penalty in tree prior for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_beta()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_min_samples_leaf}{}}}
+\subsection{Method \code{get_min_samples_leaf()}}{
+Query minimum number of samples in a tree leaf for this ForestModelConfig object
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{ForestModelConfig$get_min_samples_leaf()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}}
+\if{html}{\out{}}
+\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_max_depth}{}}}
+\subsection{Method \code{get_max_depth()}}{
+Query maximum depth of any tree in the ensemble for this ForestModelConfig object
+\subsection{Usage}{
+\if{html}{\out{
}}\preformatted{ForestModelConfig$get_max_depth()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_leaf_model_scale}{}}} +\subsection{Method \code{get_leaf_model_scale()}}{ +Query scale parameter used in Gaussian leaf models for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_leaf_model_scale()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_variance_forest_shape}{}}} +\subsection{Method \code{get_variance_forest_shape()}}{ +Query shape parameter for IG leaf models for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_variance_forest_shape()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_variance_forest_scale}{}}} +\subsection{Method \code{get_variance_forest_scale()}}{ +Query scale parameter for IG leaf models for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_variance_forest_scale()}\if{html}{\out{
}} +} + +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ForestModelConfig-get_cutpoint_grid_size}{}}} +\subsection{Method \code{get_cutpoint_grid_size()}}{ +Query number of unique cutpoints to consider for this ForestModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ForestModelConfig$get_cutpoint_grid_size()}\if{html}{\out{
}}
+}
+
+}
+}
diff --git a/man/GlobalModelConfig.Rd
new file mode 100644
index 00000000..fa28e635
--- /dev/null
+++ b/man/GlobalModelConfig.Rd
@@ -0,0 +1,80 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/config.R
+\name{GlobalModelConfig}
+\alias{GlobalModelConfig}
+\title{Object used to get / set global parameters and other global model
+configuration options in the "low-level" stochtree interface}
+\value{
+Global error variance parameter
+}
+\description{
+The "low-level" stochtree interface enables a high degree of sampler
+customization, in which users employ R wrappers around C++ objects
+like ForestDataset, Outcome, CppRng, and ForestModel to run the
+Gibbs sampler of a BART model with custom modifications.
+GlobalModelConfig allows users to specify / query the global parameters
+of a model they wish to run.
+}
+\section{Public fields}{
+\if{html}{\out{
}} +\describe{ +\item{\code{global_error_variance}}{Global error variance parameter +Create a new GlobalModelConfig object.} +} +\if{html}{\out{
}} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-GlobalModelConfig-new}{\code{GlobalModelConfig$new()}} +\item \href{#method-GlobalModelConfig-update_global_error_variance}{\code{GlobalModelConfig$update_global_error_variance()}} +\item \href{#method-GlobalModelConfig-get_global_error_variance}{\code{GlobalModelConfig$get_global_error_variance()}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-GlobalModelConfig-new}{}}} +\subsection{Method \code{new()}}{ +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{GlobalModelConfig$new(global_error_variance = 1)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{global_error_variance}}{Global error variance parameter (default: \code{1.0})} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +A new GlobalModelConfig object. +} +} +\if{html}{\out{
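A condensed sketch of both config objects driving one low-level sampler sweep (adapted from the resetForestModel example later in this diff; forest_dataset, outcome, rng, forest_model_config, num_trees, and current_sigma2 are assumed to be defined as in that example):

global_model_config <- createGlobalModelConfig(global_error_variance = 1.0)
forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
active_forest <- createForest(num_trees, leaf_dimension = 1, is_leaf_constant = TRUE)
forest_samples <- createForestSamples(num_trees, leaf_dimension = 1, is_leaf_constant = TRUE)
active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)
forest_model$sample_one_iteration(
  forest_dataset, outcome, forest_samples, active_forest,
  rng, forest_model_config, global_model_config,
  keep_forest = TRUE, gfr = FALSE
)
# Propagate an externally resampled error variance into the next sweep
global_model_config$update_global_error_variance(current_sigma2)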
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-GlobalModelConfig-update_global_error_variance}{}}} +\subsection{Method \code{update_global_error_variance()}}{ +Update global error variance parameter +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{GlobalModelConfig$update_global_error_variance(global_error_variance)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{global_error_variance}}{Global error variance parameter} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-GlobalModelConfig-get_global_error_variance}{}}} +\subsection{Method \code{get_global_error_variance()}}{ +Query global error variance parameter for this GlobalModelConfig object +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{GlobalModelConfig$get_global_error_variance()}\if{html}{\out{
}} +} + +} +} diff --git a/man/createBCFModelFromCombinedJson.Rd b/man/createBCFModelFromCombinedJson.Rd index e374c311..b1fb9ac9 100644 --- a/man/createBCFModelFromCombinedJson.Rd +++ b/man/createBCFModelFromCombinedJson.Rd @@ -70,9 +70,11 @@ rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100) # bcf_json_list <- list(saveBCFModelToJson(bcf_model)) diff --git a/man/createBCFModelFromCombinedJsonString.Rd b/man/createBCFModelFromCombinedJsonString.Rd index f1853d7f..988c7346 100644 --- a/man/createBCFModelFromCombinedJsonString.Rd +++ b/man/createBCFModelFromCombinedJsonString.Rd @@ -70,9 +70,11 @@ rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100) # bcf_json_string_list <- list(saveBCFModelToJsonString(bcf_model)) diff --git a/man/createBCFModelFromJson.Rd b/man/createBCFModelFromJson.Rd index 602db813..2cde726a 100644 --- a/man/createBCFModelFromJson.Rd +++ b/man/createBCFModelFromJson.Rd @@ -71,9 +71,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, diff --git a/man/createBCFModelFromJsonFile.Rd b/man/createBCFModelFromJsonFile.Rd index 2f9be821..cb83403e 100644 --- a/man/createBCFModelFromJsonFile.Rd +++ b/man/createBCFModelFromJsonFile.Rd @@ -71,9 +71,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = 
pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, diff --git a/man/createBCFModelFromJsonString.Rd b/man/createBCFModelFromJsonString.Rd index 7e27f9bb..1cd567ca 100644 --- a/man/createBCFModelFromJsonString.Rd +++ b/man/createBCFModelFromJsonString.Rd @@ -69,9 +69,11 @@ rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100) # bcf_json <- saveBCFModelToJsonString(bcf_model) diff --git a/man/createForestModel.Rd b/man/createForestModel.Rd index 05263bbb..d9000925 100644 --- a/man/createForestModel.Rd +++ b/man/createForestModel.Rd @@ -4,33 +4,14 @@ \alias{createForestModel} \title{Create a forest model object} \usage{ -createForestModel( - forest_dataset, - feature_types, - num_trees, - n, - alpha, - beta, - min_samples_leaf, - max_depth -) +createForestModel(forest_dataset, forest_model_config, global_model_config) } \arguments{ -\item{forest_dataset}{\code{ForestDataset} object, used to initialize forest sampling data structures} +\item{forest_dataset}{ForestDataset object, used to initialize forest sampling data structures} -\item{feature_types}{Feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)} +\item{forest_model_config}{ForestModelConfig object containing forest model parameters and settings} -\item{num_trees}{Number of trees in the forest being sampled} - -\item{n}{Number of observations in \code{forest_dataset}} - -\item{alpha}{Root node split probability in tree prior} - -\item{beta}{Depth prior penalty in tree prior} - -\item{min_samples_leaf}{Minimum number of samples in a tree leaf} - -\item{max_depth}{Maximum depth of any tree in the ensemble in the mean model. 
Setting to \code{-1} does not enforce any depth limits on trees.}
+\item{global_model_config}{GlobalModelConfig object containing global model parameters and settings}
}
\value{
\code{ForestModel} object
@@ -49,5 +30,11 @@
 max_depth <- 10
 feature_types <- as.integer(rep(0, p))
 X <- matrix(runif(n*p), ncol = p)
 forest_dataset <- createForestDataset(X)
-forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth)
+forest_model_config <- createForestModelConfig(feature_types=feature_types,
+ num_trees=num_trees, num_features=p,
+ num_observations=n, alpha=alpha, beta=beta,
+ min_samples_leaf=min_samples_leaf,
+ max_depth=max_depth, leaf_model_type=1)
+global_model_config <- createGlobalModelConfig(global_error_variance=1.0)
+forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
 }
diff --git a/man/createForestModelConfig.Rd
new file mode 100644
index 00000000..90de767c
--- /dev/null
+++ b/man/createForestModelConfig.Rd
@@ -0,0 +1,64 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/config.R
+\name{createForestModelConfig}
+\alias{createForestModelConfig}
+\title{Create a forest model config object}
+\usage{
+createForestModelConfig(
+ feature_types = NULL,
+ num_trees = NULL,
+ num_features = NULL,
+ num_observations = NULL,
+ variable_weights = NULL,
+ leaf_dimension = 1,
+ alpha = 0.95,
+ beta = 2,
+ min_samples_leaf = 5,
+ max_depth = -1,
+ leaf_model_type = 1,
+ leaf_model_scale = NULL,
+ variance_forest_shape = 1,
+ variance_forest_scale = 1,
+ cutpoint_grid_size = 100
+)
+}
+\arguments{
+\item{feature_types}{Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)}
+
+\item{num_trees}{Number of trees in the forest being sampled}
+
+\item{num_features}{Number of features in training dataset}
+
+\item{num_observations}{Number of observations in training dataset}
+
+\item{variable_weights}{Vector specifying sampling probability for all p covariates in ForestDataset}
+
+\item{leaf_dimension}{Dimension of the leaf model (default: \code{1})}
+
+\item{alpha}{Root node split probability in tree prior (default: \code{0.95})}
+
+\item{beta}{Depth prior penalty in tree prior (default: \code{2.0})}
+
+\item{min_samples_leaf}{Minimum number of samples in a tree leaf (default: \code{5})}
+
+\item{max_depth}{Maximum depth of any tree in the ensemble in the model. Setting to \code{-1} does not enforce any depth limits on trees. Default: \code{-1}.}
+
+\item{leaf_model_type}{Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: \code{1}.}
+
+\item{leaf_model_scale}{Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when \code{leaf_model_type = 2}). Calibrated internally as \code{1/num_trees}, propagated along diagonal if needed for multivariate leaf models.}
+
+\item{variance_forest_shape}{Shape parameter for IG leaf models (applicable when \code{leaf_model_type = 3}). Default: \code{1}.}
+
+\item{variance_forest_scale}{Scale parameter for IG leaf models (applicable when \code{leaf_model_type = 3}). 
Default: \code{1}.} + +\item{cutpoint_grid_size}{Number of unique cutpoints to consider (default: \code{100})} +} +\value{ +ForestModelConfig object +} +\description{ +Create a forest model config object +} +\examples{ +config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100) +} diff --git a/man/createGlobalModelConfig.Rd b/man/createGlobalModelConfig.Rd new file mode 100644 index 00000000..59225789 --- /dev/null +++ b/man/createGlobalModelConfig.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/config.R +\name{createGlobalModelConfig} +\alias{createGlobalModelConfig} +\title{Create a global model config object} +\usage{ +createGlobalModelConfig(global_error_variance = 1) +} +\arguments{ +\item{global_error_variance}{Global error variance parameter (default: \code{1.0})} +} +\value{ +GlobalModelConfig object +} +\description{ +Create a global model config object +} +\examples{ +config <- createGlobalModelConfig(global_error_variance = 100) +} diff --git a/man/getRandomEffectSamples.bartmodel.Rd b/man/getRandomEffectSamples.bartmodel.Rd index 2ff00687..9f273732 100644 --- a/man/getRandomEffectSamples.bartmodel.Rd +++ b/man/getRandomEffectSamples.bartmodel.Rd @@ -52,8 +52,10 @@ rfx_basis_train <- rfx_basis[train_inds,] rfx_term_test <- rfx_term[test_inds] rfx_term_train <- rfx_term[train_inds] bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, - rfx_group_ids_train = rfx_group_ids_train, rfx_group_ids_test = rfx_group_ids_test, - rfx_basis_train = rfx_basis_train, rfx_basis_test = rfx_basis_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100) rfx_samples <- getRandomEffectSamples(bart_model) } diff --git a/man/getRandomEffectSamples.bcfmodel.Rd b/man/getRandomEffectSamples.bcfmodel.Rd index ca03ffe4..410f44c4 100644 --- a/man/getRandomEffectSamples.bcfmodel.Rd +++ b/man/getRandomEffectSamples.bcfmodel.Rd @@ -73,9 +73,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index 3fd2f1a4..c0b14eb5 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -78,7 +78,8 @@ mu_test <- mu_x[test_inds] mu_train <- mu_x[train_inds] tau_test <- tau_x[test_inds] tau_train <- tau_x[train_inds] -bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, propensity_train = pi_train) +bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train) preds <- predict(bcf_model, X_test, Z_test, pi_test) plot(rowMeans(preds$mu_hat), mu_test, xlab = "predicted", ylab = "actual", main = "Prognostic function") diff --git a/man/resetForestModel.Rd b/man/resetForestModel.Rd index 
07f3a8fa..f0fec6ca 100644 --- a/man/resetForestModel.Rd +++ b/man/resetForestModel.Rd @@ -47,14 +47,25 @@ forest_dataset <- createForestDataset(X) y <- -5 + 10*(X[,1] > 0.5) + rnorm(n) outcome <- createOutcome(y) rng <- createCppRNG(1234) -forest_model <- createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) +global_model_config <- createGlobalModelConfig(global_error_variance=sigma2) +forest_model_config <- createForestModelConfig(feature_types=feature_types, + num_trees=num_trees, num_observations=n, + num_features=p, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, + max_depth=max_depth, + variable_weights=variable_weights, + cutpoint_grid_size=cutpoint_grid_size, + leaf_model_type=leaf_model, + leaf_model_scale=leaf_scale) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) -forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated) +forest_samples <- createForestSamples(num_trees, leaf_dimension, + is_leaf_constant, is_exponentiated) +active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.) forest_model$sample_one_iteration( forest_dataset, outcome, forest_samples, active_forest, - rng, feature_types, leaf_model, leaf_scale, variable_weights, - a_forest, b_forest, sigma2, cutpoint_grid_size, keep_forest = TRUE, - gfr = FALSE, pre_initialized = TRUE + rng, forest_model_config, global_model_config, + keep_forest = TRUE, gfr = FALSE ) resetActiveForest(active_forest, forest_samples, 0) resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE) diff --git a/man/saveBCFModelToJson.Rd b/man/saveBCFModelToJson.Rd index 89598334..171d0c53 100644 --- a/man/saveBCFModelToJson.Rd +++ b/man/saveBCFModelToJson.Rd @@ -69,9 +69,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, diff --git a/man/saveBCFModelToJsonFile.Rd b/man/saveBCFModelToJsonFile.Rd index 14417564..2c8ea980 100644 --- a/man/saveBCFModelToJsonFile.Rd +++ b/man/saveBCFModelToJsonFile.Rd @@ -71,9 +71,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, diff --git 
a/man/saveBCFModelToJsonString.Rd b/man/saveBCFModelToJsonString.Rd index e1d6769c..3c0bdee1 100644 --- a/man/saveBCFModelToJsonString.Rd +++ b/man/saveBCFModelToJsonString.Rd @@ -69,9 +69,11 @@ rfx_term_train <- rfx_term[train_inds] mu_params <- list(sample_sigma_leaf = TRUE) tau_params <- list(sample_sigma_leaf = FALSE) bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, - propensity_train = pi_train, rfx_group_ids_train = rfx_group_ids_train, + propensity_train = pi_train, + rfx_group_ids_train = rfx_group_ids_train, rfx_basis_train = rfx_basis_train, X_test = X_test, - Z_test = Z_test, propensity_test = pi_test, rfx_group_ids_test = rfx_group_ids_test, + Z_test = Z_test, propensity_test = pi_test, + rfx_group_ids_test = rfx_group_ids_test, rfx_basis_test = rfx_basis_test, num_gfr = 100, num_burnin = 0, num_mcmc = 100, mu_forest_params = mu_params, diff --git a/man/stochtree-package.Rd b/man/stochtree-package.Rd index 0377fb91..6d82b32a 100644 --- a/man/stochtree-package.Rd +++ b/man/stochtree-package.Rd @@ -4,9 +4,9 @@ \name{stochtree-package} \alias{stochtree} \alias{stochtree-package} -\title{stochtree: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference} +\title{stochtree: Stochastic tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference} \description{ -Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference. +Flexible stochastic tree ensemble software. Robust implementations of Bayesian Additive Regression Trees (Chipman, George, McCulloch (2010) \doi{10.1214/09-AOAS285}) for supervised learning and (Bayesian Causal Forests (BCF) Hahn, Murray, Carvalho (2020) \doi{10.1214/19-BA1195}) for causal inference. Enables model serialization and parallel sampling and provides a low-level interface for custom stochastic forest samplers. 
} \seealso{ Useful links: diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 2364da8f..00a3fcbc 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -1028,18 +1028,18 @@ extern "C" SEXP _stochtree_compute_leaf_indices_cpp(SEXP forest_container, SEXP END_CPP11 } // sampler.cpp -void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, bool pre_initialized); -extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP pre_initialized) { +void sample_gfr_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest); +extern "C" SEXP _stochtree_sample_gfr_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest) { BEGIN_CPP11 - sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(pre_initialized)); + sample_gfr_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest)); return R_NilValue; END_CPP11 } // sampler.cpp -void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, 
cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest, bool pre_initialized); -extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest, SEXP pre_initialized) { +void sample_mcmc_one_iteration_cpp(cpp11::external_pointer data, cpp11::external_pointer residual, cpp11::external_pointer forest_samples, cpp11::external_pointer active_forest, cpp11::external_pointer tracker, cpp11::external_pointer split_prior, cpp11::external_pointer rng, cpp11::integers feature_types, int cutpoint_grid_size, cpp11::doubles_matrix<> leaf_model_scale_input, cpp11::doubles variable_weights, double a_forest, double b_forest, double global_variance, int leaf_model_int, bool keep_forest); +extern "C" SEXP _stochtree_sample_mcmc_one_iteration_cpp(SEXP data, SEXP residual, SEXP forest_samples, SEXP active_forest, SEXP tracker, SEXP split_prior, SEXP rng, SEXP feature_types, SEXP cutpoint_grid_size, SEXP leaf_model_scale_input, SEXP variable_weights, SEXP a_forest, SEXP b_forest, SEXP global_variance, SEXP leaf_model_int, SEXP keep_forest) { BEGIN_CPP11 - sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest), cpp11::as_cpp>(pre_initialized)); + sample_mcmc_one_iteration_cpp(cpp11::as_cpp>>(data), cpp11::as_cpp>>(residual), cpp11::as_cpp>>(forest_samples), cpp11::as_cpp>>(active_forest), cpp11::as_cpp>>(tracker), cpp11::as_cpp>>(split_prior), cpp11::as_cpp>>(rng), cpp11::as_cpp>(feature_types), cpp11::as_cpp>(cutpoint_grid_size), cpp11::as_cpp>>(leaf_model_scale_input), cpp11::as_cpp>(variable_weights), cpp11::as_cpp>(a_forest), cpp11::as_cpp>(b_forest), cpp11::as_cpp>(global_variance), cpp11::as_cpp>(leaf_model_int), cpp11::as_cpp>(keep_forest)); return R_NilValue; END_CPP11 } @@ -1583,8 +1583,8 @@ static const R_CallMethodDef CallEntries[] = { {"_stochtree_rng_cpp", (DL_FUNC) &_stochtree_rng_cpp, 1}, {"_stochtree_root_reset_active_forest_cpp", (DL_FUNC) &_stochtree_root_reset_active_forest_cpp, 1}, {"_stochtree_root_reset_rfx_tracker_cpp", (DL_FUNC) &_stochtree_root_reset_rfx_tracker_cpp, 4}, - {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 17}, - {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 17}, + {"_stochtree_sample_gfr_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_gfr_one_iteration_cpp, 16}, + {"_stochtree_sample_mcmc_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_mcmc_one_iteration_cpp, 16}, {"_stochtree_sample_sigma2_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_sigma2_one_iteration_cpp, 5}, {"_stochtree_sample_tau_one_iteration_cpp", (DL_FUNC) &_stochtree_sample_tau_one_iteration_cpp, 4}, {"_stochtree_set_leaf_value_active_forest_cpp", (DL_FUNC) &_stochtree_set_leaf_value_active_forest_cpp, 
2}, diff --git a/src/sampler.cpp b/src/sampler.cpp index 5b5d8afb..4dbe5e13 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -24,9 +24,12 @@ void sample_gfr_one_iteration_cpp(cpp11::external_pointer feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { @@ -93,9 +96,12 @@ void sample_mcmc_one_iteration_cpp(cpp11::external_pointer feature_types_(feature_types.size()); for (int i = 0; i < feature_types.size(); i++) { diff --git a/test/R/testthat/test-bart.R b/test/R/testthat/test-bart.R index b0372236..9a885b91 100644 --- a/test/R/testthat/test-bart.R +++ b/test/R/testthat/test-bart.R @@ -129,3 +129,106 @@ test_that("GFR BART", { general_params = general_param_list) ) }) + +test_that("Warmstart BART", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + noise_sd <- 1 + y <- f_XW + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BART model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 0, + general_params = general_param_list) + + # Save to JSON string + bart_model_json_string <- saveBARTModelToJsonString(bart_model) + + # Run a new BART chain from the existing (X)BART model + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list) + + ) + + # Generate simulated data with random effects + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + f_XW <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + rfx_group_ids <- sample(1:2, size = n, replace = T) + rfx_basis <- rep(1, n) + rfx_coefs <- c(-5, 5) + rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis + noise_sd <- 1 + y <- f_XW + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds] + rfx_basis_train <- rfx_basis[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BART model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 10, num_burnin = 0, num_mcmc = 0, + general_params = general_param_list) + 
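Before warm-starting from a saved model, it can help to confirm how many draws the saved model actually retained, since `previous_model_warmstart_sample_num` must point at one of them. A minimal sketch, assuming the fitted object exposes `model_params$num_samples` as other parts of the package do:

```r
# Hypothetical pre-check before warm-starting from `bart_model` above:
# the warm-start index must reference a retained draw.
num_available <- bart_model$model_params$num_samples  # 10 GFR draws here
warmstart_idx <- 1                                    # the index used in the test below
stopifnot(warmstart_idx >= 1, warmstart_idx <= num_available)
```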
+ # Save to JSON string + bart_model_json_string <- saveBARTModelToJsonString(bart_model) + + # Run a new BART chain from the existing (X)BART model + general_param_list <- list(num_chains = 4, keep_every = 5) + expect_no_error( + bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + num_gfr = 0, num_burnin = 10, num_mcmc = 10, + previous_model_json = bart_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list) + ) +}) diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R new file mode 100644 index 00000000..0a34c37c --- /dev/null +++ b/test/R/testthat/test-bcf.R @@ -0,0 +1,331 @@ +test_that("MCMC BCF", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + mu_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + pi_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) + ) + tau_X <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) + ) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) + + # 3 chains, no thinning + general_param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) + + # 1 chain, thinning + general_param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) + + # 3 chains, thinning + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = 10, 
num_mcmc = 10, general_params = general_param_list) + ) +}) + +test_that("GFR BCF", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + mu_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + pi_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) + ) + tau_X <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) + ) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # 1 chain, no thinning + general_param_list <- list(num_chains = 1, keep_every = 1) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) + + # 3 chains, no thinning + general_param_list <- list(num_chains = 3, keep_every = 1) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) + + # 1 chain, thinning + general_param_list <- list(num_chains = 1, keep_every = 5) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) + + # 3 chains, thinning + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 1) + expect_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) + + # Check for error when more chains than GFR forests + general_param_list <- list(num_chains = 11, keep_every = 5) + expect_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 
10, num_burnin = 10, + num_mcmc = 10, general_params = general_param_list) + ) +}) + +test_that("Warmstart BCF", { + skip_on_cran() + + # Generate simulated data + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + mu_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + pi_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) + ) + tau_X <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) + ) + Z <- rbinom(n, 1, pi_X) + noise_sd <- 1 + y <- mu_X + tau_X*Z + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + Z_test <- Z[test_inds] + Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BCF model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 0, + num_mcmc = 0, general_params = general_param_list) + + # Save to JSON string + bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) + + # Run a new BCF chain from the existing (X)BCF model + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = 10, + num_mcmc = 10, previous_model_json = bcf_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list) + ) + + # Generate simulated data with random effects + n <- 100 + p <- 5 + X <- matrix(runif(n*p), ncol = p) + mu_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ) + pi_X <- ( + ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) + + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8) + ) + tau_X <- ( + ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) + + ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) + + ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) + + ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0) + ) + Z <- rbinom(n, 1, pi_X) + rfx_group_ids <- sample(1:2, size = n, replace = T) + rfx_basis <- rep(1, n) + rfx_coefs <- c(-5, 5) + rfx_term <- rfx_coefs[rfx_group_ids] * rfx_basis + noise_sd <- 1 + y <- mu_X + tau_X*Z + rfx_term + rnorm(n, 0, noise_sd) + test_set_pct <- 0.2 + n_test <- round(test_set_pct*n) + n_train <- n - n_test + test_inds <- sort(sample(1:n, n_test, replace = FALSE)) + train_inds <- (1:n)[!((1:n) %in% test_inds)] + X_test <- X[test_inds,] + X_train <- X[train_inds,] + Z_test <- Z[test_inds] + 
Z_train <- Z[train_inds] + pi_test <- pi_X[test_inds] + pi_train <- pi_X[train_inds] + mu_test <- mu_X[test_inds] + mu_train <- mu_X[train_inds] + tau_test <- tau_X[test_inds] + tau_train <- tau_X[train_inds] + rfx_group_ids_test <- rfx_group_ids[test_inds] + rfx_group_ids_train <- rfx_group_ids[train_inds] + rfx_basis_test <- rfx_basis[test_inds] + rfx_basis_train <- rfx_basis[train_inds] + y_test <- y[test_inds] + y_train <- y[train_inds] + + # Run a BCF model with only GFR + general_param_list <- list(num_chains = 1, keep_every = 1) + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + propensity_test = pi_test, num_gfr = 10, num_burnin = 0, + num_mcmc = 0, general_params = general_param_list) + + # Save to JSON string + bcf_model_json_string <- saveBCFModelToJsonString(bcf_model) + + # Run a new BCF chain from the existing (X)BCF model + general_param_list <- list(num_chains = 3, keep_every = 5) + expect_no_error( + bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + rfx_group_ids_train = rfx_group_ids_train, + rfx_group_ids_test = rfx_group_ids_test, + rfx_basis_train = rfx_basis_train, + rfx_basis_test = rfx_basis_test, + propensity_test = pi_test, num_gfr = 0, num_burnin = 10, + num_mcmc = 10, previous_model_json = bcf_model_json_string, + previous_model_warmstart_sample_num = 1, + general_params = general_param_list) + ) +}) diff --git a/test/R/testthat/test-residual.R b/test/R/testthat/test-residual.R index 04165271..eef4f731 100644 --- a/test/R/testthat/test-residual.R +++ b/test/R/testthat/test-residual.R @@ -36,7 +36,14 @@ test_that("Residual updates correctly propagated after forest sampling step", { cpp_rng = createCppRNG(-1) # Create forest sampler and forest container - forest_model = createForestModel(forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth) + global_model_config = createGlobalModelConfig(global_error_variance=current_sigma2) + forest_model_config = createForestModelConfig(feature_types=feature_types, num_trees=num_trees, + num_observations=n, alpha=alpha, beta=beta, + min_samples_leaf=min_samples_leaf, max_depth=max_depth, + leaf_model_type=0, leaf_model_scale=current_leaf_scale, + variable_weights=variable_weights, variance_forest_shape=a_forest, + variance_forest_scale=b_forest, cutpoint_grid_size=cutpoint_grid_size) + forest_model = createForestModel(forest_dataset, forest_model_config, global_model_config) forest_samples = createForestSamples(num_trees, 1, F) active_forest = createForest(num_trees, 1, F) @@ -47,8 +54,7 @@ test_that("Residual updates correctly propagated after forest sampling step", { # Run the forest sampling algorithm for a single iteration forest_model$sample_one_iteration( forest_dataset, residual, forest_samples, active_forest, - cpp_rng, feature_types, 0, current_leaf_scale, variable_weights, a_forest, b_forest, - current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T, pre_initialized = T + cpp_rng, forest_model_config, global_model_config, keep_forest = T, gfr = T ) # Get the current residual after running the sampler diff --git a/test/R/testthat/test-serialization.R b/test/R/testthat/test-serialization.R index e640d3f8..0d78957f 100644 --- 
a/test/R/testthat/test-serialization.R +++ b/test/R/testthat/test-serialization.R @@ -7,9 +7,9 @@ test_that("BART Serialization", { X <- matrix(runif(n*p), ncol = p) f_XW <- ( ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) + - ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + - ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + - ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) + + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) + + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5) ) noise_sd <- 1 y <- f_XW + rnorm(n, 0, noise_sd) diff --git a/vignettes/CausalInference.Rmd b/vignettes/CausalInference.Rmd index e7a8ff61..093d0260 100644 --- a/vignettes/CausalInference.Rmd +++ b/vignettes/CausalInference.Rmd @@ -110,7 +110,7 @@ initialization samples (@krantsevich2023stochastic). This is the default in ```{r} num_gfr <- 10 num_burnin <- 0 -num_mcmc <- 1000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -159,8 +159,8 @@ Next, we simulate from this ensemble model without any warm-start initialization ```{r} num_gfr <- 0 -num_burnin <- 1000 -num_mcmc <- 1000 +num_burnin <- 2000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -328,7 +328,7 @@ Next, we simulate from this ensemble model without any warm-start initialization ```{r} num_gfr <- 0 -num_burnin <- 100 +num_burnin <- 2000 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) @@ -497,7 +497,7 @@ Next, we simulate from this ensemble model without any warm-start initialization ```{r} num_gfr <- 0 -num_burnin <- 100 +num_burnin <- 2000 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) @@ -566,7 +566,7 @@ X_4 &\sim N\left(X_2,1\right)\\ We draw from the DGP defined above ```{r data_4} -n <- 1000 +n <- 500 x1 <- rnorm(n) x2 <- rnorm(n) x3 <- rnorm(n) @@ -664,7 +664,7 @@ Next, we simulate from this ensemble model without any warm-start initialization ```{r} num_gfr <- 0 -num_burnin <- 100 +num_burnin <- 2000 num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) @@ -773,9 +773,9 @@ rfx_term_train <- rfx_term[train_inds] Here we simulate only from the "warm-start" model (running root-MCMC BART with random effects is simply a matter of modifying the below code snippet by setting `num_gfr <- 0` and `num_mcmc` > 0). 
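Concretely, the root-MCMC variant described in that parenthetical only changes the sampler budget; a minimal sketch (the burn-in length is an arbitrary choice here), with all other settings as in the chunk that follows:

```r
num_gfr <- 0        # no grow-from-root warm-start
num_burnin <- 1000  # e.g.; rely on MCMC burn-in instead
num_mcmc <- 100
num_samples <- num_gfr + num_burnin + num_mcmc
```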
```{r} -num_gfr <- 100 +num_gfr <- 10 num_burnin <- 0 -num_mcmc <- 500 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -891,8 +891,8 @@ Here we simulate from the model with the original MCMC sampler, using all of the ```{r} num_gfr <- 0 -num_burnin <- 1000 -num_mcmc <- 1000 +num_burnin <- 2000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -953,8 +953,8 @@ Here we simulate from the model with the original MCMC sampler, using only covar ```{r} num_gfr <- 0 -num_burnin <- 1000 -num_mcmc <- 1000 +num_burnin <- 2000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -1016,7 +1016,7 @@ Here we simulate from the model with the warm-start sampler, using all of the co ```{r} num_gfr <- 10 num_burnin <- 0 -num_mcmc <- 1000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -1078,7 +1078,7 @@ Here we simulate from the model with the warm-start sampler, using only covariat ```{r} num_gfr <- 10 num_burnin <- 0 -num_mcmc <- 1000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -1206,7 +1206,7 @@ initialization samples (@krantsevich2023stochastic). This is the default in ```{r} num_gfr <- 10 num_burnin <- 0 -num_mcmc <- 1000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) @@ -1255,8 +1255,8 @@ Next, we simulate from this ensemble model without any warm-start initialization ```{r} num_gfr <- 0 -num_burnin <- 1000 -num_mcmc <- 1000 +num_burnin <- 2000 +num_mcmc <- 100 num_samples <- num_gfr + num_burnin + num_mcmc general_params <- list(keep_every = 5) mu_forest_params <- list(sample_sigma2_leaf = F) diff --git a/vignettes/CustomSamplingRoutine.Rmd b/vignettes/CustomSamplingRoutine.Rmd index 9ed1725c..22399325 100644 --- a/vignettes/CustomSamplingRoutine.Rmd +++ b/vignettes/CustomSamplingRoutine.Rmd @@ -101,10 +101,11 @@ beta <- 1.25 min_samples_leaf <- 1 max_depth <- 10 num_trees <- 100 -cutpoint_grid_size = 100 -global_variance_init = 1. -tau_init = 0.5 -leaf_prior_scale = matrix(c(tau_init), ncol = 1) +cutpoint_grid_size <- 100 +global_variance_init <- 1. +current_sigma2 <- global_variance_init +tau_init <- 1/num_trees +leaf_prior_scale <- as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1))) nu <- 4 lambda <- 0.5 a_leaf <- 2. 
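One subtlety in the hunk just above: `ifelse()` is vectorized over its condition, and with a length-one condition it returns a single element, so `as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1)))` always collapses to a 1 x 1 matrix. That is harmless when `p_W = 1`, as here, but silently drops dimensions for a multivariate basis; a sketch of an explicit alternative:

```r
# if/else (rather than the vectorized ifelse) preserves the intended
# p_W x p_W leaf scale matrix when the leaf basis is multivariate
leaf_prior_scale <- if (p_W > 1) diag(tau_init, p_W) else matrix(tau_init)
```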
@@ -121,9 +122,11 @@ Initialize R-level access to the C++ classes needed to sample our model if (leaf_regression) { forest_dataset <- createForestDataset(X, W) outcome_model_type <- 1 + leaf_dimension <- p_W } else { forest_dataset <- createForestDataset(X) outcome_model_type <- 0 + leaf_dimension <- 1 } outcome <- createOutcome(resid) @@ -131,9 +134,15 @@ outcome <- createOutcome(resid) rng <- createCppRNG() # Sampling data structures -forest_model <- createForestModel(forest_dataset, feature_types, - num_trees, n, alpha, beta, - min_samples_leaf, max_depth) +forest_model_config <- createForestModelConfig( + feature_types = feature_types, num_trees = num_trees, num_features = p_X, + num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, + leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, + cutpoint_grid_size = cutpoint_grid_size +) +global_model_config <- createGlobalModelConfig(global_error_variance = global_variance_init) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -145,6 +154,10 @@ if (leaf_regression) { forest_samples <- createForestSamples(num_trees, 1, T) active_forest <- createForest(num_trees, 1, T) } + +# Initialize the leaves of each tree in the forest +active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, mean(resid)) +active_forest$adjust_residual(forest_dataset, outcome, forest_model, ifelse(outcome_model_type==1, T, F), F) ``` Prepare to run the sampler @@ -163,21 +176,23 @@ Run the grow-from-root sampler to "warm-start" BART for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = T + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = T ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) } ``` @@ -188,21 +203,23 @@ scale parameters) with an MCMC sampler for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + 
global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) } ``` @@ -280,9 +297,10 @@ min_samples_leaf <- 1 max_depth <- 10 num_trees <- 100 cutpoint_grid_size = 100 -global_variance_init = 1. -tau_init = 0.5 -leaf_prior_scale = matrix(c(tau_init), ncol = 1) +global_variance_init <- 1. +current_sigma2 <- global_variance_init +tau_init <- 1/num_trees +leaf_prior_scale <- as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1))) nu <- 4 lambda <- 0.5 a_leaf <- 2. @@ -320,9 +338,15 @@ outcome <- createOutcome(resid) rng <- createCppRNG() # Sampling data structures -forest_model <- createForestModel(forest_dataset, feature_types, - num_trees, n, alpha, beta, - min_samples_leaf, max_depth) +forest_model_config <- createForestModelConfig( + feature_types = feature_types, num_trees = num_trees, num_features = p_X, + num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, + leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, + cutpoint_grid_size = cutpoint_grid_size +) +global_model_config <- createGlobalModelConfig(global_error_variance = global_variance_init) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -335,6 +359,10 @@ if (leaf_regression) { active_forest <- createForest(num_trees, 1, T) } +# Initialize the leaves of each tree in the forest +active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, mean(resid)) +active_forest$adjust_residual(forest_dataset, outcome, forest_model, ifelse(outcome_model_type==1, T, F), F) + # Random effects dataset rfx_basis <- as.matrix(rfx_basis) group_ids <- as.integer(group_ids) @@ -376,25 +404,27 @@ Run the grow-from-root sampler to "warm-start" BART for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = T + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = T ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) # Sample random effects model rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, - TRUE, global_var_samples[i+1], rng) + TRUE, current_sigma2, rng) } ``` @@ -405,25 +435,27 @@ scale parameters) with an MCMC sampler for (i 
in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) # Sample random effects model rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, - TRUE, global_var_samples[i+1], rng) + TRUE, current_sigma2, rng) } ``` @@ -515,9 +547,10 @@ min_samples_leaf <- 1 max_depth <- 10 num_trees <- 100 cutpoint_grid_size = 100 -global_variance_init = 1. -tau_init = 0.5 -leaf_prior_scale = matrix(c(tau_init), ncol = 1) +global_variance_init <- 1. +current_sigma2 <- global_variance_init +tau_init <- 1/num_trees +leaf_prior_scale <- as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1))) nu <- 4 lambda <- 0.5 a_leaf <- 2. @@ -555,9 +588,15 @@ outcome <- createOutcome(resid) rng <- createCppRNG() # Sampling data structures -forest_model <- createForestModel(forest_dataset, feature_types, - num_trees, n, alpha, beta, - min_samples_leaf, max_depth) +forest_model_config <- createForestModelConfig( + feature_types = feature_types, num_trees = num_trees, num_features = p_X, + num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, + leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, + cutpoint_grid_size = cutpoint_grid_size +) +global_model_config <- createGlobalModelConfig(global_error_variance = global_variance_init) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -570,6 +609,10 @@ if (leaf_regression) { active_forest <- createForest(num_trees, 1, T) } +# Initialize the leaves of each tree in the forest +active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, mean(resid)) +active_forest$adjust_residual(forest_dataset, outcome, forest_model, ifelse(outcome_model_type==1, T, F), F) + # Random effects dataset rfx_basis <- as.matrix(rfx_basis) group_ids <- as.integer(group_ids) @@ -611,21 +654,23 @@ Run the grow-from-root sampler to "warm-start" BART for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = T + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = T ) # Sample global variance 
parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) # Sample random effects model rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, @@ -640,21 +685,23 @@ scale parameters) with an MCMC sampler for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) # Sample leaf node variance parameter and update `leaf_prior_scale` leaf_scale_samples[i+1] <- sampleLeafVarianceOneIteration( active_forest, rng, a_leaf, b_leaf ) leaf_prior_scale[1,1] <- leaf_scale_samples[i+1] + forest_model_config$update_leaf_model_scale(leaf_prior_scale) # Sample random effects model rfx_model$sample_random_effect(rfx_dataset, outcome, rfx_tracker, rfx_samples, @@ -759,9 +806,10 @@ min_samples_leaf <- 1 max_depth <- 10 num_trees <- 100 cutpoint_grid_size = 100 -global_variance_init = 1. -tau_init = 0.5 -leaf_prior_scale = matrix(c(tau_init), ncol = 1) +global_variance_init <- 1. +current_sigma2 <- global_variance_init +tau_init <- 1/num_trees +leaf_prior_scale <- as.matrix(ifelse(p_W >= 1, diag(tau_init, p_W), diag(tau_init, 1))) nu <- 4 lambda <- 0.5 a_leaf <- 2. 
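The two-step idiom that recurs in every sampling loop of this vignette — draw a parameter, then write it back into the config object the tree sampler reads — can be seen in miniature below (reusing objects defined earlier in the vignette):

```r
# One iteration's global variance update, as used in each loop:
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(
  outcome, forest_dataset, rng, nu, lambda
)
global_var_samples[i + 1] <- current_sigma2
# Keep the sampler's view of sigma^2 in sync with the latest draw
global_model_config$update_global_error_variance(current_sigma2)
```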
@@ -789,9 +837,15 @@ outcome <- createOutcome(resid) rng <- createCppRNG() # Sampling data structures -forest_model <- createForestModel(forest_dataset, feature_types, - num_trees, n, alpha_bart, beta_bart, - min_samples_leaf, max_depth) +forest_model_config <- createForestModelConfig( + feature_types = feature_types, num_trees = num_trees, num_features = p_X, + num_observations = n, variable_weights = var_weights, leaf_dimension = leaf_dimension, + alpha = alpha, beta = beta, min_samples_leaf = min_samples_leaf, max_depth = max_depth, + leaf_model_type = outcome_model_type, leaf_model_scale = leaf_prior_scale, + cutpoint_grid_size = cutpoint_grid_size +) +global_model_config <- createGlobalModelConfig(global_error_variance = global_variance_init) +forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config) # "Active forest" (which gets updated by the sample) and # container of forest samples (which is written to when @@ -803,6 +857,10 @@ if (leaf_regression) { forest_samples <- createForestSamples(num_trees, 1, T) active_forest <- createForest(num_trees, 1, T) } + +# Initialize the leaves of each tree in the forest +active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, outcome_model_type, mean(resid)) +active_forest$adjust_residual(forest_dataset, outcome, forest_model, ifelse(outcome_model_type==1, T, F), F) ``` Prepare to run the sampler @@ -842,15 +900,16 @@ for (i in 1:num_warmstart) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, sigma2, cutpoint_grid_size, keep_forest = T, gfr = T + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = T ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) } ``` @@ -880,15 +939,16 @@ for (i in (num_warmstart+1):num_samples) { # Sample forest forest_model$sample_one_iteration( - forest_dataset, outcome, forest_samples, active_forest, rng, feature_types, - outcome_model_type, leaf_prior_scale, var_weights, - 1, 1, global_var_samples[i], cutpoint_grid_size, keep_forest = T, gfr = F + forest_dataset, outcome, forest_samples, active_forest, rng, + forest_model_config, global_model_config, keep_forest = T, gfr = F ) # Sample global variance parameter - global_var_samples[i+1] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset, rng, nu, lambda ) + global_var_samples[i+1] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) } ``` @@ -1135,14 +1195,22 @@ outcome <- createOutcome(resid) rng <- createCppRNG() # Sampling data structures -forest_model_mu <- createForestModel( - forest_dataset_mu, feature_types_mu, num_trees_mu, nrow(X_mu), - alpha_mu, beta_mu, min_samples_leaf_mu, max_depth_mu +global_model_config <- createGlobalModelConfig(global_error_variance = current_sigma2) +forest_model_config_mu <- createForestModelConfig( + feature_types = feature_types_mu, num_trees = num_trees_mu, num_features = ncol(X_mu), + num_observations = nrow(X_mu), variable_weights = variable_weights_mu, leaf_dimension = 1, + alpha = alpha_mu, beta 
= beta_mu, min_samples_leaf = min_samples_leaf_mu, max_depth = max_depth_mu, + leaf_model_type = 0, leaf_model_scale = current_leaf_scale_mu, cutpoint_grid_size = cutpoint_grid_size ) -forest_model_tau <- createForestModel( - forest_dataset_tau, feature_types_tau, num_trees_tau, nrow(X_tau), - alpha_tau, beta_tau, min_samples_leaf_tau, max_depth_tau +forest_model_mu <- createForestModel(forest_dataset_mu, forest_model_config_mu, global_model_config) +forest_model_config_tau <- createForestModelConfig( + feature_types = feature_types_tau, num_trees = num_trees_tau, num_features = ncol(X_tau), + num_observations = nrow(X_tau), variable_weights = variable_weights_tau, leaf_dimension = 1, + alpha = alpha_tau, beta = beta_tau, min_samples_leaf = min_samples_leaf_tau, max_depth = max_depth_tau, + leaf_model_type = 1, leaf_model_scale = current_leaf_scale_tau, + cutpoint_grid_size = cutpoint_grid_size ) +forest_model_tau <- createForestModel(forest_dataset_tau, forest_model_config_tau, global_model_config) # Container of forest samples forest_samples_mu <- createForestSamples(num_trees_mu, 1, T) @@ -1167,23 +1235,20 @@ if (num_gfr > 0){ # Sample the prognostic forest forest_model_mu$sample_one_iteration( forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, - feature_types_mu, 0, current_leaf_scale_mu, variable_weights_mu, - 1, 1, current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T, - pre_initialized = T + forest_model_config_mu, global_model_config, keep_forest = T, gfr = T ) # Sample variance parameters (if requested) - global_var_samples[i] <- sampleGlobalErrorVarianceOneIteration( + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( outcome, forest_dataset_mu, rng, nu, lambda ) - current_sigma2 <- global_var_samples[i] + global_var_samples[i] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) # Sample the treatment forest forest_model_tau$sample_one_iteration( forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, - feature_types_tau, 1, current_leaf_scale_tau, variable_weights_tau, - 1, 1, current_sigma2, cutpoint_grid_size, keep_forest = T, gfr = T, - pre_initialized = T + forest_model_config_tau, global_model_config, keep_forest = T, gfr = T ) # Sample adaptive coding parameters @@ -1205,8 +1270,11 @@ if (num_gfr > 0){ b_1_samples[i] <- current_b_1 # Sample variance parameters (if requested) - global_var_samples[i] <- sampleGlobalErrorVarianceOneIteration(outcome, forest_dataset_tau, rng, nu, lambda) - current_sigma2 <- global_var_samples[i] + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome, forest_dataset_tau, rng, nu, lambda + ) + global_var_samples[i] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) } } ``` @@ -1218,23 +1286,24 @@ if (num_burnin + num_mcmc > 0) { for (i in (num_gfr+1):num_samples) { # Sample the prognostic forest forest_model_mu$sample_one_iteration( - forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, feature_types_mu, - 0, current_leaf_scale_mu, variable_weights_mu, 1, 1, current_sigma2, - cutpoint_grid_size, keep_forest = T, gfr = F, pre_initialized = T + forest_dataset_mu, outcome, forest_samples_mu, active_forest_mu, rng, + forest_model_config_mu, global_model_config, keep_forest = T, gfr = F ) - # Sample global variance parameter - global_var_samples[i] <- sampleGlobalErrorVarianceOneIteration(outcome, forest_dataset_mu, rng, nu, lambda) - current_sigma2 <- global_var_samples[i] + # Sample variance parameters 
(if requested) + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome, forest_dataset_mu, rng, nu, lambda + ) + global_var_samples[i] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) # Sample the treatment forest forest_model_tau$sample_one_iteration( - forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, feature_types_tau, - 1, current_leaf_scale_tau, variable_weights_tau, 1, 1, current_sigma2, - cutpoint_grid_size, keep_forest = T, gfr = F, pre_initialized = T + forest_dataset_tau, outcome, forest_samples_tau, active_forest_tau, rng, + forest_model_config_tau, global_model_config, keep_forest = T, gfr = F ) - # Sample coding parameters + # Sample adaptive coding parameters mu_x_raw <- active_forest_mu$predict_raw(forest_dataset_mu) tau_x_raw <- active_forest_tau$predict_raw(forest_dataset_tau) s_tt0 <- sum(tau_x_raw*tau_x_raw*(Z==0)) @@ -1251,10 +1320,13 @@ if (num_burnin + num_mcmc > 0) { forest_model_tau$propagate_basis_update(forest_dataset_tau, outcome, active_forest_tau) b_0_samples[i] <- current_b_0 b_1_samples[i] <- current_b_1 - - # Sample global variance parameter - global_var_samples[i] <- sampleGlobalErrorVarianceOneIteration(outcome, forest_dataset_tau, rng, nu, lambda) - current_sigma2 <- global_var_samples[i] + + # Sample variance parameters (if requested) + current_sigma2 <- sampleGlobalErrorVarianceOneIteration( + outcome, forest_dataset_tau, rng, nu, lambda + ) + global_var_samples[i] <- current_sigma2 + global_model_config$update_global_error_variance(current_sigma2) } } ```
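Once these loops complete, the sample containers hold one draw per kept iteration. A minimal post-processing sketch, assuming the `ForestSamples` container exposes `predict_raw()` in the same way the active forests do above:

```r
# Treatment effect draws: under adaptive coding, tau(x) enters the model
# as tau_raw(x) * b_Z, so the CATE scale is tau_raw(x) * (b1 - b0)
tau_hat_raw <- forest_samples_tau$predict_raw(forest_dataset_tau)  # n x num_samples
tau_hat <- sweep(tau_hat_raw, 2, b_1_samples - b_0_samples, FUN = "*")
tau_hat_mean <- rowMeans(tau_hat)  # posterior mean CATE per observation
```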