Skip to content

Commit

Permalink
revert numba updates in pcg
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Sep 23, 2024
1 parent 90b85db commit 55882b6
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 18 deletions.
10 changes: 5 additions & 5 deletions pfb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
__version__ = '0.0.4'

def set_envs(nthreads, ncpu):
os.environ["OMP_NUM_THREADS"] = '2'
os.environ["OPENBLAS_NUM_THREADS"] = '2'
os.environ["MKL_NUM_THREADS"] = '2'
os.environ["VECLIB_MAXIMUM_THREADS"] = '2'
os.environ["NPY_NUM_THREADS"] = '2'
os.environ["OMP_NUM_THREADS"] = str(nthreads)
os.environ["OPENBLAS_NUM_THREADS"] = str(nthreads)
os.environ["MKL_NUM_THREADS"] = str(nthreads)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(nthreads)
os.environ["NPY_NUM_THREADS"] = str(nthreads)
os.environ["NUMBA_NUM_THREADS"] = str(nthreads)
os.environ["JAX_PLATFORMS"] = 'cpu'
os.environ["JAX_ENABLE_X64"] = 'True'
Expand Down
38 changes: 30 additions & 8 deletions pfb/opt/pcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,29 @@ def M(x): return x
Ap = A(p)
tA += (time() - ti)
ti = time()
# rnorm = np.vdot(r, y)
# alpha = rnorm / np.vdot(p, Ap)
rnorm = np.vdot(r, y)
alpha = rnorm / np.vdot(p, Ap)
# import ipdb; ipdb.set_trace()
rnorm, alpha = alpha_update(r, y, p, Ap)
# rnorm, alpha = alpha_update(r, y, p, Ap)
tvdot += (time() - ti)
ti = time()
# x = xp + alpha * p
# r = rp + alpha * Ap
x, r = update(x, xp, r, rp, p, Ap, alpha)
ne.evaluate('xp + alpha*p',
out=x,
local_dict={
'xp': xp,
'alpha': alpha,
'p': p},
casting='unsafe')
ne.evaluate('rp + alpha*Ap',
out=r,
local_dict={
'rp': rp,
'alpha': alpha,
'Ap': Ap},
casting='unsafe')
# x, r = update(x, xp, r, rp, p, Ap, alpha)
tupdate += (time() - ti)
y = M(r)

Expand All @@ -243,12 +257,20 @@ def M(x): return x
# rnorm_next = np.vdot(r, y)

ti = time()
# rnorm_next = np.vdot(r, y)
# beta = rnorm_next / rnorm
rnorm_next = np.vdot(r, y)
beta = rnorm_next / rnorm
ne.evaluate('beta*p-y',
out=p,
local_dict={
'beta': beta,
'p': p,
'y': y},
casting='unsafe')

# p = beta * p - y
rnorm, p = beta_update(r, y, p, rnorm)
# rnorm, p = beta_update(r, y, p, rnorm)
tp += (time() - ti)
# rnorm = rnorm_next
rnorm = rnorm_next
k += 1
epsp = eps
ti = time()
Expand Down
5 changes: 0 additions & 5 deletions pfb/workers/model2comps.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,3 @@ def _model2comps(**kw):
fits_name,
hdr,
overwrite=True)




print("All done here.", file=log)

0 comments on commit 55882b6

Please sign in to comment.