Skip to content

Commit

Permalink
lightning.data: Fix some bugs with optimize (#18949)
Browse files Browse the repository at this point in the history
Co-authored-by: thomas <[email protected]>
  • Loading branch information
tchaton and thomas authored Nov 5, 2023
1 parent 0e7a3b0 commit 3a86097
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
11 changes: 6 additions & 5 deletions src/lightning/data/streaming/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ def _get_cache_data_dir(name: Optional[str] = None) -> str:
return os.path.join(cache_dir, name.lstrip("/"))


def _wait_for_file_to_exist(s3: Any, obj: parse.ParseResult, sleep_time: int = 2) -> Any:
def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: int = 2) -> Any:
"""This function check."""
while True:
try:
return s3.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/"))
return s3.client.head_object(Bucket=obj.netloc, Key=obj.path.lstrip("/"))
except botocore.exceptions.ClientError as e:
if "the HeadObject operation: Not Found" in str(e):
sleep(sleep_time)
Expand Down Expand Up @@ -659,7 +659,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
obj = parse.urlparse(remote_filepath)
_wait_for_file_to_exist(s3, obj)
with open(node_index_filepath, "wb") as f:
s3.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
elif os.path.isdir(output_dir.path):
copyfile(remote_filepath, node_index_filepath)

Expand Down Expand Up @@ -799,15 +799,16 @@ def run(self, data_recipe: DataRecipe) -> None:
break

num_nodes = _get_num_nodes()
node_rank = _get_node_rank()
# TODO: Understand why it hangs.
if num_nodes == 1:
for w in self.workers:
w.join(0)

print("Workers are finished.")
result = data_recipe._done(num_items, self.delete_cached_files, self.output_dir)
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)

if num_nodes == _get_node_rank() + 1:
if num_nodes == node_rank + 1:
_create_dataset(
input_dir=self.input_dir.path,
storage_dir=self.output_dir.path,
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_data/streaming/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def fn(*_, **__):
raise botocore.exceptions.ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
return

s3.head_object = fn
s3.client.head_object = fn

_wait_for_file_to_exist(s3, obj, sleep_time=0.01)

Expand All @@ -213,7 +213,7 @@ def fn(*_, **__):
def fn(*_, **__):
raise ValueError("HERE")

s3.head_object = fn
s3.client.head_object = fn

with pytest.raises(ValueError, match="HERE"):
_wait_for_file_to_exist(s3, obj, sleep_time=0.01)
Expand Down

0 comments on commit 3a86097

Please sign in to comment.