Skip to content

Commit

Permalink
attempt to clear memory with client.run
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Aug 27, 2024
1 parent e82df89 commit 6ccef0b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pfb/utils/stokes2vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import xarray as xr
from numba import njit, prange, literally
from dask.graph_manipulation import clone
from distributed import get_client, worker_client
from distributed import worker_client
import dask.array as da
from xarray import Dataset
# from quartical.utils.numba import coerce_literal
Expand Down
38 changes: 25 additions & 13 deletions pfb/workers/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,18 @@ def _init(**kw):
if opts.progressbar:
print(f"\rProcessing: {n_launched}/{nds}", end='', flush=True)

times_out = []
freqs_out = []
ac_iter = as_completed(futures)
for completed_future in ac_iter:

result = completed_future.result()
try:
times_out.append(result[0])
freqs_out.append(result[1])
except:
pass

if n_launched == nds: # Stop once all jobs have been launched.
break

Expand All @@ -340,6 +349,8 @@ def _init(**kw):
wc = opts.weight_column
weight = None if wc is None else getattr(subds, wc).data

client.run(clear_memory, workers=idle_workers)

worker = associated_workers.pop(completed_future)

future = client.submit(single_stokes,
Expand All @@ -366,7 +377,6 @@ def _init(**kw):
workers=worker,
key='image-'+uuid4().hex)

futures.append(future)
ac_iter.add(future)
associated_workers[future] = worker
n_launched += 1
Expand All @@ -377,18 +387,16 @@ def _init(**kw):
if opts.progressbar:
print(f"\rProcessing: {n_launched}/{nds}", end='', flush=True)

wait(futures)

times_out = []
freqs_out = []
for f in futures:
result = f.result()
# this should fail if a chunk is fully flagged
try:
times_out.append(result[0])
freqs_out.append(result[1])
except:
pass
# times_out = []
# freqs_out = []
# for f in futures:
# result = f.result()
# # this should fail if a chunk is fully flagged
# try:
# times_out.append(result[0])
# freqs_out.append(result[1])
# except:
# pass

times_out = np.unique(times_out)
freqs_out = np.unique(freqs_out)
Expand All @@ -401,3 +409,7 @@ def _init(**kw):
f"{ntime} output times", file=log)

return


def clear_memory():
gc.collect()

0 comments on commit 6ccef0b

Please sign in to comment.