diff --git a/NAMESPACE b/NAMESPACE index 0158412..bf7c43e 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -8,6 +8,7 @@ S3method(merge,semforest) S3method(nobs,semtree) S3method(partialDependence,semforest) S3method(partialDependence,semforest_stripped) +S3method(plot,boruta) S3method(plot,diversityMatrix) S3method(plot,partialDependence) S3method(plot,semforest.proximity) diff --git a/R/boruta.R b/R/boruta.R index ec87a56..71059c1 100644 --- a/R/boruta.R +++ b/R/boruta.R @@ -43,10 +43,10 @@ boruta <- function(model, tmp <- getPredictorsOpenMx(mxmodel=model, dataset=data, covariates=predictors) model.ids <- tmp[[1]] covariate.ids <- tmp[[2]] - # } else if (inherits(model,"lavaan")){ - - # } else if ((inherits(model,"ctsemFit")) || (inherits(model,"ctsemInit"))) { - # + } else if (inherits(model,"lavaan")){ + tmp <- getPredictorsLavaan(model, dataset=data, covariates=predictors) + model.ids <- tmp[[1]] + covariate.ids <- tmp[[2]] } else { ui_stop("Unknown model type selected. Use OpenMx or lavaanified lavaan models!"); } @@ -172,7 +172,7 @@ plot.boruta = function(vim, type=0, ...) { tidyr::pivot_longer(cols = -last_col()) |> #everything()) |> dplyr::left_join(data.frame(decisionList), by=join_by("name"=="predictor")) |> dplyr::mutate(decision = case_when(is.na(decision)~"Shadow", .default=decision)) |> - dplyr::group_by(name) |> mutate(median_value = median(value,na.rm=TRUE)) + dplyr::group_by(name) |> dplyr::mutate(median_value = median(value,na.rm=TRUE)) if (type==0) { ggplot2::ggplot(impHistory, @@ -181,12 +181,16 @@ plot.boruta = function(vim, type=0, ...) { ggplot2::geom_boxplot()+ ggplot2::xlab("")+ ggplot2::ylab("Importance")+ - scale_color_discrete(name = "Predictor")+ - theme(axis.text.x = element_text(angle = 45, hjust = 1)) + ggplot2::scale_color_discrete(name = "Decision")+ + ggplot2::theme(axis.text.x = element_text(angle = 45, hjust = 1)) } else if (type==1) { ggplot2::ggplot(impHistory, - aes(x=rnd, y=value,group=name,col=name))+geom_line()+ geom_hline(aes(yintercept=median_value,col=name),lwd=2)+ - xlab("Round")+ylab("Importance")+scale_color_discrete(name = "Predictor") + ggplot2::aes(x=rnd, y=value,group=name,col=name))+ + ggplot2::geom_line()+ + ggplot2::geom_hline(aes(yintercept=median_value,col=name),lwd=2)+ + ggplot2::xlab("Round")+ + ggplot2::ylab("Importance")+ + ggplot2::scale_color_discrete(name = "Predictor") } else { stop("Unknown graph type. Please choose 0 or 1.") }