-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sort clusters and orbits by linkage ID and observation time
- Loading branch information
Showing
6 changed files
with
250 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
import pyarrow as pa | ||
import quivr as qv | ||
|
||
from ..observations import Observations | ||
|
||
__all__ = [ | ||
"sort_by_id_and_time", | ||
] | ||
|
||
|
||
def sort_by_id_and_time( | ||
linkages: qv.AnyTable, | ||
members: qv.AnyTable, | ||
observations: Observations, | ||
linkage_column: str, | ||
) -> Tuple[qv.AnyTable, qv.AnyTable]: | ||
""" | ||
Sort linkages and linkage members by orbit ID and observation time. | ||
Parameters | ||
---------- | ||
linkages : qv.AnyTable | ||
Linkages to sort. | ||
members : qv.AnyTable | ||
Linkage members to sort. | ||
observations : Observations | ||
Observations from which orbit members were generated. Observations | ||
are used to determine the observation time of each orbit member. | ||
linkage_column : str | ||
Column name in the linkage table to use for sorting. For clusters | ||
this is "cluster_id" and for orbits this is "orbit_id". | ||
Returns | ||
------- | ||
orbits : qv.AnyTable | ||
Sorted orbits. | ||
members : qv.AnyTable | ||
Sorted orbit members. | ||
""" | ||
# Grab the orbit ID column from the orbits table and add an index column | ||
linkage_table = linkages.table.select([linkage_column]) | ||
linkage_table = linkage_table.add_column( | ||
0, "index", pa.array(np.arange(0, len(linkage_table))) | ||
) | ||
|
||
# Grab the orbit ID and observation ID columns from the orbit members table and add an index column | ||
members_table = members.table.select([linkage_column, "obs_id"]) | ||
members_table = members_table.add_column( | ||
0, "index", pa.array(np.arange(0, len(members_table))) | ||
) | ||
|
||
# Grab the observation ID, observation time columns and join with the orbit members table on the observation ID | ||
observation_times = observations.flattened_table().select( | ||
["id", "coordinates.time.days", "coordinates.time.nanos"] | ||
) | ||
member_times = members_table.join( | ||
observation_times, keys=["obs_id"], right_keys=["id"] | ||
) | ||
|
||
# Sort the reduced orbit table by orbit ID and the member times table by orbit ID and observation time | ||
linkage_table = linkage_table.sort_by([(linkage_column, "ascending")]) | ||
member_times = member_times.sort_by( | ||
[ | ||
(linkage_column, "ascending"), | ||
("coordinates.time.days", "ascending"), | ||
("coordinates.time.nanos", "ascending"), | ||
] | ||
) | ||
|
||
sorted = linkages.take(linkage_table["index"]) | ||
sorted_members = members.take(member_times["index"]) | ||
|
||
if sorted.fragmented(): | ||
sorted = qv.defragment(sorted) | ||
if sorted_members.fragmented(): | ||
sorted_members = qv.defragment(sorted_members) | ||
return sorted, sorted_members |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import uuid | ||
|
||
import numpy as np | ||
import pyarrow as pa | ||
import quivr as qv | ||
from adam_core.coordinates import Origin, SphericalCoordinates | ||
from adam_core.time import Timestamp | ||
|
||
from ...observations.observations import Observations | ||
from ...observations.photometry import Photometry | ||
from ..linkages import sort_by_id_and_time | ||
|
||
|
||
class Linkages(qv.Table): | ||
linkage_id = qv.StringColumn(default=lambda: uuid.uuid4().hex) | ||
|
||
|
||
class LinkageMembers(qv.Table): | ||
linkage_id = qv.StringColumn(nullable=True) | ||
obs_id = qv.StringColumn(nullable=True) | ||
|
||
|
||
def test_sort_by_id_and_time(): | ||
# Create a table of linkages and linkage members and test that sorting them by linkage ID | ||
# and observation time works as expected | ||
linkages = Linkages.from_kwargs( | ||
linkage_id=[ | ||
"linkage_03", | ||
"linkage_04", | ||
"linkage_01", | ||
"linkage_05", | ||
"linkage_02", | ||
], | ||
) | ||
|
||
linkage_members = LinkageMembers.from_kwargs( | ||
linkage_id=[ | ||
"linkage_03", | ||
"linkage_03", | ||
"linkage_03", | ||
"linkage_04", | ||
"linkage_04", | ||
"linkage_04", | ||
"linkage_01", | ||
"linkage_01", | ||
"linkage_01", | ||
"linkage_05", | ||
"linkage_05", | ||
"linkage_05", | ||
"linkage_02", | ||
"linkage_02", | ||
"linkage_02", | ||
], | ||
obs_id=[ | ||
"obs_03", | ||
"obs_02", | ||
"obs_04", | ||
"obs_05", | ||
"obs_03", | ||
"obs_04", | ||
"obs_01", | ||
"obs_03", | ||
"obs_02", | ||
"obs_04", | ||
"obs_05", | ||
"obs_03", | ||
"obs_02", | ||
"obs_03", | ||
"obs_01", | ||
], | ||
) | ||
|
||
observations = Observations.from_kwargs( | ||
id=[f"obs_{i:02d}" for i in range(1, 6)], | ||
exposure_id=[f"exposure_{i:01d}" for i in range(1, 6)], | ||
coordinates=SphericalCoordinates.from_kwargs( | ||
rho=np.random.random(5), | ||
lon=np.random.random(5), | ||
lat=np.random.random(5), | ||
vrho=np.random.random(5), | ||
vlon=np.random.random(5), | ||
vlat=np.random.random(5), | ||
time=Timestamp.from_mjd(np.arange(59000, 59005)), | ||
origin=Origin.from_kwargs(code=pa.repeat("500", 5)), | ||
frame="eclipitic", | ||
), | ||
photometry=Photometry.from_kwargs( | ||
mag=np.random.random(5), | ||
filter=pa.repeat("V", 5), | ||
), | ||
state_id=np.arange(0, 5), | ||
) | ||
|
||
sorted_linkages, sorted_linkage_members = sort_by_id_and_time( | ||
linkages, linkage_members, observations, "linkage_id" | ||
) | ||
|
||
assert sorted_linkages.linkage_id.to_pylist() == [ | ||
"linkage_01", | ||
"linkage_02", | ||
"linkage_03", | ||
"linkage_04", | ||
"linkage_05", | ||
] | ||
assert sorted_linkage_members.linkage_id.to_pylist() == [ | ||
"linkage_01", | ||
"linkage_01", | ||
"linkage_01", | ||
"linkage_02", | ||
"linkage_02", | ||
"linkage_02", | ||
"linkage_03", | ||
"linkage_03", | ||
"linkage_03", | ||
"linkage_04", | ||
"linkage_04", | ||
"linkage_04", | ||
"linkage_05", | ||
"linkage_05", | ||
"linkage_05", | ||
] | ||
assert sorted_linkage_members.obs_id.to_pylist() == [ | ||
"obs_01", | ||
"obs_02", | ||
"obs_03", | ||
"obs_01", | ||
"obs_02", | ||
"obs_03", | ||
"obs_02", | ||
"obs_03", | ||
"obs_04", | ||
"obs_03", | ||
"obs_04", | ||
"obs_05", | ||
"obs_03", | ||
"obs_04", | ||
"obs_05", | ||
] |