Skip to content

Commit

Permalink
adding deployment for surrogate
Browse files Browse the repository at this point in the history
  • Loading branch information
JBris committed Sep 23, 2024
1 parent b974c85 commit a2a13f8
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
2 changes: 1 addition & 1 deletion app/conf/deployments_form/common.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ components:
children: Neural density calibrator
color: primary
className: me-1
- id: call-surrogate-modelling-button
- id: call-surrogate-button
label: Surrogate model
help: Make a web request to the deployed surrogate model calibrator
class_name: dash_bootstrap_components.Button
Expand Down
47 changes: 47 additions & 0 deletions app/pages/deployments_root_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,53 @@ def call_snpe(n_clicks: int | list[int], statistics_inputs: list[float]) -> Call
return dcc.send_file(outfile), True, f"Calling {task} calibrator"


@callback(
Output(f"{PAGE_ID}-download-results", "data", allow_duplicate=True),
Output(f"{PAGE_ID}-load-toast", "is_open", allow_duplicate=True),
Output(f"{PAGE_ID}-load-toast", "children", allow_duplicate=True),
Input({"index": f"{PAGE_ID}-call-surrogate-button", "type": ALL}, "n_clicks"),
State({"type": f"{PAGE_ID}-parameters", "index": ALL}, "value"),
prevent_initial_call=True,
)
def call_surrogate(
n_clicks: int | list[int], statistics_inputs: list[float]
) -> Callable:
"""Call the surrogate model endpoint.
Args:
n_clicks (int | list[int]):
The number of form clicks.
statistics_inputs (list[float]):
The list of summary statistic values.
Returns:
Callable:
The form data.
"""
if n_clicks is None or len(n_clicks) == 0: # type: ignore
return no_update

if n_clicks[0] is None or n_clicks[0] == 0: # type: ignore
return no_update

endpoint = os.environ.get(
"DEPLOYMENT_SURROGATE_INTERNAL_LINK", "http://surrogate:3000"
)

app = get_app()
statistics_form = app.settings[FORM_NAME]
summary_statistics = {}
for i, child in enumerate(statistics_form.components["parameters"]["children"]):
statistic_name = child["param"]
statistic_value = statistics_inputs[i]
summary_statistics[statistic_name] = statistic_value

json = {"data": [summary_statistics]}
task = "surrogate"
outfile = endpoint_predict(task, endpoint, json)
return dcc.send_file(outfile), True, f"Calling {task} calibrator"


######################################
# Layout
######################################
Expand Down

0 comments on commit a2a13f8

Please sign in to comment.