Skip to content

Commit

Permalink
ENH modified poltting files
Browse files Browse the repository at this point in the history
  • Loading branch information
QB3 committed Sep 7, 2020
1 parent 8da0400 commit 069c664
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 33 deletions.
8 changes: 5 additions & 3 deletions expes/expe_fig_cross_val/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,24 @@
plt.semilogx(
p_alphas, objs, color=current_palette[0], linewidth=7.0)
plt.semilogx(
p_alphas, objs, 'bo', label='0-order method (grid-search)',
p_alphas, objs, 'bo', label='0-order (grid-search)',
color=current_palette[1], markersize=15)
plt.semilogx(
p_alphas_grad, objs_grad, 'bX', label='1-st order method',
p_alphas_grad, objs_grad, 'bX', label='1-st order',
color=current_palette[2], markersize=25)
plt.xlabel(r"$\lambda / \lambda_{\max}$", fontsize=28)
plt.ylabel(
r"$\|y^{\rm{val}} - X^{\rm{val}} \hat \beta^{(\lambda)} \|^2$",
fontsize=28)
plt.tick_params(width=5)
plt.legend(fontsize=28, loc=1)
plt.legend(fontsize=28)
plt.tight_layout()

if save_fig:
fig.savefig(
fig_dir + "cross_val_and_grad_search_real_sim.pdf", bbox_inches="tight")
fig.savefig(
fig_dir + "cross_val_and_grad_search_real_sim.png", bbox_inches="tight")
fig.savefig(
fig_dir_svg + "cross_val_and_grad_search_real_sim.svg", bbox_inches="tight")
fig.show()
Expand Down
67 changes: 37 additions & 30 deletions expes/expe_linear_convergence/plot_slides.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,45 +28,52 @@
dict_title["real-sim"] = "real-sim"

plt.close('all')
fig, axarr = plt.subplots(
2, 4, sharex=False, sharey=False, figsize=[14, 8],)

lines = []
model_names = ["lasso", "logreg", "svm"]


for idx, dataset in enumerate(dataset_names):
df_data = pandas.read_pickle("%s.pkl" % dataset)
diff_beta = df_data["diff_beta"].to_numpy()[0]
diff_jac = df_data["diff_jac"].to_numpy()[0]
supp_id = df_data["supp_id"].to_numpy()[0]
#
axarr.flat[idx].semilogy(diff_beta)
lines.append(axarr.flat[idx].axvline(
x=supp_id, c='red', linestyle="--", label="Support identification"))
for model_name in model_names:

axarr.flat[idx+4].semilogy(diff_jac)
axarr.flat[idx+4].axvline(x=supp_id, c='red', linestyle="--")
fig, axarr = plt.subplots(
2, 4, sharex=False, sharey=False, figsize=[14, 8],)

axarr.flat[idx+4].set_xlabel(r"$\#$ epochs", size=fontsize)
lines = []

axarr.flat[idx].set_title("%s" % (
dict_title[dataset]), size=fontsize)
# xarr.flat[idx].set_title("%s %s" % (
# dict_title[dataset], dict_n_feature[dataset]), size=fontsize)
for idx, dataset in enumerate(dataset_names):
df_data = pandas.read_pickle("%s_%s.pkl" % (dataset, model_name))
diff_beta = df_data["diff_beta"].to_numpy()[0]
diff_jac = df_data["diff_jac"].to_numpy()[0]
supp_id = df_data["supp_id"].to_numpy()[0]
#
axarr.flat[idx].semilogy(diff_beta)
lines.append(axarr.flat[idx].axvline(
x=supp_id, c='red', linestyle="--", label="Support identification"))

axarr.flat[0].set_ylabel(
r"$||\beta^{(k)} - \hat \beta||$", fontsize=fontsize)
axarr.flat[4].set_ylabel(
r"$||\mathcal{J}^{(k)} - \hat \mathcal{J}||$", fontsize=fontsize)
axarr.flat[idx+4].semilogy(diff_jac)
axarr.flat[idx+4].axvline(x=supp_id, c='red', linestyle="--")

fig.tight_layout()
axarr.flat[idx+4].set_xlabel(r"$\#$ epochs", size=fontsize)

if save_fig:
fig.savefig(
fig_dir + "linear_convergence_lasso.pdf", bbox_inches="tight")
fig.savefig(
fig_dir_svg + "linear_convergence_lasso.svg", bbox_inches="tight")
fig.show()
axarr.flat[idx].set_title("%s" % (
dict_title[dataset]), size=fontsize)
# xarr.flat[idx].set_title("%s %s" % (
# dict_title[dataset], dict_n_feature[dataset]), size=fontsize)

axarr.flat[0].set_ylabel(
r"$||\beta^{(k)} - \hat \beta||$", fontsize=fontsize)
axarr.flat[4].set_ylabel(
r"$||\mathcal{J}^{(k)} - \hat \mathcal{J}||$", fontsize=fontsize)

fig.tight_layout()

if save_fig:
fig.savefig(
fig_dir + "linear_convergence_%s.pdf" % model_name,
bbox_inches="tight")
fig.savefig(
fig_dir_svg + "linear_convergence_%s.svg" % model_name,
bbox_inches="tight")
fig.show()


labels = ["Support identification"]
Expand Down

0 comments on commit 069c664

Please sign in to comment.