Skip to content

Commit

Permalink
Bug fix: Correctly select the coincident attribution with the lowest …
Browse files Browse the repository at this point in the history
…distance
  • Loading branch information
moeyensj committed Dec 1, 2023
1 parent 5771c92 commit 04bb04d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 23 deletions.
20 changes: 13 additions & 7 deletions thor/orbits/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ def drop_coincident_attributions(
# Flatten the table so nested columns are dot-delimited at the top level
flattened_table = self.flattened_table()

# Add index to flattened table
flattened_table = flattened_table.add_column(
0, "index", pa.array(np.arange(len(flattened_table)))
)

# Drop the residual values (a list column) due to: https://github.com/apache/arrow/issues/32504
flattened_table = flattened_table.drop(["residuals.values"])

Expand All @@ -76,17 +81,13 @@ def drop_coincident_attributions(
flattened_observations, ["obs_id"], right_keys=["id"]
)

# Add index column
flattened_table = flattened_table.add_column(
0, "index", pa.array(np.arange(len(flattened_table)))
)

# Sort the table
flattened_table = flattened_table.sort_by(
[
("orbit_id", "ascending"),
("coordinates.time.days", "ascending"),
("coordinates.time.nanos", "ascending"),
("distance", "ascending"),
]
)

Expand All @@ -100,7 +101,11 @@ def drop_coincident_attributions(
.column("index_first")
)

return self.take(indices)
filtered = self.take(indices)
if filtered.fragmented():
filtered = qv.defragment(filtered)

return filtered


def attribution_worker(
Expand Down Expand Up @@ -336,6 +341,8 @@ def attribute_observations(

attributions = qv.concatenate(attributions_list)
attributions = attributions.sort_by(["orbit_id", "obs_id", "distance"])
if attributions.fragmented():
attributions = qv.defragment(attributions)

time_end = time.time()
logger.info(
Expand Down Expand Up @@ -459,7 +466,6 @@ def merge_and_extend_orbits(
# the same time, keep only observation with smallest distance
attributions = attributions.drop_coincident_attributions(observations)

attributions = qv.defragment(attributions)
# Create a new orbit members table with the newly attributed observations and
# filter the orbits to only include those that still have observations
orbit_members_iter = FittedOrbitMembers.from_kwargs(
Expand Down
34 changes: 18 additions & 16 deletions thor/orbits/tests/test_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,34 @@

def test_Attributions_drop_coincident_attributions():
observations = Observations.from_kwargs(
id=["01", "02", "03", "04"],
exposure_id=["e01", "e01", "e02", "e02"],
id=["01", "02", "03", "04", "05"],
exposure_id=["e01", "e01", "e02", "e02", "e02"],
coordinates=SphericalCoordinates.from_kwargs(
time=Timestamp.from_mjd([59001.1, 59001.1, 59002.1, 59002.1], scale="utc"),
lon=[1, 2, 3, 4],
lat=[5, 6, 7, 8],
origin=Origin.from_kwargs(code=["500", "500", "500", "500"]),
time=Timestamp.from_mjd(
[59001.1, 59001.1, 59002.1, 59002.1, 59002.1], scale="utc"
),
lon=[1, 2, 3, 4, 5],
lat=[5, 6, 7, 8, 9],
origin=Origin.from_kwargs(code=["500", "500", "500", "500", "500"]),
),
photometry=Photometry.from_kwargs(
filter=["g", "g", "g", "g"],
mag=[10, 11, 12, 13],
filter=["g", "g", "g", "g", "g"],
mag=[10, 11, 12, 13, 14],
),
state_id=[0, 0, 1, 1],
state_id=[0, 0, 1, 1, 1],
)

attributions = Attributions.from_kwargs(
orbit_id=["o01", "o01", "o02", "o03"],
obs_id=["01", "02", "03", "03"],
distance=[0.5 / 3600, 1 / 3600, 2 / 3600, 1 / 3600],
orbit_id=["o01", "o01", "o02", "o03", "o04", "o04"],
obs_id=["01", "02", "03", "03", "04", "05"],
distance=[1 / 3600, 0.5 / 3600, 2 / 3600, 1 / 3600, 2 / 3600, 1 / 3600],
)

filtered = attributions.drop_coincident_attributions(observations)
# Orbit 1 gets linked to two observations at the same time
# We should expect to only keep the one with the smallest distance
# Orbit 2 and 3 get linked to the same observation but we should keep both
assert len(filtered) == 3
assert filtered.orbit_id.to_pylist() == ["o01", "o02", "o03"]
assert filtered.obs_id.to_pylist() == ["01", "03", "03"]
assert filtered.distance.to_pylist() == [0.5 / 3600, 2 / 3600, 1 / 3600]
assert len(filtered) == 4
assert filtered.orbit_id.to_pylist() == ["o01", "o02", "o03", "o04"]
assert filtered.obs_id.to_pylist() == ["02", "03", "03", "05"]
assert filtered.distance.to_pylist() == [0.5 / 3600, 2 / 3600, 1 / 3600, 1 / 3600]

0 comments on commit 04bb04d

Please sign in to comment.