diff --git a/dice_ml/explainer_interfaces/dice_tensorflow1.py b/dice_ml/explainer_interfaces/dice_tensorflow1.py index df17b561..6fc8a605 100644 --- a/dice_ml/explainer_interfaces/dice_tensorflow1.py +++ b/dice_ml/explainer_interfaces/dice_tensorflow1.py @@ -61,20 +61,22 @@ def __init__(self, data_interface, model_interface): self.loss_weights = [] # yloss_type, diversity_loss_type, feature_weights self.optimizer_weights = [] # optimizer - def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opposite", proximity_weight=0.5, - diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", - features_to_vary="all", permitted_range=None, yloss_type="hinge_loss", - diversity_loss_type="dpp_style:inverse_dist", feature_weights="inverse_mad", - optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, max_iter=5000, - project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False, - init_near_query_instance=True, tie_random=False, stopping_threshold=0.5, - posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000): + def _generate_counterfactuals(self, query_instance, total_CFs, + desired_class="opposite", desired_range=None, + proximity_weight=0.5, + diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", features_to_vary="all", + permitted_range=None, yloss_type="hinge_loss", diversity_loss_type="dpp_style:inverse_dist", + feature_weights="inverse_mad", optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, + max_iter=5000, project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False, + init_near_query_instance=True, tie_random=False, stopping_threshold=0.5, + posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000): """Generates diverse counterfactual explanations :param query_instance: Test point of interest. A dictionary of feature names and values or a single row dataframe. :param total_CFs: Total number of counterfactuals required. :param desired_class: Desired counterfactual class - can take 0 or 1. Default value is "opposite" to the outcome class of query_instance for binary classification. + :param desired_range: Not supported currently. :param proximity_weight: A positive float. Larger this weight, more close the counterfactuals are to the query_instance. :param diversity_weight: A positive float. Larger this weight, more diverse the counterfactuals are. @@ -159,7 +161,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp loss_diff_thres, loss_converge_maxiter, verbose, init_near_query_instance, tie_random, stopping_threshold, posthoc_sparsity_param, posthoc_sparsity_algorithm) - counterfactual_explanations = exp.CounterfactualExamples( + return exp.CounterfactualExamples( data_interface=self.data_interface, final_cfs_df=final_cfs_df, test_instance_df=test_instance_df, @@ -167,8 +169,6 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp posthoc_sparsity_param=posthoc_sparsity_param, desired_class=desired_class) - return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations]) - def do_cf_initializations(self, total_CFs, algorithm, features_to_vary): """Intializes TF variables required for CF generation.""" diff --git a/dice_ml/explainer_interfaces/dice_tensorflow2.py b/dice_ml/explainer_interfaces/dice_tensorflow2.py index aca80ade..20f0e702 100644 --- a/dice_ml/explainer_interfaces/dice_tensorflow2.py +++ b/dice_ml/explainer_interfaces/dice_tensorflow2.py @@ -49,20 +49,22 @@ def __init__(self, data_interface, model_interface): self.hyperparameters = [1, 1, 1] # proximity_weight, diversity_weight, categorical_penalty self.optimizer_weights = [] # optimizer, learning_rate - def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opposite", proximity_weight=0.5, - diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", - features_to_vary="all", permitted_range=None, yloss_type="hinge_loss", - diversity_loss_type="dpp_style:inverse_dist", feature_weights="inverse_mad", - optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, max_iter=5000, - project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False, - init_near_query_instance=True, tie_random=False, stopping_threshold=0.5, - posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000): + def _generate_counterfactuals(self, query_instance, total_CFs, + desired_class="opposite", desired_range=None, + proximity_weight=0.5, + diversity_weight=1.0, categorical_penalty=0.1, algorithm="DiverseCF", features_to_vary="all", + permitted_range=None, yloss_type="hinge_loss", diversity_loss_type="dpp_style:inverse_dist", + feature_weights="inverse_mad", optimizer="tensorflow:adam", learning_rate=0.05, min_iter=500, + max_iter=5000, project_iter=0, loss_diff_thres=1e-5, loss_converge_maxiter=1, verbose=False, + init_near_query_instance=True, tie_random=False, stopping_threshold=0.5, + posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", limit_steps_ls=10000): """Generates diverse counterfactual explanations :param query_instance: Test point of interest. A dictionary of feature names and values or a single row dataframe :param total_CFs: Total number of counterfactuals required. :param desired_class: Desired counterfactual class - can take 0 or 1. Default value is "opposite" to the outcome class of query_instance for binary classification. + :param desired_range: Not supported currently. :param proximity_weight: A positive float. Larger this weight, more close the counterfactuals are to the query_instance. :param diversity_weight: A positive float. Larger this weight, more diverse the counterfactuals are. @@ -136,7 +138,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp init_near_query_instance, tie_random, stopping_threshold, posthoc_sparsity_param, posthoc_sparsity_algorithm, limit_steps_ls) - counterfactual_explanations = exp.CounterfactualExamples( + return exp.CounterfactualExamples( data_interface=self.data_interface, final_cfs_df=final_cfs_df, test_instance_df=test_instance_df, @@ -144,7 +146,6 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp posthoc_sparsity_param=posthoc_sparsity_param, desired_class=desired_class) - return CounterfactualExplanations(cf_examples_list=[counterfactual_explanations]) def predict_fn(self, input_instance): """prediction function"""