diff --git a/pynest/nest/server/hl_api_server.py b/pynest/nest/server/hl_api_server.py index e4e7dfe9a7..c343317c9d 100644 --- a/pynest/nest/server/hl_api_server.py +++ b/pynest/nest/server/hl_api_server.py @@ -32,11 +32,10 @@ CORS_ORIGINS, EXEC_CALL_ENABLED, _check_security, - get_arguments, nest_calls, ) from .hl_api_server_mpi import api_client, do_call, log, mpi_comm -from .hl_api_server_utils import ErrorHandler +from .hl_api_server_utils import ErrorHandler, get_arguments # This ensures that the logging information shows up in the console running the server, # even when Flask's event loop is running. diff --git a/pynest/nest/server/hl_api_server_helpers.py b/pynest/nest/server/hl_api_server_helpers.py index 0b7e2c61d9..38a7afb598 100644 --- a/pynest/nest/server/hl_api_server_helpers.py +++ b/pynest/nest/server/hl_api_server_helpers.py @@ -19,18 +19,24 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -import ast -import importlib import inspect -import io import os import sys import time +import traceback import nest import RestrictedPython +from nest.lib.hl_api_exceptions import NESTError -from .hl_api_server_utils import get_boolean_environ, get_or_error +from .hl_api_server_utils import ( + Capturing, + ErrorHandler, + clean_code, + get_boolean_environ, + get_lineno, + get_modules_from_env, +) _default_origins = "http://localhost:*,http://127.0.0.1:*" ACCESS_TOKEN = os.environ.get("NEST_SERVER_ACCESS_TOKEN", "") @@ -38,8 +44,6 @@ CORS_ORIGINS = os.environ.get("NEST_SERVER_CORS_ORIGINS", _default_origins).split(",") EXEC_CALL_ENABLED = get_boolean_environ("NEST_SERVER_ENABLE_EXEC_CALL") RESTRICTION_DISABLED = get_boolean_environ("NEST_SERVER_DISABLE_RESTRICTION") -MODULES = os.environ.get("NEST_SERVER_MODULES", "import nest") -RESTRICTION_DISABLED = get_boolean_environ("NEST_SERVER_DISABLE_RESTRICTION") __all__ = [ "nestify", @@ -75,31 +79,6 @@ def _check_security(): print("\n - ".join([" "] + msg) + "\n") -class Capturing(list): - """Monitor stdout contents i.e. print.""" - - def __enter__(self): - self._stdout = sys.stdout - sys.stdout = self._stringio = io.StringIO() - return self - - def __exit__(self, *args): - self.extend(self._stringio.getvalue().splitlines()) - del self._stringio # free up some memory - sys.stdout = self._stdout - - -def clean_code(source): - codes = source.split("\n") - codes_cleaned = [] # noqa - for code in codes: - if code.startswith("import") or code.startswith("from"): - codes_cleaned.append("#" + code) - else: - codes_cleaned.append(code) - return "\n".join(codes_cleaned) - - def do_exec(kwargs): source_code = kwargs.get("source", "") source_cleaned = clean_code(source_code) @@ -132,56 +111,38 @@ def do_exec(kwargs): return response -def get_arguments(request): - """Get arguments from the request.""" - args, kwargs = [], {} - if request.is_json: - json = request.get_json() - if isinstance(json, str) and len(json) > 0: - args = [json] - elif isinstance(json, list): - args = json - elif isinstance(json, dict): - kwargs = json - if "args" in kwargs: - args = kwargs.pop("args") - elif len(request.form) > 0: - if "args" in request.form: - args = request.form.getlist("args") - else: - kwargs = request.form.to_dict() - elif len(request.args) > 0: - if "args" in request.args: - args = request.args.getlist("args") - else: - kwargs = request.args.to_dict() - return list(args), kwargs +def get_or_error(func): + """Wrapper to exec function.""" + def func_wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) -def get_modules_from_env(): - """Get modules from environment variable NEST_SERVER_MODULES. + except NESTError as err: + error_class = err.errorname + " (NESTError)" + detail = err.errormessage + lineno = get_lineno(err, 1) - This function converts the content of the environment variable NEST_SERVER_MODULES: - to a formatted dictionary for updating the Python `globals`. + except (KeyError, SyntaxError, TypeError, ValueError) as err: + error_class = err.__class__.__name__ + detail = err.args[0] + lineno = get_lineno(err, 1) - Here is an example: - `NEST_SERVER_MODULES="import nest; import numpy as np; from numpy import random"` - is converted to the following dictionary: - `{'nest': 'np': , 'random': }` - """ - modules = {} - try: - parsed = ast.iter_child_nodes(ast.parse(MODULES)) - except (SyntaxError, ValueError): - raise SyntaxError("The NEST server module environment variables contains syntax errors.") - for node in parsed: - if isinstance(node, ast.Import): - for alias in node.names: - modules[alias.asname or alias.name] = importlib.import_module(alias.name) - elif isinstance(node, ast.ImportFrom): - for alias in node.names: - modules[alias.asname or alias.name] = importlib.import_module(f"{node.module}.{alias.name}") - return modules + except Exception as err: + error_class = err.__class__.__name__ + detail = err.args[0] + lineno = get_lineno(err, -1) + + for line in traceback.format_exception(*sys.exc_info()): + print(line, flush=True) + + if lineno == -1: + message = "%s: %s" % (error_class, detail) + else: + message = "%s at line %d: %s" % (error_class, lineno, detail) + raise ErrorHandler(message, lineno) + + return func_wrapper def get_restricted_globals(): diff --git a/pynest/nest/server/hl_api_server_utils.py b/pynest/nest/server/hl_api_server_utils.py index c4e790082a..4777a2411f 100644 --- a/pynest/nest/server/hl_api_server_utils.py +++ b/pynest/nest/server/hl_api_server_utils.py @@ -19,11 +19,29 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . + +import ast +import importlib +import io import os import sys import traceback -from nest.lib.hl_api_exceptions import NESTError +MODULES = os.environ.get("NEST_SERVER_MODULES", "import nest") + + +class Capturing(list): + """Monitor stdout contents i.e. print.""" + + def __enter__(self): + self._stdout = sys.stdout + sys.stdout = self._stringio = io.StringIO() + return self + + def __exit__(self, *args): + self.extend(self._stringio.getvalue().splitlines()) + del self._stringio # free up some memory + sys.stdout = self._stdout class ErrorHandler(Exception): @@ -47,6 +65,43 @@ def to_dict(self): return rv +def clean_code(source): + codes = source.split("\n") + codes_cleaned = [] # noqa + for code in codes: + if code.startswith("import") or code.startswith("from"): + codes_cleaned.append("#" + code) + else: + codes_cleaned.append(code) + return "\n".join(codes_cleaned) + + +def get_arguments(request): + """Get arguments from the request.""" + args, kwargs = [], {} + if request.is_json: + json = request.get_json() + if isinstance(json, str) and len(json) > 0: + args = [json] + elif isinstance(json, list): + args = json + elif isinstance(json, dict): + kwargs = json + if "args" in kwargs: + args = kwargs.pop("args") + elif len(request.form) > 0: + if "args" in request.form: + args = request.form.getlist("args") + else: + kwargs = request.form.to_dict() + elif len(request.args) > 0: + if "args" in request.args: + args = request.args.getlist("args") + else: + kwargs = request.args.to_dict() + return list(args), kwargs + + def get_boolean_environ(env_key, default_value="false"): env_value = os.environ.get(env_key, default_value) return env_value.lower() in ["yes", "true", "t", "1"] @@ -65,35 +120,27 @@ def get_lineno(err, tb_idx): return lineno -def get_or_error(func): - """Wrapper to exec function.""" - - def func_wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - - except NESTError as err: - error_class = err.errorname + " (NESTError)" - detail = err.errormessage - lineno = get_lineno(err, 1) - - except (KeyError, SyntaxError, TypeError, ValueError) as err: - error_class = err.__class__.__name__ - detail = err.args[0] - lineno = get_lineno(err, 1) - - except Exception as err: - error_class = err.__class__.__name__ - detail = err.args[0] - lineno = get_lineno(err, -1) - - for line in traceback.format_exception(*sys.exc_info()): - print(line, flush=True) - - if lineno == -1: - message = "%s: %s" % (error_class, detail) - else: - message = "%s at line %d: %s" % (error_class, lineno, detail) - raise ErrorHandler(message, lineno) - - return func_wrapper +def get_modules_from_env(): + """Get modules from environment variable NEST_SERVER_MODULES. + + This function converts the content of the environment variable NEST_SERVER_MODULES: + to a formatted dictionary for updating the Python `globals`. + + Here is an example: + `NEST_SERVER_MODULES="import nest; import numpy as np; from numpy import random"` + is converted to the following dictionary: + `{'nest': 'np': , 'random': }` + """ + modules = {} + try: + parsed = ast.iter_child_nodes(ast.parse(MODULES)) + except (SyntaxError, ValueError): + raise SyntaxError("The NEST server module environment variables contains syntax errors.") + for node in parsed: + if isinstance(node, ast.Import): + for alias in node.names: + modules[alias.asname or alias.name] = importlib.import_module(alias.name) + elif isinstance(node, ast.ImportFrom): + for alias in node.names: + modules[alias.asname or alias.name] = importlib.import_module(f"{node.module}.{alias.name}") + return modules