fix: merge_insert with subcols sometimes outputs unexpected nulls (#3407)

Fixes #3406

At the root of this is a bit of a footgun with DataFusion. Prior to this
change, the query plan for getting data that was supposed to be sorted
by `_rowaddr` was:

```
ProjectionExec: expr=[id@0 as id, vector@1 as vector, _rowaddr@2 as _rowaddr, _rowaddr@2 >> 32 as _fragment_id], schema=[id:Int64;N, vector:FixedSizeList(Field { name: "item", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 32);N, _rowaddr:UInt64;N, _fragment_id:UInt64;N]
  RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, schema=[id:Int64;N, vector:FixedSizeList(Field { name: "item", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 32);N, _rowaddr:UInt64;N]
    SortExec: expr=[_rowaddr@2 ASC], preserve_partitioning=[false], schema=[id:Int64;N, vector:FixedSizeList(Field { name: "item", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 32);N, _rowaddr:UInt64;N]
      StreamingTableExec: partition_sizes=1, projection=[id, vector, _rowaddr], schema=[id:Int64;N, vector:FixedSizeList(Field { name: "item", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 32);N, _rowaddr:UInt64;N]
```

Note the `RepartitionExec` **after** the `SortExec`. Round-robin repartitioning splits the sorted stream across partitions with no ordering guarantee between them, so the final output order was non-deterministic.
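
At the DataFrame level, that plan came from sorting before deriving `_fragment_id`. Here is a minimal sketch of the problematic call order, assuming DataFusion's `DataFrame` API and a registered single-partition table `t` (not the exact Lance code, which appears in the diff below):

```rust
use datafusion::prelude::*;

// Sketch of the call order that produced the plan above. `_rowaddr` packs
// the fragment id into its upper 32 bits, hence the `>> 32` projection.
async fn problematic_order(ctx: &SessionContext) -> datafusion::error::Result<DataFrame> {
    ctx.table("t")
        .await?
        // Sorting here looks correct in isolation...
        .sort(vec![col("_rowaddr").sort(true, true)])?
        // ...but the optimizer may insert a round-robin RepartitionExec
        // above the SortExec to parallelize this projection, interleaving
        // batches and discarding the _rowaddr order.
        .with_column("_fragment_id", col("_rowaddr") >> lit(32))
}
```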

After these changes, the plan is:

```
SortPreservingMergeExec: [_rowaddr@2 ASC], schema=[id:Int64;N, vector:FixedSizeList(Field { name: "item", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 32);N, _rowaddr:UInt64;N, _fragment_id:UInt64;N]
  SortExec: expr=[_rowaddr@2 ASC], preserve_partitioning=[true], schema=[id:Int64;N, vector:FixedSizeList(Field { name: "item", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 32);N, _rowaddr:UInt64;N, _fragment_id:UInt64;N]
    ProjectionExec: expr=[id@0 as id, vector@1 as vector, _rowaddr@2 as _rowaddr, _rowaddr@2 >> 32 as _fragment_id], schema=[id:Int64;N, vector:FixedSizeList(Field { name: "item", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 32);N, _rowaddr:UInt64;N, _fragment_id:UInt64;N]
      RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, schema=[id:Int64;N, vector:FixedSizeList(Field { name: "item", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 32);N, _rowaddr:UInt64;N]
        StreamingTableExec: partition_sizes=1, projection=[id, vector, _rowaddr], schema=[id:Int64;N, vector:FixedSizeList(Field { name: "item", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 32);N, _rowaddr:UInt64;N]
```

This plan does produce a deterministic order: the `SortExec` now sits above the projection and the repartition (with `preserve_partitioning=[true]`), and the `SortPreservingMergeExec` merges the sorted partitions back into a single ordered stream.
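
At the DataFrame level, the fix is just to swap the two calls so the sort happens last. The same hedged sketch with the corrected order (the real one-line move is in `merge_insert.rs` below):

```rust
use datafusion::prelude::*;

// Corrected order: derive _fragment_id first, sort last. With the SortExec
// at the top of the plan, DataFusion preserves the order across partitions
// via SortPreservingMergeExec instead of repartitioning after the sort.
async fn fixed_order(ctx: &SessionContext) -> datafusion::error::Result<DataFrame> {
    ctx.table("t")
        .await?
        .with_column("_fragment_id", col("_rowaddr") >> lit(32))?
        .sort(vec![col("_rowaddr").sort(true, true)])
}
```

Either variant's physical plan can be printed with `DataFrame::create_physical_plan` plus `datafusion::physical_plan::displayable`, which is presumably how the plan trees above were produced.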
wjones127 authored Jan 22, 2025
1 parent 7f60aa0 commit 3cb54c6
Showing 2 changed files with 34 additions and 1 deletion.
python/python/tests/test_dataset.py (33 additions, 0 deletions)
```diff
@@ -1670,6 +1670,39 @@ def test_merge_insert_vector_column(tmp_path: Path):
     check_merge_stats(merge_dict, (1, 1, 0))
 
 
+def test_merge_insert_large():
+    # Doing a subcolumn update with merge insert triggers this error.
+    # The data needs to be large enough that DataFusion creates multiple
+    # batches when outputting join results.
+    # https://github.com/lancedb/lance/issues/3406
+    # This test is in Python because, for whatever reason, the error
+    # doesn't reproduce in the equivalent Rust test.
+    dims = 32
+    nrows = 20_000
+    data = pa.table({"id": range(nrows), "num": [str(i) for i in range(nrows)]})
+
+    ds = lance.write_dataset(data, "memory://")
+
+    ds.add_columns({"vector": f"arrow_cast(NULL, 'FixedSizeList({dims}, Float32)')"})
+
+    batch_size = 10_000
+    other_columns = pa.table(
+        {
+            "id": range(batch_size),
+            "vector": pa.FixedSizeListArray.from_arrays(
+                pc.random(batch_size * dims).cast(pa.float32()), dims
+            ),
+        }
+    )
+
+    (
+        ds.merge_insert(on="id")
+        .when_matched_update_all()
+        .when_not_matched_insert_all()
+        .execute(other_columns)
+    )
+
+
 def check_update_stats(update_dict, expected):
     assert (update_dict["num_rows_updated"],) == expected
 
```
rust/lance/src/dataset/write/merge_insert.rs (1 addition, 1 deletion)
```diff
@@ -670,8 +670,8 @@ impl MergeInsertJob {
         });
         let mut group_stream = session_ctx
             .read_one_shot(source)?
-            .sort(vec![col(ROW_ADDR).sort(true, true)])?
             .with_column("_fragment_id", col(ROW_ADDR) >> lit(32))?
+            .sort(vec![col(ROW_ADDR).sort(true, true)])?
             .group_by_stream(&["_fragment_id"])
             .await?;
 
```
