From a08dd82650e147e8095408993e76b762b04c3576 Mon Sep 17 00:00:00 2001
From: Joachim Moeyens
Date: Thu, 30 Nov 2023 17:18:48 -0800
Subject: [PATCH] Sort clusters and orbits by linkage ID and observation time

---
Reviewer notes: line breaks and indentation were restored from a
whitespace-collapsed copy of this patch; whitespace on context lines is
inferred from the hunk line counts and Python syntax and should be checked
against the target tree. Added lines were amended without changing any line
count: the iod.py deduplication log messages now read "duplicate initial
orbits" and "Initial"; linkages.py no longer shadows the builtin `sorted`
and its docstring/comment refer to linkages; the test uses frame="ecliptic".
The blob ids on the `index` lines predate these amendments.

 thor/clusters.py                  |  12 ++-
 thor/orbits/iod.py                |  20 ++++-
 thor/orbits/od.py                 |  12 ++-
 thor/utils/linkages.py            |  80 +++++++++++++++++
 thor/utils/tests/__init__.py      |   0
 thor/utils/tests/test_linkages.py | 138 ++++++++++++++++++++++++++++++
 6 files changed, 250 insertions(+), 12 deletions(-)
 create mode 100644 thor/utils/linkages.py
 create mode 100644 thor/utils/tests/__init__.py
 create mode 100644 thor/utils/tests/test_linkages.py

diff --git a/thor/clusters.py b/thor/clusters.py
index 086eff4f..2e5075e1 100644
--- a/thor/clusters.py
+++ b/thor/clusters.py
@@ -14,6 +14,7 @@
 from adam_core.ray_cluster import initialize_use_ray
 
 from .range_and_transform import TransformedDetections
+from .utils.linkages import sort_by_id_and_time
 
 # Disable GPU until the GPU-accelerated clustering codes
 # are better tested and implemented
@@ -779,16 +780,19 @@ def cluster_and_link(
             f"Cluster deduplication completed in {time_end_drop - time_start_drop:.3f} seconds."
         )
 
+        # Sort clusters by cluster ID and observation time
+        clusters, cluster_members = sort_by_id_and_time(
+            clusters, cluster_members, observations, "cluster_id"
+        )
+
     else:
         clusters = Clusters.empty()
         cluster_members = ClusterMembers.empty()
 
     time_end_cluster = time.time()
-    logger.info("Found {} clusters.".format(len(clusters)))
+    logger.info(f"Found {len(clusters)} clusters.")
     logger.info(
-        "Clustering completed in {:.3f} seconds.".format(
-            time_end_cluster - time_start_cluster
-        )
+        f"Clustering completed in {time_end_cluster - time_start_cluster:.3f} seconds."
     )
 
     return clusters, cluster_members
diff --git a/thor/orbits/iod.py b/thor/orbits/iod.py
index b18b568d..2f3a1c12 100644
--- a/thor/orbits/iod.py
+++ b/thor/orbits/iod.py
@@ -17,6 +17,7 @@
 from ..clusters import ClusterMembers
 from ..observations.observations import Observations
 from ..orbit_determination.fitted_orbits import FittedOrbitMembers, FittedOrbits
+from ..utils.linkages import sort_by_id_and_time
 from .gauss import gaussIOD
 
 logger = logging.getLogger(__name__)
@@ -641,6 +642,9 @@ def initial_orbit_determination(
         iod_orbits = qv.concatenate(iod_orbits_list)
         iod_orbit_members = qv.concatenate(iod_orbit_members_list)
 
+        time_start_drop = time.time()
+        logger.info("Removing duplicate initial orbits...")
+        num_orbits = len(iod_orbits)
         iod_orbits, iod_orbit_members = iod_orbits.drop_duplicates(
             iod_orbit_members,
             subset=[
@@ -655,18 +659,26 @@ def initial_orbit_determination(
             ],
             keep="first",
         )
+        time_end_drop = time.time()
+        logger.info(f"Removed {num_orbits - len(iod_orbits)} duplicate initial orbits.")
+        time_end_drop = time.time()
+        logger.info(
+            f"Initial orbit deduplication completed in {time_end_drop - time_start_drop:.3f} seconds."
+        )
 
-        logger.info("Found {} initial orbits.".format(len(iod_orbits)))
+        # Sort initial orbits by orbit ID and observation time
+        iod_orbits, iod_orbit_members = sort_by_id_and_time(
+            iod_orbits, iod_orbit_members, observations, "orbit_id"
+        )
 
     else:
         iod_orbits = FittedOrbits.empty()
         iod_orbit_members = FittedOrbitMembers.empty()
 
     time_end = time.perf_counter()
+    logger.info(f"Found {len(iod_orbits)} initial orbits.")
     logger.info(
-        "Initial orbit determination completed in {:.3f} seconds.".format(
-            time_end - time_start
-        )
+        f"Initial orbit determination completed in {time_end - time_start:.3f} seconds."
     )
 
     return iod_orbits, iod_orbit_members
diff --git a/thor/orbits/od.py b/thor/orbits/od.py
index d30fbaf7..6c87dccb 100644
--- a/thor/orbits/od.py
+++ b/thor/orbits/od.py
@@ -16,6 +16,7 @@
 
 from ..observations.observations import Observations
 from ..orbit_determination import FittedOrbitMembers, FittedOrbits
+from ..utils.linkages import sort_by_id_and_time
 
 logger = logging.getLogger(__name__)
 
@@ -687,16 +688,19 @@ def differential_correction(
         od_orbits = qv.concatenate(od_orbits_list)
         od_orbit_members = qv.concatenate(od_orbit_members_list)
 
+        # Sort orbits by orbit ID and observation time
+        od_orbits, od_orbit_members = sort_by_id_and_time(
+            od_orbits, od_orbit_members, observations, "orbit_id"
+        )
+
     else:
         od_orbits = FittedOrbits.empty()
         od_orbit_members = FittedOrbitMembers.empty()
 
     time_end = time.perf_counter()
-    logger.info("Differentially corrected {} orbits.".format(len(od_orbits)))
+    logger.info(f"Differentially corrected {len(od_orbits)} orbits.")
     logger.info(
-        "Differential correction completed in {:.3f} seconds.".format(
-            time_end - time_start
-        )
+        f"Differential correction completed in {time_end - time_start:.3f} seconds."
     )
 
     return od_orbits, od_orbit_members
diff --git a/thor/utils/linkages.py b/thor/utils/linkages.py
new file mode 100644
index 00000000..06d2528b
--- /dev/null
+++ b/thor/utils/linkages.py
@@ -0,0 +1,80 @@
+from typing import Tuple
+
+import numpy as np
+import pyarrow as pa
+import quivr as qv
+
+from ..observations import Observations
+
+__all__ = [
+    "sort_by_id_and_time",
+]
+
+
+def sort_by_id_and_time(
+    linkages: qv.AnyTable,
+    members: qv.AnyTable,
+    observations: Observations,
+    linkage_column: str,
+) -> Tuple[qv.AnyTable, qv.AnyTable]:
+    """
+    Sort linkages and linkage members by linkage ID and observation time.
+
+    Parameters
+    ----------
+    linkages : qv.AnyTable
+        Linkages to sort.
+    members : qv.AnyTable
+        Linkage members to sort.
+    observations : Observations
+        Observations from which linkage members were generated. Observations
+        are used to determine the observation time of each linkage member.
+    linkage_column : str
+        Column name in the linkage table to use for sorting. For clusters
+        this is "cluster_id" and for orbits this is "orbit_id".
+
+    Returns
+    -------
+    linkages : qv.AnyTable
+        Sorted linkages.
+    members : qv.AnyTable
+        Sorted linkage members.
+    """
+    # Grab the linkage ID column from the linkages table and add an index column
+    linkage_table = linkages.table.select([linkage_column])
+    linkage_table = linkage_table.add_column(
+        0, "index", pa.array(np.arange(0, len(linkage_table)))
+    )
+
+    # Grab the linkage ID and observation ID columns from the linkage members table and add an index column
+    members_table = members.table.select([linkage_column, "obs_id"])
+    members_table = members_table.add_column(
+        0, "index", pa.array(np.arange(0, len(members_table)))
+    )
+
+    # Grab the observation ID, observation time columns and join with the linkage members table on the observation ID
+    observation_times = observations.flattened_table().select(
+        ["id", "coordinates.time.days", "coordinates.time.nanos"]
+    )
+    member_times = members_table.join(
+        observation_times, keys=["obs_id"], right_keys=["id"]
+    )
+
+    # Sort the reduced linkages table by linkage ID and the linkage member times table by linkage ID and observation time
+    linkage_table = linkage_table.sort_by([(linkage_column, "ascending")])
+    member_times = member_times.sort_by(
+        [
+            (linkage_column, "ascending"),
+            ("coordinates.time.days", "ascending"),
+            ("coordinates.time.nanos", "ascending"),
+        ]
+    )
+
+    sorted_linkages = linkages.take(linkage_table["index"])
+    sorted_members = members.take(member_times["index"])
+
+    if sorted_linkages.fragmented():
+        sorted_linkages = qv.defragment(sorted_linkages)
+    if sorted_members.fragmented():
+        sorted_members = qv.defragment(sorted_members)
+    return sorted_linkages, sorted_members
diff --git a/thor/utils/tests/__init__.py b/thor/utils/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/thor/utils/tests/test_linkages.py b/thor/utils/tests/test_linkages.py
new file mode 100644
index 00000000..9961ce05
--- /dev/null
+++ b/thor/utils/tests/test_linkages.py
@@ -0,0 +1,138 @@
+import uuid
+
+import numpy as np
+import pyarrow as pa
+import quivr as qv
+from adam_core.coordinates import Origin, SphericalCoordinates
+from adam_core.time import Timestamp
+
+from ...observations.observations import Observations
+from ...observations.photometry import Photometry
+from ..linkages import sort_by_id_and_time
+
+
+class Linkages(qv.Table):
+    linkage_id = qv.StringColumn(default=lambda: uuid.uuid4().hex)
+
+
+class LinkageMembers(qv.Table):
+    linkage_id = qv.StringColumn(nullable=True)
+    obs_id = qv.StringColumn(nullable=True)
+
+
+def test_sort_by_id_and_time():
+    # Create a table of linkages and linkage members and test that sorting them by linkage ID
+    # and observation time works as expected
+    linkages = Linkages.from_kwargs(
+        linkage_id=[
+            "linkage_03",
+            "linkage_04",
+            "linkage_01",
+            "linkage_05",
+            "linkage_02",
+        ],
+    )
+
+    linkage_members = LinkageMembers.from_kwargs(
+        linkage_id=[
+            "linkage_03",
+            "linkage_03",
+            "linkage_03",
+            "linkage_04",
+            "linkage_04",
+            "linkage_04",
+            "linkage_01",
+            "linkage_01",
+            "linkage_01",
+            "linkage_05",
+            "linkage_05",
+            "linkage_05",
+            "linkage_02",
+            "linkage_02",
+            "linkage_02",
+        ],
+        obs_id=[
+            "obs_03",
+            "obs_02",
+            "obs_04",
+            "obs_05",
+            "obs_03",
+            "obs_04",
+            "obs_01",
+            "obs_03",
+            "obs_02",
+            "obs_04",
+            "obs_05",
+            "obs_03",
+            "obs_02",
+            "obs_03",
+            "obs_01",
+        ],
+    )
+
+    observations = Observations.from_kwargs(
+        id=[f"obs_{i:02d}" for i in range(1, 6)],
+        exposure_id=[f"exposure_{i:01d}" for i in range(1, 6)],
+        coordinates=SphericalCoordinates.from_kwargs(
+            rho=np.random.random(5),
+            lon=np.random.random(5),
+            lat=np.random.random(5),
+            vrho=np.random.random(5),
+            vlon=np.random.random(5),
+            vlat=np.random.random(5),
+            time=Timestamp.from_mjd(np.arange(59000, 59005)),
+            origin=Origin.from_kwargs(code=pa.repeat("500", 5)),
+            frame="ecliptic",
+        ),
+        photometry=Photometry.from_kwargs(
+            mag=np.random.random(5),
+            filter=pa.repeat("V", 5),
+        ),
+        state_id=np.arange(0, 5),
+    )
+
+    sorted_linkages, sorted_linkage_members = sort_by_id_and_time(
+        linkages, linkage_members, observations, "linkage_id"
+    )
+
+    assert sorted_linkages.linkage_id.to_pylist() == [
+        "linkage_01",
+        "linkage_02",
+        "linkage_03",
+        "linkage_04",
+        "linkage_05",
+    ]
+    assert sorted_linkage_members.linkage_id.to_pylist() == [
+        "linkage_01",
+        "linkage_01",
+        "linkage_01",
+        "linkage_02",
+        "linkage_02",
+        "linkage_02",
+        "linkage_03",
+        "linkage_03",
+        "linkage_03",
+        "linkage_04",
+        "linkage_04",
+        "linkage_04",
+        "linkage_05",
+        "linkage_05",
+        "linkage_05",
+    ]
+    assert sorted_linkage_members.obs_id.to_pylist() == [
+        "obs_01",
+        "obs_02",
+        "obs_03",
+        "obs_01",
+        "obs_02",
+        "obs_03",
+        "obs_02",
+        "obs_03",
+        "obs_04",
+        "obs_03",
+        "obs_04",
+        "obs_05",
+        "obs_03",
+        "obs_04",
+        "obs_05",
+    ]