diff --git a/pynest/nest/server/hl_api_server.py b/pynest/nest/server/hl_api_server.py
index e4e7dfe9a7..c343317c9d 100644
--- a/pynest/nest/server/hl_api_server.py
+++ b/pynest/nest/server/hl_api_server.py
@@ -32,11 +32,10 @@
CORS_ORIGINS,
EXEC_CALL_ENABLED,
_check_security,
- get_arguments,
nest_calls,
)
from .hl_api_server_mpi import api_client, do_call, log, mpi_comm
-from .hl_api_server_utils import ErrorHandler
+from .hl_api_server_utils import ErrorHandler, get_arguments
# This ensures that the logging information shows up in the console running the server,
# even when Flask's event loop is running.
diff --git a/pynest/nest/server/hl_api_server_helpers.py b/pynest/nest/server/hl_api_server_helpers.py
index 0b7e2c61d9..38a7afb598 100644
--- a/pynest/nest/server/hl_api_server_helpers.py
+++ b/pynest/nest/server/hl_api_server_helpers.py
@@ -19,18 +19,24 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-import ast
-import importlib
import inspect
-import io
import os
import sys
import time
+import traceback
import nest
import RestrictedPython
+from nest.lib.hl_api_exceptions import NESTError
-from .hl_api_server_utils import get_boolean_environ, get_or_error
+from .hl_api_server_utils import (
+ Capturing,
+ ErrorHandler,
+ clean_code,
+ get_boolean_environ,
+ get_lineno,
+ get_modules_from_env,
+)
_default_origins = "http://localhost:*,http://127.0.0.1:*"
ACCESS_TOKEN = os.environ.get("NEST_SERVER_ACCESS_TOKEN", "")
@@ -38,8 +44,6 @@
CORS_ORIGINS = os.environ.get("NEST_SERVER_CORS_ORIGINS", _default_origins).split(",")
EXEC_CALL_ENABLED = get_boolean_environ("NEST_SERVER_ENABLE_EXEC_CALL")
RESTRICTION_DISABLED = get_boolean_environ("NEST_SERVER_DISABLE_RESTRICTION")
-MODULES = os.environ.get("NEST_SERVER_MODULES", "import nest")
-RESTRICTION_DISABLED = get_boolean_environ("NEST_SERVER_DISABLE_RESTRICTION")
__all__ = [
"nestify",
@@ -75,31 +79,6 @@ def _check_security():
print("\n - ".join([" "] + msg) + "\n")
-class Capturing(list):
- """Monitor stdout contents i.e. print."""
-
- def __enter__(self):
- self._stdout = sys.stdout
- sys.stdout = self._stringio = io.StringIO()
- return self
-
- def __exit__(self, *args):
- self.extend(self._stringio.getvalue().splitlines())
- del self._stringio # free up some memory
- sys.stdout = self._stdout
-
-
-def clean_code(source):
- codes = source.split("\n")
- codes_cleaned = [] # noqa
- for code in codes:
- if code.startswith("import") or code.startswith("from"):
- codes_cleaned.append("#" + code)
- else:
- codes_cleaned.append(code)
- return "\n".join(codes_cleaned)
-
-
def do_exec(kwargs):
source_code = kwargs.get("source", "")
source_cleaned = clean_code(source_code)
@@ -132,56 +111,38 @@ def do_exec(kwargs):
return response
-def get_arguments(request):
- """Get arguments from the request."""
- args, kwargs = [], {}
- if request.is_json:
- json = request.get_json()
- if isinstance(json, str) and len(json) > 0:
- args = [json]
- elif isinstance(json, list):
- args = json
- elif isinstance(json, dict):
- kwargs = json
- if "args" in kwargs:
- args = kwargs.pop("args")
- elif len(request.form) > 0:
- if "args" in request.form:
- args = request.form.getlist("args")
- else:
- kwargs = request.form.to_dict()
- elif len(request.args) > 0:
- if "args" in request.args:
- args = request.args.getlist("args")
- else:
- kwargs = request.args.to_dict()
- return list(args), kwargs
+def get_or_error(func):
+ """Wrapper to exec function."""
+ def func_wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
-def get_modules_from_env():
- """Get modules from environment variable NEST_SERVER_MODULES.
+ except NESTError as err:
+ error_class = err.errorname + " (NESTError)"
+ detail = err.errormessage
+ lineno = get_lineno(err, 1)
- This function converts the content of the environment variable NEST_SERVER_MODULES:
- to a formatted dictionary for updating the Python `globals`.
+ except (KeyError, SyntaxError, TypeError, ValueError) as err:
+ error_class = err.__class__.__name__
+ detail = err.args[0]
+ lineno = get_lineno(err, 1)
- Here is an example:
- `NEST_SERVER_MODULES="import nest; import numpy as np; from numpy import random"`
- is converted to the following dictionary:
- `{'nest': 'np': , 'random': }`
- """
- modules = {}
- try:
- parsed = ast.iter_child_nodes(ast.parse(MODULES))
- except (SyntaxError, ValueError):
- raise SyntaxError("The NEST server module environment variables contains syntax errors.")
- for node in parsed:
- if isinstance(node, ast.Import):
- for alias in node.names:
- modules[alias.asname or alias.name] = importlib.import_module(alias.name)
- elif isinstance(node, ast.ImportFrom):
- for alias in node.names:
- modules[alias.asname or alias.name] = importlib.import_module(f"{node.module}.{alias.name}")
- return modules
+ except Exception as err:
+ error_class = err.__class__.__name__
+ detail = err.args[0]
+ lineno = get_lineno(err, -1)
+
+ for line in traceback.format_exception(*sys.exc_info()):
+ print(line, flush=True)
+
+ if lineno == -1:
+ message = "%s: %s" % (error_class, detail)
+ else:
+ message = "%s at line %d: %s" % (error_class, lineno, detail)
+ raise ErrorHandler(message, lineno)
+
+ return func_wrapper
def get_restricted_globals():
diff --git a/pynest/nest/server/hl_api_server_utils.py b/pynest/nest/server/hl_api_server_utils.py
index c4e790082a..4777a2411f 100644
--- a/pynest/nest/server/hl_api_server_utils.py
+++ b/pynest/nest/server/hl_api_server_utils.py
@@ -19,11 +19,29 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
+
+import ast
+import importlib
+import io
import os
import sys
import traceback
-from nest.lib.hl_api_exceptions import NESTError
+MODULES = os.environ.get("NEST_SERVER_MODULES", "import nest")
+
+
+class Capturing(list):
+ """Monitor stdout contents i.e. print."""
+
+ def __enter__(self):
+ self._stdout = sys.stdout
+ sys.stdout = self._stringio = io.StringIO()
+ return self
+
+ def __exit__(self, *args):
+ self.extend(self._stringio.getvalue().splitlines())
+ del self._stringio # free up some memory
+ sys.stdout = self._stdout
class ErrorHandler(Exception):
@@ -47,6 +65,43 @@ def to_dict(self):
return rv
+def clean_code(source):
+ codes = source.split("\n")
+ codes_cleaned = [] # noqa
+ for code in codes:
+ if code.startswith("import") or code.startswith("from"):
+ codes_cleaned.append("#" + code)
+ else:
+ codes_cleaned.append(code)
+ return "\n".join(codes_cleaned)
+
+
+def get_arguments(request):
+ """Get arguments from the request."""
+ args, kwargs = [], {}
+ if request.is_json:
+ json = request.get_json()
+ if isinstance(json, str) and len(json) > 0:
+ args = [json]
+ elif isinstance(json, list):
+ args = json
+ elif isinstance(json, dict):
+ kwargs = json
+ if "args" in kwargs:
+ args = kwargs.pop("args")
+ elif len(request.form) > 0:
+ if "args" in request.form:
+ args = request.form.getlist("args")
+ else:
+ kwargs = request.form.to_dict()
+ elif len(request.args) > 0:
+ if "args" in request.args:
+ args = request.args.getlist("args")
+ else:
+ kwargs = request.args.to_dict()
+ return list(args), kwargs
+
+
def get_boolean_environ(env_key, default_value="false"):
env_value = os.environ.get(env_key, default_value)
return env_value.lower() in ["yes", "true", "t", "1"]
@@ -65,35 +120,27 @@ def get_lineno(err, tb_idx):
return lineno
-def get_or_error(func):
- """Wrapper to exec function."""
-
- def func_wrapper(*args, **kwargs):
- try:
- return func(*args, **kwargs)
-
- except NESTError as err:
- error_class = err.errorname + " (NESTError)"
- detail = err.errormessage
- lineno = get_lineno(err, 1)
-
- except (KeyError, SyntaxError, TypeError, ValueError) as err:
- error_class = err.__class__.__name__
- detail = err.args[0]
- lineno = get_lineno(err, 1)
-
- except Exception as err:
- error_class = err.__class__.__name__
- detail = err.args[0]
- lineno = get_lineno(err, -1)
-
- for line in traceback.format_exception(*sys.exc_info()):
- print(line, flush=True)
-
- if lineno == -1:
- message = "%s: %s" % (error_class, detail)
- else:
- message = "%s at line %d: %s" % (error_class, lineno, detail)
- raise ErrorHandler(message, lineno)
-
- return func_wrapper
+def get_modules_from_env():
+ """Get modules from environment variable NEST_SERVER_MODULES.
+
+ This function converts the content of the environment variable NEST_SERVER_MODULES:
+ to a formatted dictionary for updating the Python `globals`.
+
+ Here is an example:
+ `NEST_SERVER_MODULES="import nest; import numpy as np; from numpy import random"`
+ is converted to the following dictionary:
+ `{'nest': 'np': , 'random': }`
+ """
+ modules = {}
+ try:
+ parsed = ast.iter_child_nodes(ast.parse(MODULES))
+ except (SyntaxError, ValueError):
+ raise SyntaxError("The NEST server module environment variables contains syntax errors.")
+ for node in parsed:
+ if isinstance(node, ast.Import):
+ for alias in node.names:
+ modules[alias.asname or alias.name] = importlib.import_module(alias.name)
+ elif isinstance(node, ast.ImportFrom):
+ for alias in node.names:
+ modules[alias.asname or alias.name] = importlib.import_module(f"{node.module}.{alias.name}")
+ return modules