Skip to content

Commit

Permalink
Add weights to solver
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesYang007 committed Nov 7, 2023
1 parent 5f674b4 commit 5983539
Show file tree
Hide file tree
Showing 42 changed files with 400 additions and 406 deletions.
2 changes: 0 additions & 2 deletions adelie/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ def create_test_data_basil(
/ snr
)
y = X @ beta + noise_scale * np.random.normal(0, 1, n)
X /= np.sqrt(n)
y /= np.sqrt(n)

return {
"X": X,
Expand Down
6 changes: 3 additions & 3 deletions adelie/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def residuals(
n, p = X.rows(), X.cols()
betas = state.betas
intercepts = state.intercepts
Xbs = np.empty((betas.shape[0], n))
X.sp_btmul(0, p, betas, Xbs)
resids = y[None] - Xbs - intercepts[:, None]
WXbs = np.empty((betas.shape[0], n))
X.sp_btmul(0, p, betas, state.weights, WXbs)
resids = (state.weights * y)[None] - WXbs - (state.weights[None] * intercepts[:, None])
return resids


Expand Down
67 changes: 33 additions & 34 deletions adelie/research/pivot_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,32 +85,45 @@ def _screen_sets_safe(
):
X = state.X
X_means = state.X_means
X_group_norms = state.X_group_norms
penalty = state.penalty
weights = state.weights
betas = state.betas
lmdas = state.lmdas
intercept = state.intercept

assert np.all(penalty > 0)

G, n_lmdas, n, p = X_group_norms.shape[0], betas.shape[0], X.rows(), X.cols()
resids = resids / np.sqrt(weights)

G, n_lmdas, n, p = penalty.shape[0], betas.shape[0], X.rows(), X.cols()

Xd = np.empty((n, p), order="F")
X.to_dense(0, p, Xd)
if intercept:
Xd -= X_means[None]
Xd *= np.sqrt(weights)[:, None]
X_group_norms = np.array([
np.linalg.norm(Xd[:, g:g+gs])
for g, gs in zip(state.groups, state.group_sizes)
])

edpp_resid_0 = y
if intercept:
edpp_resid_0 = y - np.mean(y)
edpp_resid_0 = y - np.sum(y * weights)
edpp_resid_0 = edpp_resid_0 * np.sqrt(weights)
edpp_grad = np.empty(p)
X.mul(edpp_resid_0, edpp_grad)
X.mul(np.sqrt(weights) * edpp_resid_0, edpp_grad)
edpp_abs_grad = np.array([
np.linalg.norm(edpp_grad[g:g+gs])
for g, gs in zip(state.groups, state.group_sizes)
])
g_star = np.argmax(edpp_abs_grad / penalty)
tmp = np.empty(state.group_sizes[g_star])
X.bmul(state.groups[g_star], state.group_sizes[g_star], edpp_resid_0, tmp)
X.bmul(state.groups[g_star], state.group_sizes[g_star], np.sqrt(weights) * edpp_resid_0, tmp)
edpp_v1_0 = np.empty(n)
X.btmul(state.groups[g_star], state.group_sizes[g_star], tmp, edpp_v1_0)
X.btmul(state.groups[g_star], state.group_sizes[g_star], tmp, np.sqrt(weights), edpp_v1_0)
if intercept:
edpp_v1_0 -= np.sum(tmp * X_means[
edpp_v1_0 -= np.sqrt(weights) * np.sum(tmp * X_means[
state.groups[g_star] :
state.groups[g_star] + state.group_sizes[g_star]
])
Expand All @@ -123,7 +136,7 @@ def _screen_sets_safe(
v2_perp_norms = np.linalg.norm(v2_perps, axis=-1)
edpps = np.empty((n_lmdas-1, p))
for i in range(n_lmdas-1):
X.mul((resids[i] / lmdas[i] + 0.5 * v2_perps[i]), edpps[i])
X.mul(np.sqrt(weights) * (resids[i] / lmdas[i] + 0.5 * v2_perps[i]), edpps[i])
abs_edpps = ad.diagnostic.gradient_norms(state, grads=edpps)
is_edpp = (
abs_edpps >= (penalty[None] - 0.5 * v2_perp_norms[:, None] * X_group_norms[None])
Expand Down Expand Up @@ -551,10 +564,9 @@ def arcene(path):
subset = np.std(X_train, axis=0) > 0
X_train = X_train[:, subset]

n, _ = X_train.shape
X_train /= np.std(X_train, axis=0)[None] * np.sqrt(n)
X_train /= np.std(X_train, axis=0)[None]
X_train = np.asfortranarray(X_train)
y_train /= np.std(y_train) * np.sqrt(n)
y_train /= np.std(y_train)

return X_train, y_train

Expand Down Expand Up @@ -589,10 +601,9 @@ def csv_to_csr(f, p, delimiter=" "):
subset = np.std(X_train, axis=0) > 0
X_train = X_train[:, subset]

n, _ = X_train.shape
X_train /= np.std(X_train, axis=0)[None] * np.sqrt(n)
X_train /= np.std(X_train, axis=0)[None]
X_train = np.asfortranarray(X_train)
y_train /= np.std(y_train) * np.sqrt(n)
y_train /= np.std(y_train)

return X_train, y_train

Expand All @@ -608,10 +619,9 @@ def gisette(path):
subset = np.std(X_train, axis=0) > 0
X_train = X_train[:, subset]

n, _ = X_train.shape
X_train /= np.std(X_train, axis=0)[None] * np.sqrt(n)
X_train /= np.std(X_train, axis=0)[None]
X_train = np.asfortranarray(X_train)
y_train /= np.std(y_train) * np.sqrt(n)
y_train /= np.std(y_train)

return X_train, y_train

Expand All @@ -627,12 +637,9 @@ def mnist(path):
subset = np.std(X_train, axis=0) > 0
X_train = X_train[:, subset]

n, _ = X_train.shape
X_train /= np.std(X_train, axis=0)[None] * np.sqrt(n)
X_train /= np.std(X_train, axis=0)[None]
X_train = np.asfortranarray(X_train)
y_train /= np.std(y_train) * np.sqrt(n)

X_train.shape, y_train.shape
y_train /= np.std(y_train)

return X_train, y_train

Expand All @@ -652,10 +659,9 @@ def conv(x):
subset = np.std(X_train, axis=0) > 0
X_train = X_train[:, subset]

n, _ = X_train.shape
X_train /= np.std(X_train, axis=0)[None] * np.sqrt(n)
X_train /= np.std(X_train, axis=0)[None]
X_train = np.asfortranarray(X_train)
y_train /= np.std(y_train) * np.sqrt(n)
y_train /= np.std(y_train)

return X_train, y_train

Expand Down Expand Up @@ -688,10 +694,9 @@ def gene(path):
subset = np.std(X_train, axis=0) > 0
X_train = X_train[:, subset]

n, _ = X_train.shape
X_train /= np.std(X_train, axis=0)[None] * np.sqrt(n)
X_train /= np.std(X_train, axis=0)[None]
X_train = np.asfortranarray(X_train)
y_train /= np.std(y_train) * np.sqrt(n)
y_train /= np.std(y_train)

return X_train, y_train

Expand All @@ -702,26 +707,21 @@ def spline_basis(X, **kwargs):
order="F",
**kwargs,
)
n = X.shape[0]
X = X * np.sqrt(n)
X = spl_tr.fit_transform(X)
X /= np.sqrt(n)
return X


def real_data_analysis(
X: np.ndarray,
y: np.ndarray,
configs: dict,
lazify_screen: bool =True,
):
start = time()
strong_state = ad.grpnet(
X=X,
y=y,
**configs,
screen_rule="strong",
lazify_screen=False,
)
end = time()
strong_time = end - start
Expand All @@ -732,7 +732,6 @@ def real_data_analysis(
y=y,
**configs,
screen_rule="pivot",
lazify_screen=lazify_screen,
)
end = time()
pivot_time = end - start
Expand Down
31 changes: 16 additions & 15 deletions adelie/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def objective(
lmda: float,
alpha: float,
penalty: np.ndarray,
weights: np.ndarray,
):
"""Computes the group elastic net objective.
Expand Down Expand Up @@ -61,14 +62,16 @@ def objective(
Elastic net parameter :math:`\\alpha`.
penalty : (G,) np.ndarray
List of penalty factors corresponding to each element of ``groups``.
weights : (G,) np.ndarray
Observation weights.
Returns
-------
obj : float
Group elastic net objective.
"""
return core.solver.objective(
beta0, beta, X, y, groups, group_sizes, lmda, alpha, penalty,
beta0, beta, X, y, groups, group_sizes, lmda, alpha, penalty, weights,
)


Expand Down Expand Up @@ -175,6 +178,7 @@ def grpnet(
group_sizes: np.ndarray,
alpha: float =1,
penalty: np.ndarray =None,
weights: np.ndarray =None,
lmda_path: np.ndarray =None,
max_iters: int =int(1e5),
tol: float =1e-12,
Expand Down Expand Up @@ -219,6 +223,9 @@ def grpnet(
Penalty factor for each group in the same order as ``groups``.
It must be a non-negative vector.
Default is ``None``, in which case, it is set to ``np.sqrt(group_sizes)``.
weights : (n,) np.ndarray
Observation weights.
Default is ``None``, in which case, it is set to ``np.full(n, 1/n)``.
lmda_path : (l,) np.ndarray, optional
The regularization path to solve for.
The full path is not considered if ``early_exit`` is ``True``.
Expand Down Expand Up @@ -317,24 +324,18 @@ def grpnet(
n, p = _X.rows(), _X.cols()
G = len(groups)

if weights is None:
weights = np.full(n, 1/n)

X_means = np.empty(p, dtype=dtype)
_X.means(X_means)

X_group_norms = np.empty(G, dtype=dtype)
_X.group_norms(
groups,
group_sizes,
X_means,
intercept,
X_group_norms,
)
_X.means(weights, X_means)

y_mean = np.mean(y)
y_mean = np.sum(y * weights)
yc = y
if intercept:
yc = yc - y_mean
y_var = np.sum(yc ** 2)
resid = yc
y_var = np.sum(weights * yc ** 2)
resid = weights * yc

if penalty is None:
penalty = np.sqrt(group_sizes)
Expand All @@ -354,14 +355,14 @@ def grpnet(
state = ad.state.basil_naive(
X=X,
X_means=X_means,
X_group_norms=X_group_norms,
y_mean=y_mean,
y_var=y_var,
resid=resid,
groups=groups,
group_sizes=group_sizes,
alpha=alpha,
penalty=penalty,
weights=weights,
screen_set=screen_set,
screen_beta=screen_beta,
screen_is_active=screen_is_active,
Expand Down
32 changes: 12 additions & 20 deletions adelie/src/include/adelie_core/matrix/matrix_naive_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ class MatrixNaiveBase
) const =0;

/**
* @brief Computes v X[:, j]^T where X is the current matrix.
* @brief Computes v X[:, j]^T W where X is the current matrix.
*
* @param j column index.
* @param v scalar to multiply with.
* @param weights diagonal weights W.
* @param out resulting row vector.
*/
virtual void ctmul(
int j,
value_t v,
const Eigen::Ref<const vec_value_t>& weights,
Eigen::Ref<vec_value_t> out
) const =0;

Expand All @@ -58,16 +60,18 @@ class MatrixNaiveBase
) =0;

/**
* @brief Computes v^T X[:, j:j+q]^T where X is the current matrix.
* @brief Computes v^T X[:, j:j+q]^T W where X is the current matrix.
*
* @param j begin column index.
* @param q number of columns.
* @param v vector to multiply with.
* @param weights diagonal weights W.
* @param out resulting row vector.
*/
virtual void btmul(
int j, int q,
const Eigen::Ref<const vec_value_t>& v,
const Eigen::Ref<const vec_value_t>& weights,
Eigen::Ref<vec_value_t> out
) =0;

Expand Down Expand Up @@ -95,16 +99,18 @@ class MatrixNaiveBase
/* Used outside of fitting procedures */

/**
* @brief Computes v X[:, j:j+q]^T where X is the current matrix.
* @brief Computes v X[:, j:j+q]^T W where X is the current matrix.
*
* @param j begin column index.
* @param q number of columns.
* @param v (l, p) sparse matrix to multiply with.
* @param weights diagonal weights W.
* @param out (l, n) resulting row vector.
*/
virtual void sp_btmul(
int j, int q,
const sp_mat_value_t& v,
const Eigen::Ref<const vec_value_t>& weights,
Eigen::Ref<rowmat_value_t> out
) const =0;

Expand All @@ -121,27 +127,13 @@ class MatrixNaiveBase
) const =0;

/**
* @brief Computes column-wise mean.
* @brief Computes column-wise mean (weighted by W).
*
* @param weights diagonal weights W.
* @param out resulting column means.
*/
virtual void means(
Eigen::Ref<vec_value_t> out
) const =0;

/**
* @brief Computes group-wise column norms.
*
* @param groups mapping group number to starting position.
* @param group_sizes mapping group number to group size.
* @param center true to compute centered column norms.
* @param out resulting group-wise column norms.
*/
virtual void group_norms(
const Eigen::Ref<const vec_index_t>& groups,
const Eigen::Ref<const vec_index_t>& group_sizes,
const Eigen::Ref<const vec_value_t>& means,
bool center,
const Eigen::Ref<const vec_value_t>& weights,
Eigen::Ref<vec_value_t> out
) const =0;
};
Expand Down
Loading

0 comments on commit 5983539

Please sign in to comment.