diff --git a/deploy/scripts/env_to_args.sh b/deploy/scripts/env_to_args.sh index 9c543917c..4bd22d666 100755 --- a/deploy/scripts/env_to_args.sh +++ b/deploy/scripts/env_to_args.sh @@ -20,9 +20,11 @@ pull_code() { wget $1 -O code.tar.gz elif [[ $1 == "oss://"* ]]; then python -c "import tensorflow as tf; import tensorflow_io; open('code.tar.gz', 'wb').write(tf.io.gfile.GFile('$1', 'rb').read())" + elif [[ $1 == "base64://"* ]]; then + python -c "import base64; f = open('code.tar.gz', 'wb'); f.write(base64.b64decode('$1'[9:])); f.close()" else cp $1 code.tar.gz - fi +fi tar -zxvf code.tar.gz cd $cwd } diff --git a/deploy/scripts/trainer/run_trainer_worker.sh b/deploy/scripts/trainer/run_trainer_worker.sh index c9c29b1ca..03106dedb 100755 --- a/deploy/scripts/trainer/run_trainer_worker.sh +++ b/deploy/scripts/trainer/run_trainer_worker.sh @@ -43,7 +43,12 @@ for i in "${WORKER_GROUPS[@]}"; do done fi -pull_code ${CODE_KEY} $PWD +if [[ -n "${CODE_KEY}" ]]; then + pull_code ${CODE_KEY} $PWD +else + pull_code ${CODE_TAR} $PWD +fi + cd ${ROLE} mode=$(normalize_env_to_args "--mode" "$MODE") diff --git a/web_console_v2/api/test/fedlearner_webconsole/job/yaml_formatter_test.py b/web_console_v2/api/test/fedlearner_webconsole/job/yaml_formatter_test.py index b71fe355a..4679506d7 100644 --- a/web_console_v2/api/test/fedlearner_webconsole/job/yaml_formatter_test.py +++ b/web_console_v2/api/test/fedlearner_webconsole/job/yaml_formatter_test.py @@ -14,8 +14,10 @@ # coding: utf-8 import unittest - -from fedlearner_webconsole.job.yaml_formatter import format_yaml +import tarfile +import base64 +from io import BytesIO +from fedlearner_webconsole.job.yaml_formatter import format_yaml, code_dict_encode class YamlFormatterTest(unittest.TestCase): @@ -62,6 +64,22 @@ def test_format_yaml_unknown_ph(self): format_yaml('$x.y is ${i.j}', x=x) self.assertEqual(str(cm.exception), 'Unknown placeholder: i.j') + def test_encode_code(self): + test_data = {'test/a.py': 'awefawefawefawefwaef', + 'test1/b.py': 'asdfasd', + 'c.py': '', + 'test/d.py': 'asdf'} + code_base64 = code_dict_encode(test_data) + code_dict = {} + if code_base64.startswith('base64://'): + tar_binary = BytesIO(base64.b64decode(code_base64[9:])) + with tarfile.open(fileobj=tar_binary) as tar: + for file in tar.getmembers(): + code_dict[file.name] = str(tar.extractfile(file).read(), + encoding='utf-8') + self.assertEqual(code_dict, test_data) + + if __name__ == '__main__': unittest.main()