diff --git a/python/graphstorm/sagemaker/sagemaker_train.py b/python/graphstorm/sagemaker/sagemaker_train.py
index 3e282774a2..f59c094e28 100644
--- a/python/graphstorm/sagemaker/sagemaker_train.py
+++ b/python/graphstorm/sagemaker/sagemaker_train.py
@@ -102,8 +102,8 @@ def launch_train_task(task_type, num_gpus, graph_config,
                   "--ssh-port", "22"]
     launch_cmd += [custom_script] if custom_script is not None else []
     launch_cmd += ["--cf", f"{yaml_path}",
-                   "--save-model-path", f"{save_model_path}"] + \
-                  ["--restore-model-path", f"{restore_model_path}"] \
+                   "--save-model-path", f"{save_model_path}"]
+    launch_cmd += ["--restore-model-path", f"{restore_model_path}"] \
         if restore_model_path is not None else []
     launch_cmd += extra_args
 
@@ -153,9 +153,13 @@ def run_train(args, unknownargs):
     """
     num_gpus = args.num_gpus
    data_path = args.data_path
-    restore_model_path = "/tmp/gsgnn_model_checkpoint/"
+    model_checkpoint_s3 = args.model_checkpoint_to_load
+    if model_checkpoint_s3 is not None:
+        restore_model_path = "/tmp/gsgnn_model_checkpoint/"
+        os.makedirs(restore_model_path, exist_ok=True)
+    else:
+        restore_model_path = None
     output_path = "/tmp/gsgnn_model/"
-    os.makedirs(restore_model_path, exist_ok=True)
     os.makedirs(output_path, exist_ok=True)
 
     # start the ssh server
@@ -214,7 +218,6 @@ def run_train(args, unknownargs):
     graph_data_s3 = args.graph_data_s3
     task_type = args.task_type
     train_yaml_s3 = args.train_yaml_s3
-    model_checkpoint_s3 = args.model_checkpoint_to_load
     model_artifact_s3 = args.model_artifact_s3.rstrip('/')
     custom_script = args.custom_script
 
diff --git a/sagemaker/launch/launch_train.py b/sagemaker/launch/launch_train.py
index 65d2db4020..9d1109ab5c 100644
--- a/sagemaker/launch/launch_train.py
+++ b/sagemaker/launch/launch_train.py
@@ -66,10 +66,11 @@ def run_job(input_args, image, unknowargs):
               "graph-name": graph_name,
               "graph-data-s3": graph_data_s3,
               "train-yaml-s3": train_yaml_s3,
-              "model-artifact-s3": model_artifact_s3,
-              "model-checkpoint-to-load": model_checkpoint_to_load}
+              "model-artifact-s3": model_artifact_s3}
     if custom_script is not None:
         params["custom-script"] = custom_script
+    if model_checkpoint_to_load is not None:
+        params["model-checkpoint-to-load"] = model_checkpoint_to_load
     # We must handle cases like
     # --target-etype query,clicks,asin query,search,asin
     # --feat-name ntype0:feat0 ntype1:feat1