-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpair_plot.py
82 lines (67 loc) · 2.02 KB
/
pair_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from colors import colors
import seaborn as sns
# drops all non-numeric columns besides "Hogwarts House"
def arrange_columns(df):
    """Strip ``df`` down to its numeric columns plus ``'Hogwarts House'``.

    The frame is modified in place. The ``'Index'`` column is removed when
    present, and every remaining column whose dtype kind is not one of
    ``'biufc'`` (bool, signed/unsigned int, float, complex) is removed,
    except for ``'Hogwarts House'``, which is always kept.

    Args:
        df: DataFrame to prune in place.

    Returns:
        None.
    """
    # Collect the labels first and drop them in a single call: this avoids
    # mutating the frame while walking its columns, and a dataset without an
    # 'Index' column no longer raises KeyError.
    to_drop = [
        column
        for column, dtype in df.dtypes.items()
        if column == 'Index'
        or (column != 'Hogwarts House' and dtype.kind not in 'biufc')
    ]
    df.drop(columns=to_drop, inplace=True)
# adds newlines to long label names
def adjust_plot_labels(df):
    """Wrap long column names of ``df`` so they fit as axis labels.

    Column names longer than 15 characters are split on single spaces and
    re-joined two words per line, with lines separated by a newline.
    ``'Hogwarts House'`` and names of 15 characters or fewer are left
    untouched. The frame's columns are renamed in place.

    Args:
        df: DataFrame whose column labels are strings.

    Returns:
        None.
    """
    def wrap(label):
        if label == 'Hogwarts House' or len(label) <= 15:
            return label
        words = label.split(" ")
        # Group the words in pairs and put each pair on its own line. Joining
        # (rather than appending a separator after every word) keeps the
        # label free of a dangling trailing space or newline.
        pairs = [" ".join(words[i:i + 2]) for i in range(0, len(words), 2)]
        return "\n".join(pairs)

    df.columns = [wrap(column) for column in df]
def _min_max_normalize(df):
    """Rescale each numeric column of ``df`` in place with min-max scaling."""
    for column in list(df.columns):
        series = df[column]
        if series.dtype.kind not in 'biufc':
            continue
        lowest = series.min()
        span = series.max() - lowest
        if span == 0:
            # Constant column: dividing by the zero range would only produce
            # NaN, so map every value to zero instead.
            df[column] = series - lowest
            continue
        # Whole-column arithmetic replaces the per-cell iloc loop; assigning
        # the full column also avoids writing floats into an integer column
        # one element at a time.
        df[column] = (series - lowest) / span


def main():
    """Draw a pair plot of the training dataset, coloured by Hogwarts house.

    Reads ``datasets/dataset_train.csv``, keeps the numeric columns and the
    ``'Hogwarts House'`` column, wraps long column names, min-max normalizes
    the numeric columns, then saves the seaborn pair plot to
    ``pair_plot.png`` and displays it.

    Raises:
        SystemExit: with status 1 when the dataset cannot be read or parsed.
    """
    try:
        df = pd.read_csv("datasets/dataset_train.csv")
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file; pandas' EmptyDataError
        # and ParserError, as well as decoding errors, derive from ValueError.
        print(f"{colors().RED}Error: could not read file{colors().END}")
        # Exit with a non-zero status so callers can detect the failure.
        raise SystemExit(1)

    arrange_columns(df)
    adjust_plot_labels(df)
    _min_max_normalize(df)

    # scale down the font size
    sns.set_style('darkgrid')
    sns.set(font_scale=0.5)

    # plot the pairplot and give a hue according to the "Hogwarts House" feature
    pairplot = sns.pairplot(
        df,
        hue="Hogwarts House",
        height=0.8,
        palette="bright",
        kind="scatter",
        plot_kws={"s": 3},
    )

    # remove tick labels in sub-plots
    for ax in pairplot.axes.flatten():
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    plt.savefig('pair_plot.png')
    plt.show()
# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()