Skip to content

Commit

Permalink
include blockwise layer to finish shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Dec 2, 2023
1 parent 7de1bfc commit 07ba167
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions distributed/shuffle/_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from functools import partial, cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -187,11 +187,17 @@ def rearrange_by_column_p2p(
meta_input=meta,
disk=disk,
)
_barrier_key = layer._tokens[1]
return new_dd_object(
HighLevelGraph.from_collections(name, layer, [df]),
name,
meta,
[None] * (npartitions + 1),
).map_partitions(
_get_partition_data,
_barrier_key,
meta=meta,
enforce_metadata=False,
)


Expand Down Expand Up @@ -312,10 +318,15 @@ def cull(
else:
return self, culled_deps

def _construct_graph(self) -> _T_LowLevelGraph:
@cached_property
def _tokens(self):
token = tokenize(self.name_input, self.column, self.npartitions, self.parts_out)
dsk: _T_LowLevelGraph = {}
_barrier_key = barrier_key(ShuffleId(token))
return token, _barrier_key

def _construct_graph(self) -> _T_LowLevelGraph:
token, _barrier_key = self._tokens
dsk: _T_LowLevelGraph = {}
name = "shuffle-transfer-" + token
transfer_keys = list()
for i in range(self.npartitions_input):
Expand All @@ -334,23 +345,14 @@ def _construct_graph(self) -> _T_LowLevelGraph:

dsk[_barrier_key] = (shuffle_barrier, token, transfer_keys)

name_lazy = f"lazy-{self.name}"
name = self.name
for part_out in self.parts_out:
dsk[(name_lazy, part_out)] = (
dsk[(name, part_out)] = (
shuffle_unpack,
token,
part_out,
_barrier_key,
)

# TODO: Do this in a Blockwise layer after the shuffle
name = self.name
for part_out in self.parts_out:
dsk[(name, part_out)] = (
_get_partition_data,
(name_lazy, part_out),
_barrier_key,
)
return dsk


Expand Down

0 comments on commit 07ba167

Please sign in to comment.