From 5eb984d33c42ad2c47f3932babde164f390fc89b Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Sun, 5 Jan 2025 09:17:02 +0000 Subject: [PATCH 1/9] [Examples] SageMaker Pipelines distributed training --- .../Dockerfile.processing | 39 +++ .../analyze_training_time.py | 324 ++++++++++++++++++ .../build_and_push_papers100M_image.sh | 39 +++ .../convert_arxiv_to_gconstruct.py | 163 +++++++++ .../convert_ogb_papers100M_to_gconstruct.py | 288 ++++++++++++++++ .../process_papers100M.sh | 34 ++ 6 files changed, 887 insertions(+) create mode 100644 examples/sagemaker-pipelines-graphbolt/Dockerfile.processing create mode 100644 examples/sagemaker-pipelines-graphbolt/analyze_training_time.py create mode 100644 examples/sagemaker-pipelines-graphbolt/build_and_push_papers100M_image.sh create mode 100644 examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py create mode 100644 examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100M_to_gconstruct.py create mode 100644 examples/sagemaker-pipelines-graphbolt/process_papers100M.sh diff --git a/examples/sagemaker-pipelines-graphbolt/Dockerfile.processing b/examples/sagemaker-pipelines-graphbolt/Dockerfile.processing new file mode 100644 index 0000000000..0470c21b2b --- /dev/null +++ b/examples/sagemaker-pipelines-graphbolt/Dockerfile.processing @@ -0,0 +1,39 @@ +FROM public.ecr.aws/ubuntu/ubuntu:22.04 + +# Avoid prompts from apt +ENV DEBIAN_FRONTEND=noninteractive + +# Install Python and other dependencies +RUN apt update && apt install -y \ + axel \ + curl \ + python3 \ + python3-pip \ + tree \ + unzip \ + && rm -rf /var/lib/apt/lists/* + + +COPY ripunzip_2.0.0-1_amd64.deb ripunzip_2.0.0-1_amd64.deb +RUN apt install -y ./ripunzip_2.0.0-1_amd64.deb + +RUN python3 -m pip install --no-cache-dir --upgrade pip==24.3.1 && \ + python3 -m pip install --no-cache-dir \ + numpy==1.26.4 \ + psutil==6.1.0 \ + pyarrow==18.1.0 \ + tqdm==4.67.1 \ + tqdm-loggable==0.2 + +RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" \ + && unzip awscliv2.zip \ + && ./aws/install + + +# Copy processing scripts +COPY process_papers100M.sh /opt/ml/code/ +COPY convert_ogb_papers100M_to_gconstruct.py /opt/ml/code/ + +WORKDIR /opt/ml/code/ + +CMD ["bash", "/opt/ml/code/process_papers100M.sh"] diff --git a/examples/sagemaker-pipelines-graphbolt/analyze_training_time.py b/examples/sagemaker-pipelines-graphbolt/analyze_training_time.py new file mode 100644 index 0000000000..c17674b825 --- /dev/null +++ b/examples/sagemaker-pipelines-graphbolt/analyze_training_time.py @@ -0,0 +1,324 @@ +""" +Copyright Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Analyzes the epoch and evaluation time for GraphStorm training jobs. +""" + +import argparse +import re +import time +from datetime import datetime, timedelta +from typing import Iterator, Dict, List, Union + +import boto3 + +LOG_GROUP = "/aws/sagemaker/TrainingJobs" + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Analyze training epoch and eval time." 
+ ) + source_group = parser.add_mutually_exclusive_group(required=True) + # Add pipeline name as arg + source_group.add_argument( + "--pipeline-name", + type=str, + help="The name of the pipeline.", + ) + # Add execution id as arg + parser.add_argument( + "--execution-name", + type=str, + help="The display name of the execution.", + ) + source_group.add_argument( + "--log-file", + type=str, + help="The name of a file containing logs from a local pipeline execution.", + ) + + parser.add_argument( + "--region", + type=str, + default="us-east-1", + help="The region of the log stream.", + ) + parser.add_argument( + "--verbose", + type=bool, + default=False, + help="Whether to print verbose output.", + ) + # Add days past as arg + parser.add_argument( + "--logs-days-before", + type=int, + default=2, + help="The number of days in the past to start analyzing logs.", + ) + return parser.parse_args() + + +def read_local_logs(file_path: str) -> Iterator[Dict]: + """Read logs from a local file and yield them in a format similar to CloudWatch events.""" + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + yield { + "message": line.strip(), + "timestamp": int(time.time() * 1000), # Current time in milliseconds + } + + +def get_pipeline_execution_arn(pipeline_name: str, execution_name: str) -> str: + """Get the execution ARN from a pipeline name and display name for the execution.""" + sm_client = boto3.client("sagemaker") + + try: + # List pipeline executions and find the matching one + paginator = sm_client.get_paginator("list_pipeline_executions") + for page in paginator.paginate(PipelineName=pipeline_name): + for execution in page["PipelineExecutionSummaries"]: + if execution_name in execution["PipelineExecutionDisplayName"]: + return execution["PipelineExecutionArn"] + + raise ValueError( + f"No execution found with display name {execution_name} in pipeline {pipeline_name}" + ) + + except Exception as e: + print(f"Error getting pipeline execution ARN: {e}") + raise e + + +def get_cloudwatch_logs( + logs_client, log_group: str, log_stream: str, start_time: int, end_time: int +) -> Iterator[Dict]: + """Get logs containing 'INFO' and either 'Epoch' or 'eval' from CloudWatch as a generator.""" + paginator = logs_client.get_paginator("filter_log_events") + + for page in paginator.paginate( + logGroupName=log_group, + logStreamNames=[log_stream], + startTime=start_time, + endTime=end_time, + filterPattern="INFO ?Epoch ?eval", + ): + events_generator: Iterator = page.get("events", []) + yield from events_generator + + +def analyze_logs( + log_source: Union[str, tuple[str, str, str]], + days_before: int = 2, +): + """ + Analyze logs from either CloudWatch or a local file. 
+ + Args: + log_source: Either a path to a local file (str) or a tuple of + (pipeline_name, execution_id, step_name) + days_before: Number of days in the past to start analyzing logs + """ + + # Gather events, either from file or from CloudWatch + if isinstance(log_source, str): + print(f"Reading logs from file: {log_source}") + log_events = read_local_logs(log_source) + else: + try: + start_time = int( + (datetime.now() - timedelta(days=days_before)).timestamp() * 1000 + ) + end_time = int(datetime.now().timestamp() * 1000) + + # Unpack the logs source + pipeline_name, execution_name, step_name = log_source + + # Get the training job name which is the prefix of the log stream + train_job_id = get_training_job_name( + pipeline_name, execution_name, step_name + ) + + # Get the log stream + logs_client = boto3.client("logs") + # Get log streams that match the prefix + log_streams_response = logs_client.describe_log_streams( + logGroupName=LOG_GROUP, + logStreamNamePrefix=train_job_id, + ) + + for log_stream in log_streams_response["logStreams"]: + if "algo-1" in log_stream["logStreamName"]: + log_stream_name = log_stream["logStreamName"] + break + else: + raise RuntimeError( + f"No log stream found with prefix {train_job_id}/algo-1" + ) + + print(f"Analyzing log stream: {log_stream_name}") + print(f"Time range: {datetime.fromtimestamp(start_time/1000)}") + print(f" to: {datetime.fromtimestamp(end_time/1000)}") + + log_events = get_cloudwatch_logs( + logs_client, LOG_GROUP, log_stream, start_time, end_time + ) + except Exception as e: + print(f"Error while retrieving logs from CloudWatch: {e}") + raise e + + # Patterns for both types of logs + epoch_pattern = re.compile(r"INFO:root:Epoch (\d+) take (\d+\.\d+) seconds") + eval_pattern = re.compile( + r"INFO:root: Eval time: (\d+\.\d+), Evaluation step: (\d+)" + ) + epochs_data = [] + eval_data = [] + + for event in log_events: + epoch_match = epoch_pattern.search(event["message"]) + eval_match = eval_pattern.search(event["message"]) + + if epoch_match: + epochs_data.append( + { + "epoch": int(epoch_match.group(1)), + "time": float(epoch_match.group(2)), + "timestamp": datetime.fromtimestamp(event["timestamp"] / 1000), + } + ) + elif eval_match: + eval_data.append( + { + "time": float(eval_match.group(1)), + "step": int(eval_match.group(2)), + "timestamp": datetime.fromtimestamp(event["timestamp"] / 1000), + } + ) + + # We have gathered the relevant events, return for processing + return epochs_data, eval_data + + +def get_training_job_name(pipeline_name: str, execution_id: str, step_name: str) -> str: + """Get training job name for a step in a specific pipeline execution""" + sm_client = boto3.client("sagemaker") + + try: + # Get the full execution ARN first + execution_arn = get_pipeline_execution_arn(pipeline_name, execution_id) + print(f"Found execution ARN: {execution_arn}") + + # Get the pipeline execution details + execution_steps = sm_client.list_pipeline_execution_steps( + PipelineExecutionArn=execution_arn + ) + + # Find the desired step + target_step = None + for step in execution_steps["PipelineExecutionSteps"]: + if step["StepName"] == step_name: + target_step = step + break + else: + raise ValueError(f"Step '{step_name}' not found in pipeline execution") + + # Get the training job name from metadata + metadata = target_step["Metadata"] + if "TrainingJob" not in metadata: + raise ValueError(f"No training job found in step '{step_name}'") + + training_job_name = metadata["TrainingJob"]["Arn"].split("/")[-1] + + return 
training_job_name + + except Exception as e: + print(f"Error while getting training job name: {e}") + raise e + + +def print_training_summary( + epochs_data: List[Dict], eval_data: List[Dict], verbose: bool +): + """Prints a summary of the epoch time and eval time for a GraphStorm training job""" + + print("\n=== Training Epochs Summary ===") + if epochs_data: + total_epochs = len(epochs_data) + avg_time = sum(e["time"] for e in epochs_data) / total_epochs + min_time = min(epochs_data, key=lambda x: x["time"]) + max_time = max(epochs_data, key=lambda x: x["time"]) + + print(f"Total epochs completed: {total_epochs}") + print(f"Average epoch time: {avg_time:.2f} seconds") + print( + f"Fastest epoch: Epoch {min_time['epoch']} ({min_time['time']:.2f} seconds)" + ) + print( + f"Slowest epoch: Epoch {max_time['epoch']} ({max_time['time']:.2f} seconds)" + ) + + if verbose: + print("\nEpoch Details:") + for data in epochs_data: + print( + f"Epoch {data['epoch']:3d}: {data['time']:6.2f} seconds " + f"[{data['timestamp'].strftime('%Y-%m-%d %H:%M:%S')}]" + ) + + print("\n=== Evaluation Summary ===") + if eval_data: + total_evals = len(eval_data) + avg_eval_time = sum(e["time"] for e in eval_data) / total_evals + min_eval = min(eval_data, key=lambda x: x["time"]) + max_eval = max(eval_data, key=lambda x: x["time"]) + + print(f"Total evaluations: {total_evals}") + print(f"Average evaluation time: {avg_eval_time:.2f} seconds") + print( + f"Fastest evaluation: Step {min_eval['step']} ({min_eval['time']:.2f} seconds)" + ) + print( + f"Slowest evaluation: Step {max_eval['step']} ({max_eval['time']:.2f} seconds)" + ) + + if verbose: + print("\nEvaluation Details:") + for data in eval_data: + print( + f"Step {data['step']:4d}: {data['time']:6.2f} seconds " + f"[{data['timestamp'].strftime('%Y-%m-%d %H:%M:%S')}]" + ) + + +if __name__ == "__main__": + args = parse_args() + client = boto3.client("logs", region_name=args.region) + if args.log_file: + log_representation = args.log_file + else: + log_stream_prefix = get_training_job_name( + args.pipeline_name, args.execution_name, "Training" + ) + log_representation = (args.pipeline_name, args.execution_name, "Training") + # Get the training job name which is the prefix of the log stream + print(f"Found training job: {log_stream_prefix}") + + retrieved_epochs_data, retrieved_eval_data = analyze_logs( + log_representation, args.logs_days_before + ) + + print_training_summary(retrieved_epochs_data, retrieved_eval_data, args.verbose) diff --git a/examples/sagemaker-pipelines-graphbolt/build_and_push_papers100M_image.sh b/examples/sagemaker-pipelines-graphbolt/build_and_push_papers100M_image.sh new file mode 100644 index 0000000000..4c6fadaee3 --- /dev/null +++ b/examples/sagemaker-pipelines-graphbolt/build_and_push_papers100M_image.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash +set -xEeuo pipefail +trap cleanup SIGINT SIGTERM ERR EXIT + +cleanup() { + trap - SIGINT SIGTERM ERR EXIT + # script cleanup here + rm -f ripunzip_2.0.0-1_amd64.deb +} + + +ACCOUNT=$(aws sts get-caller-identity --query Account --output text) +REGION=$(aws configure get region) +REGION=${REGION:-us-east-1} +IMAGE=papers100m-processor + +curl -L -O https://github.com/google/ripunzip/releases/download/v2.0.0/ripunzip_2.0.0-1_amd64.deb + +# Auth to AWS public ECR gallery +aws ecr-public get-login-password --region $REGION | docker login --username AWS --password-stdin public.ecr.aws + +# Build and tag image +docker build -f Dockerfile.processing -t $IMAGE . 
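# (Optional) Quick sanity check of the freshly built image before pushing;
# the imports mirror the Python packages installed in Dockerfile.processing:
#   docker run --rm $IMAGE python3 -c "import numpy, psutil, pyarrow, tqdm; print('image OK')"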
+ + +# Create repository if it doesn't exist +echo "Getting or creating container repository: $IMAGE" +if ! $(aws ecr describe-repositories --repository-names $IMAGE --region ${REGION} > /dev/null 2>&1); then + echo >&2 "WARNING: ECR repository $IMAGE does not exist in region ${REGION}. Creating..." + aws ecr create-repository --repository-name $IMAGE --region ${REGION} > /dev/null +fi + +# Auth to private ECR +aws ecr get-login-password --region $REGION | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.$REGION.amazonaws.com + +# Tag and push the image +docker tag $IMAGE:latest $ACCOUNT.dkr.ecr.$REGION.amazonaws.com/$IMAGE:latest + +docker push $ACCOUNT.dkr.ecr.$REGION.amazonaws.com/$IMAGE:latest diff --git a/examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py b/examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py new file mode 100644 index 0000000000..dd66f09f01 --- /dev/null +++ b/examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py @@ -0,0 +1,163 @@ +""" +Copyright Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Download ogbn-arxiv data and prepare for input to GConstruct +""" + +import argparse +import json + +import pyarrow as pa +import pyarrow.parquet as pq +from pyarrow import fs +from ogb.nodeproppred import NodePropPredDataset + + +def parse_args() -> argparse.Namespace: + """Get the output prefix argument for the scrip""" + parser = argparse.ArgumentParser( + description="Convert OGB arxiv data to gconstruct format and write to S3" + ) + parser.add_argument( + "--output-s3-prefix", + type=str, + required=True, + help="S3 prefix for the output directory for gconstruct format", + ) + return parser.parse_args() + + +def get_filesystem(path): + """Choose the appropriate filesystem based on the path""" + return fs.S3FileSystem() if path.startswith("s3://") else fs.LocalFileSystem() + + +def convert_ogbn_arxiv(output_prefix: str): + """Convert ogbn-arxiv data to GConstruct and output to output_prefix""" + pyarrow_fs = get_filesystem(output_prefix) + + if output_prefix.startswith("s3://"): + output_prefix = output_prefix[5:] + + # Load the entire dataset + dataset = NodePropPredDataset(name="ogbn-arxiv") + graph, labels = dataset[0] + split_idx = dataset.get_idx_split() + + # Convert node features and labels + node_feat = graph["node_feat"] + table = pa.Table.from_arrays( + [ + pa.array(range(len(node_feat))), + pa.array(list(node_feat)), + pa.array(labels.squeeze()), + pa.array(graph["node_year"].squeeze()), + ], + names=["nid", "feat", "labels", "year"], + ) + pq.write_table( + table, f"{output_prefix}/nodes/paper/nodes.parquet", filesystem=pyarrow_fs + ) + + # Convert edge index + edge_index = graph["edge_index"] + edge_table = pa.Table.from_arrays( + [pa.array(edge_index[0]), pa.array(edge_index[1])], names=["src", "dst"] + ) + pq.write_table( + edge_table, + f"{output_prefix}/edges/paper-cites-paper/edges.parquet", + filesystem=pyarrow_fs, + ) + + # Convert train/val/test splits + assert split_idx, "Split index must exist 
for ogbn-arxiv" + for split in ["train", "valid", "test"]: + split_indices = split_idx[split] + split_table = pa.Table.from_arrays([pa.array(split_indices)], names=["nid"]) + pq.write_table( + split_table, + f"{output_prefix}/splits/{split}_idx.parquet", + filesystem=pyarrow_fs, + ) + + config = { + "version": "gconstruct-v0.1", + "nodes": [ + { + "node_id_col": "nid", + "node_type": "node", + "format": {"name": "parquet"}, + "files": [f"{output_prefix}/nodes/paper/nodes.parquet"], + "features": [ + { + "feature_col": "feat", + "feature_name": "paper_feat", + }, + { + "feature_col": "year", + "feature_name": "paper_year", + "transform": {"name": "max_min_norm"}, + }, + ], + "labels": [ + { + "label_col": "labels", + "task_type": "classification", + "custom_split_filenames": { + "column": "nid", + "train": f"{output_prefix}/splits/train_idx.parquet", + "valid": f"{output_prefix}/splits/valid_idx.parquet", + "test": f"{output_prefix}/splits/test_idx.parquet", + }, + "label_stats_type": "frequency_cnt", + } + ], + } + ], + "edges": [ + { + "source_id_col": "src", + "dest_id_col": "dst", + "relation": ["node", "cites", "node"], + "format": {"name": "parquet"}, + "files": [f"{output_prefix}/edges/paper-cites-paper/edges.parquet"], + }, + { + "source_id_col": "dst", + "dest_id_col": "src", + "relation": ["node", "cites-rev", "node"], + "format": {"name": "parquet"}, + "files": [f"{output_prefix}/edges/paper-cites-paper/edges.parquet"], + }, + ], + } + + # Write config to output + with pyarrow_fs.open_output_stream( + f"{output_prefix}/gconstruct_config_arxiv.json" + ) as f: + f.write(json.dumps(config, indent=2).encode("utf-8")) + + print( + "Conversion for ogbn-arxiv completed. " + f"Output files and configuration are in {output_prefix}" + ) + + +if __name__ == "__main__": + args = parse_args() + + convert_ogbn_arxiv(args.output_prefix) diff --git a/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100M_to_gconstruct.py b/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100M_to_gconstruct.py new file mode 100644 index 0000000000..0dfda2e70c --- /dev/null +++ b/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100M_to_gconstruct.py @@ -0,0 +1,288 @@ +import argparse +import gzip +import json +import logging +from pathlib import Path +from concurrent.futures import ProcessPoolExecutor, as_completed + +import numpy as np +import psutil +import pyarrow as pa +import pyarrow.dataset as ds +import pyarrow.parquet as pq +import pyarrow.fs as fs +from tqdm_loggable.auto import tqdm + +# pylint: disable=logging-fstring-interpolation + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Convert raw OGB papers-100M data to GConstruct format" + ) + parser.add_argument( + "--input-dir", + type=str, + required=True, + help="Path to the input directory containing OGB papers-100M data", + ) + parser.add_argument( + "--output-prefix", + type=str, + required=True, + help="Prefix path to the output directory for gconstruct format. 
Can be local or s3://", + ) + return parser.parse_args() + + +def get_filesystem(path): + """Choose the appropriate filesystem based on the path""" + return fs.S3FileSystem() if path.startswith("s3://") else fs.LocalFileSystem() + + +def process_and_upload_chunk( + data, schema, output_dir, filesystem, entity_type, start, end +): + """Worker function that writes the input data as a parquet file""" + table = pa.Table.from_arrays(data, schema=schema) + ds.write_dataset( + table, + base_dir=f"{output_dir}/{entity_type}", + basename_template=f"{entity_type}-{start:012}-{end:012}-{{i}}.parquet", + format="parquet", + schema=schema, + filesystem=filesystem, + file_options=ds.ParquetFileFormat().make_write_options(compression="snappy"), + max_rows_per_file=end - start, + existing_data_behavior="overwrite_or_ignore", + ) + + +def process_data(input_dir, output_dir, filesystem): + """Process papers100M data using threads""" + # Load data using memory mapping to minimize memory usage + node_feat = np.load(input_dir / "raw" / "node_feat.npy", mmap_mode="r") + node_year = np.load(input_dir / "raw" / "node_year.npy", mmap_mode="r") + edge_index = np.load(input_dir / "raw" / "edge_index.npy", mmap_mode="r") + labels = np.load(input_dir / "raw" / "node-label.npz", mmap_mode="r")["node_label"] + + num_nodes, num_features = node_feat.shape + num_edges = edge_index.shape[1] + logging.info( + f"Node features shape: {node_feat.shape:,}, Number of edges: {num_edges:,}" + ) + + # Define schemas for nodes and edges + node_schema = pa.schema( + [ + ("nid", pa.int64()), + ("feat", pa.large_list(pa.float32())), + ("label", pa.float32()), + ("year", pa.int16()), + ] + ) + edge_schema = pa.schema([("src", pa.int64()), ("dst", pa.int64())]) + + # Calculate chunk sizes and number of workers based on available memory + available_ram = psutil.virtual_memory().available + + # Calculate memory usage per node row + node_row_bytes = ( + num_features * 4 + 8 + 2 + ) # 4 bytes per float32, 8 bytes for int64 nid, 2 bytes for int16 year + # Set node chunk size to fit within 1GB or the total number of nodes, whichever is smaller + node_chunk_size = min((1024**3) // node_row_bytes, num_nodes) + + # Calculate memory usage per edge row + edge_row_bytes = 16 # 8 bytes for each int64 (src and dst) + # Set edge chunk size to fit within 1GB or the total number of edges, whichever is smaller + edge_chunk_size = min((1024**3) // edge_row_bytes, num_edges) + + # Set the number of worker threads + # Use 2 times the number of CPU cores (or 4 if CPU count can't be determined) + # But limit based on available RAM, assuming each worker might use up to 2GB + max_workers = min(16, available_ram // (2 * 1024**3)) + + logging.info( + f"Node chunk size: {node_chunk_size:,} rows, Edge chunk size: {edge_chunk_size:,} rows." 
+ ) + logging.info(f"Max concurrent workers: {max_workers}") + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = [] + + num_node_chunks = num_nodes // node_chunk_size + # Process and upload nodes in chunks + for idx, start in enumerate( + tqdm(range(0, num_nodes, node_chunk_size)), start=1 + ): + end = min(start + node_chunk_size, num_nodes) + logging.info( + f"Reading data chunk {idx}/{num_node_chunks} for nodes {start:,}-{end:,}" + ) + data = [ + pa.array(range(start, end)), + pa.array(list(node_feat[start:end])), + pa.array(labels[start:end].squeeze()), + pa.array(node_year[start:end].astype(np.int16).squeeze()), + ] + logging.info(f"Submitting job {idx} for nodes {start:,}-{end:,}") + futures.append( + executor.submit( + process_and_upload_chunk, + data, + node_schema, + output_dir, + filesystem, + "nodes", + start, + end, + ) + ) + + # Process and upload edges in chunks + num_edge_chunks = num_edges // edge_chunk_size + for idx, start in enumerate( + tqdm(range(0, num_edges, edge_chunk_size)), start=1 + ): + end = min(start + edge_chunk_size, num_edges) + logging.info( + f"Reading data chunk {idx}/{num_edge_chunks} for edges {start:,}-{end:,}" + ) + data = [ + pa.array(edge_index[0, start:end]), + pa.array(edge_index[1, start:end]), + ] + logging.info(f"Submitting job {idx} for edges {start:,}-{end:,}") + futures.append( + executor.submit( + process_and_upload_chunk, + data, + edge_schema, + output_dir, + filesystem, + "edges", + start, + end, + ) + ) + + # Wait for all uploads to complete + for future in tqdm( + as_completed(futures), total=len(futures), desc="Processing and uploading" + ): + # This will raise any exceptions that occurred during processing or upload + future.result() + + # Process split files + split_files = {} + for split in ["train", "valid", "test"]: + with gzip.open(input_dir / "split" / "time" / f"{split}.csv.gz", "rt") as f: + split_indices = [int(line.strip()) for line in f] + split_table = pa.table({"nid": split_indices}) + pq.write_table( + split_table, f"{output_dir}/{split}_idx.parquet", filesystem=filesystem + ) + split_files[split] = f"{split}_idx.parquet" + + return split_files + + +def create_config(output_dir, filesystem, split_files): + """Create the GConstruct configuration file and write to output_dir""" + config = { + "version": "gconstruct-v0.1", + "nodes": [ + { + "node_id_col": "nid", + "node_type": "paper", + "format": {"name": "parquet"}, + "files": ["nodes"], + "features": [ + {"feature_col": "feat", "feature_name": "paper_feat"}, + { + "feature_col": "year", + "feature_name": "paper_year", + "transform": {"name": "max_min_norm"}, + }, + ], + "labels": [ + { + "label_col": "label", + "task_type": "classification", + "custom_split_filenames": { + "column": "nid", + "train": split_files["train"], + "valid": split_files["valid"], + "test": split_files["test"], + }, + "label_stats_type": "frequency_cnt", + } + ], + } + ], + "edges": [ + { + "source_id_col": "src", + "dest_id_col": "dst", + "relation": ["paper", "cites", "paper"], + "format": {"name": "parquet"}, + "files": ["edges"], + }, + { + "source_id_col": "dst", + "dest_id_col": "src", + "relation": ["paper", "cites-rev", "paper"], + "format": {"name": "parquet"}, + "files": ["edges"], + }, + ], + } + + # Write the configuration to a JSON file + with filesystem.open_output_stream( + f"{output_dir}/gconstruct_config_papers100m.json" + ) as f: + f.write(json.dumps(config, indent=2).encode("utf-8")) + + +def main(): + """Runs the conversion from raw data to 
GConstruct input format""" + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + + args = parse_args() + input_path = Path(args.input_dir) + filesystem = get_filesystem(args.output_prefix) + + # Adjust the output prefix for S3 if necessary + if args.output_prefix.startswith("s3://"): + # PyArrow expects 'bucket/key...' for S3 + output_prefix = args.output_prefix[5:] + else: + output_prefix = args.output_prefix + + # Remove trailing slash from output prefix + output_prefix = output_prefix[:-1] if output_prefix.endswith("/") else output_prefix + + # Create output directories + for path in ["nodes", "edges"]: + filesystem.create_dir(f"{output_prefix}/{path}", recursive=True) + + # Process the data and get split files information + split_files = process_data(input_path, output_prefix, filesystem) + + # Create and write the configuration file + create_config(output_prefix, filesystem, split_files) + + print( + "Conversion for ogbn-papers100M completed. " + f"Output files and GConstruct configuration are in {output_prefix}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/sagemaker-pipelines-graphbolt/process_papers100M.sh b/examples/sagemaker-pipelines-graphbolt/process_papers100M.sh new file mode 100644 index 0000000000..d24d5b92a8 --- /dev/null +++ b/examples/sagemaker-pipelines-graphbolt/process_papers100M.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +set -Eeuox pipefail +trap cleanup SIGINT SIGTERM ERR EXIT + +cleanup() { + trap - SIGINT SIGTERM ERR EXIT + # script cleanup here + kill $DISK_USAGE_PID > /dev/null 2>&1 || true +} + +# Download and unzip data in parallel +TEMP_DATA_PATH=/tmp/raw-data +mkdir -p $TEMP_DATA_PATH +cd $TEMP_DATA_PATH || exit 1 + + +echo "Will execute script $1 with output prefix $2" + + +echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ'): Downloading files using axel, this will take at least 10 minutes depending on network speed" +time axel -n 16 --quiet http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip + +echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ'): Unzipping files using ripunzip this will take 10-20 minutes" +time ripunzip unzip-file papers100M-bin.zip +# npz files are zip files, so we can also unzip them in parallel +cd papers100M-bin/raw || exit 1 +time ripunzip unzip-file data.npz && rm data.npz + + +# Run the processing script +echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ'): Processing data and uploading to S3, this will take around 20 minutes" +python3 /opt/ml/code/"$1" \ + --input-dir "$TEMP_DATA_PATH/papers100M-bin/" \ + --output-prefix "$2" From 8378cc6431151a9bd967dc390f7dab832a5e0c36 Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Tue, 7 Jan 2025 10:31:57 +0000 Subject: [PATCH 2/9] Add training yaml for papers100M --- training_scripts/gsgnn_np/papers100M_nc.yaml | 49 ++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 training_scripts/gsgnn_np/papers100M_nc.yaml diff --git a/training_scripts/gsgnn_np/papers100M_nc.yaml b/training_scripts/gsgnn_np/papers100M_nc.yaml new file mode 100644 index 0000000000..ab0ef7917b --- /dev/null +++ b/training_scripts/gsgnn_np/papers100M_nc.yaml @@ -0,0 +1,49 @@ +--- +version: 1.0 +gsf: + basic: + model_encoder_type: rgcn + graph_name: ogbn-papers100M + backend: gloo + ip_config: /ip_list.txt + part_config: null + verbose: false + mp_opt_level: O2 + no_validation: false + train_nodes: 10 + debug: false + evaluation_frequency: 500 + gnn: + num_layers: 3 + hidden_size: 128 + use_mini_batch_infer: true + input: 
+ restore_model_path: null + output: + save_model_path: null + save_embed_path: null + hyperparam: + dropout: 0.1 + lr: 0.001 + bert_tune_lr: 0.0001 + num_epochs: 15 + fanout: "3,5,8" + eval_fanout: "3,5,8" + batch_size: 1024 + eval_batch_size: 1024 + bert_infer_bs: 128 + wd_l2norm: 0 + norm: "batch" + rgcn: + num_bases: -1 + use_self_loop: true + use_dot_product: true + lp_decoder_type: dot_product + self_loop_init: false + sparse_optimizer_lr: 1e-2 + use_node_embeddings: false + node_classification: + target_ntype: "paper" + label_field: "label" + multilabel: false + num_classes: 172 From ba1c81719d845707069c540b4098b45bb98bbb46 Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Wed, 8 Jan 2025 23:34:42 +0000 Subject: [PATCH 3/9] Add README with detailed walkthrough, scripts for pipeline deployment. --- .github/workflow_scripts/lint_check.sh | 9 +- .../Dockerfile.processing | 10 +- .../sagemaker-pipelines-graphbolt/README.md | 531 ++++++++++++++++++ .../analyze_training_time.py | 17 +- .../build_and_push_papers100M_image.sh | 45 +- ...> convert_ogb_papers100m_to_gconstruct.py} | 20 +- .../deploy_arxiv_pipeline.sh | 129 +++++ .../deploy_papers100M_pipeline.sh | 139 +++++ 8 files changed, 871 insertions(+), 29 deletions(-) create mode 100644 examples/sagemaker-pipelines-graphbolt/README.md rename examples/sagemaker-pipelines-graphbolt/{convert_ogb_papers100M_to_gconstruct.py => convert_ogb_papers100m_to_gconstruct.py} (93%) create mode 100644 examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh create mode 100644 examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh diff --git a/.github/workflow_scripts/lint_check.sh b/.github/workflow_scripts/lint_check.sh index 63f4231a81..c07c39d68c 100644 --- a/.github/workflow_scripts/lint_check.sh +++ b/.github/workflow_scripts/lint_check.sh @@ -1,9 +1,11 @@ -# Move to parent directory -cd ../../ - +#!/usr/bin/env bash set -ex +# Move to repo root +cd ../../ + pip install pylint==2.17.5 + pylint --rcfile=./tests/lint/pylintrc ./python/graphstorm/*.py pylint --rcfile=./tests/lint/pylintrc ./python/graphstorm/data/*.py pylint --rcfile=./tests/lint/pylintrc ./python/graphstorm/distributed/ @@ -21,3 +23,4 @@ pylint --rcfile=./tests/lint/pylintrc ./python/graphstorm/utils.py pylint --rcfile=./tests/lint/pylintrc ./tools/convert_feat_to_wholegraph.py pylint --rcfile=./tests/lint/pylintrc ./python/graphstorm/sagemaker/ +pylint --rcfile=./tests/lint/pylintrc ./examples/sagemaker-pipelines-graphbolt/ --recursive y diff --git a/examples/sagemaker-pipelines-graphbolt/Dockerfile.processing b/examples/sagemaker-pipelines-graphbolt/Dockerfile.processing index 0470c21b2b..0abb9cad47 100644 --- a/examples/sagemaker-pipelines-graphbolt/Dockerfile.processing +++ b/examples/sagemaker-pipelines-graphbolt/Dockerfile.processing @@ -4,7 +4,7 @@ FROM public.ecr.aws/ubuntu/ubuntu:22.04 ENV DEBIAN_FRONTEND=noninteractive # Install Python and other dependencies -RUN apt update && apt install -y \ +RUN apt-get update && apt-get install -y \ axel \ curl \ python3 \ @@ -13,9 +13,9 @@ RUN apt update && apt install -y \ unzip \ && rm -rf /var/lib/apt/lists/* - +# Copy and install ripunzip COPY ripunzip_2.0.0-1_amd64.deb ripunzip_2.0.0-1_amd64.deb -RUN apt install -y ./ripunzip_2.0.0-1_amd64.deb +RUN apt-get install -y ./ripunzip_2.0.0-1_amd64.deb RUN python3 -m pip install --no-cache-dir --upgrade pip==24.3.1 && \ python3 -m pip install --no-cache-dir \ @@ -25,14 +25,14 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip==24.3.1 && \ 
tqdm==4.67.1 \ tqdm-loggable==0.2 +# Install aws cli RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" \ && unzip awscliv2.zip \ && ./aws/install - # Copy processing scripts COPY process_papers100M.sh /opt/ml/code/ -COPY convert_ogb_papers100M_to_gconstruct.py /opt/ml/code/ +COPY convert_ogb_papers100m_to_gconstruct.py /opt/ml/code/ WORKDIR /opt/ml/code/ diff --git a/examples/sagemaker-pipelines-graphbolt/README.md b/examples/sagemaker-pipelines-graphbolt/README.md new file mode 100644 index 0000000000..44499a3dd6 --- /dev/null +++ b/examples/sagemaker-pipelines-graphbolt/README.md @@ -0,0 +1,531 @@ +# Faster distributed graph neural network training with GraphStorm 0.4 + +GraphStorm is a low-code enterprise graph machine learning (ML) framework that provides ML practitioners a simple way of building, training and deploying graph ML solutions on industry-scale graph data. While GraphStorm can run efficiently on single instances for small graphs, it truly shines when scaling to enterprise-level graphs in distributed mode using a cluster of EC2 instances or Amazon SageMaker. + +GraphStorm 0.4 introduced integration with DGL-GraphBolt, a new graph storage and sampling framework that uses a compact graph representation and pipelined sampling to reduce memory requirements and speed up Graph Neural Network (GNN) training by up to 3x. In this example we'll show how GraphStorm 0.4 brings training and inference speedups of up to 3x. + +In this example, you will: + +1. Learn how to use SageMaker Pipelines with GraphStorm. +2. Understand how GraphBolt enhances GraphStorm's performance in distributed settings. +3. Follow a hands-on example of using GraphStorm with GraphBolt on Amazon SageMaker for distributed training. + +## Background: challenges of graph training + +Before diving into our hands-on example, it's important to understand some challenges associated with graph training, especially as graphs grow in size and complexity: + +1. Memory Constraints: As graphs grow larger, they may no longer fit into the memory of a single machine. A graph with 1B nodes with 512 features per node and 10B edges will require more than 4TB of memory to store, even with optimal representation. This necessitates distributed processing and more efficient graph representation. +2. Graph Sampling: In GNN mini-batch training, you need to sample neighbors for each node to propagate their representations. For multi-layer GNNs, this can lead to exponential growth in the number of nodes sampled, potentially visiting the entire graph for a single node's representation. Efficient sampling methods become necessary. +3. Remote Data Access: When training on multiple machines, retrieving node features and sampling neighborhoods from other machines will significantly impact performance due to network latency. For example, reading a 1024-feature vector from main memory will take around 3μs, while reading that vector from a remote key/value store would take 50-100x longer. + +GraphStorm and GraphBolt help address these challenges through efficient graph representations, smart sampling techniques, and sophisticated partitioning algorithms like ParMETIS. + + +## GraphBolt: pipeline-driven graph sampling + + +GraphBolt is a new data loading and graph sampling framework developed by the [DGL](https://www.dgl.ai/) team. It streamlines the operations needed to sample efficiently from a heterogeneous graph and fetch the corresponding features. 
+ +GraphBolt introduces a new, more compact graph structure representation for heterogeneous graphs, called fused Compressed Sparse Column (fCSC). This can reduce the memory cost of storing a heterogeneous graph by up to 56%, allowing users to fit larger graphs in memory and potentially use smaller, more cost-efficient instances for GNN model training. + + +### Integration with GraphStorm: + +GraphStorm 0.4.0 seamlessly integrates with GraphBolt, allowing users to leverage these performance improvements in their GNN workflows. This integration enables GraphStorm to handle larger graphs more efficiently and accelerate both training and inference processes. + +The integration of GraphBolt into GraphStorm's workflow means that users can now: + +1. Load and process larger graphs with fewer hardware resources. +2. Achieve faster training and inference times with more efficient graph sampling framework. +3. Utilize GPU resources more effectively for graph learning. + +### Performance improvements: + +Our benchmarks show significant improvements in both memory usage and training speed when using GraphStorm with GraphBolt: + + +* We've observed up to 1.8x training speedup on the [ogbn-papers 100M dataset](https://ogb.stanford.edu/docs/nodeprop/#ogbn-papers100M), with 111M nodes and 3.2B edges +* At the same time, memory usage for storing graph structures has been reduced by up to 56% in heterogeneous graphs like ogbn-papers. + +## Example model development lifecycle for GraphStorm on SageMaker + +Figure 1: GraphStorm SageMaker architecture. + +A common model development process is to perform model exploration locally on a subset of your full data, and once satisfied with the results train the full scale model. GraphStorm and SageMaker Pipelines allows you to do that by creating a model pipeline you can execute locally to retrieve model metrics, and when ready execute your pipeline on the full data, and produce models, predictions and graph embeddings to use in downstream tasks. In the next section you'll learn how to set up such pipelines for GraphStorm. + +## Set up environment for SageMaker distributed training + +You'll be using SageMaker Bring-Your-Own-Container (BYOC) to launch processing and training jobs. You need to create a PyTorch Docker image for distributed training, and we'll use the same image to process and prepare the graph for training. +You will use SageMaker Pipelines to automate jobs needed for GNN training. As a prerequisite, you'll need to have access to a [SageMaker Domain](https://docs.aws.amazon.com/sagemaker/latest/dg/gs-studio-onboard.html) to access [SageMaker Studio](https://aws.amazon.com/sagemaker-ai/studio/) and [SageMaker Pipelines](https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines.html). + +### Create a SageMaker Domain + +In order to use SageMaker Studio you will need to have a SageMaker Domain available. If you don't have one already, follow the steps in the [quick setup](https://docs.aws.amazon.com/sagemaker/latest/dg/onboard-quick-start.html) to create one: + +1. Sign in to the [SageMaker AI console](https://console.aws.amazon.com/sagemaker/). +2. Open the left navigation pane. +3. Under **Admin configurations**, choose **Domains**. +4. Choose **Create domain**. +5. Choose **Set up for single user (Quick setup**). Your domain and user profile are created automatically. 
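
If you prefer to script this step, you can also create a domain with the AWS CLI. The following is a minimal sketch; the domain name, execution role ARN, VPC ID, and subnet ID below are placeholders you need to replace with your own values.

```bash
aws sagemaker create-domain \
    --domain-name graphstorm-demo \
    --auth-mode IAM \
    --default-user-settings "ExecutionRole=arn:aws:iam::123456789012:role/MySageMakerExecutionRole" \
    --vpc-id vpc-0123456789abcdef0 \
    --subnet-ids subnet-0123456789abcdef0
```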
+ +### Set up appropriate roles to use with SageMaker Pipelines + +To set up the SageMaker Pipelines you will need permissions to create ECR repositories, pull and push to them, pull from the AWS ECR Public Gallery, launch SageMaker jobs, manage SageMaker Pipelines, and interact with data on S3. We will create a role for Amazon EC2 on the AWS console, which will also create an associated instance profile to use with an EC2 instance. + +You will also need access to a SageMaker execution that your jobs assume during execution. You can use the [Amazon SageMaker Role Manager](https://docs.aws.amazon.com/sagemaker/latest/dg/role-manager.html) to streamline the creation of the necessary roles. + + +### Set up the pipeline management environment + +For this example you can either use your existing development environment or set up a new EC2 instance. If you plan to use a new instance to prepare the large-scale data for this example, ensure it has at least 300GB of disk space available. +To set up an EC2 instance with the appropriate environment: + + +1. Launch an EC2 instance: + +```bash +# Use an Ubuntu PyTorch 2.4.0 DLAMI (Ubuntu 22.04) +aws ec2 run-instances \ + --image-id "ami-0907e5206d941612f" \ + --instance-type "m6in.4xlarge" \ + --key-name my-key-name \ + --block-device-mappings '[{ + "DeviceName": "/dev/sdf", + "Ebs": { + "VolumeSize": 300, + "VolumeType": "gp3", + "DeleteOnTermination": true + } + }]' +``` + +This command creates an instance using the "Deep Learning OSS Nvidia Driver AMI GPU PyTorch 2.4.1 (Ubuntu 22.04) 20241116" AMI, in the default VPC with the default security group. Make your instance accessible through SSH, using an appropriate security group or the [AWS Systems Session Manager](https://docs.aws.amazon.com/systems-manager/latest/userguide/session-manager.html), and log in to the instance. You can also use the [AWS Console](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/tutorial-launch-my-first-ec2-instance.html) to create a new EC2 instance. + +> NOTE: You may need to update the --image-id to the latest available. See https://docs.aws.amazon.com/dlami/latest/devguide/find-dlami-id.html for instructions. + +Once logged in, you can set up your Python environment to run GraphStorm + +```bash +conda init +eval $SHELL +conda create -y --name gsf python=3.10 +conda activate gsf + +# Install dependencies +pip install sagemaker boto3 ogb pyarrow + +# Clone the GraphStorm repository to access the example code +git clone https://github.com/awslabs/graphstorm.git ~/graphstorm +cd ~/graphstorm/examples/sagemaker-pipelines-graphbolt +``` + +### Download and prepare datasets + +In this example you will use two related datasets to demonstrate the scalability of GraphStorm. The Open Graph Benchmark (OGB) project hosts a number of graph datasets that can be used to benchmark the performance of graph learning systems. In this example you will use two citation network datasets, the ogbn-arxiv dataset for a small-scale demo, and the ogbn-papers100M dataset for a demonstration of GraphStorm's large-scale learning capabilities. + +Because the two datasets have similar schemas and the same task (node classification) they allow us to emulate a typical data science pipeline, where we first do some model development and testing on a smaller dataset locally, and once ready launch SageMaker jobs to train on the full-scale data. 
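
If you want a quick feel for the smaller dataset before converting it, the minimal sketch below uses the `ogb` package installed in the environment setup above; it downloads ogbn-arxiv under the current directory on first use and prints basic statistics.

```bash
python -c "
from ogb.nodeproppred import NodePropPredDataset
dataset = NodePropPredDataset(name='ogbn-arxiv')
graph, labels = dataset[0]
print('Nodes:', graph['num_nodes'])
print('Edges:', graph['edge_index'].shape[1])
print('Feature dim:', graph['node_feat'].shape[1])
print('Classes:', int(labels.max()) + 1)
"
```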

#### Prepare the ogbn-arxiv dataset

You'll download the smaller-scale [ogbn-arxiv](https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv) dataset to run a local test before launching larger-scale SageMaker jobs on AWS. This dataset has ~170K nodes and ~1.2M edges. You will use the following script to download the arxiv data and prepare it for GraphStorm.


```bash
# Provide the S3 bucket to use for output
BUCKET_NAME=
```


You will use this script to directly download, transform and upload the data to S3:


```bash
python convert_arxiv_to_gconstruct.py \
    --output-prefix s3://$BUCKET_NAME/ogb-arxiv-input
```

This will create the tabular graph data on S3, which you can verify by running


```bash
aws s3 ls s3://$BUCKET_NAME/ogb-arxiv-input/
                           PRE edges/
                           PRE nodes/
                           PRE splits/
2024-12-11 02:13:27       1269 gconstruct_config_arxiv.json
```

Finally, you'll also upload the GraphStorm configuration files for arxiv to use for training and inference:

```bash
# Upload the training configurations to S3
aws s3 cp ~/graphstorm/training_scripts/gsgnn_np/arxiv_nc.yaml \
    s3://$BUCKET_NAME/yaml/arxiv_nc_train.yaml
aws s3 cp ~/graphstorm/inference_scripts/np_infer/arxiv_nc.yaml \
    s3://$BUCKET_NAME/yaml/arxiv_nc_inference.yaml
```

#### Prepare the ogbn-papers100M dataset on SageMaker

The papers-100M dataset is a large-scale graph dataset, with 111M nodes and ~3.2B edges when we add reverse edges. The data size is ~57GB, so to make efficient use of AWS resources we'll download and unzip the data in parallel using multiple threads, and upload directly to S3. To do so we will use the [axel](https://github.com/axel-download-accelerator/axel) and [ripunzip](https://github.com/google/ripunzip/) tools.

You can either run this job as a SageMaker Processing job, or run the processing locally in the background while you work on building the GraphStorm Docker image and training a local model for the ogbn-arxiv dataset.

To run this process as a SageMaker Processing job, follow the steps below. You can launch the job, let it execute in the background while you proceed through the rest of the steps, and come back to this dataset later.


```bash
# Navigate to the example code and ensure Docker is installed
cd ~/graphstorm/examples/sagemaker-pipelines-graphbolt
sudo apt update
sudo apt install -y docker.io
docker -v

# Build and push a Docker image to download and process the papers100M data
bash build_and_push_papers100M_image.sh
# This creates an ECR repository at
# $ACCOUNT_ID.dkr.ecr.$REGION.amazonaws.com/papers100m-processor

# Run a SageMaker job to do the processing and upload the output to S3
SAGEMAKER_EXECUTION_ROLE=
ACCOUNT_ID=
REGION=us-east-1
python sagemaker_convert_papers100M.py \
    --output-bucket $BUCKET_NAME \
    --execution-role-arn $SAGEMAKER_EXECUTION_ROLE \
    --region $REGION \
    --instance-type ml.m5.4xlarge \
    --image-uri $ACCOUNT_ID.dkr.ecr.$REGION.amazonaws.com/papers100m-processor
```

This will produce the processed data at `s3://$BUCKET_NAME/ogb-papers100M-input`, which can then be used as input to GraphStorm.


#### [Optional] Prepare the ogbn-papers100M dataset locally

If you prefer to pre-process the data locally, you can use the commands below on an Ubuntu 22.04 instance.

```bash
# Install axel for parallel downloads
sudo apt update
sudo apt -y install axel

# Download and install ripunzip for parallel unzipping
curl -L -O https://github.com/google/ripunzip/releases/download/v2.0.0/ripunzip_2.0.0-1_amd64.deb
sudo apt install -y ./ripunzip_2.0.0-1_amd64.deb

# Download and unzip data using multiple threads, this will take 10-20 minutes
mkdir ~/papers100M-raw-data
cd ~/papers100M-raw-data
axel -n 16 http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip
ripunzip unzip-file papers100M-bin.zip
# npz files are zip files, so we can unzip data.npz in parallel as well
cd papers100M-bin/raw
ripunzip unzip-file data.npz && rm data.npz
cd ~/papers100M-raw-data

# Install process script dependencies
python -m pip install \
    numpy==1.26.4 \
    psutil==6.1.0 \
    pyarrow==18.1.0 \
    tqdm==4.67.1 \
    tqdm-loggable==0.2


# Process and upload to S3, this will take around 20 minutes
python ~/graphstorm/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100m_to_gconstruct.py \
    --input-dir ~/papers100M-raw-data/papers100M-bin \
    --output-prefix s3://$BUCKET_NAME/ogb-papers100M-input
```

### Build a GraphStorm Docker Image

Next you will build and push the GraphStorm PyTorch Docker image that you'll use to run the graph construction, training and inference jobs. If you have the papers-100M data downloading in the background, open a new terminal to build and push the GraphStorm image.


```bash
# Ensure Docker is installed
sudo apt update
sudo apt install -y docker.io
docker -v

# Enter your account ID here
ACCOUNT_ID=
REGION=us-east-1

cd ~/graphstorm

bash ./docker/build_graphstorm_image.sh --environment sagemaker --device cpu

bash docker/push_graphstorm_image.sh -e sagemaker -r $REGION -a $ACCOUNT_ID -d cpu
# This will push an image to
# ${ACCOUNT_ID}.dkr.ecr.us-east-1.amazonaws.com/graphstorm:sagemaker-cpu

# Install sagemaker with support for local mode
pip install sagemaker[local]
```

Next, you will create a SageMaker Pipeline to run the jobs that are necessary to train GNN models with GraphStorm.

## Create SageMaker Pipeline

In this section, you will create a [SageMaker Pipeline](https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines-overview.html) on AWS SageMaker. The pipeline will run the following jobs in sequence:

* Launch a GConstruct processing job. This prepares and partitions the data for distributed training.
* Launch a GraphStorm training job. This will train the model and create model output on S3.
* Launch a GraphStorm inference job. This will generate predictions and embeddings for every node in the input.

```bash
PIPELINE_NAME="ogbn-arxiv-gs-pipeline"
BUCKET_NAME="my-s3-bucket"
bash deploy_arxiv_pipeline.sh \
    --account "" \
    --bucket-name $BUCKET_NAME --role "" \
    --pipeline-name $PIPELINE_NAME \
    --use-graphbolt false
```

### Inspect pipeline

Running the above will create a SageMaker Pipeline configured to run 3 SageMaker jobs in sequence:

* A GConstruct job that converts the tabular file input to a binary partitioned graph on S3.
* A GraphStorm training job that trains a node classification model and saves the model to S3.
* A GraphStorm inference job that produces predictions for all nodes in the test set, and creates embeddings for all nodes.

To review the pipeline, navigate to [SageMaker AI Studio](https://us-east-1.console.aws.amazon.com/sagemaker/home?region=us-east-1#/studio-landing) on the AWS Console, select the domain and user profile you used to create the pipeline in the drop-down menus on the top right, then select **Open Studio**.
+ +On the left navigation menu, select **Pipelines**. There should be a pipeline named **ogbn-arxiv-gs-pipeline**. Select that, which will take you to the **Executions** tab for the pipeline. Select **Graph** to view the pipeline steps. + +### Execute SageMaker pipeline locally for ogbn-arxiv + +The ogbn-arxiv data are small enough that you can execute the pipeline locally. Execute the following command to start a local execution of the pipeline: + + +```bash +PIPELINE_NAME="ogbn-arxiv-gs-pipeline" +cd ~/graphstorm/sagemaker/pipeline +python execute_sm_pipeline.py \ + --pipeline-name $PIPELINE_NAME \ + --region us-east-1 \ + --local-execution | tee arxiv-local-logs.txt +``` + +Note that we save the log output to `arxiv-local-logs.txt`. We'll use that later to analyze the training speed. + +Once the pipeline finishes it will print a message like + +``` +Pipeline execution 655b9357-xxx-xxx-xxx-4fc691fcce94 SUCCEEDED +``` + +You can inspect its output on S3. Every pipeline execution will be under the prefix `s3://$BUCKET_NAME/pipelines-output/ogbn-arxiv-gs-pipeline/` + +Every pipeline execution that shares the same input arguments will be under a randomized execution-identifying output path. +Note that the particular execution subpath might be different in your case. + +```bash +aws s3 ls s3://$BUCKET_NAME/pipelines-output/ogbn-arxiv-gs-pipeline/ + +# 761a4ff194198d49469a3bb223d5f26e + +# There should only be one execution subpath, copy that into a new env variable +EXECUTION_SUBPATH="761a4ff194198d49469a3bb223d5f26e" +aws s3 ls --recursive \ + s3://$BUCKET_NAME/pipelines-output/ogbn-arxiv-gs-pipeline/$EXECUTION_SUBPATH + +# gconstruct: +# data_transform_new.json edge_label_stats.json edge_mapping.pt node_label_stats.json node_mapping.pt ogbn-arxiv.json part0 part1 + +# inference: +# embeddings predictions + +# model: +# epoch-0 epoch-1 epoch-2 epoch-3 epoch-4 epoch-5 epoch-6 epoch-7 epoch-8 epoch-9 + +``` + +You'll be able to see the output of each step in the pipeline. The GConstruct job created the partitioned graph, the training job created models for 10 epochs, and the inference job created embeddings for the nodes and predictions for the nodes in the test set. + +You can inspect the mean epoch and evaluation time using the provided `analyze_training_time.py` script and the log file you created: + + +```bash +python analyze_training_time.py --log-file arxiv-local-logs.txt + +Reading logs from file: arxiv-logs.txt + +=== Training Epochs Summary === +Total epochs completed: 10 +Average epoch time: 7.43 seconds + +=== Evaluation Summary === +Total evaluations: 11 +Average evaluation time: 2.25 seconds +``` + +Note that these numbers will vary depending on your instance type. + +### Create GraphBolt Pipeline + +Now that you have established a baseline for performance you can create another pipeline that uses the GraphBolt graph representation to compare the performance. + +You can use the same pipeline creation script, but change two variables, providing a new pipeline name, and setting `USE_GRAPHBOLT` to `“true”`. 


```bash
# Deploy the GraphBolt-enabled pipeline
PIPELINE_NAME="ogbn-arxiv-gs-graphbolt-pipeline"
BUCKET_NAME="my-s3-bucket"
bash deploy_arxiv_pipeline.sh \
    --account "" \
    --bucket-name $BUCKET_NAME --role "" \
    --pipeline-name $PIPELINE_NAME \
    --use-graphbolt true
# Execute the pipeline locally
python execute_sm_pipeline.py \
    --pipeline-name $PIPELINE_NAME \
    --region us-east-1 \
    --local-execution | tee arxiv-local-gb-logs.txt
```

Analyzing the training logs you can see that the per-epoch time has dropped somewhat:

```bash
python analyze_training_time.py --log-file arxiv-local-gb-logs.txt

Reading logs from file: arxiv-local-gb-logs.txt

=== Training Epochs Summary ===
Total epochs completed: 10
Average epoch time: 6.83 seconds

=== Evaluation Summary ===
Total evaluations: 11
Average evaluation time: 1.99 seconds
```

For such a small graph the performance gains are modest, around 13% faster per epoch. Moving on to larger data, however, the potential gains are much greater. In the next section you will create a pipeline and train a model for `papers-100M`, a citation graph with 111M nodes and 3.2B edges.

## Create SageMaker Pipeline for distributed training

Once the papers-100M data have finished processing and exist on S3, either through your local job or the SageMaker Processing job, you can set up a pipeline to train a model on that dataset.

### Build the GraphStorm GPU image

For this job you will use large GPU instances, so you will build and push the GPU image this time:


```bash
cd ~/graphstorm

bash ./docker/build_graphstorm_image.sh --environment sagemaker --device gpu

bash docker/push_graphstorm_image.sh -e sagemaker -r $REGION -a $ACCOUNT_ID -d gpu
```

### Deploy and execute pipelines for papers-100M

Before you deploy your new pipeline, upload the training YAML configuration for papers-100M to S3:


```bash
aws s3 cp \
    ~/graphstorm/training_scripts/gsgnn_np/papers100M_nc.yaml \
    s3://$BUCKET_NAME/yaml/
```


Now you are ready to deploy your initial pipeline for papers-100M:

```bash
PIPELINE_NAME="ogb-papers100M-pipeline"
bash deploy_papers100M_pipeline.sh \
    --account  \
    --bucket-name  --role  \
    --pipeline-name $PIPELINE_NAME \
    --use-graphbolt false
```

Execute the pipeline and let it run in the background.

```bash
python execute_sm_pipeline.py \
    --pipeline-name $PIPELINE_NAME \
    --region us-east-1 \
    --async-execution
```

> Note that your account needs to meet the required quotas for the requested instances. Here the defaults are set to four `ml.g5.48xlarge` instances for training jobs and one `ml.r5.24xlarge` instance for a processing job. To adjust your SageMaker service quotas you can use the [Service Quotas console UI](https://us-east-1.console.aws.amazon.com/servicequotas/home/services/sagemaker/quotas). To run both pipelines in parallel you will need 8 x $TRAIN_GPU_INSTANCE and 2 x $GCONSTRUCT_INSTANCE.
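
Since the execution runs asynchronously, you can optionally check its status from the terminal while it runs. This is a minimal sketch using the AWS CLI; the pipeline name matches the one deployed above.

```bash
# List recent executions of the papers-100M pipeline and their status
aws sagemaker list-pipeline-executions \
    --pipeline-name ogb-papers100M-pipeline \
    --region us-east-1 \
    --query 'PipelineExecutionSummaries[*].[PipelineExecutionDisplayName,PipelineExecutionStatus]' \
    --output table
```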
+ + +Next, you can deploy and execute another pipeline, now with GraphBolt enabled: + +```bash +PIPELINE_NAME="ogb-papers100M-graphbolt-pipeline" +bash deploy_papers100M_pipeline.sh \ + --account \ + --bucket-name --role \ + --pipeline-name $PIPELINE_NAME \ + --use-graphbolt true + +# Execute the GraphBolt-enabled pipeline on SageMaker +python execute_sm_pipeline.py \ + --pipeline-name $PIPELINE_NAME \ + --region us-east-1 \ + --async-execution +``` + +### Compare performance for GraphBolt-enabled training + +Once both pipelines have finished executing, which should take approximately 4 hours, you can compare the training times for both cases. To do so you will need to find the pipeline execution display names for the two papers-100M pipelines. + +The easiest way to do so is through the Studio pipeline interface. In the Pipelines page you visited previously, there should be two new pipelines named **ogb-papers100M-pipeline** and **ogb-papers100M-graphbolt-pipeline**. Select **ogb-papers100M-pipeline**, which will take you to the **Executions** tab for the pipeline. Copy the name of the latest successful execution and use that to run the training analysis script: + + +```bash +python analyze_training_time.py \ + --pipeline-name papers-100M-gs-pipeline \ + --execution-name execution-1734404366941 +``` + +Your output will look like + +```bash +== Training Epochs Summary === +Total epochs completed: 15 +Average epoch time: 73.95 seconds + +=== Evaluation Summary === +Total evaluations: 15 +Average evaluation time: 15.07 seconds +``` + +Now do the same for the GraphBolt-enabled pipeline: + +```bash +python analyze_training_time.py \ + --pipeline-name papers-100M-gs-graphbolt-pipeline \ + --execution-name execution-1734463209078 +``` + +You will see the improved per-epoch and evaluation times: + +```bash +== Training Epochs Summary === +Total epochs completed: 15 +Average epoch time: 54.54 seconds + +=== Evaluation Summary === +Total evaluations: 15 +Average evaluation time: 4.13 seconds +``` + +Without loss in accuracy, the latest version of GraphStorm achieved a **~1.4x speedup per epoch, and a 3.6x speedup in evaluation time!** + +## Conclusion: Accelerate Your Graph ML with GraphStorm + +This example showcased how GraphStorm 0.4, integrated with DGL-GraphBolt, significantly speeds up large-scale graph neural network training and inference. + +We encourage ML practitioners working with large graph data to try GraphStorm. Its low-code interface simplifies building, training, and deploying graph ML solutions on AWS, allowing you to focus on modeling rather than infrastructure. + +To get started, visit the GraphStorm [documentation](https://graphstorm.readthedocs.io/en/) and GraphStorm [Github repository](https://github.com/awslabs/graphstorm). diff --git a/examples/sagemaker-pipelines-graphbolt/analyze_training_time.py b/examples/sagemaker-pipelines-graphbolt/analyze_training_time.py index c17674b825..ab5c1548d0 100644 --- a/examples/sagemaker-pipelines-graphbolt/analyze_training_time.py +++ b/examples/sagemaker-pipelines-graphbolt/analyze_training_time.py @@ -28,6 +28,7 @@ def parse_args(): + """Parse log analysis args.""" parser = argparse.ArgumentParser( description="Analyze training epoch and eval time." 
) @@ -259,17 +260,9 @@ def print_training_summary( if epochs_data: total_epochs = len(epochs_data) avg_time = sum(e["time"] for e in epochs_data) / total_epochs - min_time = min(epochs_data, key=lambda x: x["time"]) - max_time = max(epochs_data, key=lambda x: x["time"]) print(f"Total epochs completed: {total_epochs}") print(f"Average epoch time: {avg_time:.2f} seconds") - print( - f"Fastest epoch: Epoch {min_time['epoch']} ({min_time['time']:.2f} seconds)" - ) - print( - f"Slowest epoch: Epoch {max_time['epoch']} ({max_time['time']:.2f} seconds)" - ) if verbose: print("\nEpoch Details:") @@ -283,17 +276,9 @@ def print_training_summary( if eval_data: total_evals = len(eval_data) avg_eval_time = sum(e["time"] for e in eval_data) / total_evals - min_eval = min(eval_data, key=lambda x: x["time"]) - max_eval = max(eval_data, key=lambda x: x["time"]) print(f"Total evaluations: {total_evals}") print(f"Average evaluation time: {avg_eval_time:.2f} seconds") - print( - f"Fastest evaluation: Step {min_eval['step']} ({min_eval['time']:.2f} seconds)" - ) - print( - f"Slowest evaluation: Step {max_eval['step']} ({max_eval['time']:.2f} seconds)" - ) if verbose: print("\nEvaluation Details:") diff --git a/examples/sagemaker-pipelines-graphbolt/build_and_push_papers100M_image.sh b/examples/sagemaker-pipelines-graphbolt/build_and_push_papers100M_image.sh index 4c6fadaee3..fceb5dadef 100644 --- a/examples/sagemaker-pipelines-graphbolt/build_and_push_papers100M_image.sh +++ b/examples/sagemaker-pipelines-graphbolt/build_and_push_papers100M_image.sh @@ -9,11 +9,49 @@ cleanup() { } -ACCOUNT=$(aws sts get-caller-identity --query Account --output text) -REGION=$(aws configure get region) -REGION=${REGION:-us-east-1} +die() { + local msg=$1 + local code=${2-1} # default exit status 1 + msg "$msg" + exit "$code" +} + +parse_params() { + # default values of variables set from params + ACCOUNT=$(aws sts get-caller-identity --query Account --output text || true) + REGION=$(aws configure get region || true) + REGION=${REGION:-"us-east-1"} + + while :; do + case "${1-}" in + -h | --help) usage ;; + -x | --verbose) set -x ;; + -a | --account) + ACCOUNT="${2-}" + shift + ;; + -r | --region) + REGION="${2-}" + shift + ;; + -?*) die "Unknown option: $1" ;; + *) break ;; + esac + shift + done + + # check required params and arguments + [[ -z "${ACCOUNT-}" ]] && die "Missing required parameter: -a/--account " + [[ -z "${REGION-}" ]] && die "Missing required parameter: -r/--region " + + return 0 +} + +parse_params "$@" + IMAGE=papers100m-processor +# Download ripunzip to copy to image curl -L -O https://github.com/google/ripunzip/releases/download/v2.0.0/ripunzip_2.0.0-1_amd64.deb # Auth to AWS public ECR gallery @@ -22,7 +60,6 @@ aws ecr-public get-login-password --region $REGION | docker login --username AWS # Build and tag image docker build -f Dockerfile.processing -t $IMAGE . - # Create repository if it doesn't exist echo "Getting or creating container repository: $IMAGE" if ! 
$(aws ecr describe-repositories --repository-names $IMAGE --region ${REGION} > /dev/null 2>&1); then diff --git a/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100M_to_gconstruct.py b/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100m_to_gconstruct.py similarity index 93% rename from examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100M_to_gconstruct.py rename to examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100m_to_gconstruct.py index 0dfda2e70c..361c35b89d 100644 --- a/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100M_to_gconstruct.py +++ b/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100m_to_gconstruct.py @@ -1,3 +1,20 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"). +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Convert papers100M data and prepare for input to GConstruct +""" import argparse import gzip import json @@ -10,13 +27,14 @@ import pyarrow as pa import pyarrow.dataset as ds import pyarrow.parquet as pq -import pyarrow.fs as fs +from pyarrow import fs from tqdm_loggable.auto import tqdm # pylint: disable=logging-fstring-interpolation def parse_args(): + """Parse conversion arguments.""" parser = argparse.ArgumentParser( description="Convert raw OGB papers-100M data to GConstruct format" ) diff --git a/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh b/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh new file mode 100644 index 0000000000..e43b4f4335 --- /dev/null +++ b/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh @@ -0,0 +1,129 @@ +#!/bin/env bash +set -euox pipefail + +SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd -P) + +msg() { + echo >&2 -e "${1-}" +} + +die() { + local msg=$1 + local code=${2-1} # default exit status 1 + msg "$msg" + exit "$code" +} + +parse_params() { + # default values of variables set from params + ACCOUNT=$(aws sts get-caller-identity --query Account --output text || true) + REGION=$(aws configure get region || true) + REGION=${REGION:-"us-east-1"} + PIPELINE_NAME="" + + + while :; do + case "${1-}" in + -h | --help) usage ;; + -x | --verbose) set -x ;; + -r | --role) + ROLE="${2-}" + shift + ;; + -a | --account) + ACCOUNT="${2-}" + shift + ;; + -b | --bucket) + BUCKET_NAME="${2-}" + shift + ;; + -n | --pipeline-name) + PIPELINE_NAME="${2-}" + shift + ;; + -g | --use-graphbolt) + USE_GRAPHBOLT="${2-}" + shift + ;; + -?*) die "Unknown option: $1" ;; + *) break ;; + esac + shift + done + + # check required params and arguments + [[ -z "${ACCOUNT-}" ]] && die "Missing required parameter: -a/--account " + [[ -z "${BUCKET-}" ]] && die "Missing required parameter: -b/--bucket " + [[ -z "${ROLE-}" ]] && die "Missing required parameter: -r/--role " + [[ -z "${USE_GRAPHBOLT-}" ]] && die "Missing required parameter: -g/--use-graphbolt " + + return 0 +} + +cleanup() { + trap - SIGINT SIGTERM ERR EXIT + # script cleanup here +} + +parse_params "$@" + +DATASET_S3_PATH="s3://${BUCKET_NAME}/ogb-arxiv-input" 
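+# NOTE: DATASET_S3_PATH above must match the S3 prefix where the arxiv
+# GConstruct input was uploaded (the --output-prefix passed to
+# convert_arxiv_to_gconstruct.py). The variables below configure the pipeline
+# output location, instance types, and training settings for this example.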
+OUTPUT_PATH="s3://${BUCKET_NAME}/pipelines-output" +GRAPH_NAME="ogbn-arxiv" +INSTANCE_COUNT="2" +REGION="us-east-1" +NUM_TRAINERS=4 + +PARTITION_OUTPUT_JSON="$GRAPH_NAME.json" +PARTITION_ALGORITHM="metis" +GCONSTRUCT_INSTANCE="ml.r5.4xlarge" +GCONSTRUCT_CONFIG="gconstruct_config_arxiv.json" + +TRAIN_CPU_INSTANCE="ml.m5.4xlarge" +TRAIN_YAML_S3="s3://$BUCKET_NAME/yaml/arxiv_nc_train.yaml" +INFERENCE_YAML_S3="s3://$BUCKET_NAME/yaml/arxiv_nc_inference.yaml" + +TASK_TYPE="node_classification" +INFERENCE_MODEL_SNAPSHOT="epoch-9" + +JOBS_TO_RUN="gconstruct train inference" +GSF_CPU_IMAGE_URI=${ACCOUNT}.dkr.ecr.$REGION.amazonaws.com/graphstorm:sagemaker-cpu +GSF_GPU_IMAGE_URI=${ACCOUNT}.dkr.ecr.$REGION.amazonaws.com/graphstorm:sagemaker-gpu +VOLUME_SIZE=50 + +if [[ -z "${PIPELINE_NAME-}" ]]; then + if [[ $USE_GRAPHBOLT == "true" ]]; then + PIPELINE_NAME="ogbn-arxiv-gs-graphbolt-pipeline" + else + PIPELINE_NAME="ogbn-arxiv-gs-pipeline" + fi +fi + +python3 $SCRIPT_DIR/../../sagemaker/pipeline/create_sm_pipeline.py \ + --cpu-instance-type ${TRAIN_CPU_INSTANCE} \ + --graph-construction-args "--num-processes 8" \ + --graph-construction-instance-type ${GCONSTRUCT_INSTANCE} \ + --graph-construction-config-filename ${GCONSTRUCT_CONFIG} \ + --graph-name ${GRAPH_NAME} \ + --graphstorm-pytorch-cpu-image-url "${GSF_CPU_IMAGE_URI}" \ + --graphstorm-pytorch-gpu-image-url "${GSF_GPU_IMAGE_URI}" \ + --inference-model-snapshot "${INFERENCE_MODEL_SNAPSHOT}" \ + --inference-yaml-s3 ${INFERENCE_YAML_S3} \ + --input-data-s3 ${DATASET_S3_PATH} \ + --instance-count ${INSTANCE_COUNT} \ + --jobs-to-run ${JOBS_TO_RUN} \ + --num-trainers ${NUM_TRAINERS} \ + --output-prefix-s3 "${OUTPUT_PATH}" \ + --pipeline-name "${PIPELINE_NAME}" \ + --partition-output-json ${PARTITION_OUTPUT_JSON} \ + --partition-algorithm ${PARTITION_ALGORITHM} \ + --region ${REGION} \ + --role "${ROLE}" \ + --train-on-cpu \ + --train-inference-task ${TASK_TYPE} \ + --train-yaml-s3 "${TRAIN_YAML_S3}" \ + --save-embeddings \ + --save-predictions \ + --volume-size-gb ${VOLUME_SIZE} \ + --use-graphbolt "${USE_GRAPHBOLT}" diff --git a/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh b/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh new file mode 100644 index 0000000000..d85a2edd1c --- /dev/null +++ b/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh @@ -0,0 +1,139 @@ +#!/bin/env bash +set -euox pipefail + +SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd -P) + + +msg() { + echo >&2 -e "${1-}" +} + +die() { + local msg=$1 + local code=${2-1} # default exit status 1 + msg "$msg" + exit "$code" +} + +parse_params() { + # default values of variables set from params + ACCOUNT=$(aws sts get-caller-identity --query Account --output text || true) + REGION=$(aws configure get region || true) + REGION=${REGION:-"us-east-1"} + PIPELINE_NAME="" + + + while :; do + case "${1-}" in + -h | --help) usage ;; + -x | --verbose) set -x ;; + -r | --role) + ROLE="${2-}" + shift + ;; + -a | --account) + ACCOUNT="${2-}" + shift + ;; + -b | --bucket) + BUCKET_NAME="${2-}" + shift + ;; + -n | --pipeline-name) + PIPELINE_NAME="${2-}" + shift + ;; + -g | --use-graphbolt) + USE_GRAPHBOLT="${2-}" + shift + ;; + -?*) die "Unknown option: $1" ;; + *) break ;; + esac + shift + done + + # check required params and arguments + [[ -z "${ACCOUNT-}" ]] && die "Missing required parameter: -a/--account " + [[ -z "${BUCKET-}" ]] && die "Missing required parameter: -b/--bucket " + [[ -z "${ROLE-}" ]] && die "Missing 
required parameter: -r/--role " + [[ -z "${USE_GRAPHBOLT-}" ]] && die "Missing required parameter: -g/--use-graphbolt " + + return 0 +} + +cleanup() { + trap - SIGINT SIGTERM ERR EXIT + # script cleanup here +} + +parse_params "$@" + +if [[ ${USE_GRAPHBOLT} == "true" || ${USE_GRAPHBOLT} == "false" ]]; then + : # Do nothing +else + die "-g/--use-graphbolt parameter needs to be one of 'true' or 'false', got ${USE_GRAPHBOLT}" +fi + + +JOBS_TO_RUN="gconstruct train inference" + +OUTPUT_PATH="s3://${BUCKET_NAME}/pipelines-output" +GRAPH_NAME="papers-100M" +INSTANCE_COUNT="4" + +CPU_INSTANCE_TYPE="ml.r5.24xlarge" +TRAIN_GPU_INSTANCE="ml.g5.48xlarge" +GCONSTRUCT_INSTANCE="ml.r5.24xlarge" +NUM_TRAINERS=8 + +GSF_CPU_IMAGE_URI=${ACCOUNT}.dkr.ecr.$REGION.amazonaws.com/graphstorm:sagemaker-cpu +GSF_GPU_IMAGE_URI=${ACCOUNT}.dkr.ecr.$REGION.amazonaws.com/graphstorm:sagemaker-gpu + +GCONSTRUCT_CONFIG="gconstruct_config_papers100m.json" +GRAPH_CONSTRUCTION_ARGS="--add-reverse-edges False --num-processes 16" + +PARTITION_OUTPUT_JSON="metadata.json" +PARTITION_OUTPUT_JSON="$GRAPH_NAME.json" +PARTITION_ALGORITHM="metis" +TRAIN_YAML_S3="s3://$BUCKET_NAME/yaml/papers100M_nc.yaml" +INFERENCE_YAML_S3="s3://$BUCKET_NAME/yaml/papers100M_nc.yaml" +TASK_TYPE="node_classification" +INFERENCE_MODEL_SNAPSHOT="epoch-14" +VOLUME_SIZE=400 + +if [[ -z "${PIPELINE_NAME-}" ]]; then + if [[ $USE_GRAPHBOLT == "true" ]]; then + PIPELINE_NAME="papers100M-gs-graphbolt-pipeline" + else + PIPELINE_NAME="papers100M-gs-pipeline" + fi +fi + +python3 $SCRIPT_DIR/../../sagemaker/pipeline/create_sm_pipeline.py \ + --execution-role "${ROLE}" \ + --cpu-instance-type ${CPU_INSTANCE_TYPE} \ + --gpu-instance-type ${TRAIN_GPU_INSTANCE} \ + --graph-construction-args "${GRAPH_CONSTRUCTION_ARGS}" \ + --graph-construction-instance-type ${GCONSTRUCT_INSTANCE} \ + --graph-construction-config-filename ${GCONSTRUCT_CONFIG} \ + --graph-name ${GRAPH_NAME} \ + --graphstorm-pytorch-cpu-image-url "${GSF_CPU_IMAGE_URI}" \ + --graphstorm-pytorch-gpu-image-url "${GSF_GPU_IMAGE_URI}" \ + --inference-model-snapshot "${INFERENCE_MODEL_SNAPSHOT}" \ + --inference-yaml-s3 "${INFERENCE_YAML_S3}" \ + --input-data-s3 "${DATASET_S3_PATH}" \ + --instance-count ${INSTANCE_COUNT} \ + --jobs-to-run "${JOBS_TO_RUN}" \ + --num-trainers ${NUM_TRAINERS} \ + --output-prefix-s3 "${OUTPUT_PATH}" \ + --pipeline-name "${PIPELINE_NAME}" \ + --partition-output-json ${PARTITION_OUTPUT_JSON} \ + --partition-algorithm ${PARTITION_ALGORITHM} \ + --region ${REGION} \ + --train-inference-task ${TASK_TYPE} \ + --train-yaml-s3 "${TRAIN_YAML_S3}" \ + --save-embeddings \ + --save-predictions \ + --volume-size-gb ${VOLUME_SIZE} \ + --use-graphbolt ${USE_GRAPHBOLT} From 08fbb22b754fdd4fdcee9931601f690a39fcb7e2 Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Thu, 9 Jan 2025 14:13:55 -0800 Subject: [PATCH 4/9] Updates and bug fixes. 
Add sagemaker launch script for papers100m conversion job --- .../sagemaker-pipelines-graphbolt/README.md | 52 +++++----- .../convert_arxiv_to_gconstruct.py | 14 +-- .../convert_ogb_papers100m_to_gconstruct.py | 3 +- .../deploy_arxiv_pipeline.sh | 12 +-- .../deploy_papers100M_pipeline.sh | 15 +-- .../process_papers100M.sh | 4 +- .../sagemaker_convert_papers100m.py | 94 +++++++++++++++++++ 7 files changed, 143 insertions(+), 51 deletions(-) create mode 100644 examples/sagemaker-pipelines-graphbolt/sagemaker_convert_papers100m.py diff --git a/examples/sagemaker-pipelines-graphbolt/README.md b/examples/sagemaker-pipelines-graphbolt/README.md index 44499a3dd6..b98f912645 100644 --- a/examples/sagemaker-pipelines-graphbolt/README.md +++ b/examples/sagemaker-pipelines-graphbolt/README.md @@ -90,7 +90,7 @@ aws ec2 run-instances \ --instance-type "m6in.4xlarge" \ --key-name my-key-name \ --block-device-mappings '[{ - "DeviceName": "/dev/sdf", + "DeviceName": "/dev/sda1", "Ebs": { "VolumeSize": 300, "VolumeType": "gp3", @@ -108,15 +108,14 @@ Once logged in, you can set up your Python environment to run GraphStorm ```bash conda init eval $SHELL -conda create -y --name gsf python=3.10 -conda activate gsf +# Available on the DLAMI, otherwise create a new conda env +conda activate pytorch # Install dependencies -pip install sagemaker boto3 ogb pyarrow +pip install sagemaker[local] boto3 ogb pyarrow # Clone the GraphStorm repository to access the example code git clone https://github.com/awslabs/graphstorm.git ~/graphstorm -cd ~/graphstorm/examples/sagemaker-pipelines-graphbolt ``` ### Download and prepare datasets @@ -136,12 +135,11 @@ You'lll download the smaller-scale [ogbn-arxiv](https://ogb.stanford.edu/docs/no BUCKET_NAME= ``` - You will use this script to directly download, transform and upload the data to S3: - ```bash -python convert_ogb_arxiv_to_gconstruct.py \ +cd ~/graphstorm/examples/sagemaker-pipelines-graphbolt +python convert_arxiv_to_gconstruct.py \ --output-prefix s3://$BUCKET_NAME/ogb-arxiv-input ``` @@ -188,12 +186,14 @@ bash build_and_push_papers100M_image.sh # $ACCOUNT_ID.dkr.ecr.$REGION.amazonaws.com/papers100m-processor # Run a SageMaker job to do the processing and upload the output to S3 -SAGEMAKER_EXECUTION_ROLE= +SAGEMAKER_EXECUTION_ROLE_ARN= ACCOUNT_ID= REGION=us-east-1 -python sagemaker_convert_papers100M.py \ + +aws configure set region $REGION +python sagemaker_convert_papers100m.py \ --output-bucket $BUCKET_NAME \ - --execution-role-arn $SAGEMAKER_EXECUTION_ROLE \ + --execution-role-arn $SAGEMAKER_EXECUTION_ROLE_ARN \ --region $REGION \ --instance-type ml.m5.4xlarge \ --image-uri $ACCOUNT_ID.dkr.ecr.$REGION.amazonaws.com/papers100m-processor @@ -201,6 +201,8 @@ python sagemaker_convert_papers100M.py \ This will produce the processed data at `s3://$BUCKET_NAME/ogb-papers100M-input` which can then be used as input to GraphStorm. +> NOTE: Ensure your instance IAM profile is allow to perform `iam:GetRole` and `iam:GetPolicy` on your `SAGEMAKER_EXECUTION_ROLE_ARN`. 
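+
+To quickly verify these permissions from your instance, you can call the relevant APIs directly. The commands below are a minimal sketch: the role name `SageMakerExecutionRole` is an assumption, so substitute the name portion of your own execution role ARN.
+
+```bash
+# Check which identity the instance profile resolves to
+aws sts get-caller-identity
+
+# This call exercises iam:GetRole on the execution role and should succeed.
+# Replace SageMakerExecutionRole with the role name from your own ARN.
+aws iam get-role --role-name SageMakerExecutionRole
+```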
+ #### [Optional] Prepare the ogbn-papers100M dataset locally @@ -220,7 +222,8 @@ mkdir ~/papers100M-raw-data cd ~/papers100M-raw-data axel -n 16 http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip ripuznip unzip-file papers100M-bin.zip -ripunzip unzip-file papers100M-bin/raw/data.npz && rm papers100M-bin/raw/data.npz +cd papers100M-bin/raw +ripunzip unzip-file data.npz && rm data.npz # Install process script dependencies python -m pip install \ @@ -232,6 +235,7 @@ python -m pip install \ # Process and upload to S3, this will take around 20 minutes +cd ~/graphstorm/examples/sagemaker-pipelines-graphbolt python convert_ogb_papers100m_to_gconstruct.py \ --input-dir ~/papers100M-raw-data --output-dir s3://$BUCKET_NAME/ogb-papers100M-input @@ -248,10 +252,6 @@ sudo apt update sudo apt install -y Docker.io docker -v -# Enter you account ID here -ACCOUNT_ID= -REGION=us-east-1 - cd ~/graphstorm bash ./docker/build_graphstorm_image.sh --environment sagemaker --device cpu @@ -259,9 +259,6 @@ bash ./docker/build_graphstorm_image.sh --environment sagemaker --device cpu bash docker/push_graphstorm_image.sh -e sagemaker -r $REGION -a $ACCOUNT_ID -d cpu # This will push an image to # ${ACCOUNT_ID}.dkr.ecr.us-east-1.amazonaws.com/graphstorm:sagemaker-cpu - -# Install sagemaker with support for local mode -pip install sagemaker[local] ``` Next, you will create a SageMaker Pipeline to run the jobs that are necessary to train GNN models with GraphStorm. @@ -276,10 +273,10 @@ In this section, you will create a [Sagemaker Pipeline](https://docs.aws.amazon. ```bash PIPELINE_NAME="ogbn-arxiv-gs-pipeline" -BUCKET_NAME="my-s3-bucket" + bash deploy_papers100M_pipeline.sh \ - --account "" \ - --bucket-name $BUCKET_NAME --role "" \ + --account $ACCOUNT_ID \ + --bucket-name $BUCKET_NAME --role $SAGEMAKER_EXECUTION_ROLE_ARN \ --pipeline-name $PIPELINE_NAME \ --use-graphbolt false ``` @@ -303,8 +300,8 @@ The ogbn-arxiv data are small enough that you can execute the pipeline locally. ```bash PIPELINE_NAME="ogbn-arxiv-gs-pipeline" -cd ~/graphstorm/sagemaker/pipeline -python execute_sm_pipeline.py \ + +python ~/graphstorm/sagemaker/pipeline/execute_sm_pipeline.py \ --pipeline-name $PIPELINE_NAME \ --region us-east-1 \ --local-execution | tee arxiv-local-logs.txt @@ -382,7 +379,7 @@ bash deploy_arxiv_pipeline.sh \ --pipeline-name $PIPELINE_NAME \ --use-graphbolt true # Execute the pipeline locally -python execute_sm_pipeline.py \ +python ~/graphstorm/sagemaker/pipeline/execute_sm_pipeline.py \ --pipeline-name $PIPELINE_NAME \ --region us-east-1 \ --local-execution | tee arxiv-local-gb-logs.txt @@ -439,6 +436,7 @@ Now you are ready to deploy your initial pipeline for papers-100M ```bash PIPELINE_NAME="ogb-papers100M-pipeline" +cd ~/graphstorm/examples/sagemaker-pipelines-graphbolt/ bash deploy_papers100M_pipeline.sh \ --account \ --bucket-name --role \ @@ -449,7 +447,7 @@ bash deploy_papers100M_pipeline.sh \ Execute the pipeline and let it run the background. 
```bash -python execute_sm_pipeline.py \ +python ~/graphstorm/sagemaker/pipeline/execute_sm_pipeline.py \ --pipeline-name $PIPELINE_NAME \ --region us-east-1 --async-execution @@ -469,7 +467,7 @@ bash deploy_papers100M_pipeline.sh \ --use-graphbolt true # Execute the GraphBolt-enabled pipeline on SageMaker -python execute_sm_pipeline.py \ +python ~/graphstorm/sagemaker/pipeline/execute_sm_pipeline.py \ --pipeline-name $PIPELINE_NAME \ --region us-east-1 \ --async-execution diff --git a/examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py b/examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py index dd66f09f01..456c2c5bdf 100644 --- a/examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py +++ b/examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py @@ -100,7 +100,7 @@ def convert_ogbn_arxiv(output_prefix: str): "node_id_col": "nid", "node_type": "node", "format": {"name": "parquet"}, - "files": [f"{output_prefix}/nodes/paper/nodes.parquet"], + "files": ["nodes/paper/nodes.parquet"], "features": [ { "feature_col": "feat", @@ -118,9 +118,9 @@ def convert_ogbn_arxiv(output_prefix: str): "task_type": "classification", "custom_split_filenames": { "column": "nid", - "train": f"{output_prefix}/splits/train_idx.parquet", - "valid": f"{output_prefix}/splits/valid_idx.parquet", - "test": f"{output_prefix}/splits/test_idx.parquet", + "train": "splits/train_idx.parquet", + "valid": "splits/valid_idx.parquet", + "test": "splits/test_idx.parquet", }, "label_stats_type": "frequency_cnt", } @@ -133,14 +133,14 @@ def convert_ogbn_arxiv(output_prefix: str): "dest_id_col": "dst", "relation": ["node", "cites", "node"], "format": {"name": "parquet"}, - "files": [f"{output_prefix}/edges/paper-cites-paper/edges.parquet"], + "files": ["edges/paper-cites-paper/edges.parquet"], }, { "source_id_col": "dst", "dest_id_col": "src", "relation": ["node", "cites-rev", "node"], "format": {"name": "parquet"}, - "files": [f"{output_prefix}/edges/paper-cites-paper/edges.parquet"], + "files": ["/edges/paper-cites-paper/edges.parquet"], }, ], } @@ -160,4 +160,4 @@ def convert_ogbn_arxiv(output_prefix: str): if __name__ == "__main__": args = parse_args() - convert_ogbn_arxiv(args.output_prefix) + convert_ogbn_arxiv(args.output_s3_prefix) diff --git a/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100m_to_gconstruct.py b/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100m_to_gconstruct.py index 361c35b89d..497e9d7902 100644 --- a/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100m_to_gconstruct.py +++ b/examples/sagemaker-pipelines-graphbolt/convert_ogb_papers100m_to_gconstruct.py @@ -15,6 +15,7 @@ Convert papers100M data and prepare for input to GConstruct """ + import argparse import gzip import json @@ -87,7 +88,7 @@ def process_data(input_dir, output_dir, filesystem): num_nodes, num_features = node_feat.shape num_edges = edge_index.shape[1] logging.info( - f"Node features shape: {node_feat.shape:,}, Number of edges: {num_edges:,}" + f"Node features shape: {node_feat.shape}, Number of edges: {num_edges:,}" ) # Define schemas for nodes and edges diff --git a/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh b/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh index e43b4f4335..91f3a1830f 100644 --- a/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh +++ b/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh @@ -26,15 +26,15 @@ parse_params() { case "${1-}" in -h | --help) usage ;; -x | 
--verbose) set -x ;; - -r | --role) - ROLE="${2-}" + -r | --execution-role) + ROLE_ARN="${2-}" shift ;; -a | --account) ACCOUNT="${2-}" shift ;; - -b | --bucket) + -b | --bucket-name) BUCKET_NAME="${2-}" shift ;; @@ -54,8 +54,8 @@ parse_params() { # check required params and arguments [[ -z "${ACCOUNT-}" ]] && die "Missing required parameter: -a/--account " - [[ -z "${BUCKET-}" ]] && die "Missing required parameter: -b/--bucket " - [[ -z "${ROLE-}" ]] && die "Missing required parameter: -r/--role " + [[ -z "${BUCKET_NAME-}" ]] && die "Missing required parameter: -b/--bucket " + [[ -z "${ROLE_ARN-}" ]] && die "Missing required parameter: -r/--execution-role " [[ -z "${USE_GRAPHBOLT-}" ]] && die "Missing required parameter: -g/--use-graphbolt " return 0 @@ -102,6 +102,7 @@ fi python3 $SCRIPT_DIR/../../sagemaker/pipeline/create_sm_pipeline.py \ --cpu-instance-type ${TRAIN_CPU_INSTANCE} \ + --execution-role "${ROLE_ARN}" \ --graph-construction-args "--num-processes 8" \ --graph-construction-instance-type ${GCONSTRUCT_INSTANCE} \ --graph-construction-config-filename ${GCONSTRUCT_CONFIG} \ @@ -119,7 +120,6 @@ python3 $SCRIPT_DIR/../../sagemaker/pipeline/create_sm_pipeline.py \ --partition-output-json ${PARTITION_OUTPUT_JSON} \ --partition-algorithm ${PARTITION_ALGORITHM} \ --region ${REGION} \ - --role "${ROLE}" \ --train-on-cpu \ --train-inference-task ${TASK_TYPE} \ --train-yaml-s3 "${TRAIN_YAML_S3}" \ diff --git a/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh b/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh index d85a2edd1c..4f94a03cab 100644 --- a/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh +++ b/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh @@ -27,15 +27,15 @@ parse_params() { case "${1-}" in -h | --help) usage ;; -x | --verbose) set -x ;; - -r | --role) - ROLE="${2-}" + -r | --execution-role) + ROLE_ARN="${2-}" shift ;; -a | --account) ACCOUNT="${2-}" shift ;; - -b | --bucket) + -b | --bucket-name) BUCKET_NAME="${2-}" shift ;; @@ -56,7 +56,7 @@ parse_params() { # check required params and arguments [[ -z "${ACCOUNT-}" ]] && die "Missing required parameter: -a/--account " [[ -z "${BUCKET-}" ]] && die "Missing required parameter: -b/--bucket " - [[ -z "${ROLE-}" ]] && die "Missing required parameter: -r/--role " + [[ -z "${ROLE_ARN-}" ]] && die "Missing required parameter: -r/--execution-role " [[ -z "${USE_GRAPHBOLT-}" ]] && die "Missing required parameter: -g/--use-graphbolt " return 0 @@ -78,6 +78,7 @@ fi JOBS_TO_RUN="gconstruct train inference" +DATASET_S3_PATH="s3://${BUCKET_NAME}/papers-100M-input" OUTPUT_PATH="s3://${BUCKET_NAME}/pipelines-output" GRAPH_NAME="papers-100M" INSTANCE_COUNT="4" @@ -91,7 +92,7 @@ GSF_CPU_IMAGE_URI=${ACCOUNT}.dkr.ecr.$REGION.amazonaws.com/graphstorm:sagemaker- GSF_GPU_IMAGE_URI=${ACCOUNT}.dkr.ecr.$REGION.amazonaws.com/graphstorm:sagemaker-gpu GCONSTRUCT_CONFIG="gconstruct_config_papers100m.json" -GRAPH_CONSTRUCTION_ARGS="--add-reverse-edges False --num-processes 16" +GRAPH_CONSTRUCTION_ARGS="--num-processes 16" PARTITION_OUTPUT_JSON="metadata.json" PARTITION_OUTPUT_JSON="$GRAPH_NAME.json" @@ -111,7 +112,7 @@ if [[ -z "${PIPELINE_NAME-}" ]]; then fi python3 $SCRIPT_DIR/../../sagemaker/pipeline/create_sm_pipeline.py \ - --execution-role "${ROLE}" \ + --execution-role "${ROLE_ARN}" \ --cpu-instance-type ${CPU_INSTANCE_TYPE} \ --gpu-instance-type ${TRAIN_GPU_INSTANCE} \ --graph-construction-args "${GRAPH_CONSTRUCTION_ARGS}" \ @@ -124,7 +125,7 @@ python3 
$SCRIPT_DIR/../../sagemaker/pipeline/create_sm_pipeline.py \ --inference-yaml-s3 "${INFERENCE_YAML_S3}" \ --input-data-s3 "${DATASET_S3_PATH}" \ --instance-count ${INSTANCE_COUNT} \ - --jobs-to-run "${JOBS_TO_RUN}" \ + --jobs-to-run ${JOBS_TO_RUN} \ --num-trainers ${NUM_TRAINERS} \ --output-prefix-s3 "${OUTPUT_PATH}" \ --pipeline-name "${PIPELINE_NAME}" \ diff --git a/examples/sagemaker-pipelines-graphbolt/process_papers100M.sh b/examples/sagemaker-pipelines-graphbolt/process_papers100M.sh index d24d5b92a8..d99def1d53 100644 --- a/examples/sagemaker-pipelines-graphbolt/process_papers100M.sh +++ b/examples/sagemaker-pipelines-graphbolt/process_papers100M.sh @@ -5,7 +5,6 @@ trap cleanup SIGINT SIGTERM ERR EXIT cleanup() { trap - SIGINT SIGTERM ERR EXIT # script cleanup here - kill $DISK_USAGE_PID > /dev/null 2>&1 || true } # Download and unzip data in parallel @@ -16,7 +15,6 @@ cd $TEMP_DATA_PATH || exit 1 echo "Will execute script $1 with output prefix $2" - echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ'): Downloading files using axel, this will take at least 10 minutes depending on network speed" time axel -n 16 --quiet http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip @@ -29,6 +27,6 @@ time ripunzip unzip-file data.npz && rm data.npz # Run the processing script echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ'): Processing data and uploading to S3, this will take around 20 minutes" -python3 /opt/ml/code/"$1" \ +time python3 /opt/ml/code/"$1" \ --input-dir "$TEMP_DATA_PATH/papers100M-bin/" \ --output-prefix "$2" diff --git a/examples/sagemaker-pipelines-graphbolt/sagemaker_convert_papers100m.py b/examples/sagemaker-pipelines-graphbolt/sagemaker_convert_papers100m.py new file mode 100644 index 0000000000..f611070595 --- /dev/null +++ b/examples/sagemaker-pipelines-graphbolt/sagemaker_convert_papers100m.py @@ -0,0 +1,94 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"). +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Launch SageMaker job to convert papers100M data and prepare for input to GConstruct +""" +import argparse +import os + +from sagemaker.processing import ScriptProcessor +from sagemaker.network import NetworkConfig +from sagemaker import get_execution_role + +_ROOT = os.path.abspath(os.path.dirname(__file__)) + + +def parse_args() -> argparse.Namespace: + """Parse job launch arguments""" + parser = argparse.ArgumentParser( + description="Convert Papers100M dataset to GConstruct format using SageMaker Processing." + ) + + parser.add_argument( + "--execution-role-arn", + type=str, + default=None, + help="SageMaker Execution Role ARN", + ) + parser.add_argument( + "--region", type=str, required=True, help="SageMaker Processing region." + ) + parser.add_argument("--image-uri", type=str, required=True) + parser.add_argument( + "--output-bucket", + type=str, + required=True, + help="S3 output bucket for processed papers100M data. 
" + "Data will be saved under ``/ogb-papers100M-input/``", + ) + parser.add_argument( + "--instance-type", + type=str, + default="ml.m5.4xlarge", + help="SageMaker Processing Instance type.", + ) + + return parser.parse_args() + + +def main(): + """Launch the papers100M conversion job on SageMaker""" + args = parse_args() + + # Create a ScriptProcessor to run the processing bash script + script_processor = ScriptProcessor( + command=["bash"], + image_uri=args.image_uri, + role=args.execution_role_arn or get_execution_role(), + instance_count=1, + instance_type=args.instance_type, + volume_size_in_gb=400, + max_runtime_in_seconds=8 * 60 * 60, # Adjust as needed + base_job_name="papers100m-processing", + network_config=NetworkConfig( + enable_network_isolation=False + ), # Enable internet access to be able to download the data + ) + + # Submit the processing job + script_processor.run( + code="process_papers100M.sh", + inputs=[], + outputs=[], + arguments=[ + "convert_ogb_papers100m_to_gconstruct.py", + f"s3://{args.output_bucket}/papers-100M-input", + ], + wait=False, + ) + + +if __name__ == "__main__": + main() From db877bb2dbdc370b1a17530dc8bb452ecc104a28 Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Thu, 9 Jan 2025 16:37:54 -0800 Subject: [PATCH 5/9] Small fix for arxiv reverse edge path --- .../convert_arxiv_to_gconstruct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py b/examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py index 456c2c5bdf..f947478a13 100644 --- a/examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py +++ b/examples/sagemaker-pipelines-graphbolt/convert_arxiv_to_gconstruct.py @@ -140,7 +140,7 @@ def convert_ogbn_arxiv(output_prefix: str): "dest_id_col": "src", "relation": ["node", "cites-rev", "node"], "format": {"name": "parquet"}, - "files": ["/edges/paper-cites-paper/edges.parquet"], + "files": ["edges/paper-cites-paper/edges.parquet"], }, ], } From 44436c2fb108bdff89d95df58c372ddb7e21ca0f Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Mon, 13 Jan 2025 08:55:13 -0800 Subject: [PATCH 6/9] Apply suggestions from code review Co-authored-by: xiang song(charlie.song) --- .../sagemaker-pipelines-graphbolt/README.md | 20 +++++++++---------- .../deploy_arxiv_pipeline.sh | 2 +- .../deploy_papers100M_pipeline.sh | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/sagemaker-pipelines-graphbolt/README.md b/examples/sagemaker-pipelines-graphbolt/README.md index b98f912645..85c4a26b96 100644 --- a/examples/sagemaker-pipelines-graphbolt/README.md +++ b/examples/sagemaker-pipelines-graphbolt/README.md @@ -15,7 +15,7 @@ In this example, you will: Before diving into our hands-on example, it's important to understand some challenges associated with graph training, especially as graphs grow in size and complexity: 1. Memory Constraints: As graphs grow larger, they may no longer fit into the memory of a single machine. A graph with 1B nodes with 512 features per node and 10B edges will require more than 4TB of memory to store, even with optimal representation. This necessitates distributed processing and more efficient graph representation. -2. Graph Sampling: In GNN mini-batch training, you need to sample neighbors for each node to propagate their representations. 
For multi-layer GNNs, this can lead to exponential growth in the number of nodes sampled, potentially visiting the entire graph for a single node's representation. Efficient sampling methods become necessary. +2. Graph Sampling: In GNN mini-batch training, you need to sample neighbors for each node to propagate their representations. For multi-layer GNNs, this can lead to exponential growth in the number of nodes sampled. Efficient sampling methods become necessary. 3. Remote Data Access: When training on multiple machines, retrieving node features and sampling neighborhoods from other machines will significantly impact performance due to network latency. For example, reading a 1024-feature vector from main memory will take around 3μs, while reading that vector from a remote key/value store would take 50-100x longer. GraphStorm and GraphBolt help address these challenges through efficient graph representations, smart sampling techniques, and sophisticated partitioning algorithms like ParMETIS. @@ -51,7 +51,7 @@ Our benchmarks show significant improvements in both memory usage and training s Figure 1: GraphStorm SageMaker architecture. -A common model development process is to perform model exploration locally on a subset of your full data, and once satisfied with the results train the full scale model. GraphStorm and SageMaker Pipelines allows you to do that by creating a model pipeline you can execute locally to retrieve model metrics, and when ready execute your pipeline on the full data, and produce models, predictions and graph embeddings to use in downstream tasks. In the next section you'll learn how to set up such pipelines for GraphStorm. +A common model development process is to perform model exploration locally on a subset of your full data, and once satisfied with the results train the full scale model. GraphStorm-SageMaker Pipelines allows you to do that by creating a model pipeline you can execute locally to retrieve model metrics, and when ready execute your pipeline on the full data to produce models, predictions and graph embeddings for downstream tasks. In the next section you'll learn how to set up such pipelines for GraphStorm. ## Set up environment for SageMaker distributed training @@ -70,7 +70,7 @@ In order to use SageMaker Studio you will need to have a SageMaker Domain availa ### Set up appropriate roles to use with SageMaker Pipelines -To set up the SageMaker Pipelines you will need permissions to create ECR repositories, pull and push to them, pull from the AWS ECR Public Gallery, launch SageMaker jobs, manage SageMaker Pipelines, and interact with data on S3. We will create a role for Amazon EC2 on the AWS console, which will also create an associated instance profile to use with an EC2 instance. +To set up the SageMaker Pipelines you will need permissions to create ECR repositories, pull and push docker images to them, pull images from the AWS ECR Public Gallery, launch SageMaker jobs, manage SageMaker Pipelines, and interact with data on S3. We will create a role for Amazon EC2 on the AWS console, which will also create an associated instance profile to use with an EC2 instance. You will also need access to a SageMaker execution that your jobs assume during execution. You can use the [Amazon SageMaker Role Manager](https://docs.aws.amazon.com/sagemaker/latest/dg/role-manager.html) to streamline the creation of the necessary roles. 
@@ -120,14 +120,14 @@ git clone https://github.com/awslabs/graphstorm.git ~/graphstorm ### Download and prepare datasets -In this example you will use two related datasets to demonstrate the scalability of GraphStorm. The Open Graph Benchmark (OGB) project hosts a number of graph datasets that can be used to benchmark the performance of graph learning systems. In this example you will use two citation network datasets, the ogbn-arxiv dataset for a small-scale demo, and the ogbn-papers100M dataset for a demonstration of GraphStorm's large-scale learning capabilities. +The Open Graph Benchmark (OGB) project hosts a number of graph datasets that can be used to benchmark the performance of graph learning systems. In this example you will use two citation network datasets, the ogbn-arxiv dataset for a small-scale demo, and the ogbn-papers100M dataset for a demonstration of GraphStorm's large-scale learning capabilities. Because the two datasets have similar schemas and the same task (node classification) they allow us to emulate a typical data science pipeline, where we first do some model development and testing on a smaller dataset locally, and once ready launch SageMaker jobs to train on the full-scale data. #### Prepare the ogbn-arxiv dataset -You'lll download the smaller-scale [ogbn-arxiv](https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv) dataset to run a local test before launching larger scale SageMaker jobs on AWS. This dataset has ~170K nodes and ~1.2M edges. You will use the following script to download the arxiv data and prepare them for GraphStorm. +You'll download the smaller-scale [ogbn-arxiv](https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv) dataset to run a local test before launching larger scale SageMaker jobs on AWS. This dataset has ~170K nodes and ~1.2M edges. You will use the following script to download the arxiv data and prepare them for GraphStorm. ```bash @@ -265,11 +265,11 @@ Next, you will create a SageMaker Pipeline to run the jobs that are necessary to ## Create SageMaker Pipeline -In this section, you will create a [Sagemaker Pipeline](https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines-overview.html) on AWS SageMaker. The pipeline will run the following jobs in sequence: +In this section, you will create a [Sagemaker Pipeline](https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines-overview.html) on AWS SageMaker. The pipeline will run the following jobs in sequence: * Launch GConstruct Processing job. This prepares and partitions the data for distributed training.. * Launch GraphStorm Training Job. This will train the model and create model output on S3. -* Launch GraphStorm Inference Job. This will generate predictions and embeddings for every node in the input. +* Launch GraphStorm Inference Job. This will generate predictions and embeddings for every node in the input graph. ```bash PIPELINE_NAME="ogbn-arxiv-gs-pipeline" @@ -289,7 +289,7 @@ Running the above will create a SageMaker Pipeline configured to run 3 SageMaker * A GraphStorm training job that trains a node classification model and saves the model to S3. * A GraphStorm inference job that produces predictions for all nodes in the test set, and creates embeddings for all nodes. -To review the pipeline, navigate to [SageMaker AI Studio](https://us-east-1.console.aws.amazon.com/sagemaker/home?region=us-east-1#/studio-landing) on the AWS Console, select the domain and user profile you used to create the pipeline in the drop-down menus on the top right, then select **Open Studio**. 
+To review the pipeline, navigate to [SageMaker AI Studio](https://us-east-1.console.aws.amazon.com/sagemaker/home?region=us-east-1#/studio-landing) on the AWS Console, select the domain and user profile you used to create the pipeline in the drop-down menus on the top right, then select **Open Studio**. On the left navigation menu, select **Pipelines**. There should be a pipeline named **ogbn-arxiv-gs-pipeline**. Select that, which will take you to the **Executions** tab for the pipeline. Select **Graph** to view the pipeline steps. @@ -366,7 +366,7 @@ Note that these numbers will vary depending on your instance type. Now that you have established a baseline for performance you can create another pipeline that uses the GraphBolt graph representation to compare the performance. -You can use the same pipeline creation script, but change two variables, providing a new pipeline name, and setting `USE_GRAPHBOLT` to `“true”`. +You can use the same pipeline creation script, but change two variables, providing a new pipeline name, and setting `USE_GRAPHBOLT` to `“true”` as `--use-graphbolt true`. ```bash @@ -385,7 +385,7 @@ python ~/graphstorm/sagemaker/pipeline/execute_sm_pipeline.py \ --local-execution | tee arxiv-local-gb-logs.txt ``` -Analyzing the training logs you can see the per-epoch time has dropped somewhat: +Analyzing the training logs you can see a noticeable reduction in per-epoch time: ```bash python analyze_training_time.py --log-file arxiv-local-gb-logs.txt diff --git a/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh b/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh index 91f3a1830f..23b395d731 100644 --- a/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh +++ b/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh @@ -54,7 +54,7 @@ parse_params() { # check required params and arguments [[ -z "${ACCOUNT-}" ]] && die "Missing required parameter: -a/--account " - [[ -z "${BUCKET_NAME-}" ]] && die "Missing required parameter: -b/--bucket " + [[ -z "${BUCKET_NAME-}" ]] && die "Missing required parameter: -b/--bucket-name " [[ -z "${ROLE_ARN-}" ]] && die "Missing required parameter: -r/--execution-role " [[ -z "${USE_GRAPHBOLT-}" ]] && die "Missing required parameter: -g/--use-graphbolt " diff --git a/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh b/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh index 4f94a03cab..7609045490 100644 --- a/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh +++ b/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh @@ -55,7 +55,7 @@ parse_params() { # check required params and arguments [[ -z "${ACCOUNT-}" ]] && die "Missing required parameter: -a/--account " - [[ -z "${BUCKET-}" ]] && die "Missing required parameter: -b/--bucket " + [[ -z "${BUCKET_NAME-}" ]] && die "Missing required parameter: -b/--bucket-name " [[ -z "${ROLE_ARN-}" ]] && die "Missing required parameter: -r/--execution-role " [[ -z "${USE_GRAPHBOLT-}" ]] && die "Missing required parameter: -g/--use-graphbolt " From deaa1c04b2c375d0200ae651b7667a1290fc1d5c Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Mon, 13 Jan 2025 17:40:31 +0000 Subject: [PATCH 7/9] Address review comments --- .../sagemaker-pipelines-graphbolt/README.md | 69 +++++++++++-------- .../analyze_training_time.py | 2 +- .../sagemaker_convert_papers100m.py | 3 +- 3 files changed, 43 insertions(+), 31 deletions(-) diff --git a/examples/sagemaker-pipelines-graphbolt/README.md 
b/examples/sagemaker-pipelines-graphbolt/README.md index 85c4a26b96..7b1b85588b 100644 --- a/examples/sagemaker-pipelines-graphbolt/README.md +++ b/examples/sagemaker-pipelines-graphbolt/README.md @@ -2,7 +2,7 @@ GraphStorm is a low-code enterprise graph machine learning (ML) framework that provides ML practitioners a simple way of building, training and deploying graph ML solutions on industry-scale graph data. While GraphStorm can run efficiently on single instances for small graphs, it truly shines when scaling to enterprise-level graphs in distributed mode using a cluster of EC2 instances or Amazon SageMaker. -GraphStorm 0.4 introduced integration with DGL-GraphBolt, a new graph storage and sampling framework that uses a compact graph representation and pipelined sampling to reduce memory requirements and speed up Graph Neural Network (GNN) training by up to 3x. In this example we'll show how GraphStorm 0.4 brings training and inference speedups of up to 3x. +GraphStorm 0.4 introduced integration with DGL-GraphBolt, a new graph storage and sampling framework that uses a compact graph representation and pipelined sampling to reduce memory requirements and speed up Graph Neural Network (GNN) training. In this example we'll show how GraphStorm 0.4 brings training and inference speedups of up to 3x on the papers100M dataset. In this example, you will: @@ -23,7 +23,6 @@ GraphStorm and GraphBolt help address these challenges through efficient graph r ## GraphBolt: pipeline-driven graph sampling - GraphBolt is a new data loading and graph sampling framework developed by the [DGL](https://www.dgl.ai/) team. It streamlines the operations needed to sample efficiently from a heterogeneous graph and fetch the corresponding features. GraphBolt introduces a new, more compact graph structure representation for heterogeneous graphs, called fused Compressed Sparse Column (fCSC). This can reduce the memory cost of storing a heterogeneous graph by up to 56%, allowing users to fit larger graphs in memory and potentially use smaller, more cost-efficient instances for GNN model training. @@ -35,7 +34,7 @@ GraphStorm 0.4.0 seamlessly integrates with GraphBolt, allowing users to leverag The integration of GraphBolt into GraphStorm's workflow means that users can now: -1. Load and process larger graphs with fewer hardware resources. +1. Train models on larger graphs with fewer hardware resources. 2. Achieve faster training and inference times with more efficient graph sampling framework. 3. Utilize GPU resources more effectively for graph learning. @@ -72,12 +71,13 @@ In order to use SageMaker Studio you will need to have a SageMaker Domain availa To set up the SageMaker Pipelines you will need permissions to create ECR repositories, pull and push docker images to them, pull images from the AWS ECR Public Gallery, launch SageMaker jobs, manage SageMaker Pipelines, and interact with data on S3. We will create a role for Amazon EC2 on the AWS console, which will also create an associated instance profile to use with an EC2 instance. -You will also need access to a SageMaker execution that your jobs assume during execution. You can use the [Amazon SageMaker Role Manager](https://docs.aws.amazon.com/sagemaker/latest/dg/role-manager.html) to streamline the creation of the necessary roles. +You will also need access to a [SageMaker execution role](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) that your jobs assume during execution. 
+You can use the [Amazon SageMaker Role Manager](https://docs.aws.amazon.com/sagemaker/latest/dg/role-manager.html) to streamline the creation of the necessary roles. ### Set up the pipeline management environment -For this example you can either use your existing development environment or set up a new EC2 instance. If you plan to use a new instance to prepare the large-scale data for this example, ensure it has at least 300GB of disk space available. +For this example we recommend you to set up a new EC2 instance with at least 300 GByte of disk space. To set up an EC2 instance with the appropriate environment: @@ -101,7 +101,7 @@ aws ec2 run-instances \ This command creates an instance using the "Deep Learning OSS Nvidia Driver AMI GPU PyTorch 2.4.1 (Ubuntu 22.04) 20241116" AMI, in the default VPC with the default security group. Make your instance accessible through SSH, using an appropriate security group or the [AWS Systems Session Manager](https://docs.aws.amazon.com/systems-manager/latest/userguide/session-manager.html), and log in to the instance. You can also use the [AWS Console](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/tutorial-launch-my-first-ec2-instance.html) to create a new EC2 instance. -> NOTE: You may need to update the --image-id to the latest available. See https://docs.aws.amazon.com/dlami/latest/devguide/find-dlami-id.html for instructions. +> NOTE: You may need to update the --image-id to the latest available. See https://docs.aws.amazon.com/dlami/latest/devguide/find-dlami-id.html for instructions on finding the latest DLAMI. Once logged in, you can set up your Python environment to run GraphStorm @@ -151,12 +151,12 @@ aws s3 ls s3://$BUCKET_NAME/ogb-arxiv-input/ PRE edges/ PRE nodes/ PRE splits/ -2024-12-11 02:13:27 1269 gconstruct_config_arxiv.json +XXXX-XX-XX XX:XX:XX 1269 gconstruct_config_arxiv.json ``` Finally you'll also upload GraphStorm training configuration files for arxiv to use for training and inference -``` +```bash # Upload the training configurations to S3 aws s3 cp ~/graphstorm/training_scripts/gsgnn_np/arxiv_nc.yaml \ s3://$BUCKET_NAME/yaml/arxiv_nc_train.yaml @@ -327,18 +327,35 @@ aws s3 ls s3://$BUCKET_NAME/pipelines-output/ogbn-arxiv-gs-pipeline/ # There should only be one execution subpath, copy that into a new env variable EXECUTION_SUBPATH="761a4ff194198d49469a3bb223d5f26e" -aws s3 ls --recursive \ - s3://$BUCKET_NAME/pipelines-output/ogbn-arxiv-gs-pipeline/$EXECUTION_SUBPATH +aws s3 ls \ + s3://$BUCKET_NAME/pipelines-output/ogbn-arxiv-gs-pipeline/$EXECUTION_SUBPATH/ + +# You will see the top-level outputs +# gconstruct/ +# inference/ +# model/ -# gconstruct: +# gconstruct/ output +aws s3 ls \ + s3://$BUCKET_NAME/pipelines-output/ogbn-arxiv-gs-pipeline/$EXECUTION_SUBPATH/gconstruct/ + +# We get the 2 graph partitions (part0, part1) and metadata JSON files that describe the graph # data_transform_new.json edge_label_stats.json edge_mapping.pt node_label_stats.json node_mapping.pt ogbn-arxiv.json part0 part1 -# inference: -# embeddings predictions +# model/ output +aws s3 ls \ + s3://$BUCKET_NAME/pipelines-output/ogbn-arxiv-gs-pipeline/$EXECUTION_SUBPATH/model/ -# model: +# We get a model snapshot for every epoch # epoch-0 epoch-1 epoch-2 epoch-3 epoch-4 epoch-5 epoch-6 epoch-7 epoch-8 epoch-9 +# inference/ output +aws s3 ls \ + s3://$BUCKET_NAME/pipelines-output/ogbn-arxiv-gs-pipeline/$EXECUTION_SUBPATH/inference/ + +# We get two prefixes, one containing the embeddings and one the predictions +# embeddings/ predictions/ + ``` 
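+
+If you want to gauge how large these outputs are, a recursive listing with a summary works well. This is a sketch that reuses the prefix from above; adjust the path for your own bucket and execution.
+
+```bash
+# Summarize the total size of the generated embeddings (sketch)
+aws s3 ls --recursive --human-readable --summarize \
+    s3://$BUCKET_NAME/pipelines-output/ogbn-arxiv-gs-pipeline/$EXECUTION_SUBPATH/inference/embeddings/
+```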
You'll be able to see the output of each step in the pipeline. The GConstruct job created the partitioned graph, the training job created models for 10 epochs, and the inference job created embeddings for the nodes and predictions for the nodes in the test set. @@ -371,16 +388,16 @@ You can use the same pipeline creation script, but change two variables, providi ```bash # Deploy the GraphBolt-enabled pipeline -PIPELINE_NAME="ogbn-arxiv-gs-graphbolt-pipeline" +PIPELINE_NAME_GRAPHBOLT="ogbn-arxiv-gs-graphbolt-pipeline" BUCKET_NAME="my-s3-bucket" bash deploy_arxiv_pipeline.sh \ --account "" \ --bucket-name $BUCKET_NAME --role "" \ - --pipeline-name $PIPELINE_NAME \ + --pipeline-name $PIPELINE_NAME_GRAPHBOLT \ --use-graphbolt true # Execute the pipeline locally python ~/graphstorm/sagemaker/pipeline/execute_sm_pipeline.py \ - --pipeline-name $PIPELINE_NAME \ + --pipeline-name $PIPELINE_NAME_GRAPHBOLT \ --region us-east-1 \ --local-execution | tee arxiv-local-gb-logs.txt ``` @@ -459,16 +476,16 @@ python ~/graphstorm/sagemaker/pipeline/execute_sm_pipeline.py \ Next, you can deploy and execute another pipeline, now with GraphBolt enabled: ```bash -PIPELINE_NAME="ogb-papers100M-graphbolt-pipeline" +PIPELINE_NAME_GRAPHBOLT="ogb-papers100M-graphbolt-pipeline" bash deploy_papers100M_pipeline.sh \ --account \ --bucket-name --role \ - --pipeline-name $PIPELINE_NAME \ + --pipeline-name $PIPELINE_NAME_GRAPHBOLT \ --use-graphbolt true # Execute the GraphBolt-enabled pipeline on SageMaker python ~/graphstorm/sagemaker/pipeline/execute_sm_pipeline.py \ - --pipeline-name $PIPELINE_NAME \ + --pipeline-name $PIPELINE_NAME_GRAPHBOLT \ --region us-east-1 \ --async-execution ``` @@ -482,7 +499,7 @@ The easiest way to do so is through the Studio pipeline interface. In the Pipeli ```bash python analyze_training_time.py \ - --pipeline-name papers-100M-gs-pipeline \ + --pipeline-name $PIPELINE_NAME \ --execution-name execution-1734404366941 ``` @@ -502,7 +519,7 @@ Now do the same for the GraphBolt-enabled pipeline: ```bash python analyze_training_time.py \ - --pipeline-name papers-100M-gs-graphbolt-pipeline \ + --pipeline-name $PIPELINE_NAME_GRAPHBOLT \ --execution-name execution-1734463209078 ``` @@ -520,10 +537,4 @@ Average evaluation time: 4.13 seconds Without loss in accuracy, the latest version of GraphStorm achieved a **~1.4x speedup per epoch, and a 3.6x speedup in evaluation time!** -## Conclusion: Accelerate Your Graph ML with GraphStorm - -This example showcased how GraphStorm 0.4, integrated with DGL-GraphBolt, significantly speeds up large-scale graph neural network training and inference. - -We encourage ML practitioners working with large graph data to try GraphStorm. Its low-code interface simplifies building, training, and deploying graph ML solutions on AWS, allowing you to focus on modeling rather than infrastructure. - -To get started, visit the GraphStorm [documentation](https://graphstorm.readthedocs.io/en/) and GraphStorm [Github repository](https://github.com/awslabs/graphstorm). +We encourage you to try out GraphStorm with GraphBolt enabled to see how it can benefit your large-scale graph learning use cases. 
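+
+As a convenience, you can also look up pipeline execution display names from the CLI instead of the Studio UI when running `analyze_training_time.py`. This is a minimal sketch and assumes the pipeline name variables defined earlier in this example.
+
+```bash
+# List recent executions of a pipeline to find a display name to analyze (sketch)
+aws sagemaker list-pipeline-executions \
+    --pipeline-name "$PIPELINE_NAME" \
+    --region us-east-1 \
+    --query "PipelineExecutionSummaries[].{Name:PipelineExecutionDisplayName,Status:PipelineExecutionStatus}" \
+    --output table
+```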
diff --git a/examples/sagemaker-pipelines-graphbolt/analyze_training_time.py b/examples/sagemaker-pipelines-graphbolt/analyze_training_time.py index ab5c1548d0..fb03440f69 100644 --- a/examples/sagemaker-pipelines-graphbolt/analyze_training_time.py +++ b/examples/sagemaker-pipelines-graphbolt/analyze_training_time.py @@ -68,7 +68,7 @@ def parse_args(): "--logs-days-before", type=int, default=2, - help="The number of days in the past to start analyzing logs.", + help="Limit log analysis to logs created this many days before today.", ) return parser.parse_args() diff --git a/examples/sagemaker-pipelines-graphbolt/sagemaker_convert_papers100m.py b/examples/sagemaker-pipelines-graphbolt/sagemaker_convert_papers100m.py index f611070595..2c6798817e 100644 --- a/examples/sagemaker-pipelines-graphbolt/sagemaker_convert_papers100m.py +++ b/examples/sagemaker-pipelines-graphbolt/sagemaker_convert_papers100m.py @@ -40,7 +40,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--region", type=str, required=True, help="SageMaker Processing region." ) - parser.add_argument("--image-uri", type=str, required=True) + parser.add_argument("--image-uri", type=str, required=True, + help="URI for the 'papers100m-processor' image.") parser.add_argument( "--output-bucket", type=str, From ccb0e7f660bf7deeee05b045e119004aab398733 Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Mon, 13 Jan 2025 18:41:49 +0000 Subject: [PATCH 8/9] Fix argument names --- .../sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh | 4 ++-- .../deploy_papers100M_pipeline.sh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh b/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh index 23b395d731..2a4d27b72f 100644 --- a/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh +++ b/examples/sagemaker-pipelines-graphbolt/deploy_arxiv_pipeline.sh @@ -107,8 +107,8 @@ python3 $SCRIPT_DIR/../../sagemaker/pipeline/create_sm_pipeline.py \ --graph-construction-instance-type ${GCONSTRUCT_INSTANCE} \ --graph-construction-config-filename ${GCONSTRUCT_CONFIG} \ --graph-name ${GRAPH_NAME} \ - --graphstorm-pytorch-cpu-image-url "${GSF_CPU_IMAGE_URI}" \ - --graphstorm-pytorch-gpu-image-url "${GSF_GPU_IMAGE_URI}" \ + --graphstorm-pytorch-cpu-image-uri "${GSF_CPU_IMAGE_URI}" \ + --graphstorm-pytorch-gpu-image-uri "${GSF_GPU_IMAGE_URI}" \ --inference-model-snapshot "${INFERENCE_MODEL_SNAPSHOT}" \ --inference-yaml-s3 ${INFERENCE_YAML_S3} \ --input-data-s3 ${DATASET_S3_PATH} \ diff --git a/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh b/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh index 7609045490..c46ee1de9f 100644 --- a/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh +++ b/examples/sagemaker-pipelines-graphbolt/deploy_papers100M_pipeline.sh @@ -119,8 +119,8 @@ python3 $SCRIPT_DIR/../../sagemaker/pipeline/create_sm_pipeline.py \ --graph-construction-instance-type ${GCONSTRUCT_INSTANCE} \ --graph-construction-config-filename ${GCONSTRUCT_CONFIG} \ --graph-name ${GRAPH_NAME} \ - --graphstorm-pytorch-cpu-image-url "${GSF_CPU_IMAGE_URI}" \ - --graphstorm-pytorch-gpu-image-url "${GSF_GPU_IMAGE_URI}" \ + --graphstorm-pytorch-cpu-image-uri "${GSF_CPU_IMAGE_URI}" \ + --graphstorm-pytorch-gpu-image-uri "${GSF_GPU_IMAGE_URI}" \ --inference-model-snapshot "${INFERENCE_MODEL_SNAPSHOT}" \ --inference-yaml-s3 "${INFERENCE_YAML_S3}" \ --input-data-s3 "${DATASET_S3_PATH}" \ From 
3a24904b91824f84e908bc4fa6894142cb0bf7ff Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Mon, 13 Jan 2025 14:26:08 -0800 Subject: [PATCH 9/9] Update README.md with review comment --- examples/sagemaker-pipelines-graphbolt/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/sagemaker-pipelines-graphbolt/README.md b/examples/sagemaker-pipelines-graphbolt/README.md index 7b1b85588b..6d04f00f12 100644 --- a/examples/sagemaker-pipelines-graphbolt/README.md +++ b/examples/sagemaker-pipelines-graphbolt/README.md @@ -2,7 +2,7 @@ GraphStorm is a low-code enterprise graph machine learning (ML) framework that provides ML practitioners a simple way of building, training and deploying graph ML solutions on industry-scale graph data. While GraphStorm can run efficiently on single instances for small graphs, it truly shines when scaling to enterprise-level graphs in distributed mode using a cluster of EC2 instances or Amazon SageMaker. -GraphStorm 0.4 introduced integration with DGL-GraphBolt, a new graph storage and sampling framework that uses a compact graph representation and pipelined sampling to reduce memory requirements and speed up Graph Neural Network (GNN) training. In this example we'll show how GraphStorm 0.4 brings training and inference speedups of up to 3x on the papers100M dataset. +GraphStorm 0.4 introduced integration with DGL-GraphBolt, a new graph storage and sampling framework that uses a compact graph representation and pipelined sampling to reduce memory requirements and speed up Graph Neural Network (GNN) training. In this example we'll show how GraphStorm 0.4 brings inference speedups of up to 4x, and per-epoch training speedup up to 2x on the papers100M dataset, with even larger speedups possible [1]. In this example, you will: @@ -538,3 +538,5 @@ Average evaluation time: 4.13 seconds Without loss in accuracy, the latest version of GraphStorm achieved a **~1.4x speedup per epoch, and a 3.6x speedup in evaluation time!** We encourage you to try out GraphStorm with GraphBolt enabled to see how it can benefit your large-scale graph learning use cases. + +[1] DGL team GraphBolt benchmarks: https://www.dgl.ai/release/2024/03/06/release.html