Skip to content

Commit

Permalink
Bugfix sagemaker on restore model checkpoint (#558)
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato authored Oct 14, 2023
1 parent 7e7687e commit fe8d317
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
13 changes: 8 additions & 5 deletions python/graphstorm/sagemaker/sagemaker_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def launch_train_task(task_type, num_gpus, graph_config,
"--ssh-port", "22"]
launch_cmd += [custom_script] if custom_script is not None else []
launch_cmd += ["--cf", f"{yaml_path}",
"--save-model-path", f"{save_model_path}"] + \
["--restore-model-path", f"{restore_model_path}"] \
"--save-model-path", f"{save_model_path}"]
launch_cmd += ["--restore-model-path", f"{restore_model_path}"] \
if restore_model_path is not None else []
launch_cmd += extra_args

Expand Down Expand Up @@ -153,9 +153,13 @@ def run_train(args, unknownargs):
"""
num_gpus = args.num_gpus
data_path = args.data_path
restore_model_path = "/tmp/gsgnn_model_checkpoint/"
model_checkpoint_s3 = args.model_checkpoint_to_load
if model_checkpoint_s3 is not None:
restore_model_path = "/tmp/gsgnn_model_checkpoint/"
os.makedirs(restore_model_path, exist_ok=True)
else:
restore_model_path = None
output_path = "/tmp/gsgnn_model/"
os.makedirs(restore_model_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

# start the ssh server
Expand Down Expand Up @@ -214,7 +218,6 @@ def run_train(args, unknownargs):
graph_data_s3 = args.graph_data_s3
task_type = args.task_type
train_yaml_s3 = args.train_yaml_s3
model_checkpoint_s3 = args.model_checkpoint_to_load
model_artifact_s3 = args.model_artifact_s3.rstrip('/')
custom_script = args.custom_script

Expand Down
5 changes: 3 additions & 2 deletions sagemaker/launch/launch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,11 @@ def run_job(input_args, image, unknowargs):
"graph-name": graph_name,
"graph-data-s3": graph_data_s3,
"train-yaml-s3": train_yaml_s3,
"model-artifact-s3": model_artifact_s3,
"model-checkpoint-to-load": model_checkpoint_to_load}
"model-artifact-s3": model_artifact_s3}
if custom_script is not None:
params["custom-script"] = custom_script
if model_checkpoint_to_load is not None:
params["model-checkpoint-to-load"] = model_checkpoint_to_load
# We must handle cases like
# --target-etype query,clicks,asin query,search,asin
# --feat-name ntype0:feat0 ntype1:feat1
Expand Down

0 comments on commit fe8d317

Please sign in to comment.