Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added toolkit compare #719

Merged
merged 10 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions src/datachain/lib/diff.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
import string
from collections.abc import Sequence
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union

import sqlalchemy as sa
Expand All @@ -16,6 +17,21 @@
C = Column


def get_status_col_name() -> str:
Copy link
Member

@skshetry skshetry Dec 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the column name need to be random? Can we have a default column name that can be changed by users?

Eg:

def compare(col="status"):
   pass


dc.compare(col=...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this column name will not be in the results. It's only needed for our internal implementation of the diff. The user will have a separate chain for each status, so the status column is not needed in that case.

"""Returns new unique status col name"""
return "diff_" + "".join(
random.choice(string.ascii_letters) # noqa: S311
for _ in range(10)
)


class CompareStatus(str, Enum):
    """Single-letter status codes assigned to rows when two chains are compared.

    The class also derives from ``str``, so each member compares equal to its
    one-character value (e.g. ``CompareStatus.ADDED == "A"``).
    """

    # Status code for rows reported as added.
    ADDED = "A"
    # Status code for rows reported as deleted.
    DELETED = "D"
    # Status code for rows reported as modified.
    MODIFIED = "M"
    # Status code for rows reported as unchanged.
    UNCHANGED = "U"


def compare( # noqa: PLR0912, PLR0915, C901
left: "DataChain",
right: "DataChain",
Expand Down Expand Up @@ -72,13 +88,10 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
"At least one of added, deleted, modified, unchanged flags must be set"
)

# we still need status column for internal implementation even if not
# needed in output
need_status_col = bool(status_col)
status_col = status_col or "diff_" + "".join(
random.choice(string.ascii_letters) # noqa: S311
for _ in range(10)
)
# we still need status column for internal implementation even if not
# needed in the output
status_col = status_col or get_status_col_name()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make sure you drop this random column if it was None.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Status column is dropped with select_except(...) before returning to the user.


# calculate on and compare column names
right_on = right_on or on
Expand Down Expand Up @@ -112,25 +125,27 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
for c in [f"{_rprefix(c, rc)}{rc}" for c, rc in zip(on, right_on)]
]
)
diff_cond.append((added_cond, "A"))
diff_cond.append((added_cond, CompareStatus.ADDED))

This comment was marked as off-topic.

if modified and compare:
modified_cond = sa.or_(
*[
C(c) != C(f"{_rprefix(c, rc)}{rc}")
for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
]
)
diff_cond.append((modified_cond, "M"))
diff_cond.append((modified_cond, CompareStatus.MODIFIED))
if unchanged and compare:
unchanged_cond = sa.and_(
*[
C(c) == C(f"{_rprefix(c, rc)}{rc}")
for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
]
)
diff_cond.append((unchanged_cond, "U"))
diff_cond.append((unchanged_cond, CompareStatus.UNCHANGED))

diff = sa.case(*diff_cond, else_=None if compare else "M").label(status_col)
diff = sa.case(*diff_cond, else_=None if compare else CompareStatus.MODIFIED).label(
status_col
)
diff.type = String()

left_right_merge = left.merge(
Expand All @@ -145,7 +160,7 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
)
)

diff_col = sa.literal("D").label(status_col)
diff_col = sa.literal(CompareStatus.DELETED).label(status_col)
diff_col.type = String()

right_left_merge = right.merge(
Expand Down
3 changes: 2 additions & 1 deletion src/datachain/toolkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .diff import compare
from .split import train_test_split

__all__ = ["train_test_split"]
__all__ = ["compare", "train_test_split"]
111 changes: 111 additions & 0 deletions src/datachain/toolkit/diff.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please explain the difference between a toolkit and a lib?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was to move "heavy" code from dc.py to somewhere else and keep only simple wrapper function in dc.py.

If it's implemented in lib/diff.py - it's enough and we don't need toolkit.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we organize the code into individual top-level modules, e.g. datachain.diff, instead of cramming everything into a nested module under datachain.toolkit or datachain.lib?

Namespaces are one honking great idea -- let's do more of those!
Flat is better than nested.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that looks like the best option.

@ilongin what do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also for top-level modules. I will move this to datachain.diff

Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional, Union

from datachain.lib.diff import CompareStatus, get_status_col_name
from datachain.lib.diff import compare as chain_compare
from datachain.query.schema import Column

if TYPE_CHECKING:
from datachain.lib.dc import DataChain


C = Column


def compare(
left: "DataChain",
right: "DataChain",
on: Union[str, Sequence[str]],
right_on: Optional[Union[str, Sequence[str]]] = None,
compare: Optional[Union[str, Sequence[str]]] = None,
right_compare: Optional[Union[str, Sequence[str]]] = None,
added: bool = True,
deleted: bool = True,
modified: bool = True,
unchanged: bool = False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need these arguments? Can we leave this up to the user to filter?

dc.compare(...).filter(C("col") == "added")

Copy link
Contributor Author

@ilongin ilongin Dec 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DataChain.compare() returns new chain with that column which can be filtered and this new toolkit method does exactly what you described, so it saves the user that one filter step. Now, we can discuss if toolkit method is even needed in the first place.
So with this PR we have:

  1. compare() in src.datachain.diff -> accepts 2 chains and returns new "diff" chain with status column
  2. DataChain.compare() -> simple wrapper around 1) where left chain is self
  3. compare() in src.datachain.toolkit -> wrapper around 1) but instead of returning one "diff" chain with status column, it splits that chain into multiple chains where each chain represents only one status which is basically what you did in your comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think with this new toolkit we maybe have too many functions, although 2) was meant to be "private"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the 3rd should have a different name and should follow similar pattern as the compare() - it should be in dc.py if there is not much code or in lib.diff otherwise with a wrapper in dc.py

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyway, lib.diff is a better place for the code than a new toolkit.

PS: I might be the person who proposed the toolkit file but it does not seem a good idea in this case 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 3) should not be in dc.py as it returns multiple instances of DataChain so it should be in util file.
Also, the question is should we use src.datachain.diff.py or src.datachain.lib.diff.py for public util functions?

) -> dict[str, "DataChain"]:
"""Comparing two chains by identifying rows that are added, deleted, modified
or unchanged. Result is the new chain that has additional column with possible
values: `A`, `D`, `M`, `U` representing added, deleted, modified and unchanged
rows respectively. Note that if only one "status" is asked, by setting proper
flags, this additional column is not created as it would have only one value
for all rows. Beside additional diff column, new chain has schema of the chain
on which method was called.

ilongin marked this conversation as resolved.
Show resolved Hide resolved
Compares two chains and returns multiple chains, one for each requested status
(`added`, `deleted`, `modified`, `unchanged`). The result is returned as a
dictionary in which each item represents one of the statuses, keyed by
`A`, `D`, `M` and `U` respectively. Note that the status column is not present
in the resulting chains.

Parameters:
left: Chain to calculate diff on.
right: Chain to calculate diff from.
on: Column or list of columns to match on. If both chains have the
same columns then this column is enough for the match. Otherwise,
`right_on` parameter has to specify the columns for the other chain.
This value is used to find corresponding row in other dataset. If not
found there, row is considered as added (or removed if vice versa), and
if found then row can be either modified or unchanged.
right_on: Optional column or list of columns
for the `right` chain to match.
compare: Column or list of columns to compare on. If both chains have
the same columns then this column is enough for the compare. Otherwise,
`right_compare` parameter has to specify the columns for the other
chain. This value is used to see if row is modified or unchanged. If
not set, all columns will be used for comparison
right_compare: Optional column or list of columns
for the `right` chain to compare to.
added (bool): Whether to return chain containing only added rows.
deleted (bool): Whether to return chain containing only deleted rows.
modified (bool): Whether to return chain containing only modified rows.
unchanged (bool): Whether to return chain containing only unchanged rows.

Example:
```py
chains = compare(
persons,
new_persons,
on=["id"],
right_on=["other_id"],
compare=["name"],
added=True,
deleted=True,
modified=True,
unchanged=True,
)
```
"""
status_col = get_status_col_name()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User can define it, can't they?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this toolkit method status column is not returned to the user so it's created only for our internal implementation and removed before returning to the user. User can define status column in core DataChain.compare() which returns one chain with all statuses written in that status column


res = chain_compare(
left,
right,
on,
right_on=right_on,
compare=compare,
right_compare=right_compare,
added=added,
deleted=deleted,
modified=modified,
unchanged=unchanged,
status_col=status_col,
)

chains = {}

def filter_by_status(compare_status) -> "DataChain":
return res.filter(C(status_col) == compare_status).select_except(status_col)

if added:
chains[CompareStatus.ADDED.value] = filter_by_status(CompareStatus.ADDED)
if deleted:
chains[CompareStatus.DELETED.value] = filter_by_status(CompareStatus.DELETED)
if modified:
chains[CompareStatus.MODIFIED.value] = filter_by_status(CompareStatus.MODIFIED)
if unchanged:
chains[CompareStatus.UNCHANGED.value] = filter_by_status(
CompareStatus.UNCHANGED
)

return chains
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how the status column is cleaned up, or am I missing something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's cleaned in filter_by_status() method with select_except()

67 changes: 66 additions & 1 deletion tests/func/test_toolkit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from datachain.toolkit import train_test_split
from datachain.lib.dc import DataChain
from datachain.toolkit import compare, train_test_split


@pytest.mark.parametrize(
Expand Down Expand Up @@ -49,3 +50,67 @@ def test_train_test_split_errors(not_random_ds):
train_test_split(not_random_ds, [0.5])
with pytest.raises(ValueError, match="Weights should be non-negative"):
train_test_split(not_random_ds, [-1, 1])


@pytest.mark.parametrize("added", (True, False))
@pytest.mark.parametrize("deleted", (True, False))
@pytest.mark.parametrize("modified", (True, False))
@pytest.mark.parametrize("unchanged", (True, False))
def test_compare(test_session, added, deleted, modified, unchanged):
    """Check toolkit ``compare`` for every combination of the four status flags.

    ``ds1`` is passed as the left chain and ``ds2`` as the right chain, matched
    on ``id``. With all flags disabled a ``ValueError`` is expected. Otherwise
    the result must contain exactly one chain per enabled flag, each holding
    the expected rows and none exposing the internal status column.
    """
    ds1 = DataChain.from_values(
        id=[1, 2, 4],
        name=["John1", "Doe", "Andy"],
        session=test_session,
    ).save("ds1")

    ds2 = DataChain.from_values(
        id=[1, 3, 4],
        name=["John", "Mark", "Andy"],
        session=test_session,
    ).save("ds2")

    flags = {
        "added": added,
        "deleted": deleted,
        "modified": modified,
        "unchanged": unchanged,
    }

    if not any(flags.values()):
        with pytest.raises(ValueError) as exc_info:
            compare(ds1, ds2, on=["id"], **flags)
        assert str(exc_info.value) == (
            "At least one of added, deleted, modified, unchanged flags must be set"
        )
        return

    chains = compare(ds1, ds2, on=["id"], **flags)

    # status key -> (was it requested?, rows expected in that chain)
    expected = {
        "A": (added, [(2, "Doe")]),
        "D": (deleted, [(3, "Mark")]),
        "M": (modified, [(1, "John1")]),
        "U": (unchanged, [(4, "Andy")]),
    }

    # Only the requested statuses may appear in the result.
    assert set(chains) == {
        status for status, (requested, _) in expected.items() if requested
    }

    collect_fields = ["id", "name"]
    for status, (requested, rows) in expected.items():
        if not requested:
            continue
        chain = chains[status]
        # The internal status column is named "diff_<random letters>", so an
        # exact-match check against "diff" could never fail; look for the
        # prefix instead to detect a leaked status column.
        leaked = [
            signal
            for signal in chain.signals_schema.db_signals()
            if str(signal).startswith("diff_")
        ]
        assert not leaked
        assert list(chain.order_by("id").collect(*collect_fields)) == rows
Loading
Loading