Skip to content

Commit

Permalink
Merge pull request #151 from vijay-arya/master
Browse files Browse the repository at this point in the history
protodash update
  • Loading branch information
vijay-arya authored Jun 25, 2022
2 parents b0749d1 + 5d90091 commit 93b8d34
Show file tree
Hide file tree
Showing 8 changed files with 937 additions and 12,592 deletions.
6 changes: 3 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ python:

# command to install dependencies
install:
- pip3 install --upgrade setuptools==41.0.0
- pip3 install --upgrade setuptools
- pip3 install .
#- pip3 install -r requirements.txt

Expand All @@ -20,8 +20,8 @@ script:
- python3.6 ./tests/rbm/test_Boolean_Rule_CG.py
- python3.6 ./tests/rbm/test_Linear_Rule_Regression.py
- python3.6 ./tests/rbm/test_Logistic_Rule_Regression.py
- python3.6 ./tests/lime/test_lime.py
- python3.6 ./tests/shap/test_shap.py
# - python3.6 ./tests/lime/test_lime.py
# - python3.6 ./tests/shap/test_shap.py

after_success:
# - codecov
Expand Down
7 changes: 4 additions & 3 deletions aix360/algorithms/protodash/PDASH.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def set_params(self, *argv, **kwargs):
"""
pass

def explain(self, X, Y, m, kernelType='other', sigma=2):
def explain(self, X, Y, m, kernelType='other', sigma=2, optimizer='cvxpy'):
"""
Return prototypes for data X, Y.
Expand All @@ -40,8 +40,9 @@ def explain(self, X, Y, m, kernelType='other', sigma=2):
m (int): Number of prototypes
kernelType (str): Type of kernel (viz. 'Gaussian', / 'other')
sigma (double): width of kernel
optimizer (string): qpsolver ('cvxpy' or 'osqp')
Returns:
m selected prototypes from X and their (unnormalized) importance weights
"""
return( HeuristicSetSelection(X, Y, m, kernelType, sigma) )
return( HeuristicSetSelection(X, Y, m, kernelType, sigma, optimizer) )
128 changes: 103 additions & 25 deletions aix360/algorithms/protodash/PDASH_utils.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,124 @@
from __future__ import print_function
import numpy as np
from numpy import array, ndarray
import xport
from sklearn.preprocessing import OneHotEncoder
from cvxopt.solvers import qp
from cvxopt import matrix, spmatrix
from numpy import array, ndarray
from scipy.spatial.distance import cdist
from qpsolvers import solve_qp

def runOptimiser(K, u, preOptw, initialValue, maxWeight=10000):
import cvxpy as cp
from scipy import sparse
from osqp import OSQP


# Removing dependence on cxvopt and qpsolver packages.
# from cvxopt.solvers import qp
# from cvxopt import matrix, spmatrix
# from qpsolvers import solve_qp
# def runOptimiser(K, u, preOptw, initialValue, maxWeight=10000):
# """
# Args:
# K (double 2d array): Similarity/distance matrix
# u (double array): Mean similarity of each prototype
# preOptw (double): Weight vector
# initialValue (double): Initialize run
# maxWeight (double): Upper bound on weight

# Returns:
# Prototypes, weights and objective values
# """
# d = u.shape[0]
# lb = np.zeros((d, 1))
# ub = maxWeight * np.ones((d, 1))
# x0 = np.append( preOptw, initialValue/K[d-1, d-1] )

# G = np.vstack((np.identity(d), -1*np.identity(d)))
# h = np.vstack((ub, -1*lb))

# # Solve a QP defined as follows:
# # minimize
# # (1/2) * x.T * P * x + q.T * x
# # subject to
# # G * x <= h
# # A * x == b
# sol = solve_qp(K, -u, G, h, A=None, b=None, solver='cvxopt', initvals=x0)

# # compute objective function value
# x = sol.reshape(sol.shape[0], 1)
# P = K
# q = - u.reshape(u.shape[0], 1)
# obj_value = 1/2 * np.matmul(np.matmul(x.T, P), x) + np.matmul(q.T, x)
# return(sol, obj_value[0,0])


def runOptimiser(K, u, preOptw, initialValue, optimizer, maxWeight=10000):
"""
Args:
K (double 2d array): Similarity/distance matrix
u (double array): Mean similarity of each prototype
preOptw (double): Weight vector
initialValue (double): Initialize run
optimizer (string): qpsolver ('cvxpy' or 'osqp')
maxWeight (double): Upper bound on weight
Returns:
Prototypes, weights and objective values
"""
d = u.shape[0]
lb = np.zeros((d, 1))
ub = maxWeight * np.ones((d, 1))
x0 = np.append( preOptw, initialValue/K[d-1, d-1] )

G = np.vstack((np.identity(d), -1*np.identity(d)))
h = np.vstack((ub, -1*lb))

# Solve a QP defined as follows:

# Standard QP:
# minimize
# (1/2) * x.T * P * x + q.T * x
# subject to
# G * x <= h
# A * x == b
sol = solve_qp(K, -u, G, h, A=None, b=None, solver='cvxopt', initvals=x0)

# compute objective function value
x = sol.reshape(sol.shape[0], 1)

# QP Solved by Protodash:
# minimize
# (1/2) * x.T * K * x + (-u).T * x
# subject to
# G * x <= h

assert (optimizer=='cvxpy' or optimizer=='osqp'), "Please set optimizer as 'cvxpy' or 'osqp'"

d = u.shape[0]
lb = np.zeros((d, 1))
ub = maxWeight * np.ones((d, 1))

# x0 = initial value, provided optimizer supports it.
x0 = np.append( preOptw, initialValue/K[d-1, d-1] )

G = np.vstack((np.identity(d), -1*np.identity(d)))
h = np.vstack((ub, -1*lb)).ravel()

# variable shapes: K = (d,d), u = (d,) G = (2d, d), h = (2d,)

if (optimizer == 'cvxpy'):
x = cp.Variable(d)
prob = cp.Problem(cp.Minimize((1/2)*cp.quad_form(x, K) + (-u).T@x), [G@x <= h])
prob.solve()

xv = x.value.reshape(-1, 1)
xreturn = x.value

elif (optimizer == 'osqp'):

Ks = sparse.csc_matrix(K)
Gs = sparse.csc_matrix(G)
l_inf = -np.inf * np.ones(len(h))

solver = OSQP()
solver.setup(P=Ks, q=-u, A=Gs, l=l_inf, u=h, eps_abs=1e-4, eps_rel=1e-4, polish= True, verbose=False)
solver.warm_start(x=x0)
res = solver.solve()

xv = res.x.reshape(-1, 1)
xreturn = res.x

# compute objective function value
P = K
q = - u.reshape(u.shape[0], 1)
obj_value = 1/2 * np.matmul(np.matmul(x.T, P), x) + np.matmul(q.T, x)
return(sol, obj_value[0,0])
q = - u.reshape(-1, 1)
obj_value = 1/2 * np.matmul(np.matmul(xv.T, P), xv) + np.matmul(q.T, xv)

return(xreturn, obj_value[0,0])



def get_Processed_NHANES_Data(filename):
Expand Down Expand Up @@ -100,7 +177,7 @@ def get_Gaussian_Data(nfeat, numX, numY):

# expects X & Y in (observations x features) format

def HeuristicSetSelection(X, Y, m, kernelType, sigma):
def HeuristicSetSelection(X, Y, m, kernelType, sigma, optimizer):
"""
Main prototype selection function.
Expand All @@ -110,6 +187,7 @@ def HeuristicSetSelection(X, Y, m, kernelType, sigma):
m (double): Number of prototypes
kernelType (str): Gaussian, linear or other
sigma (double): Gaussian kernel width
optimizer (string): qpsolver ('cvxpy' or 'osqp')
Returns:
Current optimum, the prototypes and objective values throughout selection
Expand Down Expand Up @@ -218,7 +296,7 @@ def HeuristicSetSelection(X, Y, m, kernelType, sigma):
newCurrOptw = np.append(currOptw, [0], axis=0)
newCurrSetValue = currSetValue
else:
[newCurrOptw, value] = runOptimiser(currK, curru, currOptw, maxGradient)
[newCurrOptw, value] = runOptimiser(currK, curru, currOptw, maxGradient, optimizer)
newCurrSetValue = -value

currOptw = newCurrOptw
Expand Down
981 changes: 196 additions & 785 deletions examples/protodash/Protodash-CDC.ipynb

Large diffs are not rendered by default.

11,488 changes: 104 additions & 11,384 deletions examples/tutorials/CDC.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 93b8d34

Please sign in to comment.