-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SageMaker] Add SageMaker GraphBolt internal implementation (#1116)
*Issue #, if available:* *Description of changes:* * The original GraphBolt on SageMaker PR was missing the internal (GSF library) implementation, we only provided the launch script and entry point. This adds the missing module in the GraphStorm library. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
- Loading branch information
Showing
1 changed file
with
112 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
""" Copyright 2023 Contributors | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
SageMaker implementation of DistDGL to GraphBolt conversion. | ||
""" | ||
|
||
import importlib.metadata | ||
import logging | ||
import os | ||
import time | ||
from collections import defaultdict | ||
from packaging import version | ||
|
||
import boto3 | ||
import sagemaker | ||
from sagemaker.s3 import S3Uploader | ||
|
||
from graphstorm.gpartition.convert_to_graphbolt import run_gb_conversion | ||
|
||
|
||
def run_gb_convert(s3_output_path: str, local_dist_part_config: str, njobs: int): | ||
"""Convert DistDGL partitions to GraphBolt format on SageMaker Processing. | ||
Parameters | ||
---------- | ||
s3_output_path : str | ||
The S3 prefix under which we will upload the GB converted graph files. | ||
local_dist_part_config : str | ||
The local path to the downloaded partition data metadata.json file. | ||
njobs: int | ||
Number of parallel processes to use during GraphBolt conversion. | ||
One process per partition is used, so ``njobs`` partitions needs | ||
to be able to fit into memory. | ||
Raises | ||
------ | ||
ValueError | ||
If the version of DGL used is under 2.1.0 | ||
""" | ||
assert os.path.isfile( | ||
local_dist_part_config | ||
), f"{local_dist_part_config=} does not exist." | ||
|
||
dgl_version = importlib.metadata.version("dgl") | ||
if version.parse(dgl_version) < version.parse("2.1.0"): | ||
raise ValueError( | ||
"GraphBolt conversion requires DGL version >= 2.1.0, " | ||
f"but DGL version was {dgl_version}. " | ||
) | ||
|
||
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) | ||
sagemaker_session = sagemaker.Session(boto_session=boto_session) | ||
|
||
# Run the actual conversion, this will create the fused_csc_sampling_graph.pt | ||
# under each partition in local_dist_part_path and a new metadata.json | ||
gb_start = time.time() | ||
run_gb_conversion(local_dist_part_config, njobs) | ||
logging.info("GraphBolt conversion took %f sec.", time.time() - gb_start) | ||
|
||
# Iterate through the partition data and upload only the modified/new | ||
# files to the corresponding path on S3 | ||
upload_start = time.time() | ||
fused_files_exist = defaultdict(lambda: False) | ||
for root, _, files in os.walk(os.path.dirname(local_dist_part_config)): | ||
for file in files: | ||
if file.endswith("fused_csc_sampling_graph.pt"): | ||
partition_id = root.split("/")[-1] | ||
# Set fused file existence to true for this partition | ||
fused_files_exist[partition_id] = True | ||
# Partition data need to be uploaded to partition-id dirs | ||
s3_path = os.path.join(s3_output_path, f"{partition_id}") | ||
elif file.endswith(".json"): | ||
# Partition output metadata file needs to be uploaded to root dir | ||
s3_path = s3_output_path | ||
else: | ||
# We skip other files | ||
partition_id = root.split("/")[-1] | ||
if "part" in partition_id: | ||
# Set file existence to False only if | ||
# we haven't encountered a fused file already | ||
fused_files_exist[partition_id] = ( | ||
False or fused_files_exist[partition_id] | ||
) | ||
continue | ||
|
||
logging.info("Uploading local %s to %s", os.path.join(root, file), s3_path) | ||
S3Uploader.upload( | ||
local_path=os.path.join(root, file), | ||
desired_s3_uri=s3_path, | ||
sagemaker_session=sagemaker_session, | ||
) | ||
|
||
for partition_id, fused_file_exists in fused_files_exist.items(): | ||
if not fused_file_exists: | ||
raise RuntimeError( | ||
f"Partition {partition_id} did not have " | ||
"a fused_csc_sampling_graph.pt file." | ||
) | ||
|
||
logging.info("Uploading took %f sec.", time.time() - upload_start) |