-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimpurity.py
41 lines (32 loc) · 901 Bytes
/
impurity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
from util import weight, group, make_pairs, make_nonexc
def impurepairs(Y, idx, w):
    '''Impurity measured as the number of impure pairs.

    Forwards ``Y`` and ``idx`` to ``make_pairs`` and returns its result
    unchanged.  ``w`` is accepted (every impurity function in this module
    takes ``(Y, idx, w)``) but is not used by this measure.
    '''
    pair_count = make_pairs(Y, idx)
    return pair_count
def nonexcluded(Y, idx, w):
    '''Impurity measured as the aggregated "non-excluded" count.

    Per the original description: the sum, over each object, of the number
    of objects belonging to other classes.  The labels selected by ``idx``
    are handed to ``make_nonexc`` with ``sorted=False`` and
    ``aggregate=True``; its result is returned unchanged.  ``w`` is accepted
    but not used by this measure.
    '''
    selected_labels = Y[idx]
    return make_nonexc(selected_labels, sorted=False, aggregate=True)
def entropy(Y, idx, w):
    '''Weighted entropy (base 2) of the labels ``Y[idx]``.

    Parameters
    ----------
    Y : indexable by ``idx``; holds the labels.
    idx : index array selecting the samples to consider.  It is reordered
        locally (the caller's array is not modified).
    w : passed through to ``weight`` together with index arrays; its exact
        meaning is defined by ``util.weight``.

    Returns
    -------
    ``0`` when ``idx`` is empty; otherwise ``-sum(p * log2(p))`` where, for
    each run reported by ``group``, ``p`` is the weight of that run divided
    by the weight of all selected samples.
    '''
    if len(idx) == 0:
        return 0

    # Sort labels and indices together; presumably so that ``group`` sees
    # equal labels as contiguous runs.
    Y = Y[idx]
    order = np.argsort(Y)
    Y, idx = Y[order], idx[order]

    total = weight(idx, w)
    result = 0
    # ``group`` yields (start, length) pairs that slice ``idx`` into runs.
    for start, length in group(Y):
        p = weight(idx[start:start + length], w) / total
        if p == 0:
            # A zero-weight class contributes nothing; evaluating
            # 0 * log2(0) would give 0 * -inf = nan and poison the result.
            continue
        result = result - p * np.log2(p)
    return result
def gini(Y, idx, w):
    '''Weighted Gini impurity of the labels ``Y[idx]``.

    Parameters
    ----------
    Y : indexable by ``idx``; holds the labels.
    idx : index array selecting the samples to consider.  It is reordered
        locally (the caller's array is not modified).
    w : passed through to ``weight`` together with index arrays; its exact
        meaning is defined by ``util.weight``.

    Returns
    -------
    ``0`` when ``idx`` is empty; otherwise ``1 - sum(p ** 2)`` where, for
    each run reported by ``group``, ``p`` is the weight of that run divided
    by the weight of all selected samples.
    '''
    if len(idx) == 0:
        return 0

    # Sort labels and indices together; presumably so that ``group`` sees
    # equal labels as contiguous runs.
    Y = Y[idx]
    order = np.argsort(Y)
    Y, idx = Y[order], idx[order]

    total = weight(idx, w)
    # Named ``squared_total`` rather than ``sum`` so the builtin is not shadowed.
    squared_total = 0
    # ``group`` yields (start, length) pairs that slice ``idx`` into runs.
    for start, length in group(Y):
        p = weight(idx[start:start + length], w) / total
        squared_total = squared_total + p ** 2
    return 1 - squared_total