diff --git a/.gitignore b/.gitignore
index 48143ca5a..e64f8bf44 100644
--- a/.gitignore
+++ b/.gitignore
@@ -93,3 +93,4 @@ config.mk
server.config.js
es.match_phrase.js
es.match_phrase.json
+web_console/config
diff --git a/deploy/charts/fedlearner/charts/fedlearner-web-console/templates/deployment.yaml b/deploy/charts/fedlearner/charts/fedlearner-web-console/templates/deployment.yaml
index acd9cb837..13c628f54 100644
--- a/deploy/charts/fedlearner/charts/fedlearner-web-console/templates/deployment.yaml
+++ b/deploy/charts/fedlearner/charts/fedlearner-web-console/templates/deployment.yaml
@@ -34,9 +34,9 @@ spec:
image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
imagePullPolicy: {{ .Values.image.pullPolicy }}
command:
- - node
+ - sh
args:
- - bootstrap.js
+ - startup.sh
env:
- name: NODE_ENV
value: production
diff --git a/deploy/charts/fedlearner/charts/fedlearner-web-console/values.yaml b/deploy/charts/fedlearner/charts/fedlearner-web-console/values.yaml
index a9512c4ec..6fa6ce554 100644
--- a/deploy/charts/fedlearner/charts/fedlearner-web-console/values.yaml
+++ b/deploy/charts/fedlearner/charts/fedlearner-web-console/values.yaml
@@ -80,9 +80,11 @@ cluster:
DB_PASSWORD: fedlearner
DB_HOST: fedlearner-stack-mariadb
DB_PORT: 3306
- DB_SYNC: true
+ DB_SYNC: false
GRPC_AUTHORITY: ""
KIBANA_HOST: fedlearner-stack-kibana
KIBANA_PORT: 443
ES_HOST: fedlearner-stack-elasticsearch-client
ES_PORT: 9200
+ ETCD_ADDR: "fedlearner-stack-etcd.default.svc.cluster.local:2379"
+ KVSTORE_TYPE: "mysql"
diff --git a/deploy/integrated_test/client_integrated_test.py b/deploy/integrated_test/client_integrated_test.py
index 7c4986388..ceb85dbe0 100644
--- a/deploy/integrated_test/client_integrated_test.py
+++ b/deploy/integrated_test/client_integrated_test.py
@@ -2,7 +2,8 @@
import argparse
import json
import requests
-from tools import login, request_and_response, build_raw_data, build_data_join_ticket, build_train_ticket
+from tools import login, request_and_response, build_raw_data, \
+ build_data_join_ticket, build_nn_ticket, build_tree_ticket
def build_federation_json(args):
@@ -18,17 +19,61 @@ def build_federation_json(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
- parser.add_argument('--name', type=str)
- parser.add_argument('--x-federation', type=str)
- parser.add_argument('--image', type=str)
- parser.add_argument('--url', type=str)
- parser.add_argument('--username', type=str)
- parser.add_argument('--password', type=str)
- parser.add_argument('--api-version')
+ parser.add_argument('--name',
+ type=str,
+ help='Name for peer federation.')
+ parser.add_argument('--x-federation',
+ type=str,
+ help='Name for local federation.')
+ parser.add_argument('--image',
+ type=str,
+ help='Image address.')
+ parser.add_argument('--data-portal-type',
+ type=str,
+                        help='Type of raw data, Streaming (default) or PSI.',
+ choices=['Streaming', 'PSI'],
+ default='Streaming')
+ parser.add_argument('--model-type',
+ type=str,
+                        help='Type of train model, nn_model (default) or tree_model.',
+ choices=['nn_model', 'tree_model'],
+ default='nn_model')
+ parser.add_argument('--rsa-key-path',
+ type=str,
+ help='Path to RSA public key.')
+ parser.add_argument('--rsa-key-pem',
+ type=str,
+ help='Either rsa key path or rsa key pem must be given.')
+ parser.add_argument('--url',
+ type=str,
+ help='URL to webconsole.',
+ default='127.0.0.1:1989')
+ parser.add_argument('--username',
+ type=str,
+ help='Username of webconsole.',
+ default='ada')
+ parser.add_argument('--password',
+ type=str,
+ help='Password of webconsole.',
+ default='ada')
+ parser.add_argument('--api-version',
+ help='API version of webconsole.',
+ default=1)
args = parser.parse_args()
+ args.streaming = args.data_portal_type == 'Streaming'
+ if not args.streaming:
+ args.cmd_args = {'Master': ["/app/deploy/scripts/rsa_psi/run_psi_data_join_master.sh"],
+ 'Worker': ["/app/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh"]}
+ if args.rsa_key_pem is not None:
+ args.psi_extras = [{"name": "RSA_KEY_PEM", "value": args.rsa_key_pem}]
+ elif args.rsa_key_path is not None:
+ args.psi_extras = [{"name": "RSA_KEY_PATH", "value": args.rsa_key_path}]
+ else:
+            raise Exception('Either RSA_KEY_PEM or RSA_KEY_PATH must be provided when using PSI.')
+ args.psi_extras.append({"name": "SIGN_RPC_TIMEOUT_MS", "value": "128000"})
+
args.url = args.url.strip().rstrip('/') + '/api/v' + str(args.api_version)
cookie = login(args)
-
federation_json, suffix = build_federation_json(args)
federation_id, federation_name = request_and_response(args=args,
url=args.url + '/federations',
@@ -45,15 +90,22 @@ def build_federation_json(args):
requests.post(url=args.url + '/raw_data/' + str(raw_data_id) + '/submit', cookies=cookie)
join_ticket_json, suffix = build_data_join_ticket(args, federation_id, raw_data_name,
- 'template_json/template_join_ticket.json', 'Leader')
+ 'template_json/template_streaming_join_ticket.json'
+ if args.streaming
+ else 'template_json/template_psi_join_ticket.json',
+ 'Leader' if args.streaming else 'Follower')
join_ticket_id, join_ticket_name = request_and_response(args=args,
url=args.url + '/tickets',
json_data=join_ticket_json,
cookies=cookie,
name_suffix=suffix)
- train_ticket_json, suffix = build_train_ticket(args, federation_id,
- 'template_json/template_train_ticket.json', 'Follower')
+ if args.model_type == 'nn_model':
+ train_ticket_json, suffix = build_nn_ticket(args, federation_id,
+ 'template_json/template_nn_ticket.json', 'Follower')
+ else:
+ train_ticket_json, suffix = build_tree_ticket(args, federation_id,
+ 'template_json/template_tree_ticket.json', 'Follower')
train_ticket_id, train_ticket_name = request_and_response(args=args,
url=args.url + '/tickets',
json_data=train_ticket_json,
diff --git a/deploy/integrated_test/template_json/template_train_ticket.json b/deploy/integrated_test/template_json/template_nn_ticket.json
similarity index 100%
rename from deploy/integrated_test/template_json/template_train_ticket.json
rename to deploy/integrated_test/template_json/template_nn_ticket.json
diff --git a/deploy/integrated_test/template_json/template_psi_join_ticket.json b/deploy/integrated_test/template_json/template_psi_join_ticket.json
new file mode 100644
index 000000000..503ca701e
--- /dev/null
+++ b/deploy/integrated_test/template_json/template_psi_join_ticket.json
@@ -0,0 +1,178 @@
+{
+ "public_params": {
+ "spec": {
+ "flReplicaSpecs": {
+ "Master": {
+ "pair": true,
+ "replicas": 1,
+ "template": {
+ "spec": {
+ "containers": [
+ {
+ "env": [
+ {
+ "name": "PARTITION_NUM",
+ "value": "4"
+ },
+ {
+ "name": "START_TIME",
+ "value": "0"
+ },
+ {
+ "name": "END_TIME",
+ "value": "999999999999"
+ },
+ {
+ "name": "NEGATIVE_SAMPLING_RATE",
+ "value": "1.0"
+ },
+ {
+ "name": "RAW_DATA_SUB_DIR",
+ "value": "portal_publish_dir/"
+ }
+ ],
+ "image": "!image",
+ "ports": [
+ {
+ "containerPort": 50051,
+ "name": "flapp-port"
+ }
+ ],
+ "command": [
+ "/app/deploy/scripts/wait4pair_wrapper.sh"
+ ],
+ "args": []
+ }
+ ]
+ }
+ }
+ },
+ "Worker": {
+ "pair": true,
+ "replicas": 4,
+ "template": {
+ "spec": {
+ "containers": [
+ {
+ "env": [
+ {
+ "name": "PARTITION_NUM",
+ "value": "4"
+ },
+ {
+ "name": "RAW_DATA_SUB_DIR",
+ "value": "portal_publish_dir/"
+ },
+ {
+ "name": "DATA_BLOCK_DUMP_INTERVAL",
+ "value": "600"
+ },
+ {
+ "name": "DATA_BLOCK_DUMP_THRESHOLD",
+ "value": "65536"
+ },
+ {
+ "name": "EXAMPLE_ID_DUMP_INTERVAL",
+ "value": "600"
+ },
+ {
+ "name": "EXAMPLE_ID_DUMP_THRESHOLD",
+ "value": "65536"
+ },
+ {
+ "name": "PSI_RAW_DATA_ITER",
+ "value": "TF_RECORD"
+ },
+ {
+ "name": "PSI_OUTPUT_BUILDER",
+ "value": "TF_RECORD"
+ },
+ {
+ "name": "DATA_BLOCK_BUILDER",
+ "value": "TF_RECORD"
+ },
+ {
+ "name": "EXAMPLE_JOINER",
+ "value": "SORT_RUN_JOINER"
+ },
+ {
+ "name": "SIGN_RPC_TIMEOUT_MS",
+ "value": "128000"
+ }
+ ],
+ "image": "!image",
+ "ports": [
+ {
+ "containerPort": 50051,
+ "name": "flapp-port"
+ }
+ ],
+ "command": [
+ "/app/deploy/scripts/wait4pair_wrapper.sh"
+ ],
+ "args": []
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ },
+ "private_params": {
+ "spec": {
+ "flReplicaSpecs": {
+ "Master": {
+ "template": {
+ "spec": {
+ "containers": [
+ {
+ "image": "!image",
+ "env": [
+ {
+ "name": "RAW_DATA_SUB_DIR",
+ "value": "portal_publish_dir/"
+ },
+ {
+ "name": "PARTITION_NUM",
+ "value": "4"
+ }
+ ]
+ }
+ ]
+ }
+ },
+ "replicas": 1
+ },
+ "Worker": {
+ "template": {
+ "spec": {
+ "containers": [
+ {
+ "image": "!image",
+ "env": [
+ {
+ "name": "RAW_DATA_SUB_DIR",
+ "value": "portal_publish_dir/"
+ },
+ {
+ "name": "PARTITION_NUM",
+ "value": "4"
+ }
+ ]
+ }
+ ]
+ }
+ },
+ "replicas": 4
+ }
+ }
+ }
+ },
+ "name": "!name",
+ "federation_id": "!federation_id",
+ "job_type": "psi_data_join",
+ "role": "!role",
+ "expire_time": "!expire_time",
+ "remark": "Built by integrated test."
+}
diff --git a/deploy/integrated_test/template_json/template_raw_data.json b/deploy/integrated_test/template_json/template_raw_data.json
index 5f261e7cc..f7b9eab92 100644
--- a/deploy/integrated_test/template_json/template_raw_data.json
+++ b/deploy/integrated_test/template_json/template_raw_data.json
@@ -2,7 +2,7 @@
"name": "!name",
"federation_id": "!federation_id",
"output_partition_num": 4,
- "data_portal_type": "Streaming",
+ "data_portal_type": "!data_portal_type",
"input": "/app/deploy/integrated_test/tfrecord_raw_data",
"image": "!image",
"context": {
diff --git a/deploy/integrated_test/template_json/template_join_ticket.json b/deploy/integrated_test/template_json/template_streaming_join_ticket.json
similarity index 97%
rename from deploy/integrated_test/template_json/template_join_ticket.json
rename to deploy/integrated_test/template_json/template_streaming_join_ticket.json
index c08be5f92..fa52e48bd 100644
--- a/deploy/integrated_test/template_json/template_join_ticket.json
+++ b/deploy/integrated_test/template_json/template_streaming_join_ticket.json
@@ -61,13 +61,21 @@
"containers": [
{
"env": [
+ {
+ "name": "PARTITION_NUM",
+ "value": "4"
+ },
+ {
+ "name": "RAW_DATA_SUB_DIR",
+ "value": "portal_publish_dir/"
+ },
{
"name": "DATA_BLOCK_DUMP_INTERVAL",
"value": "600"
},
{
"name": "DATA_BLOCK_DUMP_THRESHOLD",
- "value": "262144"
+ "value": "65536"
},
{
"name": "EXAMPLE_ID_DUMP_INTERVAL",
@@ -75,7 +83,7 @@
},
{
"name": "EXAMPLE_ID_DUMP_THRESHOLD",
- "value": "262144"
+ "value": "65536"
},
{
"name": "EXAMPLE_ID_BATCH_SIZE",
@@ -96,14 +104,6 @@
{
"name": "RAW_DATA_ITER",
"value": "TF_RECORD"
- },
- {
- "name": "RAW_DATA_SUB_DIR",
- "value": "portal_publish_dir/"
- },
- {
- "name": "PARTITION_NUM",
- "value": "4"
}
],
"image": "!image",
@@ -182,5 +182,5 @@
"job_type": "data_join",
"role": "!role",
"expire_time": "!expire_time",
- "remark": "Build by integrated test."
+ "remark": "Built by integrated test."
}
diff --git a/deploy/integrated_test/template_json/template_tree_ticket.json b/deploy/integrated_test/template_json/template_tree_ticket.json
new file mode 100644
index 000000000..07765938e
--- /dev/null
+++ b/deploy/integrated_test/template_json/template_tree_ticket.json
@@ -0,0 +1,85 @@
+{
+ "public_params": {
+ "spec": {
+ "flReplicaSpecs": {
+ "Worker": {
+ "pair": true,
+ "replicas": 1,
+ "template": {
+ "spec": {
+ "containers": [
+ {
+ "env": [
+ {
+ "name": "FILE_EXT",
+ "value": ".data"
+ },
+ {
+ "name": "FILE_TYPE",
+ "value": "tfrecord"
+ },
+ {
+ "name": "SEND_SCORES_TO_FOLLOWER",
+ "value": ""
+ },
+ {
+ "name": "MODE",
+ "value": "train"
+ },
+ {
+ "name": "DATA_SOURCE",
+ "value": "!DATA_SOURCE"
+ }
+ ],
+ "image": "!image",
+ "ports": [
+ {
+ "containerPort": 50051,
+ "name": "flapp-port"
+ }
+ ],
+ "command": [
+ "/app/deploy/scripts/wait4pair_wrapper.sh"
+ ],
+ "args": [
+ "/app/deploy/scripts/trainer/run_tree_worker.sh"
+ ]
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ },
+ "private_params": {
+ "spec": {
+ "flReplicaSpecs": {
+ "Worker": {
+ "template": {
+ "spec": {
+ "containers": [
+ {
+ "image": "!image",
+ "env": [
+ {
+ "name": "DATA_SOURCE",
+ "value": "!DATA_SOURCE"
+ }
+ ]
+ }
+ ]
+ }
+ }
+ }
+ }
+ }
+ },
+ "name": "!name",
+ "federation_id": -1,
+ "job_type": "tree_model",
+ "role": "!role",
+ "expire_time": "!expire_time",
+ "remark": "Built by integrated test.",
+ "undefined": ""
+}
diff --git a/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0000.rd b/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0000.rd
index f10e60143..662ee7561 100644
Binary files a/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0000.rd and b/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0000.rd differ
diff --git a/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0001.rd b/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0001.rd
index c86af6015..5428c488e 100644
Binary files a/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0001.rd and b/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0001.rd differ
diff --git a/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0002.rd b/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0002.rd
index 8bef6f388..d92eb1819 100644
Binary files a/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0002.rd and b/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0002.rd differ
diff --git a/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0003.rd b/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0003.rd
index 99d155524..35e28e1fb 100644
Binary files a/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0003.rd and b/deploy/integrated_test/tfrecord_raw_data/raw_data_partition_0003.rd differ
diff --git a/deploy/integrated_test/tools.py b/deploy/integrated_test/tools.py
index ac6c30aee..471abd65b 100644
--- a/deploy/integrated_test/tools.py
+++ b/deploy/integrated_test/tools.py
@@ -45,6 +45,8 @@ def request_and_response(args, url, json_data, cookies, name_suffix=''):
try:
response = json.loads(response.text)
except json.decoder.JSONDecodeError:
+ print('Json data to be sent:')
+ print(json_data)
raise Exception('404 error encountered when building/modifying {}. '
'Please check whether webconsole api changed.'.format(url.split('/')[-1]))
if 'error' not in response.keys():
@@ -61,6 +63,7 @@ def build_raw_data(args, fed_id, filepath):
name_suffix = '-raw-data'
raw_json['name'] = args.name + name_suffix
raw_json['federation_id'] = fed_id
+ raw_json['data_portal_type'] = args.data_portal_type
raw_json['image'] = args.image
fl_rep_spec = raw_json['context']['yaml_spec']['spec']['flReplicaSpecs']
fl_rep_spec['Master']['template']['spec']['containers'][0]['image'] = args.image
@@ -79,9 +82,14 @@ def build_data_join_ticket(args, fed_id, raw_name, filepath, role):
ticket_json['sdk_version'] = args.image.split(':')[-1]
ticket_json['expire_time'] = str(datetime.datetime.now().year + 1) + '-12-31'
for param in ['public_params', 'private_params']:
- for pod in ticket_json[param]['spec']['flReplicaSpecs'].values():
+ for pod_name, pod in ticket_json[param]['spec']['flReplicaSpecs'].items():
container = pod['template']['spec']['containers'][0]
container['image'] = args.image
+ if not args.streaming:
+ if param == 'public_params':
+ container['args'] = args.cmd_args[pod_name]
+ if pod_name == 'Worker':
+ container['env'].extend(args.psi_extras)
for d in container['env']:
if d['name'] == 'RAW_DATA_SUB_DIR':
d['value'] += raw_name
diff --git a/deploy/scripts/rsa_psi/run_psi_data_join_master.sh b/deploy/scripts/rsa_psi/run_psi_data_join_master.sh
new file mode 100755
index 000000000..1ddf7969a
--- /dev/null
+++ b/deploy/scripts/rsa_psi/run_psi_data_join_master.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# Copyright 2020 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -ex
+
+data_join_master_cmd=/app/deploy/scripts/data_join/run_data_join_master.sh
+
+export RAW_DATA_SUB_DIR="portal_publish_dir/${APPLICATION_ID}_psi_preprocess"
+
+# Reverse the role assignment for data join so that leader for PSI preprocessor
+# becomes follower for data join. Data join's workers get their role from
+# master so we don't need to do this for worker.
+if [ $ROLE == "leader" ]; then
+ export ROLE="follower"
+else
+ export ROLE="leader"
+fi
+
+${data_join_master_cmd}
diff --git a/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh b/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh
new file mode 100755
index 000000000..fd0bdb449
--- /dev/null
+++ b/deploy/scripts/rsa_psi/run_psi_data_join_worker.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# Copyright 2020 The FedLearner Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -ex
+
+psi_data_join_leader_worker_cmd=/app/deploy/scripts/data_join/run_psi_data_join_leader_worker_v2.sh
+psi_data_join_follower_worker_cmd=/app/deploy/scripts/data_join/run_psi_data_join_follower_worker_v2.sh
+
+export INPUT_FILE_SUBSCRIBE_DIR=$RAW_DATA_SUB_DIR
+export RAW_DATA_PUBLISH_DIR="portal_publish_dir/${APPLICATION_ID}_psi_preprocess"
+if [ $ROLE == "leader" ]; then
+ ${psi_data_join_leader_worker_cmd}
+else
+ ${psi_data_join_follower_worker_cmd}
+fi
diff --git a/deploy/scripts/rsa_psi/run_psi_preprocessor.sh b/deploy/scripts/rsa_psi/run_psi_preprocessor.sh
index 2a77b54d9..2ec30a0cc 100755
--- a/deploy/scripts/rsa_psi/run_psi_preprocessor.sh
+++ b/deploy/scripts/rsa_psi/run_psi_preprocessor.sh
@@ -60,9 +60,13 @@ builder_compressed_type=$(normalize_env_to_args "--builder_compressed_type" $PSI
preprocessor_offload_processor_number=$(normalize_env_to_args "--preprocessor_offload_processor_number" $PREPROCESSOR_OFFLOAD_PROCESSOR_NUMBER)
kvstore_type=$(normalize_env_to_args '--kvstore_type' $KVSTORE_TYPE)
+
+# Turn off display to avoid RSA_KEY_PEM showing in log
+set +x
+
python -m fedlearner.data_join.cmd.rsa_psi_preprocessor_cli \
--psi_role=$ROLE \
- --rsa_key_path=$RSA_KEY_PATH \
+ --rsa_key_path="$RSA_KEY_PATH" \
--rsa_key_pem="$RSA_KEY_PEM" \
--output_file_dir="$OUTPUT_BASE_DIR/psi_output" \
--raw_data_publish_dir=$RAW_DATA_PUBLISH_DIR \
diff --git a/deploy/scripts/rsa_psi/run_rsa_psi_signer.sh b/deploy/scripts/rsa_psi/run_rsa_psi_signer.sh
index 2f88342b2..78cd484d2 100755
--- a/deploy/scripts/rsa_psi/run_rsa_psi_signer.sh
+++ b/deploy/scripts/rsa_psi/run_rsa_psi_signer.sh
@@ -24,8 +24,11 @@ slow_sign_threshold=$(normalize_env_to_args "--slow_sign_threshold" $SLOW_SIGN_T
worker_num=$(normalize_env_to_args "--worker_num" $WORKER_NUM)
signer_offload_processor_number=$(normalize_env_to_args "--signer_offload_processor_number" $SIGNER_OFFLOAD_PROCESSOR_NUMBER)
+# Turn off display to avoid RSA_KEY_PEM showing in log
+set +x
+
python -m fedlearner.data_join.cmd.rsa_psi_signer_service \
--listen_port=50051 \
- --rsa_private_key_path=$RSA_PRIVATE_KEY_PATH \
+ --rsa_private_key_path="$RSA_PRIVATE_KEY_PATH" \
--rsa_privet_key_pem="$RSA_KEY_PEM" \
$slow_sign_threshold $worker_num $signer_offload_processor_number
diff --git a/deploy/scripts/trainer/run_tree_worker.sh b/deploy/scripts/trainer/run_tree_worker.sh
index 23e79cab3..6f1edd3a0 100755
--- a/deploy/scripts/trainer/run_tree_worker.sh
+++ b/deploy/scripts/trainer/run_tree_worker.sh
@@ -22,24 +22,31 @@ source /app/deploy/scripts/env_to_args.sh
NUM_WORKERS=`python -c 'import json, os; print(len(json.loads(os.environ["CLUSTER_SPEC"])["clusterSpec"]["Worker"]))'`
+if [[ -z "${DATA_PATH}" && -n "${DATA_SOURCE}" ]]; then
+ export DATA_PATH="${STORAGE_ROOT_PATH}/data_source/${DATA_SOURCE}/data_block"
+fi
+
mode=$(normalize_env_to_args "--mode" "$MODE")
data_path=$(normalize_env_to_args "--data-path" "$DATA_PATH")
validation_data_path=$(normalize_env_to_args "--validation-data-path" "$VALIDATION_DATA_PATH")
no_data=$(normalize_env_to_args "--no-data" "$NO_DATA")
file_ext=$(normalize_env_to_args "--file-ext" "$FILE_EXT")
+file_type=$(normalize_env_to_args "--file-type" "$FILE_TYPE")
load_model_path=$(normalize_env_to_args "--load-model-path" "$LOAD_MODEL_PATH")
verbosity=$(normalize_env_to_args "--verbosity" "$VERBOSITY")
+loss_type=$(normalize_env_to_args "--loss-type" "$LOSS_TYPE")
learning_rate=$(normalize_env_to_args "--learning-rate" "$LEARNING_RATE")
max_iters=$(normalize_env_to_args "--max-iters" "$MAX_ITERS")
max_depth=$(normalize_env_to_args "--max-depth" "$MAX_DEPTH")
l2_regularization=$(normalize_env_to_args "--l2-regularization" "$L2_REGULARIZATION")
max_bins=$(normalize_env_to_args "--max-bins" "$MAX_BINS")
-num_parallel=$(normalize_env_to_args "--num-parallel" "$NUM_PARALELL")
+num_parallel=$(normalize_env_to_args "--num-parallel" "$NUM_PARALLEL")
verify_example_ids=$(normalize_env_to_args "--verify-example-ids" "$VERIFY_EXAMPLE_IDS")
ignore_fields=$(normalize_env_to_args "--ignore-fields" "$IGNORE_FIELDS")
cat_fields=$(normalize_env_to_args "--cat-fields" "$CAT_FIELDS")
use_streaming=$(normalize_env_to_args "--use-streaming" "$USE_STREAMING")
send_scores_to_follower=$(normalize_env_to_args "--send-scores-to-follower" "$SEND_SCORES_TO_FOLLOWER")
+send_metrics_to_follower=$(normalize_env_to_args "--send-metrics-to-follower" "$SEND_METRICS_TO_FOLLOWER")
python -m fedlearner.model.tree.trainer \
@@ -53,8 +60,9 @@ python -m fedlearner.model.tree.trainer \
--checkpoint-path="$OUTPUT_BASE_DIR/checkpoints" \
--output-path="$OUTPUT_BASE_DIR/outputs" \
$mode $data_path $validation_data_path \
- $no_data $file_ext $load_model_path \
- $verbosity $learning_rate $max_iters \
+ $no_data $file_ext $file_type $load_model_path \
+ $verbosity $loss_type $learning_rate $max_iters \
$max_depth $l2_regularization $max_bins \
$num_parallel $verify_example_ids $ignore_fields \
- $cat_fields $use_streaming $send_scores_to_follower
+ $cat_fields $use_streaming $send_scores_to_follower \
+ $send_metrics_to_follower
diff --git a/package-lock.json b/package-lock.json
deleted file mode 100644
index b8ceed905..000000000
--- a/package-lock.json
+++ /dev/null
@@ -1,11 +0,0 @@
-{
- "requires": true,
- "lockfileVersion": 1,
- "dependencies": {
- "esm": {
- "version": "3.2.25",
- "resolved": "https://registry.npmjs.org/esm/-/esm-3.2.25.tgz",
- "integrity": "sha512-U1suiZ2oDVWv4zPO56S0NcR5QriEahGtdN2OR6FiOG4WJvcjBVFB0qI4+eKoWFH483PKGuLuu6V8Z4T5g63UVA=="
- }
- }
-}
diff --git a/web_console/Dockerfile b/web_console/Dockerfile
index b7ec22d01..8ace1ca81 100644
--- a/web_console/Dockerfile
+++ b/web_console/Dockerfile
@@ -13,4 +13,4 @@ RUN npm run build
EXPOSE 1989
-CMD ["node", "bootstrap.js"]
+CMD ["startup.sh"]
diff --git a/web_console/api/federation.js b/web_console/api/federation.js
index 4f7eac4ec..4fc1545a6 100644
--- a/web_console/api/federation.js
+++ b/web_console/api/federation.js
@@ -111,4 +111,18 @@ router.get('/api/v1/federations/:id/tickets', SessionMiddleware, async (ctx) =>
ctx.body = { data };
});
+router.get('/api/v1/federations/:id/heartbeat', SessionMiddleware, async (ctx) => {
+ const federation = await Federation.findByPk(ctx.params.id);
+ if (!federation) {
+ ctx.status = 404;
+ ctx.body = {
+ error: 'Federation not found',
+ };
+ return;
+ }
+ const client = new FederationClient(federation);
+ const res = await client.heartBeat();
+ ctx.body = res;
+});
+
module.exports = router;
diff --git a/web_console/api/job.js b/web_console/api/job.js
index 2a0a63069..34ac7f124 100644
--- a/web_console/api/job.js
+++ b/web_console/api/job.js
@@ -1,6 +1,7 @@
const router = require('@koa/router')();
const { Op } = require('sequelize');
const SessionMiddleware = require('../middlewares/session');
+const FindOptionsMiddleware = require('../middlewares/find_options');
const k8s = require('../libs/k8s');
const es = require('../libs/es');
const { Job, Ticket, Federation } = require('../models');
@@ -8,6 +9,7 @@ const FederationClient = require('../rpc/client');
const getConfig = require('../utils/get_confg');
const checkParseJson = require('../utils/check_parse_json');
const { clientValidateJob, clientGenerateYaml } = require('../utils/job_builder');
+const { client } = require('../libs/k8s');
const config = getConfig({
NAMESPACE: process.env.NAMESPACE,
@@ -19,15 +21,35 @@ try {
es_oparator_match_phrase = require('../es.match_phrase');
} catch (err) { /* */ }
-router.get('/api/v1/jobs', SessionMiddleware, async (ctx) => {
+router.get('/api/v1/jobs', SessionMiddleware, FindOptionsMiddleware, async (ctx) => {
const jobs = await Job.findAll({
+ ...ctx.findOptions,
order: [['created_at', 'DESC']],
});
const { flapps } = await k8s.getFLAppsByNamespace(namespace);
- const data = jobs.map((job) => ({
- ...(flapps.items.find((item) => item.metadata.name === job.name)),
- localdata: job,
- }));
+ let data = [];
+ for (job of jobs) {
+ if (job.status == null) job.status = 'started';
+ if (job.federation_id == null) {
+ const clientTicket = await Ticket.findOne({
+ where: {
+ name: { [Op.eq]: job.client_ticket_name },
+ },
+ });
+ job.federation_id = clientTicket.federation_id;
+ }
+ if (job.status === 'stopped') {
+ data.push({
+ ...JSON.parse(job.k8s_meta_snapshot).flapp,
+ localdata: job,
+ });
+ } else {
+ data.push({
+ ...(flapps.items.find((item) => item.metadata.name === job.name)),
+ localdata: job,
+ });
+ }
+ }
ctx.body = { data };
});
@@ -41,7 +63,24 @@ router.get('/api/v1/job/:id', SessionMiddleware, async (ctx) => {
};
return;
}
- const { flapp } = await k8s.getFLApp(namespace, job.name);
+
+ if (job.status == null) job.status = 'started';
+ if (job.federation_id == null) {
+ const clientTicket = await Ticket.findOne({
+ where: {
+ name: { [Op.eq]: job.client_ticket_name },
+ },
+ });
+ job.federation_id = clientTicket.federation_id;
+ }
+
+ var flapp;
+ if (job.status === 'stopped') {
+ flapp = JSON.parse(job.k8s_meta_snapshot).flapp;
+ } else {
+ flapp = (await k8s.getFLApp(namespace, job.name)).flapp;
+ }
+
ctx.body = {
data: {
...flapp,
@@ -52,7 +91,28 @@ router.get('/api/v1/job/:id', SessionMiddleware, async (ctx) => {
router.get('/api/v1/job/:k8s_name/pods', SessionMiddleware, async (ctx) => {
const { k8s_name } = ctx.params;
- const { pods } = await k8s.getFLAppPods(namespace, k8s_name);
+
+ const job = await Job.findOne({
+ where: {
+ name: { [Op.eq]: k8s_name },
+ },
+ });
+
+ if (!job) {
+ ctx.status = 404;
+ ctx.body = {
+ error: 'Job not found',
+ };
+ return;
+ }
+
+ var pods;
+ if (job.status === 'stopped') {
+ pods = JSON.parse(job.k8s_meta_snapshot).pods;
+ } else {
+ pods = (await k8s.getFLAppPods(namespace, k8s_name)).pods;
+ }
+
ctx.body = { data: pods.items };
});
@@ -126,11 +186,6 @@ router.post('/api/v1/job', SessionMiddleware, async (ctx) => {
return;
}
- const job = {
- name, job_type, client_ticket_name, server_ticket_name,
- client_params, server_params,
- };
-
const exists = await Job.findOne({
where: {
name: { [Op.eq]: name },
@@ -157,25 +212,47 @@ router.post('/api/v1/job', SessionMiddleware, async (ctx) => {
return;
}
+ const clientFed = await Federation.findByPk(clientTicket.federation_id);
+ if (!clientFed) {
+ ctx.status = 422;
+ ctx.body = {
+ error: 'Federation does not exist',
+ };
+ return;
+ }
+ const rpcClient = new FederationClient(clientFed);
+
+ let serverTicket;
try {
- clientValidateJob(job, clientTicket);
- } catch (e) {
- ctx.status = 400;
+ const { data } = await rpcClient.getTickets({ job_type: '', role: '' });
+ serverTicket = data.find(x => x.name === server_ticket_name);
+ if (!serverTicket) {
+ throw new Error(`Cannot find server ticket ${server_ticket_name}`);
+ }
+ } catch (err) {
+ ctx.status = 500;
ctx.body = {
- error: `client_params validation failed: ${e.message}`,
+ error: `Cannot get server ticket: ${err.message}`,
};
return;
}
- const clientFed = await Federation.findByPk(clientTicket.federation_id);
- if (!clientFed) {
- ctx.status = 422;
+ const job = {
+ name, job_type, client_ticket_name, server_ticket_name,
+ client_params, server_params, status: 'started',
+ federation_id: clientFed.id,
+ };
+
+ try {
+ clientValidateJob(job, clientTicket, serverTicket);
+ } catch (e) {
+ ctx.status = 400;
ctx.body = {
- error: 'Federation does not exist',
+ error: `client_params validation failed: ${e.message}`,
};
return;
}
- const rpcClient = new FederationClient(clientFed);
+
try {
await rpcClient.createJob({
...job,
@@ -184,7 +261,7 @@ router.post('/api/v1/job', SessionMiddleware, async (ctx) => {
} catch (err) {
ctx.status = 500;
ctx.body = {
- error: err.details,
+ error: `RPC Error: ${err.message}`,
};
return;
}
@@ -215,6 +292,166 @@ router.post('/api/v1/job', SessionMiddleware, async (ctx) => {
ctx.body = { data };
});
+router.post('/api/v1/job/:id/update', SessionMiddleware, async (ctx) => {
+ // get old job info
+ const { id } = ctx.params;
+ const old_job = await Job.findByPk(id);
+ if (!old_job) {
+ ctx.status = 404;
+ ctx.body = {
+ error: 'Job not found',
+ };
+ return;
+ }
+
+ if (old_job.status === 'error') {
+ ctx.status = 422;
+ ctx.body = {
+ error: 'Cannot update errored job',
+ };
+ return;
+ }
+
+ const {
+ name, job_type, client_ticket_name, server_ticket_name,
+ client_params, server_params, status,
+ } = ctx.request.body;
+
+ if (old_job.status === 'started' && status != 'stopped') {
+ ctx.status = 422;
+ ctx.body = {
+ error: 'Cannot change running job',
+ };
+ return;
+ }
+
+ if (name != old_job.name) {
+ ctx.status = 422;
+ ctx.body = {
+ error: 'cannot change job name',
+ };
+ return;
+ }
+
+ if (job_type != old_job.job_type) {
+ ctx.status = 422;
+ ctx.body = {
+ error: 'cannot change job type',
+ };
+ return;
+ }
+
+ const clientTicket = await Ticket.findOne({
+ where: {
+ name: { [Op.eq]: client_ticket_name },
+ },
+ });
+ if (!clientTicket) {
+ ctx.status = 422;
+ ctx.body = {
+ error: `client_ticket ${client_ticket_name} does not exist`,
+ };
+ return;
+ }
+
+ const OldClientTicket = await Ticket.findOne({
+ where: {
+ name: { [Op.eq]: old_job.client_ticket_name },
+ },
+ });
+ if (!OldClientTicket) {
+ ctx.status = 422;
+ ctx.body = {
+ error: `client_ticket ${old_job.client_ticket_name} does not exist`,
+ };
+ return;
+ }
+
+ if (clientTicket.federation_id != OldClientTicket.federation_id) {
+ ctx.status = 422;
+ ctx.body = {
+ error: 'cannot change job federation',
+ };
+ return;
+ }
+
+ const clientFed = await Federation.findByPk(clientTicket.federation_id);
+ if (!clientFed) {
+ ctx.status = 422;
+ ctx.body = {
+ error: 'Federation does not exist',
+ };
+ return;
+ }
+ const rpcClient = new FederationClient(clientFed);
+
+ let serverTicket;
+ try {
+ const { data } = await rpcClient.getTickets({ job_type: '', role: '' });
+ serverTicket = data.find(x => x.name === server_ticket_name);
+ if (!serverTicket) {
+ throw new Error(`Cannot find server ticket ${server_ticket_name}`);
+ }
+ } catch (err) {
+ ctx.status = 500;
+ ctx.body = {
+ error: `RPC Error: ${err.message}`,
+ };
+ return;
+ }
+
+ const new_job = {
+ name, job_type, client_ticket_name, server_ticket_name,
+ client_params, server_params, status,
+ federation_id: clientTicket.federation_id,
+ };
+
+ try {
+ clientValidateJob(new_job, clientTicket, serverTicket);
+ } catch (e) {
+ ctx.status = 400;
+ ctx.body = {
+ error: `client_params validation failed: ${e.message}`,
+ };
+ return;
+ }
+
+ // update job
+ try {
+ await rpcClient.updateJob({
+ ...new_job,
+ server_params: JSON.stringify(server_params),
+ });
+ } catch (err) {
+ ctx.status = 500;
+ ctx.body = {
+ error: `RPC Error: ${err.message}`,
+ };
+ return;
+ }
+
+ if (old_job.status === 'started' && new_job.status === 'stopped') {
+ flapp = (await k8s.getFLApp(namespace, new_job.name)).flapp;
+ pods = (await k8s.getFLAppPods(namespace, new_job.name)).pods;
+ old_job.k8s_meta_snapshot = JSON.stringify({flapp, pods});
+ await k8s.deleteFLApp(namespace, new_job.name);
+ } else if (old_job.status === 'stopped' && new_job.status === 'started') {
+ const clientYaml = clientGenerateYaml(clientFed, new_job, clientTicket);
+ await k8s.createFLApp(namespace, clientYaml);
+ }
+
+ old_job.client_ticket_name = new_job.client_ticket_name;
+ old_job.server_ticket_name = new_job.server_ticket_name;
+ old_job.client_params = new_job.client_params;
+ old_job.server_params = new_job.server_params;
+ old_job.status = new_job.status;
+ old_job.federation_id = new_job.federation_id;
+
+ const data = await old_job.save();
+
+ ctx.body = { data };
+});
+
router.delete('/api/v1/job/:id', SessionMiddleware, async (ctx) => {
// TODO: just owner can delete
const { id } = ctx.params;
@@ -228,6 +465,11 @@ router.delete('/api/v1/job/:id', SessionMiddleware, async (ctx) => {
return;
}
+ if (!data.status || data.status == 'started') {
+ await k8s.deleteFLApp(namespace, data.name);
+ }
+ await data.destroy({ force: true });
+
const ticket = await Ticket.findOne({
where: {
name: { [Op.eq]: data.client_ticket_name },
@@ -240,12 +482,10 @@ router.delete('/api/v1/job/:id', SessionMiddleware, async (ctx) => {
} catch (err) {
ctx.status = 500;
ctx.body = {
- error: err.details,
+ error: `RPC Error: ${err.message}`,
};
return;
}
- await k8s.deleteFLApp(namespace, data.name);
- await data.destroy({ force: true });
ctx.body = { data };
});
diff --git a/web_console/components/BooleanSelect.jsx b/web_console/components/BooleanSelect.jsx
new file mode 100644
index 000000000..2ecd86bd4
--- /dev/null
+++ b/web_console/components/BooleanSelect.jsx
@@ -0,0 +1,19 @@
+import React from 'react';
+import { Select } from '@zeit-ui/react';
+
+const options = [
+ { label: 'True', value: 'true' },
+ { label: 'False', value: 'false' },
+]
+
+export default function BooleanSelect(props) { // True/False selector; onChange receives a real boolean
+ const actualValue = props.value?.toString() || 'true' // null/undefined default to 'true'
+ const actualOnChange = (value) => {
+ props.onChange(value === 'true');
+ };
+ return (
+
+ );
+}
diff --git a/web_console/components/ClientTicketSelect.jsx b/web_console/components/ClientTicketSelect.jsx
index 13ec80a9d..3f2d92656 100644
--- a/web_console/components/ClientTicketSelect.jsx
+++ b/web_console/components/ClientTicketSelect.jsx
@@ -1,13 +1,23 @@
-import React from 'react';
+import React, { useCallback } from 'react';
import { Select } from '@zeit-ui/react';
import useSWR from 'swr';
import { fetcher } from '../libs/http';
+import { JOB_TYPE_CLASS } from '../constants/job'
-export default function ClientTicketSelect(props) {
+let filter = () => true
+export default function ClientTicketSelect({type, ...props}) {
const { data } = useSWR('tickets', fetcher);
- const tickets = (data && data.data) || [];
- const actualValue = tickets.find((x) => x.name === props.value)?.value;
+ if (type) {
+ filter = el => JOB_TYPE_CLASS[type].some(t => el.job_type === t)
+ }
+
+ const tickets = data
+ ? data.data.filter(filter)
+ : [];
+
+ // const actualValue = tickets.find((x) => x.name === props.value)?.value;
+ const actualValue = props.value || ''
const actualOnChange = (value) => {
const ticket = tickets.find((x) => x.name === value);
props.onChange(ticket.name);
diff --git a/web_console/components/CommonJobList.jsx b/web_console/components/CommonJobList.jsx
new file mode 100644
index 000000000..a9b0672da
--- /dev/null
+++ b/web_console/components/CommonJobList.jsx
@@ -0,0 +1,747 @@
+import React, { useMemo, useState, useCallback } from 'react';
+import css from 'styled-jsx/css';
+import { Link, Text, Input, Fieldset, Button, Card, Description, useTheme, useInput, Tooltip } from '@zeit-ui/react';
+import AlertCircle from '@geist-ui/react-icons/alertCircle'
+import Search from '@zeit-ui/react-icons/search';
+import NextLink from 'next/link';
+import useSWR from 'swr';
+import produce from 'immer'
+
+import { fetcher } from '../libs/http';
+import { FLAppStatus, handleStatus, getStatusColor, JobStatus } from '../utils/job';
+import Layout from '../components/Layout';
+import PopConfirm from '../components/PopConfirm';
+import Dot from '../components/Dot';
+import Empty from '../components/Empty';
+import { deleteJob, createJob } from '../services/job';
+import Form from '../components/Form';
+import {
+ JOB_DATA_JOIN_PARAMS,
+ JOB_NN_PARAMS,
+ JOB_PSI_DATA_JOIN_PARAMS,
+ JOB_TREE_PARAMS,
+ JOB_DATA_JOIN_REPLICA_TYPE,
+ JOB_NN_REPLICA_TYPE,
+ JOB_PSI_DATA_JOIN_REPLICA_TYPE,
+ JOB_TREE_REPLICA_TYPE,
+} from '../constants/form-default'
+import { getParsedValueFromData, fillJSON, getValueFromJson, getValueFromEnv, filterArrayValue } from '../utils/form_utils';
+import { getJobStatus } from '../utils/job'
+import { JOB_TYPE_CLASS, JOB_TYPE } from '../constants/job'
+
+// import {mockJobList} from '../constants/mock_data'
+
+function useStyles(theme) {
+ return css`
+ .counts-wrap {
+ padding: 0 5%;
+ display: flex;
+ align-items: center;
+ justify-content: space-between;
+ }
+
+ .num {
+ text-align: center;
+ color: ${theme.palette.accents_5};
+ margin-bottom: 1em;
+ }
+ .h {
+ font-weight: normal;
+ margin: 1em 0;
+ }
+ .b {
+ color: ${theme.palette.accents_6};
+ font-size: 1.4em;
+ }
+
+ .list-wrap {
+ position: relative;
+ }
+ .filter-bar {
+ position: absolute;
+ right: 0;
+ top: 0;
+ display: flex;
+ align-items: center;
+ justify-content: flex-end;
+ }
+ .filter-form {
+ display: flex;
+ align-items: center;
+ margin-right: 10px;
+ }
+ .filter-input {
+ width: 200px;
+ margin-right: 10px;
+ }
+
+ .content-list-wrap {
+ list-style: none;
+ padding: 0;
+ margin: 0;
+ }
+ .content-list {
+ padding: 20px 10px;
+ border-bottom: 1px solid ${theme.palette.border};
+ }
+ .content-list:last-of-type {
+ border-bottom: none;
+ }
+ .desc-wrap {
+ display: flex;
+ }
+ `;
+}
+
+const RESOURCE_PATH_PREFIX = 'spec.flReplicaSpecs.[replicaType].template.spec.containers[].resources'
+const ENV_PATH = 'spec.flReplicaSpecs.[replicaType].template.spec.containers[].env'
+const PARAMS_GROUP = ['client_params', 'server_params']
+
+// Write one form field's parsed value into `container` at the field's JSON path.
+function handleParamData(container, data, field) {
+ if (field.type === 'label') { return }
+ let path = field.path || field.key
+ let value = data
+ // keys like "Worker num" hold replica counts and must be stored as integers
+ if (/[\s\S]* num$/.test(field.key)) {
+ value = parseInt(value, 10)
+ }
+
+ fillJSON(container, path, value)
+}
+
+// Populate a form field definition from saved job data; marks the field as editing.
+function fillField(data, field) {
+ if (data === undefined) return field
+ let isSetValueWithEmpty = false
+ let disabled = false
+ let v = getValueFromJson(data, field.path || field.key) || field.emptyDefault || ''
+
+ // federation_id is pinned (and locked) when a federation was chosen globally
+ if (field.key === 'federation_id') {
+ const federationID = parseInt(localStorage.getItem('federationID'), 10)
+ if (federationID > 0) {
+ v = federationID
+ disabled = true
+ }
+ }
+ // "<replicaType> num" fields mirror the replica count stored in the params
+ else if (/[\s\S]* num$/.test(field.key)) {
+ let replicaType = field.key.split(' ')[0]
+ let path = `spec.flReplicaSpecs.${replicaType}.replicas`
+ v = getValueFromJson(data['client_params'], path)
+ || getValueFromJson(data['server_params'], path)
+ }
+
+ // complex values are edited as pretty-printed JSON text
+ if (typeof v === 'object' && v !== null) {
+ v = JSON.stringify(v, null, 2)
+ }
+
+ if (v !== undefined || (v === undefined && isSetValueWithEmpty)) {
+ field.value = v
+ }
+ field.editing = true
+ if (!field.props) field.props = {}
+ field.props.disabled = disabled
+ return field
+}
+
+let federationId = null, jobType = null
+
+// Propagate the module-level jobType / federationId selections into the
+// props of the client/server ticket-select fields.
+// Uses forEach rather than map: the callback only mutates the immer draft
+// in place; no mapped array is needed.
+const passFieldInfo = fields => produce(fields, draft => {
+ draft.forEach(field => {
+ if (field.key === 'client_ticket_name') { field.props.job_type = jobType }
+ if (field.key === 'server_ticket_name') { field.props.federation_id = federationId }
+ })
+})
+
+// Fill field definitions with values from saved job data (init=true refreshes
+// every group; otherwise only targetGroup, for the given form type).
+function mapValueToFields({data, fields, targetGroup, type = 'form', init = false}) {
+ return produce(fields, draft => {
+ // forEach, not map: fillField mutates each draft field in place
+ draft.forEach((x) => {
+ if (x.groupName) {
+ if (!data[x.groupName]) return
+ if (!init && x.groupName !== targetGroup) return
+ if (x.formTypes) {
+ // on init fill every form type; otherwise only the active one
+ let types = init ? x.formTypes : [type]
+ types.forEach(el => {
+ x.fields[el].forEach(field => fillField(data[x.groupName], field))
+ })
+ } else {
+ x.fields.forEach(field => fillField(data[x.groupName], field))
+ }
+ } else {
+ fillField(data, x)
+ }
+ });
+ })
+}
+
+let formMeta = {}
+const setFormMeta = value => { formMeta = value }
+
+export default function JobList({
+ datasoure,
+ training,
+ filter,
+ ...props
+}) {
+ const theme = useTheme();
+ const styles = useStyles(theme);
+
+ let JOB_REPLICA_TYPE, NAME_KEY, FILTER_TYPES, PAGE_NAME, INIT_PARAMS, DEFAULT_JOB_TYPE
+ if (datasoure) {
+
+ PAGE_NAME = 'datasource'
+
+ JOB_REPLICA_TYPE = JOB_DATA_JOIN_REPLICA_TYPE
+
+ NAME_KEY = 'DATA_SOURCE_NAME'
+
+ FILTER_TYPES = JOB_TYPE_CLASS.datasource
+
+ INIT_PARAMS = JOB_DATA_JOIN_PARAMS
+
+ DEFAULT_JOB_TYPE = JOB_TYPE.data_join
+
+ } else {
+
+ PAGE_NAME = 'training'
+
+ JOB_REPLICA_TYPE = JOB_NN_REPLICA_TYPE
+
+ NAME_KEY = 'TRAINING_NAME'
+
+ FILTER_TYPES = JOB_TYPE_CLASS.training
+
+ INIT_PARAMS = JOB_NN_PARAMS
+
+ DEFAULT_JOB_TYPE = JOB_TYPE.nn_model
+
+ }
+
+ filter = filter
+ || useCallback(job => FILTER_TYPES.some(type => type === job.localdata.job_type), [])
+
+ const getParamsFormFields = useCallback(() => JOB_REPLICA_TYPE.reduce((total, currType) => {
+ total.push(...[
+ { key: currType, type: 'label' },
+ {
+ key: currType + '.env',
+ label: 'env',
+ type: 'name-value',
+ path: `spec.flReplicaSpecs.${currType}.template.spec.containers[].env`,
+ span: 24,
+ emptyDefault: [],
+ props: {
+ ignoreKeys: filterArrayValue([
+ datasoure && 'DATA_SOURCE_NAME',
+ training && 'TRAINING_NAME',
+ ])
+ }
+ },
+ {
+ key: 'resoure.' + currType + '.cup_request',
+ label: 'cpu request',
+ path: RESOURCE_PATH_PREFIX.replace('[replicaType]', currType) + '.requests.cpu',
+ span: 12,
+ },
+ {
+ key: 'resoure.' + currType + '.cup_limit',
+ label: 'cpu limit',
+ path: RESOURCE_PATH_PREFIX.replace('[replicaType]', currType) + '.limits.cpu',
+ span: 12,
+ },
+ {
+ key: 'resoure.' + currType + '.memory_request',
+ label: 'memory request',
+ path: RESOURCE_PATH_PREFIX.replace('[replicaType]', currType) + '.requests.memory',
+ span: 12,
+ },
+ {
+ key: 'resoure.' + currType + '.memory_limit',
+ label: 'memory limit',
+ path: RESOURCE_PATH_PREFIX.replace('[replicaType]', currType) + '.limits.memory',
+ span: 12,
+ },
+ ])
+ return total
+ }, []), [RESOURCE_PATH_PREFIX, JOB_REPLICA_TYPE])
+
+ const { data, mutate } = useSWR('jobs', fetcher);
+ const jobs = data && data.data ? data.data.filter(el => el.metadata).filter(filter) : null
+ // const jobs = mockJobList.data
+
+ // form-meta conversion helpers
+ const rewriteFields = useCallback((draft, data) => {
+ // this function is called inside immer's produce()
+ // env
+ const insert2Env = filterArrayValue([
+ { name: NAME_KEY, getValue: data => data.name },
+ ])
+
+ PARAMS_GROUP.forEach(paramType => {
+ JOB_REPLICA_TYPE.forEach(replicaType => {
+ if (!draft[paramType]) {
+ draft[paramType] = {}
+ }
+ let envs = getValueFromJson(draft[paramType], ENV_PATH.replace('[replicaType]', replicaType))
+ if (!envs) return
+
+ let envNames = envs.map(env => env.name)
+
+ insert2Env.forEach(el => {
+ let idx = envNames.indexOf(el.name)
+ let value = el.getValue(data) || ''
+ if (idx >= 0) {
+ envs[idx].value = value.toString()
+ } else {
+ // envs is not extensible here, so push() would throw; use concat instead
+ envs = envs.concat({name: el.name, value: value.toString()})
+ }
+ })
+
+ // trigger immer's interceptor
+ fillJSON(draft[paramType], ENV_PATH.replace('[replicaType]', replicaType), envs)
+
+ // replicas
+ let path = `spec.flReplicaSpecs.${replicaType}.replicas`
+ if (replicaType !== 'Master') {
+ let num = parseInt(data[`${replicaType} num`])
+ !isNaN(num) && fillJSON(draft[paramType], path, parseInt(data[`${replicaType} num`]))
+ }
+
+ })
+ })
+
+ // delete useless fields
+
+ JOB_REPLICA_TYPE
+ .forEach(replicaType =>
+ draft[`${replicaType} num`] && delete draft[`${replicaType} num`]
+ )
+
+ }, [JOB_REPLICA_TYPE])
+ const mapFormMeta2FullData = useCallback((fields = fields) => {
+ let data = {}
+ fields.map((x) => {
+ if (x.groupName) {
+ data[x.groupName] = { ...formMeta[x.groupName] }
+ data[x.groupName][x.groupName] = formMeta[x.groupName]
+ } else {
+ data[x.key] = formMeta[x.key]
+ }
+ })
+ return data
+ }, [])
+ const writeJson2FormMeta = useCallback((groupName, data) => {
+ setFormMeta(produce(formMeta, draft => {
+ fields.map((x) => {
+ if (x.groupName) {
+ if (x.groupName !== groupName) return
+ draft[groupName] = JSON.parse(data[groupName][groupName])
+ } else {
+ draft[x.key] = getParsedValueFromData(data, x) || draft[x.key]
+ }
+ })
+
+ rewriteFields(draft, data)
+ }))
+ }, [])
+ const writeForm2FormMeta = useCallback((groupName, data) => {
+ setFormMeta(produce(formMeta, draft => {
+ let value
+
+ fields.map(x => {
+ if (x.groupName) {
+ if (x.groupName !== groupName) return
+ if (!draft[groupName]) { draft[groupName] = {} }
+
+ for (let field of getParamsFormFields()) {
+ value = getParsedValueFromData(data[groupName], field)
+ handleParamData(draft[groupName], value, field)
+ }
+
+ } else {
+ value = getParsedValueFromData(data, x) || draft[x.key]
+ handleParamData(draft, value, x)
+ }
+ })
+ rewriteFields(draft, data)
+ }))
+ }, [])
+ // ---end---
+ const onJobTypeChange = useCallback((value, totalData, groupFormType) => {
+ writeFormMeta(totalData, groupFormType)
+
+ switch (value) {
+ case JOB_TYPE.data_join:
+ JOB_REPLICA_TYPE = JOB_DATA_JOIN_REPLICA_TYPE
+ setFormMeta({...formMeta, ...JOB_DATA_JOIN_PARAMS}); break
+ case JOB_TYPE.psi_data_join:
+ JOB_REPLICA_TYPE = JOB_PSI_DATA_JOIN_REPLICA_TYPE
+ setFormMeta({...formMeta, ...JOB_PSI_DATA_JOIN_PARAMS}); break
+ case JOB_TYPE.nn_model:
+ JOB_REPLICA_TYPE = JOB_NN_REPLICA_TYPE
+ setFormMeta({...formMeta, ...JOB_NN_PARAMS}); break
+ case JOB_TYPE.tree_model:
+ JOB_REPLICA_TYPE = JOB_TREE_REPLICA_TYPE
+ setFormMeta({...formMeta, ...JOB_TREE_PARAMS}); break
+ }
+
+ jobType = value
+
+ setFields(
+ passFieldInfo(mapValueToFields({
+ data: mapFormMeta2FullData(fields),
+ fields: getDefaultFields(),
+ init: true,
+ }))
+ )
+ }, [])
+ const getDefaultFields = useCallback(() => filterArrayValue([
+ {
+ key: 'name',
+ required: true,
+ },
+ {
+ key: 'job_type',
+ type: 'jobType',
+ props: {type: PAGE_NAME},
+ required: true,
+ label: (
+ <>
+ job_type
+