Skip to content

Commit

Permalink
Loop over orbits first then collect results before running next chunk…
Browse files Browse the repository at this point in the history
… of observations in attribute_observations
  • Loading branch information
moeyensj committed Dec 1, 2023
1 parent b83d680 commit d2f1074
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions thor/orbits/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
__all__ = ["Attributions", "attribute_observations", "merge_and_extend_orbits"]


LATLOT_INDEX = np.array([2, 1])
LATLON_INDEX = np.array([2, 1])


class Attributions(qv.Table):
Expand Down Expand Up @@ -182,8 +182,8 @@ def attribution_worker(
coords_predicted = ephemeris_i.coordinates

# Haversine metric requires latitude first then longitude...
coords_latlon = np.radians(coords.values[:, LATLOT_INDEX])
coords_predicted_latlon = np.radians(coords_predicted.values[:, LATLOT_INDEX])
coords_latlon = np.radians(coords.values[:, LATLON_INDEX])
coords_predicted_latlon = np.radians(coords_predicted.values[:, LATLON_INDEX])

num_obs = len(coords_predicted)
k = np.minimum(3, num_obs)
Expand Down Expand Up @@ -296,11 +296,15 @@ def attribute_observations(
refs_to_free.append(observations_ref)
logger.info("Placed observations in the object store.")

futures = []
for orbit_id_chunk in _iterate_chunks(orbit_ids, orbits_chunk_size):
for observations_indices_chunk in _iterate_chunks(
observation_indices, observations_chunk_size
):
# For each chunk of observations run attribution with all orbits.
# We wait for each chunk of orbits to finish before starting the next
# chunk of observations to reduce the memory pressure. If not, the number
# of expected futures will be large (num_orbits / orbit_chunk_size * num_observation_chunks)
for observations_indices_chunk in _iterate_chunks(
observation_indices, observations_chunk_size
):
futures = []
for orbit_id_chunk in _iterate_chunks(orbit_ids, orbits_chunk_size):
futures.append(
attribution_worker_remote.remote(
orbit_id_chunk,
Expand All @@ -313,9 +317,9 @@ def attribute_observations(
)
)

while futures:
finished, futures = ray.wait(futures, num_returns=1)
attributions_list.append(ray.get(finished[0]))
while futures:
finished, futures = ray.wait(futures, num_returns=1)
attributions_list.append(ray.get(finished[0]))

if len(refs_to_free) > 0:
ray.internal.free(refs_to_free)
Expand Down

0 comments on commit d2f1074

Please sign in to comment.