Skip to content

Commit

Permalink
+ comapred real parameters and sampled
Browse files Browse the repository at this point in the history
  • Loading branch information
gregorysedykh committed Jan 25, 2024
1 parent 60a6376 commit bb069a3
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 12 deletions.
Binary file added res/sample_compare_Y_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/sample_compare_Y_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/sample_compare_Y_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
38 changes: 26 additions & 12 deletions src/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@
class_samples = []

for variable_params in parameters[y]:

mean, std = variable_params
samples = np.random.normal(mean, std, SAMPLES)
class_samples.append(samples)

class_samples = np.column_stack(class_samples)
sampled_data.append(class_samples)
sampled_labels.extend([y]*SAMPLES)
sampled_labels.extend([y] * SAMPLES)

# On concatène les échantillons
sampled_data = np.vstack(sampled_data)
Expand All @@ -56,7 +55,28 @@
print(f"Mean: {mean_sampled}")
print(f"Std: {std_sampled}")

vars = [0, 1, 2, 3]
vars = [0, 1, 2, 3]

plt.figure()
plt.title(f"Courbe des distribution de probabilité sachant Y={c}")

for var in vars:
plt.plot(
np.linspace(means[var] - 10, means[var] + 10, 1000),
normal_pdf(means[var], stds[var])(np.linspace(means[var] - 10, means[var] + 10, 1000)),
label=f"X_{var} réelle",
)
plt.plot(
np.linspace(mean_sampled[var] - 10, mean_sampled[var] + 10, 1000),
normal_pdf(mean_sampled[var], std_sampled[var])(
np.linspace(mean_sampled[var] - 10, mean_sampled[var] + 10, 1000)
),
label=f"X_{var} échantillonnée",
)

plt.legend()
plt.savefig(f"src/res/sample_compare_Y_{c}")
plt.close()


print("\n\n")
Expand All @@ -66,7 +86,9 @@
X_train = pd.DataFrame(X_train)
X_test = pd.DataFrame(X_test)

sampled_params = {c: list(zip(np.mean(X_train[y_train == c], axis=0), np.std(X_train[y_train == c], axis=0))) for c in classes}
sampled_params = {
c: list(zip(np.mean(X_train[y_train == c], axis=0), np.std(X_train[y_train == c], axis=0))) for c in classes
}

# --- Notre implémentation de Naive Bayes ---
print("Notre Naive Bayes")
Expand Down Expand Up @@ -103,11 +125,3 @@
# --------------------------------------------

# ---------------------------------------------------------------------------








0 comments on commit bb069a3

Please sign in to comment.