Skip to content

Commit

Permalink
fix(cluster): fix edge case with connected_components() and no edges …
Browse files Browse the repository at this point in the history
…but records

The .max() of the labels was getting called on a column with 0 entries, so the result was NULL
  • Loading branch information
NickCrews committed Oct 6, 2024
1 parent 9194de4 commit b40e768
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mismo/cluster/_connected_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ def _label_datasets(ds: Datasets, labels: ir.Table, *, label_as: str) -> Dataset
def _get_additional_labels(labels: ir.Table, record_ids: ir.Column) -> ir.Table:
nodes = record_ids.name("record_id").as_table()
is_missing_label = nodes.record_id.notin(labels.record_id)
max_existing_label = labels.component.max()
# fill_null(0) in case labels is empty
max_existing_label = labels.component.max().fill_null(0)
additional_labels = nodes.filter(is_missing_label).select(
"record_id",
component=(ibis.row_number() + max_existing_label + 1).cast("int64"),
Expand Down
11 changes: 11 additions & 0 deletions mismo/cluster/test/test_connected_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ def test_cc_single_records(table_factory, label_as):
assert clusters == {frozenset({0, 1, 2}), frozenset({3})}


def test_cc_no_links_but_records(table_factory, label_as):
"""If there are no links, each record should be its own cluster."""
link_schema = {"record_id_l": "int64", "record_id_r": "int64"}
link_df = pd.DataFrame({"record_id_l": [], "record_id_r": []})
links = table_factory(link_df, schema=link_schema)
nodes = table_factory({"record_id": [0, 1, 2]})
labeled = connected_components(links=links, records=nodes, label_as=label_as)
clusters = _labels_to_clusters(labeled, label_as)
assert clusters == {frozenset({0}), frozenset({1}), frozenset({2})}


def test_cc_multi_records(table_factory, label_as):
# multiple input record tables
links = table_factory([(0, 1), (1, 2)], columns=["record_id_l", "record_id_r"])
Expand Down

0 comments on commit b40e768

Please sign in to comment.