Skip to content

Commit

Permalink
Export bvls and pinball, and add gaussian_cov tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesYang007 committed Oct 19, 2024
1 parent f932538 commit b2b9bda
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
2 changes: 2 additions & 0 deletions adelie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
GroupElasticNet
)
from .solver import (
bvls,
gaussian_cov,
grpnet,
pinball,
)
61 changes: 61 additions & 0 deletions tests/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,62 @@ def test_grpnet(
check_solutions(args, state, cvxpy_glm, eps=1e-4)


# ==========================================================================================
# TEST gaussian_cov
# ==========================================================================================


@pytest.mark.parametrize("constraint", [False, True])
@pytest.mark.parametrize("n, p, G", [
[10, 50, 10],
[40, 13, 7],
])
def test_gaussian_cov(
n, p, G, constraint, adev_tol=0.2,
):
data = ad.data.dense(n, p, p)
X, glm = data["X"], data["glm"]
groups = np.concatenate([
[0],
np.random.choice(np.arange(1, p), size=G-1, replace=False)
])
groups = np.sort(groups).astype(int)
group_sizes = np.concatenate([groups, [p]], dtype=int)
group_sizes = group_sizes[1:] - group_sizes[:-1]

if constraint:
constraints = [None] * G
c_order = np.random.choice(G, G // 2, replace=False)
for i in c_order:
size = group_sizes[i]
constraints[i] = zero_constraint(size, dtype=np.float64)
else:
constraints = None

state_naive = ad.grpnet(
X=X,
glm=glm,
constraints=constraints,
groups=groups,
intercept=False,
adev_tol=adev_tol,
progress_bar=False,
)

A = np.asfortranarray(X.T @ X) / n
v = X.T @ glm.y / n
state_cov = ad.gaussian_cov(
A=A,
v=v,
constraints=constraints,
groups=groups,
lmda_path=state_naive.lmdas,
progress_bar=False,
)

assert np.allclose(state_naive.betas.toarray(), state_cov.betas.toarray())


# ==========================================================================================
# TEST bvls
# ==========================================================================================
Expand Down Expand Up @@ -1017,6 +1073,11 @@ def test_bvls(n, p, seed=0):
assert np.allclose(actual, expected)


# ==========================================================================================
# TEST pinball
# ==========================================================================================


@pytest.mark.parametrize("m", [3, 5, 10, 20])
@pytest.mark.parametrize("d", [1, 5, 10])
@pytest.mark.parametrize("seed", np.arange(10))
Expand Down

0 comments on commit b2b9bda

Please sign in to comment.