Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor model parameters into "config" objects to future-proof low-level interface #135

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ export(createCppRNG)
export(createForest)
export(createForestDataset)
export(createForestModel)
export(createForestModelConfig)
export(createForestSamples)
export(createGlobalModelConfig)
export(createOutcome)
export(createPreprocessorFromJson)
export(createPreprocessorFromJsonString)
Expand Down
59 changes: 43 additions & 16 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -540,11 +540,22 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train

# Sampling data structures
feature_types <- as.integer(feature_types)
global_model_config <- createGlobalModelConfig(global_error_variance=current_sigma2)
if (include_mean_forest) {
forest_model_mean <- createForestModel(forest_dataset_train, feature_types, num_trees_mean, nrow(X_train), alpha_mean, beta_mean, min_samples_leaf_mean, max_depth_mean)
forest_model_config_mean <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_mean, num_features=ncol(X_train),
num_observations=nrow(X_train), variable_weights=variable_weights_mean, leaf_dimension=leaf_dimension,
alpha=alpha_mean, beta=beta_mean, min_samples_leaf=min_samples_leaf_mean, max_depth=max_depth_mean,
leaf_model_type=leaf_model_mean_forest, leaf_model_scale=current_leaf_scale,
cutpoint_grid_size=cutpoint_grid_size)
forest_model_mean <- createForestModel(forest_dataset_train, forest_model_config_mean, global_model_config)
}
if (include_variance_forest) {
forest_model_variance <- createForestModel(forest_dataset_train, feature_types, num_trees_variance, nrow(X_train), alpha_variance, beta_variance, min_samples_leaf_variance, max_depth_variance)
forest_model_config_variance <- createForestModelConfig(feature_types=feature_types, num_trees=num_trees_variance, num_features=ncol(X_train),
num_observations=nrow(X_train), variable_weights=variable_weights_variance, leaf_dimension=1,
alpha=alpha_variance, beta=beta_variance, min_samples_leaf=min_samples_leaf_variance,
max_depth=max_depth_variance, leaf_model_type=leaf_model_variance_forest,
cutpoint_grid_size=cutpoint_grid_size)
forest_model_variance <- createForestModel(forest_dataset_train, forest_model_config_variance, global_model_config)
}

# Container of forest samples
Expand Down Expand Up @@ -601,11 +612,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
if (requires_basis) init_values_mean_forest <- rep(0., ncol(leaf_basis_train))
else init_values_mean_forest <- 0.
active_forest_mean$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mean, leaf_model_mean_forest, init_values_mean_forest)
active_forest_mean$adjust_residual(forest_dataset_train, outcome_train, forest_model_mean, requires_basis, F)
}

# Initialize the leaves of each tree in the variance forest
if (include_variance_forest) {
active_forest_variance$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_variance, leaf_model_variance_forest, variance_forest_init)

}

# Run GFR (warm start) if specified
Expand All @@ -624,26 +637,28 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train

if (include_mean_forest) {
forest_model_mean$sample_one_iteration(
forest_dataset_train, outcome_train, forest_samples_mean, active_forest_mean,
rng, feature_types, leaf_model_mean_forest, current_leaf_scale, variable_weights_mean,
a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = T
)
}
if (include_variance_forest) {
forest_model_variance$sample_one_iteration(
forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance,
rng, feature_types, leaf_model_variance_forest, current_leaf_scale, variable_weights_variance,
a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = T, pre_initialized = T
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = T
)
}
if (sample_sigma_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
if (keep_sample) global_var_samples[sample_counter] <- current_sigma2
global_model_config$update_global_error_variance(current_sigma2)
}
if (sample_sigma_leaf) {
leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf)
current_leaf_scale <- as.matrix(leaf_scale_double)
if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double
forest_model_config_mean$update_leaf_model_scale(current_leaf_scale)
}
if (has_rfx) {
rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng)
Expand All @@ -663,6 +678,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
if (sample_sigma_leaf) {
leaf_scale_double <- leaf_scale_samples[forest_ind + 1]
current_leaf_scale <- as.matrix(leaf_scale_double)
forest_model_config_mean$update_leaf_model_scale(current_leaf_scale)
}
}
if (include_variance_forest) {
Expand All @@ -673,14 +689,18 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
resetRandomEffectsModel(rfx_model, rfx_samples, forest_ind, sigma_alpha_init)
resetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train, rfx_samples)
}
if (sample_sigma_global) current_sigma2 <- global_var_samples[forest_ind + 1]
if (sample_sigma_global) {
current_sigma2 <- global_var_samples[forest_ind + 1]
global_model_config$update_global_error_variance(current_sigma2)
}
} else if (has_prev_model) {
if (include_mean_forest) {
resetActiveForest(active_forest_mean, previous_forest_samples_mean, previous_model_warmstart_sample_num - 1)
resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE)
if (sample_sigma_leaf && (!is.null(previous_leaf_var_samples))) {
leaf_scale_double <- previous_leaf_var_samples[previous_model_warmstart_sample_num]
current_leaf_scale <- as.matrix(leaf_scale_double)
forest_model_config_mean$update_leaf_model_scale(current_leaf_scale)
}
}
if (include_variance_forest) {
Expand All @@ -696,6 +716,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
if (sample_sigma_global) {
if (!is.null(previous_global_var_samples)) {
current_sigma2 <- previous_global_var_samples[previous_model_warmstart_sample_num]
global_model_config$update_global_error_variance(current_sigma2)
}
}
} else {
Expand All @@ -705,6 +726,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
resetForestModel(forest_model_mean, active_forest_mean, forest_dataset_train, outcome_train, TRUE)
if (sample_sigma_leaf) {
current_leaf_scale <- as.matrix(sigma_leaf_init)
forest_model_config_mean$update_leaf_model_scale(current_leaf_scale)
}
}
if (include_variance_forest) {
Expand All @@ -717,7 +739,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
sigma_xi_init, sigma_xi_shape, sigma_xi_scale)
rootResetRandomEffectsTracker(rfx_tracker_train, rfx_model, rfx_dataset_train, outcome_train)
}
if (sample_sigma_global) current_sigma2 <- sigma2_init
if (sample_sigma_global) {
current_sigma2 <- sigma2_init
global_model_config$update_global_error_variance(current_sigma2)
}
}
for (i in (num_gfr+1):num_samples) {
is_mcmc <- i > (num_gfr + num_burnin)
Expand Down Expand Up @@ -746,26 +771,28 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train

if (include_mean_forest) {
forest_model_mean$sample_one_iteration(
forest_dataset_train, outcome_train, forest_samples_mean, active_forest_mean,
rng, feature_types, leaf_model_mean_forest, current_leaf_scale, variable_weights_mean,
a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = F
)
}
if (include_variance_forest) {
forest_model_variance$sample_one_iteration(
forest_dataset_train, outcome_train, forest_samples_variance, active_forest_variance,
rng, feature_types, leaf_model_variance_forest, current_leaf_scale, variable_weights_variance,
a_forest, b_forest, current_sigma2, cutpoint_grid_size, keep_forest = keep_sample, gfr = F, pre_initialized = T
forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
global_model_config = global_model_config, keep_forest = keep_sample, gfr = F
)
}
if (sample_sigma_global) {
current_sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome_train, forest_dataset_train, rng, a_global, b_global)
if (keep_sample) global_var_samples[sample_counter] <- current_sigma2
global_model_config$update_global_error_variance(current_sigma2)
}
if (sample_sigma_leaf) {
leaf_scale_double <- sampleLeafVarianceOneIteration(active_forest_mean, rng, a_leaf, b_leaf)
current_leaf_scale <- as.matrix(leaf_scale_double)
if (keep_sample) leaf_scale_samples[sample_counter] <- leaf_scale_double
forest_model_config_mean$update_leaf_model_scale(current_leaf_scale)
}
if (has_rfx) {
rfx_model$sample_random_effect(rfx_dataset_train, outcome_train, rfx_tracker_train, rfx_samples, keep_sample, current_sigma2, rng)
Expand Down
Loading
Loading