-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfrontend.py
250 lines (228 loc) · 8.48 KB
/
frontend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import json
import os
import pathlib
import pickle
import shutil
import tempfile
from uuid import uuid4
from dash import ClientsideFunction, Input, Output, State, dcc
from dash_component_editor import JSONParameterEditor
from file_manager.data_project import DataProject
from src.app_layout import DATA_DIR, TILED_KEY, USER, app, long_callback_manager
from src.callbacks.display import ( # noqa: F401
close_warning_modal,
open_warning_modal,
refresh_image,
refresh_label,
refresh_results,
update_slider_boundaries_new_dataset,
update_slider_boundaries_prediction,
update_slider_value,
)
from src.callbacks.download import disable_download, toggle_storage_modal # noqa: F401
from src.callbacks.execute import close_resources_popup, execute # noqa: F401
from src.callbacks.load_labels import load_from_splash_modal # noqa: F401
from src.callbacks.table import delete_row, open_job_modal, update_table # noqa: F401
from src.utils.data_utils import get_input_params, prepare_directories
from src.utils.job_utils import MlexJob
from src.utils.model_utils import get_gui_components, get_model_content
# Deployment configuration: host/port the Dash server binds to, overridable
# via environment variables for containerized deployments.
APP_HOST = os.getenv("APP_HOST", "127.0.0.1")
APP_PORT = os.getenv("APP_PORT", "8062")
# Directory mounted into the compute containers; defaults to the local data
# dir when no separate mount is configured (paths then need no remapping).
DIR_MOUNT = os.getenv("DIR_MOUNT", DATA_DIR)
# Underlying Flask server, exposed for WSGI runners (e.g. gunicorn).
server = app.server
# Client-side callback: applies the (optional) log transform to the displayed
# image entirely in the browser, avoiding a server round-trip per toggle.
# JS implementation lives in the "clientside" namespace (assets folder).
app.clientside_callback(
    ClientsideFunction(namespace="clientside", function_name="transform_image"),
    Output("img-output", "src"),
    Input("log-transform", "on"),
    Input("img-output-store", "data"),
    prevent_initial_call=True,
)
app.clientside_callback(
"""
function(n) {
if (typeof Intl === 'object' && typeof Intl.DateTimeFormat === 'function') {
const dtf = Intl.DateTimeFormat();
if (typeof dtf === 'object' && typeof dtf.resolvedOptions === 'function') {
const ro = dtf.resolvedOptions();
if (typeof ro === 'object' && typeof ro.timeZone === 'string') {
return ro.timeZone;
}
}
}
return 'Timezone information not available';
}
""",
Output("timezone-browser", "value"),
Input("interval", "n_intervals"),
)
@app.callback(
    Output("app-parameters", "children"),
    Input("model-selection", "value"),
    Input("action", "value"),
    # FIX: was misspelled "prevent_intial_call", which Dash silently ignored,
    # causing this callback to fire on initial page load as well.
    prevent_initial_call=True,
)
def load_parameters(model_selection, action_selection):
    """
    This callback dynamically populates the parameters and contents of the website according to the
    selected action & model.
    Args:
        model_selection: Selected model (from content registry)
        action_selection: Selected action (pre-defined actions in Data Clinic)
    Returns:
        app-parameters: Parameters according to the selected model & action
    """
    # Fetch the GUI parameter description for this model/action from the
    # content registry, then render it as an auto-generated parameter editor.
    parameters = get_gui_components(model_selection, action_selection)
    gui_item = JSONParameterEditor(
        _id={"type": str(uuid4())},  # pattern match _id (base id), name
        json_blob=parameters,
    )
    # Register the editor's own callbacks with the app so its widgets are live.
    gui_item.init_callbacks(app)
    return gui_item
@app.long_callback(
    Output("download-out", "data"),
    Input("download-button", "n_clicks"),
    State("jobs-table", "data"),
    State("jobs-table", "selected_rows"),
    manager=long_callback_manager,
    # FIX: was misspelled "prevent_intial_call", which Dash silently ignored.
    prevent_initial_call=True,
)
def save_results(download, job_data, row):
    """
    This callback saves the experimental results as a ZIP file
    Args:
        download: Download button
        job_data: Table of jobs
        row: Selected job/row
    Returns:
        ZIP file with results
    """
    if download and row:
        experiment_id = job_data[row[0]]["experiment_id"]
        experiment_path = pathlib.Path(f"{DATA_DIR}/mlex_store/{USER}/{experiment_id}")
        # FIX: the temporary directory was created but never used — the archive
        # was written to the shared gettempdir() and leaked after each download.
        # Build the archive inside the managed directory instead; send_file
        # reads the file before the context exits and cleans it up.
        with tempfile.TemporaryDirectory() as tmp_dir:
            archive_path = os.path.join(tmp_dir, "results")
            shutil.make_archive(archive_path, "zip", experiment_path)
            return dcc.send_file(f"{archive_path}.zip")
    else:
        return None
@app.long_callback(
    Output("job-alert-confirm", "is_open"),
    Input("submit", "n_clicks"),
    State("app-parameters", "children"),
    State("num-cpus", "value"),
    State("num-gpus", "value"),
    State("action", "value"),
    State("jobs-table", "data"),
    State("jobs-table", "selected_rows"),
    State({"base_id": "file-manager", "name": "data-project-dict"}, "data"),
    State("model-name", "value"),
    State("event-id", "value"),
    State("model-selection", "value"),
    State("log-transform", "on"),
    State("img-labeled-indx", "options"),
    # FIX: was ("True", "False") — string values, both truthy, so the
    # "job running" alert never closed. Dash expects the actual property
    # values to set while running / when done.
    running=[(Output("job-alert", "is_open"), True, False)],
    manager=long_callback_manager,
    prevent_initial_call=True,
)
def submit_ml_job(
    submit,
    children,
    num_cpus,
    num_gpus,
    action_selection,
    job_data,
    row,
    data_project_dict,
    model_name,
    event_id,
    model_id,
    log,
    labeled_dropdown,
):
    """
    This callback submits a job request to the compute service according to the selected action & model
    Args:
        submit: Submit button
        children: Model parameters
        num_cpus: Number of CPUs assigned to job
        num_gpus: Number of GPUs assigned to job
        action_selection: Action selected
        job_data: Lists of jobs
        row: Selected row (job)
        data_project_dict: Data project dictionary
        model_name: User-defined name for training or prediction model
        event_id: Tagging event id for version control of tags
        model_id: UID of model in content registry
        log: Log toggle
        labeled_dropdown: Indexes of the labeled images in this data set
    Returns:
        open the alert indicating that the job was submitted
    """
    # Get model information from content registry
    model_uri, [train_cmd, prediction_cmd] = get_model_content(model_id)
    # Get model parameters
    input_params = get_input_params(children)
    input_params["log"] = log
    kwargs = {}
    data_project = DataProject.from_dict(data_project_dict, api_key=TILED_KEY)
    if action_selection == "train_model":
        experiment_id, orig_out_path, data_info = prepare_directories(
            USER,
            data_project,
            labeled_indices=labeled_dropdown,
            correct_path=(DATA_DIR == DIR_MOUNT),
        )
        relative_data_dir, out_path, data_info = _map_to_container_paths(
            orig_out_path, data_info
        )
        command = f"{train_cmd} -d {data_info} -o {out_path} -e {event_id}"
    else:
        experiment_id, orig_out_path, data_info = prepare_directories(
            USER, data_project, train=False, correct_path=(DATA_DIR == DIR_MOUNT)
        )
        relative_data_dir, out_path, data_info = _map_to_container_paths(
            orig_out_path, data_info
        )
        # Prediction reuses the trained model from the selected training job
        training_exp_id = job_data[row[0]]["experiment_id"]
        model_path = pathlib.Path(
            f"{relative_data_dir}/mlex_store/{USER}/{training_exp_id}"
        )
        command = f"{prediction_cmd} -d {data_info} -m {model_path} -o {out_path}"
        kwargs = {"train_params": job_data[row[0]]["parameters"]}
    # Persist the data-project description next to the results so downstream
    # tools can reconstruct the dataset used for this job
    with open(f"{orig_out_path}/.file_manager_vars.pkl", "wb") as file:
        pickle.dump(
            data_project_dict,
            file,
        )
    # Define MLExjob
    job = MlexJob(
        service_type="backend",
        description=model_name,
        working_directory=str(DIR_MOUNT),
        job_kwargs={
            "uri": model_uri,
            "type": "docker",
            "cmd": f"{command} -p '{json.dumps(input_params)}'",
            "kwargs": {
                "job_type": action_selection,
                "experiment_id": experiment_id,
                "dataset": data_project.project_id,
                "params": input_params,
                **kwargs,
            },
        },
    )
    # Submit job
    job.submit(USER, num_cpus, num_gpus)
    return True


def _map_to_container_paths(orig_out_path, data_info):
    """
    Translate host-side paths into the paths visible inside the docker
    container, deduplicating logic that was previously copy-pasted in both
    the train and prediction branches.

    Args:
        orig_out_path: Output path as created on the host
        data_info: Path to the data description file on the host
    Returns:
        (relative_data_dir, out_path, data_info) rewritten for the container
    """
    if DIR_MOUNT == DATA_DIR:
        # Data dir is mounted at the fixed container location; remap both paths
        relative_data_dir = "/app/work/data"
        out_path = "/app/work/data" + str(orig_out_path).split(DATA_DIR, 1)[-1]
        data_info = "/app/work/data" + str(data_info).split(DATA_DIR, 1)[-1]
    else:
        relative_data_dir = DATA_DIR
        # FIX: out_path was previously left unassigned on this branch, raising
        # NameError when building the command. Use the original output path,
        # which prepare_directories(correct_path=False) left un-remapped.
        # NOTE(review): confirm this matches the intended non-mounted layout.
        out_path = orig_out_path
    return relative_data_dir, out_path, data_info
# Script entry point: run the Dash development server directly (production
# deployments should serve `server` via a WSGI runner instead).
if __name__ == "__main__":
    app.run_server(debug=True, host=APP_HOST, port=APP_PORT)