Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use pyarrow groupby on joined strings of observation IDs to … #136

Merged
merged 1 commit into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 34 additions & 20 deletions thor/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class Clusters(qv.Table):
num_obs = qv.Int64Column()

def drop_duplicates(
    self,
    cluster_members: "ClusterMembers",
) -> Tuple["Clusters", "ClusterMembers"]:
    """
    Drop clusters that have identical sets of observation IDs.

    Parameters
    ----------
    cluster_members : `~thor.clusters.ClusterMembers`
        Table mapping each cluster ID to its member observation IDs.

    Returns
    -------
    `~thor.clusters.Clusters`
        A table of clusters with duplicate clusters removed.
    `~thor.clusters.ClusterMembers`
        The cluster members belonging to the remaining clusters.
    """
    # Sort by cluster_id and obs_id so that grouping below is deterministic.
    clusters_sorted = self.sort_by([("cluster_id", "ascending")])
    cluster_members_sorted = cluster_members.sort_by(
        [("cluster_id", "ascending"), ("obs_id", "ascending")]
    )

    # Group by cluster_id and aggregate a list of distinct obs_ids.
    # use_threads=False keeps group order stable (first-appearance order).
    grouped_by_cluster_id = cluster_members_sorted.table.group_by(
        ["cluster_id"], use_threads=False
    ).aggregate([("obs_id", "distinct")])
    obs_ids_per_cluster = grouped_by_cluster_id["obs_id_distinct"].to_pylist()

    # A distinct aggregation is not guaranteed to preserve the order of the
    # elements within each list, but it does preserve the order of the lists
    # themselves. So we sort each list of obs_ids and, while we are at it,
    # squash each list into a single string on which we can group later.
    # Pyarrow currently does not support groupby on lists of strings; this is
    # likely a missing feature. As an example, the following does not work:
    # grouped_by_obs_lists = grouped_by_cluster_id.group_by(
    #     ["obs_id_distinct"],
    #     use_threads=False
    # ).aggregate([("index", "first")])
    #
    # NOTE: the obs_ids are joined with a separator so that distinct sets
    # cannot collide after concatenation (e.g. ["ab", "c"] vs ["a", "bc"]).
    for i, obs_ids_i in enumerate(obs_ids_per_cluster):
        obs_ids_i.sort()
        obs_ids_per_cluster[i] = "\n".join(obs_ids_i)

    squashed_obs_ids = pa.table(
        {
            "index": pa.array(np.arange(0, len(obs_ids_per_cluster))),
            "obs_ids": obs_ids_per_cluster,
        }
    )
    # Group on the squashed strings and keep the index of the first cluster
    # with each distinct observation set.
    indices = (
        squashed_obs_ids.group_by(["obs_ids"], use_threads=False)
        .aggregate([("index", "first")])["index_first"]
        .combine_chunks()
    )

    filtered = clusters_sorted.take(indices)
    filtered_cluster_members = cluster_members_sorted.apply_mask(
        pc.is_in(cluster_members_sorted.cluster_id, filtered.cluster_id)
    )
    return filtered, filtered_cluster_members

Expand Down
56 changes: 56 additions & 0 deletions thor/tests/test_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pytest

from ..clusters import (
ClusterMembers,
Clusters,
_adjust_labels,
_build_label_aliases,
_extend_2d_array,
Expand Down Expand Up @@ -348,3 +350,57 @@ def test_label_clusters():
expected = np.array([-1, 0, -1, 1, -1, 0])
labels = _label_clusters(hits, points)
np.testing.assert_array_equal(expected, labels)


def test_Clusters_drop_duplicates():
    # Test that cluster deduplication works as expected.
    # The same 5 clusters are repeated 10000 times; drop_duplicates should
    # keep only the first occurrence of each distinct observation set.
    base_obs_ids = [
        [f"obs_{k:02d}" for k in range(start, start + 5)] for start in range(1, 6)
    ]

    repeated_obs_ids = base_obs_ids * 10000
    cluster_ids = [f"c{idx:05d}" for idx in range(len(repeated_obs_ids))]
    n_clusters = len(cluster_ids)

    clusters = Clusters.from_kwargs(
        cluster_id=cluster_ids,
        vtheta_x=np.full(n_clusters, 0.0),
        vtheta_y=np.full(n_clusters, 0.0),
        arc_length=np.full(n_clusters, 0.0),
        num_obs=np.full(n_clusters, 5),
    )
    cluster_members = ClusterMembers.from_kwargs(
        cluster_id=np.repeat(cluster_ids, 5),
        obs_id=[obs for group in repeated_obs_ids for obs in group],
    )

    deduped_clusters, deduped_members = clusters.drop_duplicates(cluster_members)

    # Only the first copy of each of the 5 distinct clusters should survive.
    assert len(deduped_clusters) == 5
    assert deduped_clusters.cluster_id.to_pylist() == cluster_ids[:5]

    # The surviving members are exactly those of the first 5 clusters, in order.
    assert len(deduped_members) == 25
    np.testing.assert_equal(
        deduped_members.cluster_id.to_numpy(zero_copy_only=False),
        np.repeat(cluster_ids[:5], 5),
    )
    np.testing.assert_equal(
        deduped_members.obs_id.to_numpy(zero_copy_only=False),
        np.hstack(np.array(base_obs_ids)),
    )
Loading