diff --git a/CMakeLists.txt b/CMakeLists.txt index 29cfbecb..91ae3d29 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,6 +43,7 @@ option(WITH_HDF5 "Use HDF5 as a database back end" OFF) option(WITH_RMQ "Use RabbitMQ as a database back end (require a reachable and running RabbitMQ server service)" OFF) option(WITH_AMS_DEBUG "Enable verbose messages" OFF) option(WITH_PERFFLOWASPECT "Use PerfFlowAspect for Profiling" OFF) +option(WITH_WORKFLOW "Install python drivers used by the outer workflow" OFF) option(BUILD_SHARED_LIBS "Build using shared libraries" ON) if (WITH_MPI) @@ -193,39 +194,6 @@ if (WITH_TORCH) list(APPEND AMS_APP_DEFINES "-D__ENABLE_TORCH__") endif() -# ------------------------------------------------------------------------------ -if (WITH_RMQ) - if (WITH_CUDA) - add_compile_definitions(THRUST_IGNORE_CUB_VERSION_CHECK) - endif() - list(APPEND AMS_APP_DEFINES "-D__ENABLE_RMQ__") - - find_package(amqpcpp REQUIRED) - get_target_property(amqpcpp_INCLUDE_DIR amqpcpp INTERFACE_INCLUDE_DIRECTORIES) - list(APPEND AMS_APP_INCLUDES ${amqpcpp_INCLUDE_DIR}) - - find_package(OpenSSL REQUIRED) - if (OPENSSL_FOUND) - list(APPEND AMS_APP_INCLUDES ${OPENSSL_INCLUDE_DIR}) - list(APPEND AMS_APP_LIBRARIES "${OPENSSL_LIBRARIES}") - list(APPEND AMS_APP_LIBRARIES ssl) - message(STATUS "OpenSSL includes found: " ${OPENSSL_INCLUDE_DIR}) - message(STATUS "OpenSSL libraries found: " ${OPENSSL_LIBRARIES}) - else() - message(STATUS "OpenSSL Not Found") - endif() - - find_package(libevent REQUIRED) # event loop library - list(APPEND AMS_APP_INCLUDES ${LIBEVENT_INCLUDE_DIR}) - list(APPEND AMS_APP_LIBRARIES "${LIBEVENT_LIBRARIES}") - list(APPEND AMS_APP_LIBRARIES amqpcpp event_pthreads event) - # if (WITH_MPI) - # # Mandatory otherwise MPI_Init hangs forever for unknown reasons - # # probably linked to libevent, pthread ot OpenSSL - # list(APPEND AMS_APP_LIBRARIES MPI::MPI_CXX) - # endif() -endif() - # ------------------------------------------------------------------------------ if (WITH_FAISS) ## TODO: still need to create FindFaiss.cmake diff --git a/scripts/bootstrap_flux.sh b/scripts/bootstrap_flux.sh new file mode 100755 index 00000000..92b7d5f9 --- /dev/null +++ b/scripts/bootstrap_flux.sh @@ -0,0 +1,224 @@ +#!/usr/bin/env bash +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +usage="Usage: $(basename "$0") [#NODES] [JSON file] -- Script that bootstrap Flux on NNODES and writes Flux URIs to the JSON file." + +function version() { + echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; +} + +# Check if the allocation has the right size +# args: +# - $1 : number of nodes requested for Flux +function check_main_allocation() { + if [[ "$(flux getattr size)" -eq "$1" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Flux launch successful with $(flux getattr size) nodes" + else + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error: Requested nodes=$1 but Flux allocation size=$(flux getattr size)" + exit 1 + fi +} + +# Check if the 3 inputs are integers +function check_input_integers() { + re='^[0-9]+$' + if ! [[ $1 =~ $re && $2 =~ $re && $3 =~ $re ]] ; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error: number of nodes is not an integer ($1, $2, $3)" + exit 1 + fi +} + +# Check if an allocation is running and if yes set the second parameter to the URI +# - $1 : Flux Job ID +# - $2 : the resulting URI +function check_allocation_running() { + local JOBID="$1" + local _result=$2 + local temp_uri='' + # NOTE: with more recent versions of Flux, instead of sed here we could use flux jobs --no-header + if [[ "$(flux jobs -o '{status}' $JOBID | sed -n '1!p')" == "RUN" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Job $JOBID is running" + temp_uri=$(flux uri --remote $JOBID) + else + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Warning: failed to launch job ($JOBID)" + fi + eval $_result="'$temp_uri'" +} + +# Wait for a file to be created +# - $1 : the file +# - $2 : Max number of retry (one retry every 5 seconds) +function wait_for_file() { + local FLUX_SERVER="$1" + local EXIT_COUNTER=0 + local MAX_COUNTER="$2" + while [ ! -f $FLUX_SERVER ]; do + sleep 5s + echo "[$(date +'%m%d%Y-%T')@$(hostname)] $FLUX_SERVER does not exist yet." + exit_counter=$((EXIT_COUNTER + 1)) + if [ "$EXIT_COUNTER" -eq "$MAX_COUNTER" ]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Timeout: Failed to find file (${FLUX_SERVER})." + exit 1 + fi + done +} + +# ------------------------------------------------------------------------------ +# the script needs the number of nodes for flux +FLUX_NODES="$1" +# JSON configuration for AMS that will get updated by this script +AMS_JSON="$2" +FLUX_SERVER="ams-uri.log" +FLUX_LOG="ams-flux.log" +# Flux-core Minimum version required by AMS +[[ -z ${MIN_VER_FLUX+z} ]] && MIN_VER_FLUX="0.45.0" + +re='^[0-9]+$' +if ! [[ $FLUX_NODES =~ $re ]] ; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] ERROR: '$FLUX_NODES' is not a number." + echo $usage + exit 1 +fi + +if ! [[ -f "$AMS_JSON" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error: $AMS_JSON does not exists." + exit 1 +fi + +echo "[$(date +'%m%d%Y-%T')@$(hostname)] Launching Flux with $FLUX_NODES nodes" +echo "[$(date +'%m%d%Y-%T')@$(hostname)] Writing Flux configuration/URIs into $AMS_JSON" + +unset FLUX_URI +export LC_ALL="C" +export FLUX_F58_FORCE_ASCII=1 +export FLUX_SSH="ssh" +# Cleanup from previous runs +rm -f $FLUX_SERVER $FLUX_LOG + +if ! [[ -x "$(command -v flux)" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error: flux is not installed." + exit 1 +fi +echo "[$(date +'%m%d%Y-%T')@$(hostname)] flux = $(which flux)" + +flux_version=$(version $(flux version | awk '/^commands/ {print $2}')) +MIN_VER_FLUX_LONG=$(version ${MIN_VER_FLUX}) +# We need to remove leading 0 because they are interpreted as octal numbers in bash +if [[ "${flux_version#00}" -lt "${MIN_VER_FLUX_LONG#00}" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error Flux $(flux version | awk '/^commands/ {print $2}') is not supported.\ + AMS requires flux>=${MIN_VER_FLUX}" + exit 1 +fi + +echo "[$(date +'%m%d%Y-%T')@$(hostname)] flux version" +flux version + +# We create a Flux wrapper around sleep on the fly to get the main Flux URI +FLUX_SLEEP_WRAPPER="./$(mktemp flux-wrapper.XXXX.sh)" +cat << 'EOF' > $FLUX_SLEEP_WRAPPER +#!/usr/bin/env bash +echo "ssh://$(hostname)$(flux getattr local-uri | sed -e 's!local://!!')" > "$1" +sleep inf +EOF +chmod u+x $FLUX_SLEEP_WRAPPER + +MACHINE=$(echo $HOSTNAME | sed -e 's/[0-9]*$//') +if [[ "$MACHINE" == "lassen" ]] ; then + # To use module command we must source this file + # Those options are needed on IBM machines (CORAL) + # Documented: https://flux-framework.readthedocs.io/en/latest/tutorials/lab/coral.html + source /etc/profile.d/z00_lmod.sh + module use /usr/tce/modulefiles/Core + module use /usr/global/tools/flux/blueos_3_ppc64le_ib/modulefiles + module load pmi-shim + + PMIX_MCA_gds="^ds12,ds21" \ + jsrun -a 1 -c ALL_CPUS -g ALL_GPUS -n ${FLUX_NODES} \ + --bind=none --smpiargs="-disable_gpu_hooks" \ + flux start -o,-S,log-filename=$FLUX_LOG -v $FLUX_SLEEP_WRAPPER $FLUX_SERVER & +elif [[ "$MACHINE" == "pascal" || "$MACHINE" == "ruby" ]] ; then + srun -n ${FLUX_NODES} -N ${FLUX_NODES} --pty --mpi=none --mpibind=off \ + flux start -o,-S,log-filename=$FLUX_LOG -v $FLUX_SLEEP_WRAPPER $FLUX_SERVER & +else + echo "[$(date +'%m%d%Y-%T')@$(hostname)] machine $MACHINE is not supported at the moment." + exit 1 +fi + +echo "" +# now, wait for the flux info file +# we retry 20 times (one retry every 5 seconds) +wait_for_file $FLUX_SERVER 20 +export FLUX_URI=$(cat $FLUX_SERVER) +echo "[$(date +'%m%d%Y-%T')@$(hostname)] Run: export FLUX_URI=$(cat $FLUX_SERVER)" +check_main_allocation ${FLUX_NODES} + +# Read configuration file with number of nodes/cores for each sub allocations +NODES_PHYSICS=$(jq ".physics.nodes" $AMS_JSON) +NODES_ML=$(jq ".ml.nodes" $AMS_JSON) +NODES_CONTAINERS=$(jq ".containers.nodes" $AMS_JSON) +check_input_integers $NODES_PHYSICS $NODES_ML $NODES_CONTAINERS + +CORES_PHYSICS=$(jq ".physics.cores" $AMS_JSON) +CORES_ML=$(jq ".ml.cores" $AMS_JSON) +CORES_CONTAINERS=$(jq ".containers.cores" $AMS_JSON) +check_input_integers $CORES_PHYSICS $CORES_ML $CORES_CONTAINERS + +GPUS_PHYSICS=$(jq ".physics.gpus" $AMS_JSON) +GPUS_ML=$(jq ".ml.gpus" $AMS_JSON) +GPUS_CONTAINERS=$(jq ".containers.gpus" $AMS_JSON) +check_input_integers $GPUS_PHYSICS $GPUS_ML $GPUS_CONTAINERS + +# Partition resources for physics, ML and containers (RabbitMQ, filtering) +# NOTE: with more recent Flux (>=0.46), we could use flux alloc --bg instead +JOBID_PHYSICS=$( + flux mini batch --job-name="ams-physics" \ + --output="ams-physics-{{id}}.log" \ + --exclusive \ + --nslots=1 --nodes=$NODES_PHYSICS \ + --cores-per-slot=$CORES_PHYSICS \ + --gpus-per-slot=$GPUS_PHYSICS \ + --wrap sleep inf +) +sleep 2s +check_allocation_running $JOBID_PHYSICS FLUX_PHYSICS_URI + +JOBID_ML=$( + flux mini batch --job-name="ams-ml" \ + --output="ams-ml-{{id}}.log" \ + --exclusive \ + --nslots=1 --nodes=$NODES_ML\ + --cores-per-slot=$CORES_ML \ + --gpus-per-slot=$GPUS_ML \ + --wrap sleep inf +) +sleep 2s +check_allocation_running $JOBID_ML FLUX_ML_URI + +JOBID_CONTAINERS=$( + flux mini batch --job-name="ams-containers" \ + --output="ams-containers-{{id}}.log" \ + --nslots=1 --nodes=$NODES_CONTAINERS \ + --cores-per-slot=$CORES_CONTAINERS \ + --gpus-per-slot=$GPUS_CONTAINERS \ + --wrap sleep inf +) +sleep 2s +check_allocation_running $JOBID_CONTAINERS FLUX_CONTAINERS_URI + +# Add all URIs to existing AMS JSON file +AMS_JSON_BCK=${AMS_JSON}.bck +cp -f $AMS_JSON $AMS_JSON_BCK +jq '. += {flux:{}}' $AMS_JSON > $AMS_JSON_BCK && cp $AMS_JSON_BCK $AMS_JSON +jq --arg var "$(id -u)" '.flux += {"uid":$var}' $AMS_JSON > $AMS_JSON_BCK && cp $AMS_JSON_BCK $AMS_JSON +jq --arg flux_uri "$FLUX_URI" '.flux += {"global_uri":$flux_uri}' $AMS_JSON > $AMS_JSON_BCK && cp $AMS_JSON_BCK $AMS_JSON +jq --arg flux_uri "$FLUX_PHYSICS_URI" '.flux += {"physics_uri":$flux_uri}' $AMS_JSON > $AMS_JSON_BCK && cp $AMS_JSON_BCK $AMS_JSON +jq --arg flux_uri "$FLUX_ML_URI" '.flux += {"ml_uri":$flux_uri}' $AMS_JSON > $AMS_JSON_BCK && cp $AMS_JSON_BCK $AMS_JSON +jq --arg flux_uri "$FLUX_CONTAINERS_URI" '.flux += {"container_uri":$flux_uri}' $AMS_JSON > $AMS_JSON_BCK && cp $AMS_JSON_BCK $AMS_JSON + +# We move the file only if jq is sucessful otherwise jq will likey erase the original file +if [[ "$?" -eq 0 ]]; then + mv -f $AMS_JSON_BCK $AMS_JSON && rm -f $AMS_JSON_BCK +fi diff --git a/scripts/launch_ams.sh b/scripts/launch_ams.sh new file mode 100755 index 00000000..f3b5ad96 --- /dev/null +++ b/scripts/launch_ams.sh @@ -0,0 +1,99 @@ +#!/usr/bin/env bash +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +usage="Usage: $(basename "$0") [#NODES] [JSON file] -- Launch the AMS workflow on N nodes based on JSON configuration file." + +if [ -z ${AMS_ROOT+x} ]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Please export AMS_ROOT to where AMS repository is located." + echo "[$(date +'%m%d%Y-%T')@$(hostname)] with: export AMS_ROOT=" + exit 1 +fi + +# the script needs the number of nodes for flux +FLUX_NODES="$1" +# JSON configuration for AMS that will get updated by this script +AMS_JSON="$2" +UPDATE_SECRETS="${AMS_ROOT}/scripts/rmq_add_secrets.sh" +# Ssh bridge could be needed if OpenShift is not reachable from every clusters. +SSH_BRIDGE="quartz" +BOOTSTRAP="${AMS_ROOT}/scripts/bootstrap_flux.sh" +START_PHYSICS="${AMS_ROOT}/scripts/launch_physics.sh" + +re='^[0-9]+$' +if ! [[ $FLUX_NODES =~ $re ]] ; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] ERROR: '$FLUX_NODES' is not a number." + echo $usage + exit 1 +fi +if ! [[ -f "$AMS_JSON" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error: $AMS_JSON does not exists." + exit 1 +fi + +# Get the absolute path +AMS_JSON=$(realpath "$2") + +# Flux-core Minimum version required by AMS +export MIN_VER_FLUX="0.45.0" +export LC_ALL="C" +export FLUX_F58_FORCE_ASCII=1 +export FLUX_SSH="ssh" + +USE_DB=$(jq -r ".ams_app.use_db" $AMS_JSON) +DBTYPE=$(jq -r ".ams_app.dbtype" $AMS_JSON) +# 1. If we use RabbitMQ and asynchronous traning we add needed secrets OpenShift so AMS daemon can connect to RabbitMQ +# Note: +# This step might fail if you have not logged in OpenShift already. +# If that step fails, please try to log in OC with the following command +# oc login --insecure-skip-tls-verify=true --server=https://api.czapps.llnl.gov:6443 -u $(whoami) +if [[ $USE_DB && $DBTYPE = "rmq" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Trying to update secrets on OpenShift" + ssh ${SSH_BRIDGE} bash <<-EOF + $UPDATE_SECRETS $AMS_JSON +EOF +fi + +echo "[$(date +'%m%d%Y-%T')@$(hostname)] Starting bootstraping Flux on $FLUX_NODES" +# 2. We bootstrap Flux on FLUX_NODES nodes +$BOOTSTRAP $FLUX_NODES $AMS_JSON + +RMQ_TMP="rmq.json" +CERT_TLS="rmq.cert" +# This require to install the AMS python package +AMS_BROKER_EXE="AMSBroker" + +# 3. We send the current UID and the Flux ML URI to the AMS daemon listening +if [[ $USE_DB && $DBTYPE = "rmq" ]]; then + RMQ_CONFIG=$(jq ".rabbitmq" $AMS_JSON) + echo $RMQ_CONFIG > $RMQ_TMP + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Extracted RabbitMQ configuration to ${RMQ_TMP}" + + REMOTE_HOST=$(jq -r '.rabbitmq."service-host"' $AMS_JSON) + REMOTE_PORT=$(jq -r '.rabbitmq."service-port"' $AMS_JSON) + openssl s_client -connect ${REMOTE_HOST}:${REMOTE_PORT} -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' > ${CERT_TLS} + if [[ "$?" -eq 0 ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Successfuly generated TLS certificate written in ${CERT_TLS}" + else + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error during TLS certificate generation" + exit 1 + fi + AMS_DAEMON_QUEUE=$(jq -r '.daemon."queue-training-init"' $AMS_JSON) + AMS_UID=$(id -u) + AMS_ML_URI=$(jq -r '.flux.ml_uri' $AMS_JSON) + # Warning: there should be no whitespace in the message + MSG="{\"uid\":${AMS_UID},\"ml_uri\":\"${AMS_ML_URI}\"}" + ${AMS_BROKER_EXE} -c ${RMQ_TMP} -t ${CERT_TLS} -q ${AMS_DAEMON_QUEUE} -s $MSG + if [[ "$?" -eq 0 ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Successfuly sent message ${MSG} to queue ${AMS_DAEMON_QUEUE}" + else + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error: message did not get send to RabbitMQ" + exit 1 + fi +fi + +# 4. We start the physics code +$START_PHYSICS $AMS_JSON +echo "[$(date +'%m%d%Y-%T')@$(hostname)] AMS workflow is ready to run!" diff --git a/scripts/launch_physics.sh b/scripts/launch_physics.sh new file mode 100755 index 00000000..37a917f7 --- /dev/null +++ b/scripts/launch_physics.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +usage="Usage: $(basename "$0") [JSON file] -- Script that launch AMS based on settings defined in the JSON file." +function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; } + +AMS_JSON="$1" +if ! [[ -f "$AMS_JSON" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error: $AMS_JSON does not exists." + exit 1 +fi + +# Flux-core Minimum version required by AMS +[[ -z ${MIN_VER_FLUX+z} ]] && MIN_VER_FLUX="0.45.0" + +export LC_ALL="C" +export FLUX_F58_FORCE_ASCII=1 +export FLUX_SSH="ssh" + +# We check that Flux exist +if ! [[ -x "$(command -v flux)" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error: flux is not installed." + exit 1 +fi + +flux_version=$(version $(flux version | awk '/^commands/ {print $2}')) +MIN_VER_FLUX_LONG=$(version ${MIN_VER_FLUX}) +# We need to remove leading 0 because they are interpreted as octal numbers in bash +if [[ "${flux_version#00}" -lt "${MIN_VER_FLUX_LONG#00}" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error: Flux $(flux version | awk '/^commands/ {print $2}') is not supported.\ + AMS requires flux>=${MIN_VER_FLUX}" + exit 1 +fi + +# The -r flag is important here to remove quotes around results from jq which confuse Flux +FLUX_URI=$(jq -r ".flux.global_uri" $AMS_JSON) +PHYSICS_URI=$(jq -r ".flux.physics_uri" $AMS_JSON) +NODES_PHYSICS=$(jq ".physics.nodes" $AMS_JSON) +EXEC=$(jq -r ".ams_app.executable" $AMS_JSON) +ML_PATH=$(jq -r ".ams_app.modelpath" $AMS_JSON) +USE_GPU=$(jq -r ".ams_app.use_gpu" $AMS_JSON) +USE_DB=$(jq -r ".ams_app.use_db" $AMS_JSON) +DBTYPE=$(jq -r ".ams_app.dbtype" $AMS_JSON) +MPI_RANKS=$(jq -r ".ams_app.mpi_ranks" $AMS_JSON) +# -1 for all debug messages, 0 for no debug messages +VERBOSE=$(jq -r ".ams_app.verbose" $AMS_JSON) + +AMS_ARGS="-S ${ML_PATH}" + +if $USE_GPU; then + AMS_ARGS="${AMS_ARGS} -d cuda" +fi + +if $USE_DB; then + OUTPUTS="output" + if [[ $DBTYPE = "csv" || $DBTYPE = "hdf5" ]]; then + mkdir -p $OUTPUTS + elif [[ $DBTYPE = "rmq" ]]; then + RMQ_CONFIG=$(jq ".rabbitmq" $AMS_JSON) + # We have to write that JSON for AMS app to work (AMS does not read from stdin) + OUTPUTS="${OUTPUTS}.json" + echo $RMQ_CONFIG > $OUTPUTS + fi + AMS_ARGS="${AMS_ARGS} -dt ${DBTYPE} -db ${OUTPUTS}" +fi + +echo "[$(date +'%m%d%Y-%T')@$(hostname)] Launching AMS on ${NODES_PHYSICS} nodes" +echo "[$(date +'%m%d%Y-%T')@$(hostname)] AMS binary = ${EXEC}" +echo "[$(date +'%m%d%Y-%T')@$(hostname)] AMS verbose level = ${VERBOSE}" +echo "[$(date +'%m%d%Y-%T')@$(hostname)] AMS Arguments = ${AMS_ARGS}" +echo "[$(date +'%m%d%Y-%T')@$(hostname)] MPI ranks = ${MPI_RANKS}" +echo "[$(date +'%m%d%Y-%T')@$(hostname)] > Cores/rank = 1" +echo "[$(date +'%m%d%Y-%T')@$(hostname)] > GPUs/rank = 1" + +ams_jobid=$( + LIBAMS_VERBOSITY_LEVEL=${VERBOSE} FLUX_URI=$PHYSICS_URI flux mini submit \ + --job-name="ams-app" \ + -N ${NODES_PHYSICS} -n $MPI_RANKS -c 1 -g 1 \ + -o mpi=spectrum -o cpu-affinity=per-task -o gpu-affinity=per-task \ + ${EXEC} ${AMS_ARGS} +) +echo "[$(date +'%m%d%Y-%T')@$(hostname)] Launched job $ams_jobid" +echo "[$(date +'%m%d%Y-%T')@$(hostname)] To debug: FLUX_URI=$PHYSICS_URI flux job attach $ams_jobid" diff --git a/scripts/rmq_add_secrets.sh b/scripts/rmq_add_secrets.sh new file mode 100755 index 00000000..56876694 --- /dev/null +++ b/scripts/rmq_add_secrets.sh @@ -0,0 +1,69 @@ +#!/usr/bin/env bash +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +usage="Usage: $(basename "$0") [JSON file] [USER (optional)] -- Script that adds the JSON file as secrets in OpenShift" +#TODO: Only for LC (add it somewhere as config value) +export PATH="$PATH:/usr/global/openshift/bin/" + +check_cmd() { + err=$($@ 2>&1) + if [ $? -ne 0 ]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] error: $err" + # if [[ -x "$(command -v oc)" ]]; then + # oc logout + # fi + exit 1 + else + echo $err + fi +} + +AMS_JSON="$1" +USER="$2" +[ -z "$2" ] && USER=$(whoami) # If no argument $2 we take the default user +URL="https://api.czapps.llnl.gov" +PORT=6443 +PROJECT_NAME="cz-amsdata" +RMQ_CREDS="creds.json" +SECRET="rabbitmq-creds" + +if ! [[ -f "$AMS_JSON" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error: config file \"$AMS_JSON\" does not exists." + exit 1 +fi + +if ! [[ -x "$(command -v oc)" ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error: OpenShift (oc) not found." + exit 1 +fi +echo "[$(date +'%m%d%Y-%T')@$(hostname)] oc = $(which oc)" + +RMQ_CONFIG=$(jq ".rabbitmq" $AMS_JSON) +echo "$RMQ_CONFIG" > $RMQ_CREDS + +echo "[$(date +'%m%d%Y-%T')@$(hostname)] Login in ${URL}:${PORT} as ${USER}" +oc login --insecure-skip-tls-verify=true --server=${URL}:${PORT} -u ${USER} +# Warning: Do not use function check_cmd to wrap oc login here (it will block oc login) +if [[ "$?" -ne 0 ]]; then + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Error while connecting to OpenShift." + exit 1 +fi +echo "[$(date +'%m%d%Y-%T')@$(hostname)] Logged in as $(oc whoami), switching to project ${PROJECT_NAME}" +check_cmd oc project $PROJECT_NAME + +err=$(oc create secret generic $SECRET --from-file=$RMQ_CREDS 2>&1) + +if [[ "$?" -ne 0 ]]; then + check_cmd echo $err | grep "already exists" + echo "[$(date +'%m%d%Y-%T')@$(hostname)] secret already exists, we are updating it." + check_cmd oc delete secret $SECRET + check_cmd oc create secret generic $SECRET --from-file=$RMQ_CREDS +else + check_cmd oc get secrets $SECRET + echo "[$(date +'%m%d%Y-%T')@$(hostname)] Added secrets successfully." +fi + +# check_cmd oc logout \ No newline at end of file diff --git a/src/AMSWorkflow/CMakeLists.txt b/src/AMSWorkflow/CMakeLists.txt new file mode 100644 index 00000000..65f4ab69 --- /dev/null +++ b/src/AMSWorkflow/CMakeLists.txt @@ -0,0 +1,25 @@ +find_package(Python COMPONENTS Interpreter REQUIRED) + +add_subdirectory(app) +add_subdirectory(ams) + +configure_file("setup.py" "${CMAKE_CURRENT_BINARY_DIR}/setup.py" COPYONLY) + +file(GLOB_RECURSE pyfiles *.py app/*.py ams/*.py) + +# detect virtualenv and set Pip args accordingly +set(AMS_PY_APP "${CMAKE_CURRENT_BINARY_DIR}") +if(DEFINED ENV{VIRTUAL_ENV} OR DEFINED ENV{CONDA_PREFIX}) + set(_pip_args) +else() + set(_pip_args "--user") +endif() + +message(WARNING "AMS Python Source files are ${pyfiles}") +message(WARNING "AMS Python built cmd is : ${Python_EXECUTABLE} -m pip install ${_pip_args} ${AMS_PY_APP}") + +add_custom_target(PyAMS ALL + COMMAND ${Python_EXECUTABLE} -m pip install ${_pip_args} ${AMS_PY_APP} + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" + COMMENT "Build AMS-WF Python Modules and Applications" + DEPENDS ${pyfiles}) diff --git a/src/AMSWorkflow/ams/CMakeLists.txt b/src/AMSWorkflow/ams/CMakeLists.txt new file mode 100644 index 00000000..a945efa6 --- /dev/null +++ b/src/AMSWorkflow/ams/CMakeLists.txt @@ -0,0 +1,11 @@ +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +file(GLOB_RECURSE pyfiles *.py) +foreach (filename ${pyfiles}) + get_filename_component(target "${filename}" NAME) + message(STATUS "Copying ${filename} to ${target}") + configure_file("${filename}" "${CMAKE_CURRENT_BINARY_DIR}/${target}" COPYONLY) +endforeach (filename) diff --git a/src/AMSWorkflow/ams/database.py b/src/AMSWorkflow/ams/database.py new file mode 100644 index 00000000..7ccf3ed6 --- /dev/null +++ b/src/AMSWorkflow/ams/database.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from abc import ABC, abstractmethod +import logging +import csv +import numpy as np +import os +import sys +from typing import Dict, List, Any, Union, Type + + +class DBInterface(ABC): + """ + Represents a database instance in AMS. + """ + + @classmethod + def __subclasshook__(cls, subclass): + """ + Ensure subclass implement all the abstrac method + defined in the interface. Errors will be raised + if all methods aren't overridden. + """ + return (hasattr(subclass, "__str__") and + callable(subclass.__str__) and + hasattr(subclass, "open") and + callable(subclass.open) and + hasattr(subclass, "close") and + callable(subclass.close) and + hasattr(subclass, "store") and + callable(subclass.store) or + NotImplemented) + + @abstractmethod + def __str__(self) -> str: + """ Return a string representation of the broker """ + raise NotImplementedError + + def __repr__(self) -> str: + """ Return a string representation of the broker """ + return self.__str__() + + @abstractmethod + def open(self): + """ Connect to the DB (or open file if file-based DB) """ + raise NotImplementedError + + @abstractmethod + def close(self): + """ Close DB """ + raise NotImplementedError + + @abstractmethod + def store(self, inputs, outputs) -> int: + """ + Store the two arrays using a given backend + Return the number of characters written + """ + raise NotImplementedError + +class csvDB(DBInterface): + """ + A simple CSV backend. + """ + def __init__(self, file_name: str, delimiter: str = ':'): + super().__init__() + self.file_name = file_name + self.delimiter = delimiter + self.fd = None + + def __str__(self) -> str: + return f"{__class__.__name__}(fd={self.fd}, delimiter={self.delimiter})" + + def open(self): + self.fd = open(self.file_name, 'a') + + def close(self): + self.fd.close() + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def store(self, inputs: np.array, outputs: np.array) -> int: + """ Store the two arrays in a CSV file """ + assert len(inputs) == len(outputs) + if self.fd and self.fd.closed: + return 0 + csvwriter = csv.writer(self.fd, + delimiter = self.delimiter, + quotechar = "'", + quoting = csv.QUOTE_MINIMAL + ) + nelem = len(inputs) + elem_wrote: int = 0 + # We follow the mini-app format, inputs elem and then output elems + for i in range(nelem): + elem_wrote += csvwriter.writerow(np.concatenate((inputs[i], outputs[i]), axis=0)) + return elem_wrote \ No newline at end of file diff --git a/src/AMSWorkflow/ams/orchestrator.py b/src/AMSWorkflow/ams/orchestrator.py new file mode 100644 index 00000000..a2619d7e --- /dev/null +++ b/src/AMSWorkflow/ams/orchestrator.py @@ -0,0 +1,103 @@ +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import flux +import json +from flux.security import SecurityContext +from flux import job +from flux.job import JobspecV1 + +from .ams_rmq import RMQClient + + +class Orchestrator: + def __init__(self, server_config_fn, certificate): + with open(server_config_fn, "r") as fd: + server_config = json.load(fd) + + self.host = (server_config["service-host"],) + self.vhost = (server_config["rabbitmq-vhost"],) + self.port = (server_config["service-port"],) + self.user = (server_config["rabbitmq-user"],) + self.password = server_config["rabbitmq-password"] + self.certificate = certificate + +class AMSDaemon(Orchestrator): + """ + Class modeling a rmq-client daemon running on a compute + allocation that will issue flux run commands + """ + + def __init__(self, server_config_fn, certificate): + super().__init__(server_config_fn, certificate) + + def getMLJobSpec(self, client): + with client.connect("test3") as channel: + spec = channel.receive(n_msg=1).pop() + # TODO: Write some simple wrapper class around the 'dict' + # to correcly create ML Job Specification + return spec + + def __run(self, flux_handle, client, jobspec): + with client.connect("ml-start") as channel: + while True: + # Currently we ignore the message + channel.receive(n_msg=1) + jobid = flux.job.submit(flux_handle, jobspec, pre_signed=True, wait=True) + job.wait() + # TODO: Send completed message in RMQ + + def __call__(self): + flux_handle = flux.Flux() + + with RMQClient(self.host, self.port, self.vhost, self.user, self.password, self.certificate) as client: + # We currently assume a single ML job specification + ml_job_spec = self.getMLJobSpec(client) + + # Create a Flux Job Specification + jobspec = JobspecV1.from_command( + command=ml_job_spec["jobspec"]["command"], + num_tasks=ml_job_spec["jobspec"]["num_tasks"], + num_nodes=ml_job_spec["jobspec"]["num_nodes"], + cores_per_task=ml_job_spec["jobspec"]["cores_per_task"], + gpus_per_task=ml_job_spec["jobspec"]["gpus_per_task"], + ) + + ctx = SecurityContext() + signed_jobspec = ctx.sign_wrap_as( + ml_job_spec["uid"], jobspec.dumps(), mech_type="none" + ).decode("utf-8") + # This is a 'busy' loop + self.__run(flux_handle, client, signed_jobspec) + +class FluxDaemonWrapper(Orchestrator): + """ + class to start Daemon through Flux + """ + + def __init__(self, server_config_fn, certificate): + super().__init__(server_config_fn, certificate) + + def getFluxUri(self, client): + with client.connect("test3") as channel: + msg = channel.receive(n_msg=1).pop() + return msg['ml_uri'] + + def __call__(self, application_cmd : list): + if not isinstance(application_cmd, list): + raise TypeError('StartDaemon requires application_cmd as a list') + + with RMQClient(self.host, self.port, self.vhost, self.user, self.password, self.certificate) as client: + self.uri = getFluxUri(client) + + flux_cmd = [ + "flux", + "proxy", + "--force", + f"{self.uri}", + "flux"] + cmd = flux_cmd + application_cmd + subprocess.run(cmd) + diff --git a/src/AMSWorkflow/ams/rmq.py b/src/AMSWorkflow/ams/rmq.py new file mode 100644 index 00000000..2429a3ca --- /dev/null +++ b/src/AMSWorkflow/ams/rmq.py @@ -0,0 +1,143 @@ +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pika +import ssl +import sys +import os +import logging +import json + +class RMQChannel: + """ + A wrapper around RMQ channel + """ + + def __init__(self, connection, q_name): + self.connection = connection + self.q_name = q_name + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + @staticmethod + def callback(method, properties, body): + return body.decode("utf-8") + + def open(self): + self.channel = self.connection.channel() + self.channel.queue_declare(queue = self.q_name) + + def close(self): + self.channel.close() + + def receive(self, n_msg: int = None, accum_msg = list()): + """ + Consume a message on the queue and post processing by calling the callback. + @param n_msg The number of messages to receive. + - if n_msg is None, this call will block for ever and will process all messages that arrives + - if n_msg = 1 for example, this function will block until one message has been processed. + @return a list containing all received messages + """ + + if self.channel and self.channel.is_open: + self.logger.info( + f"Starting to consume messages from queue={self.q_name}, routing_key={self.routing_key} ..." + ) + # we will consume only n_msg and requeue all other messages + # if there are more messages in the queue. + # It will block as long as n_msg did not get read + if n_msg: + n_msg = max(n_msg, 0) + message_consumed = 0 + # Comsume n_msg messages and break out + for method_frame, properties, body in self.channel.consume(self.q_name): + # Call the call on the message parts + try: + accum_msg.append( + RMQClient.callback( + method_frame, + properties, + body, + ) + ) + except Exception as e: + self.logger.error(f"Exception {type(e)}: {e}") + self.logger.debug(traceback.format_exc()) + finally: + # Acknowledge the message even on failure + self.channel.basic_ack(delivery_tag=method_frame.delivery_tag) + self.logger.warning( + f"Consumed message {message_consumed+1}/{method_frame.delivery_tag} (exchange={method_frame.exchange}, routing_key={method_frame.routing_key})" + ) + message_consumed += 1 + # Escape out of the loop after nb_msg messages + if message_consumed == n_msg: + # Cancel the consumer and return any pending messages + self.channel.cancel() + break + return accum_msg + + def send(self, text: str): + """ + Send a message + @param text The text to send + """ + self.channel.basic_publish(exchange="", routing_key=self.q_name, body=text) + return + + def get_messages(self): + return # messages + + def purge(self): + """Removes all the messages from the queue.""" + if self.channel and self.channel.is_open: + self.channel.queue_purge(self.q_name) + + +class RMQClient: + """ + RMQClient is a class that manages the RMQ client lifecycle. + """ + def __init__(self, host, port, vhost, user, password, cert, logger: logging.Logger = None): + # CA Cert, can be generated with (where $REMOTE_HOST and $REMOTE_PORT can be found in the JSON file): + # openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' rmq-pds.crt + self.logger = logger if logger else logging.getLogger(__name__) + self.cert = cert + self.context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + self.context.verify_mode = ssl.CERT_REQUIRED + self.context.load_verify_locations(self.cert) + self.host = host + self.vhost = vhost + self.port = port + self.user = user + self.password = password + + self.credentials = pika.PlainCredentials( + self.user, self.password) + + self.connection_params = pika.ConnectionParameters( + host=self.host, + port=self.port, + virtual_host=self.vhost, + credentials=self.credentials, + ssl_options=pika.SSLOptions(self.context), + ) + + def __enter__(self): + self.connection = pika.BlockingConnection(self.connection_params) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.connection.close() + + def connect(self, queue): + """Connect to the queue""" + return RMQChannel(self.connection, queue) + diff --git a/src/AMSWorkflow/ams_wf/AMSBroker.py b/src/AMSWorkflow/ams_wf/AMSBroker.py new file mode 100644 index 00000000..2f981198 --- /dev/null +++ b/src/AMSWorkflow/ams_wf/AMSBroker.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import json +import os +import sys + +from ams.rmq import RMQClient + +def main(): + parser = argparse.ArgumentParser( + description="AMS Broker interface to send/receive messages." + ) + + parser.add_argument( + "-c", + "--config", + dest="config", + help="Path to broker configuration file", + required=True, + ) + + parser.add_argument( + "-t", + "--certificate", + dest="certificate", + help="Path to TLS certificate file", + required=True, + ) + + parser.add_argument( + "-s", + "--send", + dest="msg_send", + type=str, + help="Message to send", + required=True, + ) + + parser.add_argument( + "-q", + "--queue", + dest="queue", + type=str, + help="Queue to which the message will be sent", + required=True, + ) + args = parser.parse_args() + + if not os.path.isfile(args.config): + print(f"Error: config file {args.config} does not exist") + sys.exit(1) + + if not os.path.isfile(args.certificate): + print(f"Error: certificate file {args.certificate} does not exist") + sys.exit(1) + + with open(args.config, "r") as fd: + config = json.load(fd) + + host = config["service-host"] + vhost = config["rabbitmq-vhost"] + port = config["service-port"] + user = config["rabbitmq-user"] + password = config["rabbitmq-password"] + + with RMQClient(host, port, vhost, user, password, args.certificate) as client: + with client.connect(args.queue) as channel: + channel.send(args.msg_send) + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("Interrupted") + try: + sys.exit(0) + except SystemExit: + os._exit(0) diff --git a/src/AMSWorkflow/ams_wf/AMSDBStage.py b/src/AMSWorkflow/ams_wf/AMSDBStage.py new file mode 100644 index 00000000..97cd2522 --- /dev/null +++ b/src/AMSWorkflow/ams_wf/AMSDBStage.py @@ -0,0 +1,15 @@ +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +from pathlib import Path + + +def main(): + print("Place Holder for AMSDBStage.py") + + +if __name__ == "__main__": + main() diff --git a/src/AMSWorkflow/ams_wf/AMSOrchestrator.py b/src/AMSWorkflow/ams_wf/AMSOrchestrator.py new file mode 100755 index 00000000..210598a1 --- /dev/null +++ b/src/AMSWorkflow/ams_wf/AMSOrchestrator.py @@ -0,0 +1,73 @@ +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys +import os +from ams.rmq import RMQClient +from ams.orchestrator import AMSDaemon +from ams.orchestrator import FluxDaemonWrapper + +import argparse + + +def main(): + daemon_actions = ["start", "wrap"] + parser = argparse.ArgumentParser( + description="AMS Machine Learning Daemon running on Training allocation" + ) + parser.add_argument( + "-a", + "--action", + dest='action', + choices=daemon_actions, + help="Decide whether to start daemon process directly or through flux wrap script", + required=True, + ) + + parser.add_argument( + "-c", + "--certificate", + dest="certificate", + help="Path to certificate file to establish connection", + required=True, + ) + + parser.add_argument( + "-cfg", + "--config", + dest="config", + help="Path to AMS configuration file", + required=True, + ) + + args = parser.parse_args() + if args.action == "start": + daemon = AMSDaemon(args.config, args.certificate) + # Busy wait for messages and spawn ML training jobs + daemon() + elif args.action == "wrap": + daemon_cmd = [ + "python", + __file__, + "--action", + "start", + "-c", + args.certificate, + "-cfg", + args.config, + ] + daemon = FluxDaemonWrapper(args.config, args.certificate) + daemon(daemon_cmd) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("Interrupted") + try: + sys.exit(0) + except SystemExit: + os._exit(0) diff --git a/src/AMSWorkflow/ams_wf/AMSTrain.py b/src/AMSWorkflow/ams_wf/AMSTrain.py new file mode 100644 index 00000000..4252f468 --- /dev/null +++ b/src/AMSWorkflow/ams_wf/AMSTrain.py @@ -0,0 +1,12 @@ +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse + +def main(): + print('Hello from AMSTrain.py') + +if __name__ == '__main__': + main() diff --git a/src/AMSWorkflow/ams_wf/CMakeLists.txt b/src/AMSWorkflow/ams_wf/CMakeLists.txt new file mode 100644 index 00000000..9102fad1 --- /dev/null +++ b/src/AMSWorkflow/ams_wf/CMakeLists.txt @@ -0,0 +1,12 @@ +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +file(GLOB_RECURSE pyfiles *.py) +foreach (filename ${pyfiles}) + get_filename_component(target "${filename}" NAME) + message(STATUS "Copying ${filename} to ${target}") + configure_file("${filename}" "${CMAKE_CURRENT_BINARY_DIR}/${target}" COPYONLY) +endforeach (filename) + diff --git a/src/AMSWorkflow/setup.py b/src/AMSWorkflow/setup.py new file mode 100644 index 00000000..5f1020c3 --- /dev/null +++ b/src/AMSWorkflow/setup.py @@ -0,0 +1,30 @@ +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import setuptools + +setuptools.setup( + name="ams-wf", + version="1.0", + packages=['ams_wf', 'ams'], + install_requires = [ + 'argparse', + 'pika>=1.3.0', + 'numpy>=1.2.0' + ], + entry_points={ + 'console_scripts': [ + 'AMSBroker=ams_wf.AMSBroker:main', + 'AMSDBStage=ams_wf.AMSDBStage:main', + 'AMSOrchestrator=ams_wf.AMSOrchestrator:main', + 'AMSTrain=ams_wf.AMSTrain:main'] + }, + classifiers = [ + "Development Status :: 3 - Alpha", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + ] +) diff --git a/src/AMS.cpp b/src/AMSlib/AMS.cpp similarity index 100% rename from src/AMS.cpp rename to src/AMSlib/AMS.cpp diff --git a/src/AMSlib/CMakeLists.txt b/src/AMSlib/CMakeLists.txt new file mode 100644 index 00000000..1e1198a8 --- /dev/null +++ b/src/AMSlib/CMakeLists.txt @@ -0,0 +1,82 @@ +# Copyright (c) Lawrence Livermore National Security, LLC and other AMS +# Project developers. See top-level LICENSE AND COPYRIGHT files for dates and +# other details. No copyright assignment is required to contribute + +# ------------------------------------------------------------------------------ +# handle sources and headers +file(GLOB_RECURSE MINIAPP_INCLUDES "*.hpp") +#set global library path to link with tests if necessary +set(LIBRARY_OUTPUT_PATH ${AMS_LIB_OUT_PATH}) +set(AMS_LIB_SRC ${MINIAPP_INCLUDES} AMS.cpp wf/resource_manager.cpp) +# two targets: a shared lib and an exec +add_library(AMS ${AMS_LIB_SRC} ${MINIAPP_INCLUDES}) + +# ------------------------------------------------------------------------------ +if (WITH_CUDA) + set_target_properties(AMS PROPERTIES CUDA_ARCHITECTURES ${AMS_CUDA_ARCH}) + + # if (BUILD_SHARED_LIBS) + # set_target_properties(AMS PROPERTIES CUDA_SEPARABLE_COMPILATION ON) + # else() + # set_target_properties(AMS PROPERTIES CUDA_SEPARABLE_COMPILATION ON) + # set_target_properties(AMS PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) + # endif() + + set_source_files_properties(AMS.cpp PROPERTIES LANGUAGE CUDA) + set_source_files_properties(AMS.cpp PROPERTIES CUDA_ARCHITECTURES ${AMS_CUDA_ARCH}) + set_source_files_properties(AMS.cpp PROPERTIES COMPILE_FLAGS "--expt-extended-lambda") + + if (WITH_PERFFLOWASPECT) + set_property(SOURCE AMS.cpp APPEND_STRING PROPERTY COMPILE_FLAGS " -ccbin clang++ -Xcompiler=-Xclang -Xcompiler=-load -Xcompiler=-Xclang -Xcompiler=${PERFFLOWASPECT_LIB_DIR}/libWeavePass.so") + set_source_files_properties(wf/resource_manager.cpp COMPILE_FLAGS "-Xclang -load -Xclang ${PERFFLOWASPECT_LIB_DIR}/libWeavePass.so") + endif() +endif() + +# ------------------------------------------------------------------------------ +# setup the lib first +message(STATUS "ALL INCLUDES ARE ${AMS_APP_INCLUDES}") +target_compile_definitions(AMS PRIVATE ${AMS_APP_DEFINES}) +target_include_directories(AMS PRIVATE ${AMS_APP_INCLUDES}) +target_include_directories(AMS PUBLIC + $ + $) +target_include_directories(AMS PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_directories(AMS PUBLIC ${AMS_APP_LIB_DIRS}) +target_link_libraries(AMS PUBLIC ${AMS_APP_LIBRARIES} stdc++fs) + +#------------------------------------------------------------------------------- +# create the configuration header file with the respective information +#------------------------------------------------------------------------------- +set(CALIPER_DEFINES "// #define __AMS_ENABLE_CALIPER__") +set(MPI_DEFINES "// #define __AMS_ENABLE_MPI__") +set(PERFF_DEFINES "// #define __AMS_ENABLE_PERFFLOWASPECT__") + +if (${WITH_CALIPER}) + set(CALIPER_DEFINES "#define __AMS_ENABLE_CALIPER__") +endif() + +if (${WITH_MPI}) + set(MPI_DEFINES "#define __AMS_ENABLE_MPI__") +endif() + +if (${WITH_PERFFLOWASPECT}) + set(PERFF_DEFINES "#define __AMS_ENABLE_PERFFLOWASPECT__") +endif() + +configure_file ("${CMAKE_CURRENT_SOURCE_DIR}/include/AMS-config.h.in" "${PROJECT_BINARY_DIR}/include/AMS-config.h") +configure_file ("${CMAKE_CURRENT_SOURCE_DIR}/include/AMS.h" "${PROJECT_BINARY_DIR}/include/AMS.h" COPYONLY) + +# setup the exec +#SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-rpath -Wl,$ORIGIN") +# ------------------------------------------------------------------------------ +# installation paths +install(TARGETS AMS + EXPORT AMSTargets + DESTINATION lib) + +install(EXPORT AMSTargets + FILE AMS.cmake + DESTINATION lib/cmake/AMS) + +install(FILES ${PROJECT_BINARY_DIR}/include/AMS.h DESTINATION include) +install(FILES ${PROJECT_BINARY_DIR}/include/AMS-config.h DESTINATION include) diff --git a/src/include/AMS-config.h.in b/src/AMSlib/include/AMS-config.h.in similarity index 100% rename from src/include/AMS-config.h.in rename to src/AMSlib/include/AMS-config.h.in diff --git a/src/include/AMS.h b/src/AMSlib/include/AMS.h similarity index 100% rename from src/include/AMS.h rename to src/AMSlib/include/AMS.h diff --git a/src/ml/hdcache.hpp b/src/AMSlib/ml/hdcache.hpp similarity index 100% rename from src/ml/hdcache.hpp rename to src/AMSlib/ml/hdcache.hpp diff --git a/src/ml/surrogate.hpp b/src/AMSlib/ml/surrogate.hpp similarity index 100% rename from src/ml/surrogate.hpp rename to src/AMSlib/ml/surrogate.hpp diff --git a/src/AMSlib/wf/basedb.hpp b/src/AMSlib/wf/basedb.hpp new file mode 100644 index 00000000..4fe10878 --- /dev/null +++ b/src/AMSlib/wf/basedb.hpp @@ -0,0 +1,1909 @@ +/* + * Copyright 2021-2023 Lawrence Livermore National Security, LLC and other + * AMSLib Project Developers + * + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + */ + +#ifndef __AMS_BASE_DB__ +#define __AMS_BASE_DB__ + +#include +#include +#include +#include +#include +#include + +#include "AMS.h" +#include "wf/debug.h" +#include "wf/utils.hpp" + +namespace fs = std::experimental::filesystem; + +#ifdef __ENABLE_REDIS__ +#include + +#include +// TODO: We should comment out "using" in header files as +// it propagates to every other file including this file +#warning Redis is currently not supported/tested +using namespace sw::redis; +#endif + + +#ifdef __ENABLE_HDF5__ +#include + +#define HDF5_ERROR(Eid) \ + if (Eid < 0) { \ + std::cerr << "[Error] Happened in " << __FILE__ << ":" \ + << __PRETTY_FUNCTION__ << " ( " << __LINE__ << ")\n"; \ + exit(-1); \ + } +#endif + +#ifdef __ENABLE_RMQ__ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#endif // __ENABLE_RMQ__ + +/** + * @brief A simple pure virtual interface to store data in some + * persistent storage device + */ +template +class BaseDB +{ + /** @brief unique id of the process running this simulation */ + uint64_t id; + +public: + BaseDB(const BaseDB&) = delete; + BaseDB& operator=(const BaseDB&) = delete; + + BaseDB(uint64_t id) : id(id) {} + virtual ~BaseDB() {} + + /** + * @brief Define the type of the DB (File, Redis etc) + */ + virtual std::string type() = 0; + + /** + * @brief Takes an input and an output vector each holding 1-D vectors data, and + * store. them in persistent data storage. + * @param[in] num_elements Number of elements of each 1-D vector + * @param[in] inputs Vector of 1-D vectors containing the inputs to be stored + * @param[in] inputs Vector of 1-D vectors, each 1-D vectors contains + * 'num_elements' values to be stored + * @param[in] outputs Vector of 1-D vectors, each 1-D vectors contains + * 'num_elements' values to be stored + */ + virtual void store(size_t num_elements, + std::vector& inputs, + std::vector& outputs) = 0; +}; + +/** + * @brief A pure virtual interface for data bases storing data using + * some file format (filesystem DB). + */ +template +class FileDB : public BaseDB +{ +protected: + /** @brief Path to file to write data to */ + std::string fn; + /** @brief absolute path to directory storing the data */ + std::string fp; + + /** + * @brief check error code, if it exists print message and exit application + * @param[in] ec error code + */ + void checkError(std::error_code& ec) + { + if (ec) { + std::cerr << "Error in is_regular_file: " << ec.message(); + exit(-1); + } + } + +public: + /** + * @brief Takes an input and an output vector each holding 1-D vectors data, and + * store. them in persistent data storage. + * @param[in] path Path to an existing directory where to store our data + * @param[in] suffix The suffix of the file to write to + * @param[in] rId a unique Id for each process taking part in a distributed + * execution (rank-id) + * */ + FileDB(std::string path, const std::string suffix, uint64_t rId) + : BaseDB(rId) + { + fs::path Path(path); + std::error_code ec; + + if (!fs::exists(Path, ec)) { + std::cerr << "[ERROR]: Path:'" << path << "' does not exist\n"; + exit(-1); + } + + checkError(ec); + + if (!fs::is_directory(Path, ec)) { + std::cerr << "[ERROR]: Path:'" << path << "' is a file NOT a directory\n"; + exit(-1); + } + + Path = fs::absolute(Path); + fp = Path.string(); + + // We can now create the filename + std::string dbfn("data_"); + dbfn += std::to_string(rId) + suffix; + Path /= fs::path(dbfn); + fn = Path.string(); + DBG(DB, "File System DB writes to file %s", fn.c_str()) + } +}; + + +template +class csvDB final : public FileDB +{ +private: + /** @brief file descriptor */ + std::fstream fd; + +public: + csvDB(const csvDB&) = delete; + csvDB& operator=(const csvDB&) = delete; + + /** + * @brief constructs the class and opens the file to write to + * @param[in] fn Name of the file to store data to + * @param[in] rId a unique Id for each process taking part in a distributed + * execution (rank-id) + */ + csvDB(std::string path, uint64_t rId) : FileDB(path, ".csv", rId) + { + fd.open(this->fn, std::ios_base::app | std::ios_base::out); + if (!fd.is_open()) { + std::cerr << "Cannot open db file: " << this->fn << std::endl; + } + DBG(DB, "DB Type: %s", type().c_str()) + } + + /** + * @brief deconstructs the class and closes the file + */ + ~csvDB() { fd.close(); } + + /** + * @brief Define the type of the DB (File, Redis etc) + */ + std::string type() override { return "csv"; } + + /** + * @brief Takes an input and an output vector each holding 1-D vectors data, and + * store them into a csv file delimited by ':'. This should never be used for + * large scale simulations as txt/csv format will be extremely slow. + * @param[in] num_elements Number of elements of each 1-D vector + * @param[in] inputs Vector of 1-D vectors containing the inputs to bestored + * @param[in] inputs Vector of 1-D vectors, each 1-D vectors contains + * 'num_elements' values to be stored + * @param[in] outputs Vector of 1-D vectors, each 1-D vectors contains + * 'num_elements' values to be stored + */ + PERFFASPECT() + virtual void store(size_t num_elements, + std::vector& inputs, + std::vector& outputs) override + { + DBG(DB, + "DB of type %s stores %ld elements of input/output dimensions (%d, %d)", + type().c_str(), + num_elements, + inputs.size(), + outputs.size()) + + const size_t num_in = inputs.size(); + const size_t num_out = outputs.size(); + + for (size_t i = 0; i < num_elements; i++) { + for (size_t j = 0; j < num_in; j++) { + fd << inputs[j][i] << ":"; + } + + for (size_t j = 0; j < num_out - 1; j++) { + fd << outputs[j][i] << ":"; + } + fd << outputs[num_out - 1][i] << "\n"; + } + } +}; + +#ifdef __ENABLE_HDF5__ + +template +class hdf5DB final : public FileDB +{ +private: + /** @brief file descriptor */ + hid_t HFile; + /** @brief vector holding the hdf5 dataset descriptor. + * We currently store every input on a separate dataset + */ + std::vector HDIsets; + + /** @brief vector holding the hdf5 dataset descriptor. + * We currently store every output on a separate dataset + */ + std::vector HDOsets; + + /** @brief Total number of elements we have in our file */ + hsize_t totalElements; + + /** @brief HDF5 associated data type with specific TypeValue type */ + hid_t HDType; + + /** @brief create or get existing hdf5 dataset with the provided name + * storing data as Ckunked pieces. The Chunk value controls the chunking + * performed by HDF5 and thus controls the write performance + * @param[in] group in which we will store data under + * @param[in] dName name of the data set + * @param[in] Chunk chunk size of dataset used by HDF5. + * @reval dataset HDF5 key value + */ + hid_t getDataSet(hid_t group, + std::string dName, + const size_t Chunk = 32L * 1024L * 1024L) + { + // Our datasets a.t.m are 1-D vectors + const int nDims = 1; + // We always start from 0 + hsize_t dims = 0; + hid_t dset = -1; + + int exists = H5Lexists(group, dName.c_str(), H5P_DEFAULT); + + if (exists > 0) { + dset = H5Dopen(group, dName.c_str(), H5P_DEFAULT); + HDF5_ERROR(dset); + // We are assuming symmetrical data sets a.t.m + if (totalElements == 0) { + hid_t dspace = H5Dget_space(dset); + const int ndims = H5Sget_simple_extent_ndims(dspace); + hsize_t dims[ndims]; + H5Sget_simple_extent_dims(dspace, dims, NULL); + totalElements = dims[0]; + } + return dset; + } else { + // We will extend the data-set size, so we use unlimited option + hsize_t maxDims = H5S_UNLIMITED; + hid_t fileSpace = H5Screate_simple(nDims, &dims, &maxDims); + HDF5_ERROR(fileSpace); + + hid_t pList = H5Pcreate(H5P_DATASET_CREATE); + HDF5_ERROR(pList); + + herr_t ec = H5Pset_layout(pList, H5D_CHUNKED); + HDF5_ERROR(ec); + + // cDims impacts performance considerably. + // TODO: Align this with the caching mechanism for this option to work + // out. + hsize_t cDims = Chunk; + H5Pset_chunk(pList, nDims, &cDims); + dset = H5Dcreate(group, + dName.c_str(), + HDType, + fileSpace, + H5P_DEFAULT, + pList, + H5P_DEFAULT); + HDF5_ERROR(dset); + H5Sclose(fileSpace); + H5Pclose(pList); + } + return dset; + } + + /** + * @brief Create the HDF5 datasets and store their descriptors in the in/out + * vectors + * @param[in] num_elements of every vector + * @param[in] numIn number of input 1-D vectors + * @param[in] numOut number of output 1-D vectors + */ + void createDataSets(size_t numElements, + const size_t numIn, + const size_t numOut) + { + for (int i = 0; i < numIn; i++) { + hid_t dSet = getDataSet(HFile, std::string("input_") + std::to_string(i)); + HDIsets.push_back(dSet); + } + + for (int i = 0; i < numOut; i++) { + hid_t dSet = + getDataSet(HFile, std::string("output_") + std::to_string(i)); + HDOsets.push_back(dSet); + } + } + + /** + * @brief Write all the data in the vectors in the respective datasets. + * @param[in] dsets Vector containing the hdf5-dataset descriptor for every + * vector to be written + * @param[in] data vectors containing 1-D vectors of numElements values each + * to be written in the db. + * @param[in] numElements The number of elements each vector has + */ + void writeDataToDataset(std::vector& dsets, + std::vector& data, + size_t numElements) + { + int index = 0; + for (auto* I : data) { + writeVecToDataset(dsets[index++], static_cast(I), numElements); + } + } + + /** @brief Writes a single 1-D vector to the dataset + * @param[in] dSet the dataset to write the data to + * @param[in] data the data we need to write + * @param[in] elements the number of data elements we have + */ + void writeVecToDataset(hid_t dSet, void* data, size_t elements) + { + const int nDims = 1; + hsize_t dims = elements; + hsize_t start; + hsize_t count; + hid_t memSpace = H5Screate_simple(nDims, &dims, NULL); + HDF5_ERROR(memSpace); + + dims = totalElements + elements; + H5Dset_extent(dSet, &dims); + + hid_t fileSpace = H5Dget_space(dSet); + HDF5_ERROR(fileSpace); + + // Data set starts at offset totalElements + start = totalElements; + // And we append additional elements + count = elements; + // Select hyperslab + herr_t err = H5Sselect_hyperslab( + fileSpace, H5S_SELECT_SET, &start, NULL, &count, NULL); + HDF5_ERROR(err); + + H5Dwrite(dSet, HDType, memSpace, fileSpace, H5P_DEFAULT, data); + H5Sclose(fileSpace); + } + +public: + // Delete copy constructors. We do not want to copy the DB around + hdf5DB(const hdf5DB&) = delete; + hdf5DB& operator=(const hdf5DB&) = delete; + + /** + * @brief constructs the class and opens the hdf5 file to write to + * @param[in] fn Name of the file to store data to + * @param[in] rId a unique Id for each process taking part in a distributed + * execution (rank-id) + */ + hdf5DB(std::string path, uint64_t rId) : FileDB(path, ".h5", rId) + { + if (isDouble::default_value()) + HDType = H5T_NATIVE_DOUBLE; + else + HDType = H5T_NATIVE_FLOAT; + std::error_code ec; + bool exists = fs::exists(this->fn); + this->checkError(ec); + + if (exists) + HFile = H5Fopen(this->fn.c_str(), H5F_ACC_RDWR, H5P_DEFAULT); + else + HFile = + H5Fcreate(this->fn.c_str(), H5F_ACC_EXCL, H5P_DEFAULT, H5P_DEFAULT); + HDF5_ERROR(HFile); + totalElements = 0; + } + + /** + * @brief deconstructs the class and closes the file + */ + ~hdf5DB() + { + // HDF5 Automatically closes all opened fds at exit of application. + // herr_t err = H5Fclose(HFile); + // HDF5_ERROR(err); + } + + /** + * @brief Define the type of the DB + */ + std::string type() override { return "hdf5"; } + + /** + * @brief Takes an input and an output vector each holding 1-D vectors data, + * and store them into a hdf5 file delimited by ':'. This should never be used + * for large scale simulations as txt/hdf5 format will be extremely slow. + * @param[in] num_elements Number of elements of each 1-D vector + * @param[in] inputs Vector of 1-D vectors containing the inputs to bestored + * @param[in] inputs Vector of 1-D vectors, each 1-D vectors contains + * 'num_elements' values to be stored + * @param[in] outputs Vector of 1-D vectors, each 1-D vectors contains + * 'num_elements' values to be stored + */ + PERFFASPECT() + virtual void store(size_t num_elements, + std::vector& inputs, + std::vector& outputs) override + { + + DBG(DB, + "DB of type %s stores %ld elements of input/output dimensions (%d, %d)", + type().c_str(), + num_elements, + inputs.size(), + outputs.size()) + const size_t num_in = inputs.size(); + const size_t num_out = outputs.size(); + + if (HDIsets.empty()) { + createDataSets(num_elements, num_in, num_out); + } + + if (HDIsets.size() != num_in || HDOsets.size() != num_out) { + std::cerr << "The data dimensionality is different than the one in the " + "DB\n"; + exit(-1); + } + + writeDataToDataset(HDIsets, inputs, num_elements); + writeDataToDataset(HDOsets, outputs, num_elements); + totalElements += num_elements; + } +}; +#endif + +#ifdef __ENABLE_REDIS__ +template +class RedisDB : public BaseDB +{ + const std::string _fn; // path to the file storing the DB access config + uint64_t _dbid; + Redis* _redis; + uint64_t keyId; + +public: + RedisDB(const RedisDB&) = delete; + RedisDB& operator=(const RedisDB&) = delete; + + /** + * @brief constructs the class and opens the file to write to + * @param[in] fn Name of the file to store data to + * @param[in] rId a unique Id for each process taking part in a distributed + * execution (rank-id) + */ + RedisDB(std::string fn, uint64_t rId) + : BaseDB(rId), _fn(fn), _redis(nullptr), keyId(0) + { + _dbid = reinterpret_cast(this); + auto connection_info = read_json(fn); + + ConnectionOptions connection_options; + connection_options.type = ConnectionType::TCP; + connection_options.host = connection_info["host"]; + connection_options.port = std::stoi(connection_info["service-port"]); + connection_options.password = connection_info["database-password"]; + connection_options.db = 0; // Optionnal, 0 is the default + connection_options.tls.enabled = + true; // Required to connect to PDS within LC + connection_options.tls.cacert = connection_info["cert"]; + + ConnectionPoolOptions pool_options; + pool_options.size = 100; // Pool size, i.e. max number of connections. + + _redis = new Redis(connection_options, pool_options); + } + + ~RedisDB() + { + std::cerr << "Deleting RedisDB object\n"; + delete _redis; + } + + inline std::string type() override { return "RedisDB"; } + + inline std::string info() { return _redis->info(); } + + // Return the number of keys in the DB + inline long long dbsize() { return _redis->dbsize(); } + + /* ! + * ! WARNING: Flush the entire Redis, accross all DBs! + * ! + */ + inline void flushall() { _redis->flushall(); } + + /* + * ! WARNING: Flush the entire current DB! + * ! + */ + inline void flushdb() { _redis->flushdb(); } + + std::unordered_map read_json(std::string fn) + { + std::ifstream config; + std::unordered_map connection_info = { + {"database-password", ""}, + {"host", ""}, + {"service-port", ""}, + {"cert", ""}, + }; + + config.open(fn, std::ifstream::in); + if (config.is_open()) { + std::string line; + // Quite inefficient parsing (to say the least..) but the file to parse is + // small (4 lines) + // TODO: maybe use Boost or another JSON library + while (std::getline(config, line)) { + if (line.find("{") != std::string::npos || + line.find("}") != std::string::npos) { + continue; + } + line.erase(std::remove(line.begin(), line.end(), ' '), line.end()); + line.erase(std::remove(line.begin(), line.end(), ','), line.end()); + line.erase(std::remove(line.begin(), line.end(), '"'), line.end()); + + std::string key = line.substr(0, line.find(':')); + line.erase(0, line.find(":") + 1); + connection_info[key] = line; + // std::cerr << "key=" << key << " and value=" << line << std::endl; + } + config.close(); + } else { + std::cerr << "Config located at: " << fn << std::endl; + throw std::runtime_error("Could not open Redis config file"); + } + return connection_info; + } + + void store(size_t num_elements, + std::vector& inputs, + std::vector& outputs) + { + + const size_t num_in = inputs.size(); + const size_t num_out = outputs.size(); + + // TODO: + // Make insertion more efficient. + // Right now it's pretty naive and expensive + auto start = std::chrono::high_resolution_clock::now(); + + for (size_t i = 0; i < num_elements; i++) { + std::string key = std::to_string(_dbid) + ":" + std::to_string(keyId) + + ":" + + std::to_string(i); // In Redis a key must be a string + std::ostringstream fd; + for (size_t j = 0; j < num_in; j++) { + fd << inputs[j][i] << ":"; + } + for (size_t j = 0; j < num_out - 1; j++) { + fd << outputs[j][i] << ":"; + } + fd << outputs[num_out - 1][i]; + std::string val(fd.str()); + _redis->set(key, val); + } + + keyId += 1; + + auto stop = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(stop - start); + auto nb_keys = this->dbsize(); + + std::cout << std::setprecision(2) << "Inserted " << num_elements + << " keys [Total keys = " << nb_keys << "] into RedisDB [Total " + << duration.count() << "ms, " + << static_cast(num_elements) / duration.count() + << " key/ms]" << std::endl; + } +}; + +#endif // __ENABLE_HDF5__ + +#ifdef __ENABLE_RMQ__ + +/** + * @brief AMS represents the header as follows: + * The header is 12 bytes long: + * - 1 byte is the size of the header (here 12). Limit max: 255 + * - 1 byte is the precision (4 for float, 8 for double). Limit max: 255 + * - 2 bytes are the MPI rank (0 if AMS is not running with MPI). Limit max: 65535 + * - 4 bytes are the number of elements in the message. Limit max: 2^32 - 1 + * - 2 bytes are the input dimension. Limit max: 65535 + * - 2 bytes are the output dimension. Limit max: 65535 + * + * |__Header__|__Datatype__|___Rank___|__#elems__|___InDim___|___OutDim___|...real data...| + * ^ ^ ^ ^ ^ ^ ^ ^ + * | Byte 1 | Byte 2 | Byte 3-4 | Byte 4-8 | Byte 8-10 | Byte 10-12 | Byte 12-X | + * + * where X = datatype * num_element * (InDim + OutDim). Total message size is 12+X. + * + * The data starts at byte 12, ends at byte X. + * The data is structured as pairs of input/outputs. Let K be the total number of + * elements, then we have K pairs of inputs/outputs (either float or double): + * + * |__Header_(12B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| + */ +template +struct AMSMsgHeader { + /** @brief Heaader size (bytes) */ + uint8_t hsize; + /** @brief Data type size (bytes) */ + uint8_t dtype; + /** @brief MPI rank */ + uint16_t mpi_rank; + /** @brief Number of elements */ + uint32_t num_elem; + /** @brief Inputs dimension */ + uint16_t in_dim; + /** @brief Outputs dimension */ + uint16_t out_dim; + + /** + * @brief Constructor for AMSMsgHeader + * @param[in] mpi_rank MPI rank + * @param[in] num_elem Number of elements (input/outputs) + * @param[in] in_dim Inputs dimension + * @param[in] out_dim Outputs dimension + */ + AMSMsgHeader(size_t mpi_rank, size_t num_elem, size_t in_dim, size_t out_dim) + : hsize(static_cast(AMSMsgHeader::size())), + dtype(static_cast(sizeof(TypeValue))), + mpi_rank(static_cast(mpi_rank)), + num_elem(static_cast(num_elem)), + in_dim(static_cast(in_dim)), + out_dim(static_cast(out_dim)) + { + } + + /** + * @brief Return the size of a header in the AMS protocol. + * @return The size of a message header in AMS (in byte) + */ + static size_t size() + { + return sizeof(hsize) + sizeof(dtype) + sizeof(mpi_rank) + sizeof(num_elem) + + sizeof(in_dim) + sizeof(out_dim); + } + + /** + * @brief Fill an empty buffer with a valid header. + * @param[in] data_blob The buffer to fill + * @return The number of bytes in the header or 0 if error + */ + size_t encode(uint8_t* data_blob) + { + if (!data_blob) return 0; + + size_t current_offset = 0; + // Header size (should be 1 bytes) + data_blob[current_offset] = hsize; + current_offset += sizeof(hsize); + // Data type (should be 1 bytes) + data_blob[current_offset] = dtype; + current_offset += sizeof(dtype); + // MPI rank (should be 2 bytes) + std::memcpy(data_blob + current_offset, &(mpi_rank), sizeof(mpi_rank)); + current_offset += sizeof(mpi_rank); + // Num elem (should be 4 bytes) + std::memcpy(data_blob + current_offset, &(num_elem), sizeof(num_elem)); + current_offset += sizeof(num_elem); + // Input dim (should be 2 bytes) + std::memcpy(data_blob + current_offset, &(in_dim), sizeof(in_dim)); + current_offset += sizeof(in_dim); + // Output dim (should be 2 bytes) + std::memcpy(data_blob + current_offset, &(out_dim), sizeof(out_dim)); + current_offset += sizeof(out_dim); + + return current_offset; + } +}; + + +/** + * @brief Class representing a message for the AMSLib + */ +template +class AMSMessage +{ +private: + /** @brief message ID */ + int _id; + /** @brief The MPI rank (0 if MPI is not used) */ + int _rank; + /** @brief The data represented as a binary blob */ + uint8_t* _data; + /** @brief The total size of the binary blob in bytes */ + size_t _total_size; + /** @brief The number of input/output pairs */ + size_t _num_elements; + /** @brief The dimensions of inputs */ + size_t _input_dim; + /** @brief The dimensions of outputs */ + size_t _output_dim; + +public: + /** + * @brief Constructor + * @param[in] num_elements Number of elements + * @param[in] inputs Inputs + * @param[in] outputs Outputs + */ + AMSMessage(int id, + size_t num_elements, + const std::vector& inputs, + const std::vector& outputs) + : _id(id), + _num_elements(num_elements), + _input_dim(inputs.size()), + _output_dim(outputs.size()), + _data(nullptr), + _total_size(0) + { +#ifdef __ENABLE_MPI__ + MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); +#endif + _total_size = AMSMsgHeader::size() + getDataSize(); + _data = ams::ResourceManager::allocate(_total_size, + AMSResourceType::HOST); + + AMSMsgHeader header(_rank, + _num_elements, + _input_dim, + _output_dim); + size_t current_offset = header.encode(_data); + current_offset = encode_data(_data, current_offset, inputs, outputs); + } + + AMSMessage(const AMSMessage&) = delete; + AMSMessage& operator=(const AMSMessage&) = delete; + + AMSMessage(AMSMessage&& other) noexcept { *this = std::move(other); } + + AMSMessage& operator=(AMSMessage&& other) noexcept + { + if (this != &other) { + _id = other._id; + _num_elements = other._num_elements; + _input_dim = other._input_dim; + _output_dim = other._output_dim; + _total_size = other._total_size; + _data = other._data; + other._data = nullptr; + } + return *this; + } + + /** + * @brief Fill a buffer with a data section starting at a given position. + * @param[in] data_blob The buffer to fill + * @param[in] offset Position where to start writing in the buffer + * @param[in] inputs Inputs + * @param[in] outputs Outputs + * @return The number of bytes in the message or 0 if error + */ + size_t encode_data(uint8_t* data_blob, + size_t offset, + const std::vector& inputs, + const std::vector& outputs) + { + if (!data_blob) return 0; + // Creating the body part of the messages + // TODO: slow method (one copy per element!), improve by reducing number of copies + for (size_t i = 0; i < _num_elements; i++) { + for (size_t j = 0; j < _input_dim; j++) { + ams::ResourceManager::copy(&(inputs[j][i]), reinterpret_cast(_data + offset), sizeof(TypeValue)); + offset += sizeof(TypeValue); + } + for (size_t j = 0; j < _output_dim; j++) { + ams::ResourceManager::copy(&(outputs[j][i]), reinterpret_cast(_data + offset), sizeof(TypeValue)); + offset += sizeof(TypeValue); + } + } + + return offset; + } + + /** + * @brief Return the size of the data portion for that message + * @return Size in bytes of the data portion + */ + size_t getDataSize() + { + return (_num_elements * (_input_dim + _output_dim)) * sizeof(TypeValue); + } + + /** + * @brief Return the underlying data pointer + * @return Data pointer (binary blob) + */ + uint8_t* data() const { return _data; } + + /** + * @brief Return message ID + * @return message ID + */ + int id() const { return _id; } + + /** + * @brief Return the size in bytes of the underlying binary blob + * @return Byte size of data pointer + */ + size_t size() const { return _total_size; } + + ~AMSMessage() + { + if (_data) + ams::ResourceManager::deallocate(_data, AMSResourceType::HOST); + } +}; // class AMSMessage + +/** @brief Structure that represents a received RabbitMQ message. + * - The first field is the message content (body) + * - The second field is the RMQ exchange from which the message + * has been received + * - The third field is the routing key + * - The fourth is the delivery tag (ID of the message) + * - The fifth field is a boolean that indicates if that message + * has been redelivered by RMQ. + */ +typedef std::tuple + inbound_msg; + +/** + * @brief Specific handler for RabbitMQ connections based on libevent. + */ +template +class RMQConsumerHandler : public AMQP::LibEventHandler +{ +private: + /** @brief Path to TLS certificate */ + std::string _cacert; + /** @brief The MPI rank (0 if MPI is not used) */ + int _rank; + /** @brief LibEvent I/O loop */ + std::shared_ptr _loop; + /** @brief main channel used to send data to the broker */ + std::shared_ptr _channel; + /** @brief RabbitMQ queue */ + std::string _queue; + /** @brief Queue that contains all the messages received on receiver queue */ + std::shared_ptr> _messages; + +public: + /** + * @brief Constructor + * @param[in] loop Event Loop + * @param[in] cacert SSL Cacert + * @param[in] rank MPI rank + */ + RMQConsumerHandler(std::shared_ptr loop, + std::string cacert, + std::string queue) + : AMQP::LibEventHandler(loop.get()), + _loop(loop), + _rank(0), + _cacert(std::move(cacert)), + _queue(queue), + _messages(std::make_shared>()), + _channel(nullptr) + { +#ifdef __ENABLE_MPI__ + MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); +#endif + } + + ~RMQConsumerHandler() = default; + +private: + /** + * @brief Method that is called after a TCP connection has been set up, and + * right before the SSL handshake is going to be performed to secure the + * connection (only for amqps:// connections). This method can be overridden + * in user space to load client side certificates. + * @param[in] connection The connection for which TLS was just started + * @param[in] ssl Pointer to the SSL structure that can be + * modified + * @return bool True to proceed / accept the connection, false + * to break up + */ + virtual bool onSecuring(AMQP::TcpConnection* connection, SSL* ssl) + { + ERR_clear_error(); + unsigned long err; +#if OPENSSL_VERSION_NUMBER < 0x10100000L + int ret = SSL_use_certificate_file(ssl, _cacert.c_str(), SSL_FILETYPE_PEM); +#else + int ret = SSL_use_certificate_chain_file(ssl, _cacert.c_str()); +#endif + // TODO: with openssl 3.0 + // SSL_set_options(ssl, SSL_OP_IGNORE_UNEXPECTED_EOF); + + if (ret != 1) { + std::string error("openssl: error loading ca-chain (" + _cacert + + ") + from ["); + SSL_get_error(ssl, ret); + if ((err = ERR_get_error())) { + error += std::string(ERR_reason_error_string(err)); + } + error += "]"; + throw std::runtime_error(error); + } else { + DBG(RMQConsumerHandler, + "Success logged with ca-chain %s", + _cacert.c_str()) + return true; + } + } + + /** + * @brief Method that is called when the secure TLS connection has been + * established. This is only called for amqps:// connections. It allows you to + * inspect whether the connection is secure enough for your liking (you can + * for example check the server certificate). The AMQP protocol still has + * to be started. + * @param[in] connection The connection that has been secured + * @param[in] ssl SSL structure from openssl library + * @return bool True if connection can be used + */ + virtual bool onSecured(AMQP::TcpConnection* connection, + const SSL* ssl) override + { + DBG(RMQConsumerHandler, + "[rank=%d] Secured TLS connection has been established.", + _rank) + return true; + } + + /** + * @brief Method that is called by the AMQP library when the login attempt + * succeeded. After this the connection is ready to use. + * @param[in] connection The connection that can now be used + */ + virtual void onReady(AMQP::TcpConnection* connection) override + { + DBG(RMQConsumerHandler, + "[rank=%d] Sucessfuly logged in. Connection ready to use.\n", + _rank) + + _channel = std::make_shared(connection); + _channel->onError([&](const char* message) { + CFATAL(RMQConsumerHandler, + false, + "[rank=%d] Error on channel: %s", + _rank, + message) + }); + + _channel->declareQueue(_queue) + .onSuccess([&](const std::string& name, + uint32_t messagecount, + uint32_t consumercount) { + if (messagecount > 0 || consumercount > 1) { + CWARNING(RMQConsumerHandler, + _rank == 0, + "[rank=%d] declared queue: %s (messagecount=%d, " + "consumercount=%d)", + _rank, + _queue.c_str(), + messagecount, + consumercount) + } + // We can now install callback functions for when we will consumme messages + // callback function that is called when the consume operation starts + auto startCb = [](const std::string& consumertag) { + DBG(RMQConsumerHandler, + "consume operation started with tag: %s", + consumertag.c_str()) + }; + + // callback function that is called when the consume operation failed + auto errorCb = [](const char* message) { + CFATAL(RMQConsumerHandler, + false, + "consume operation failed: %s", + message); + }; + // callback operation when a message was received + auto messageCb = [&](const AMQP::Message& message, + uint64_t deliveryTag, + bool redelivered) { + // acknowledge the message + _channel->ack(deliveryTag); + std::string msg(message.body(), message.bodySize()); + DBG(RMQConsumerHandler, + "message received [tag=%d] : '%s' of size %d B from '%s'/'%s'", + deliveryTag, + msg.c_str(), + message.bodySize(), + message.exchange().c_str(), + message.routingkey().c_str()) + _messages->push_back(std::make_tuple(std::move(msg), + message.exchange(), + message.routingkey(), + deliveryTag, + redelivered)); + }; + + /* callback that is called when the consumer is cancelled by RabbitMQ (this + * only happens in rare situations, for example when someone removes the queue + * that you are consuming from) + */ + auto cancelledCb = [](const std::string& consumertag) { + WARNING(RMQConsumerHandler, + "consume operation cancelled by the RabbitMQ server: %s", + consumertag.c_str()) + }; + + // start consuming from the queue, and install the callbacks + _channel->consume(_queue) + .onReceived(messageCb) + .onSuccess(startCb) + .onCancelled(cancelledCb) + .onError(errorCb); + }) + .onError([&](const char* message) { + CFATAL(RMQConsumerHandler, + false, + "[ERROR][rank=%d] Error while creating broker queue (%s): %s", + _rank, + _queue.c_str(), + message) + }); + } + + /** + * Method that is called when the AMQP protocol is ended. This is the + * counter-part of a call to connection.close() to graceful shutdown + * the connection. Note that the TCP connection is at this time still + * active, and you will also receive calls to onLost() and onDetached() + * @param connection The connection over which the AMQP protocol ended + */ + virtual void onClosed(AMQP::TcpConnection* connection) override + { + DBG(RMQConsumerHandler, "[rank=%d] Connection is closed.\n", _rank) + } + + /** + * @brief Method that is called by the AMQP library when a fatal error occurs + * on the connection, for example because data received from RabbitMQ + * could not be recognized, or the underlying connection is lost. This + * call is normally followed by a call to onLost() (if the error occurred + * after the TCP connection was established) and onDetached(). + * @param[in] connection The connection on which the error occurred + * @param[in] message A human readable error message + */ + virtual void onError(AMQP::TcpConnection* connection, + const char* message) override + { + DBG(RMQConsumerHandler, + "[rank=%d] fatal error when establishing TCP connection: %s\n", + _rank, + message) + } + + /** + * Final method that is called. This signals that no further calls to your + * handler will be made about the connection. + * @param connection The connection that can be destructed + */ + virtual void onDetached(AMQP::TcpConnection* connection) override + { + // add your own implementation, like cleanup resources or exit the application + DBG(RMQConsumerHandler, "[rank=%d] Connection is detached.\n", _rank) + } +}; // class RMQConsumerHandler + +/** + * @brief Class that manages a RabbitMQ broker and handles connection, event + * loop and set up various handlers. + */ +template +class RMQConsumer +{ +private: + /** @brief Connection to the broker */ + AMQP::TcpConnection* _connection; + /** @brief name of the queue to send data */ + std::string _queue; + /** @brief TLS certificate file */ + std::string _cacert; + /** @brief MPI rank (if MPI is used, otherwise 0) */ + int _rank; + /** @brief The event loop for sender (usually the default one in libevent) */ + std::shared_ptr _loop; + /** @brief The handler which contains various callbacks for the sender */ + std::shared_ptr> _handler; + /** @brief Queue that contains all the messages received on receiver queue (messages can be popped in) */ + std::vector _messages; + +public: + RMQConsumer(const RMQConsumer&) = delete; + RMQConsumer& operator=(const RMQConsumer&) = delete; + + RMQConsumer(const AMQP::Address& address, + std::string cacert, + std::string queue) + : _rank(0), _queue(queue), _cacert(cacert), _handler(nullptr) + { +#ifdef __ENABLE_MPI__ + MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); +#endif +#ifdef EVTHREAD_USE_PTHREADS_IMPLEMENTED + evthread_use_pthreads(); +#endif + CDEBUG(RMQConsumer, + _rank == 0, + "Libevent %s (LIBEVENT_VERSION_NUMBER = %#010x)", + event_get_version(), + event_get_version_number()); + CDEBUG(RMQConsumer, + _rank == 0, + "%s (OPENSSL_VERSION_NUMBER = %#010x)", + OPENSSL_VERSION_TEXT, + OPENSSL_VERSION_NUMBER); +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSL_library_init(); +#else + OPENSSL_init_ssl(0, NULL); +#endif + CINFO(RMQConsumer, + _rank == 0, + "RabbitMQ address: %s:%d/%s (queue = %s)", + address.hostname().c_str(), + address.port(), + address.vhost().c_str(), + _queue.c_str()) + + _loop = std::shared_ptr(event_base_new(), + [](struct event_base* event) { + event_base_free(event); + }); + _handler = + std::make_shared>(_loop, _cacert, _queue); + _connection = new AMQP::TcpConnection(_handler.get(), address); + } + + /** + * @brief Start the underlying I/O loop (blocking call) + */ + void start() { event_base_dispatch(_loop.get()); } + + /** + * @brief Stop the underlying I/O loop + */ + void stop() { event_base_loopexit(_loop.get(), NULL); } + + /** + * @brief Return the most recent messages and delete it + * @return A structure inbound_msg which is a std::tuple (see typedef) + */ + inbound_msg pop_messages() + { + if (!_messages.empty()) { + inbound_msg msg = _messages.back(); + _messages.pop_back(); + return msg; + } + return std::make_tuple("", "", "", -1, false); + } + + /** + * @brief Return the message corresponding to the delivery tag. Do not delete the + * message. + * @param[in] delivery_tag Delivery tag that will be returned (if found) + * @return A structure inbound_msg which is a std::tuple (see typedef) + */ + inbound_msg get_messages(uint64_t delivery_tag) + { + if (!_messages.empty()) { + auto it = std::find_if(_messages.begin(), + _messages.end(), + [&delivery_tag](const inbound_msg& e) { + return std::get<3>(e) == delivery_tag; + }); + if (it != _messages.end()) return *it; + } + return std::make_tuple("", "", "", -1, false); + } + + ~RMQConsumer() + { + _connection->close(false); + delete _connection; + } +}; // class RMQConsumer + +/** + * @brief Specific handler for RabbitMQ connections based on libevent. + */ +template +class RMQPublisherHandler : public AMQP::LibEventHandler +{ +private: + /** @brief Path to TLS certificate */ + std::string _cacert; + /** @brief The MPI rank (0 if MPI is not used) */ + int _rank; + /** @brief LibEvent I/O loop */ + std::shared_ptr _loop; + /** @brief main channel used to send data to the broker */ + std::shared_ptr _channel; + /** @brief AMQP reliable channel (wrapper of classic channel with added functionalities) */ + std::shared_ptr> _rchannel; + /** @brief RabbitMQ queue */ + std::string _queue; + /** @brief Total number of messages sent */ + int _nb_msg; + /** @brief Number of messages successfully acknowledged */ + int _nb_msg_ack; + +public: + /** + * @brief Constructor + * @param[in] loop Event Loop + * @param[in] cacert SSL Cacert + * @param[in] rank MPI rank + */ + RMQPublisherHandler(std::shared_ptr loop, + std::string cacert, + std::string queue) + : AMQP::LibEventHandler(loop.get()), + _loop(loop), + _rank(0), + _cacert(std::move(cacert)), + _queue(queue), + _nb_msg_ack(0), + _nb_msg(0), + _channel(nullptr), + _rchannel(nullptr) + { +#ifdef __ENABLE_MPI__ + MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); +#endif + } + + /** + * @brief Publish data on RMQ queue. + * @param[in] data The data pointer + * @param[in] data_size The number of bytes in the data pointer + */ + void publish(const AMSMessage& msg) + { + if (_rchannel) { + // publish a message via the reliable-channel + _rchannel + ->publish("", _queue, reinterpret_cast(msg.data()), msg.size()) + .onAck([&]() { + DBG(RMQPublisherHandler, + "[rank=%d] message #%d got acknowledged successfully by RMQ " + "server", + _rank, + _nb_msg) + _nb_msg_ack++; + }) + .onNack([&]() { + WARNING(RMQPublisherHandler, + "[rank=%d] message #%d received negative acknowledged by " + "RMQ " + "server", + _rank, + _nb_msg) + }) + .onLost([&]() { + CFATAL(RMQPublisherHandler, + false, + "[rank=%d] message #%d likely got lost by RMQ server", + _rank, + _nb_msg) + }) + .onError([&](const char* err_message) { + CFATAL(RMQPublisherHandler, + false, + "[rank=%d] message #%d did not get send: %s", + _rank, + _nb_msg, + err_message) + }); + } else { + WARNING(RMQPublisherHandler, + "[rank=%d] The reliable channel was not ready for message #%d.", + _rank, + _nb_msg) + } + _nb_msg++; + } + + ~RMQPublisherHandler() = default; + +private: + /** + * @brief Method that is called after a TCP connection has been set up, and + * right before the SSL handshake is going to be performed to secure the + * connection (only for amqps:// connections). This method can be overridden + * in user space to load client side certificates. + * @param[in] connection The connection for which TLS was just started + * @param[in] ssl Pointer to the SSL structure that can be + * modified + * @return bool True to proceed / accept the connection, false + * to break up + */ + virtual bool onSecuring(AMQP::TcpConnection* connection, SSL* ssl) + { + ERR_clear_error(); + unsigned long err; +#if OPENSSL_VERSION_NUMBER < 0x10100000L + int ret = SSL_use_certificate_file(ssl, _cacert.c_str(), SSL_FILETYPE_PEM); +#else + int ret = SSL_use_certificate_chain_file(ssl, _cacert.c_str()); +#endif + if (ret != 1) { + std::string error("openssl: error loading ca-chain (" + _cacert + + ") + from ["); + SSL_get_error(ssl, ret); + if ((err = ERR_get_error())) { + error += std::string(ERR_reason_error_string(err)); + } + error += "]"; + throw std::runtime_error(error); + } else { + DBG(RMQPublisherHandler, + "Success logged with ca-chain %s", + _cacert.c_str()) + return true; + } + } + + /** + * @brief Method that is called when the secure TLS connection has been + * established. This is only called for amqps:// connections. It allows you to + * inspect whether the connection is secure enough for your liking (you can + * for example check the server certificate). The AMQP protocol still has + * to be started. + * @param[in] connection The connection that has been secured + * @param[in] ssl SSL structure from openssl library + * @return bool True if connection can be used + */ + virtual bool onSecured(AMQP::TcpConnection* connection, + const SSL* ssl) override + { + DBG(RMQPublisherHandler, + "[rank=%d] Secured TLS connection has been established.", + _rank) + return true; + } + + /** + * @brief Method that is called by the AMQP library when the login attempt + * succeeded. After this the connection is ready to use. + * @param[in] connection The connection that can now be used + */ + virtual void onReady(AMQP::TcpConnection* connection) override + { + DBG(RMQPublisherHandler, + "[rank=%d] Sucessfuly logged in. Connection ready to use.\n", + _rank) + + _channel = std::make_shared(connection); + _channel->onError([&](const char* message) { + CFATAL(RMQPublisherHandler, + false, + "[rank=%d] Error on channel: %s", + _rank, + message) + }); + + _channel->declareQueue(_queue) + .onSuccess([&](const std::string& name, + uint32_t messagecount, + uint32_t consumercount) { + if (messagecount > 0 || consumercount > 1) { + CWARNING(RMQPublisherHandler, + _rank == 0, + "[rank=%d] declared queue: %s (messagecount=%d, " + "consumercount=%d)", + _rank, + _queue.c_str(), + messagecount, + consumercount) + } + // We can now instantiate the shared buffer between AMS and RMQ + DBG(RMQPublisherHandler, + "[rank=%d] declared queue: %s", + _rank, + _queue.c_str()) + _rchannel = + std::make_shared>(*_channel.get()); + }) + .onError([&](const char* message) { + CFATAL(RMQPublisherHandler, + false, + "[ERROR][rank=%d] Error while creating broker queue (%s): %s", + _rank, + _queue.c_str(), + message) + }); + } + + /** + * Method that is called when the AMQP protocol is ended. This is the + * counter-part of a call to connection.close() to graceful shutdown + * the connection. Note that the TCP connection is at this time still + * active, and you will also receive calls to onLost() and onDetached() + * @param connection The connection over which the AMQP protocol ended + */ + virtual void onClosed(AMQP::TcpConnection* connection) override + { + DBG(RMQPublisherHandler, "[rank=%d] Connection is closed.\n", _rank) + } + + /** + * @brief Method that is called by the AMQP library when a fatal error occurs + * on the connection, for example because data received from RabbitMQ + * could not be recognized, or the underlying connection is lost. This + * call is normally followed by a call to onLost() (if the error occurred + * after the TCP connection was established) and onDetached(). + * @param[in] connection The connection on which the error occurred + * @param[in] message A human readable error message + */ + virtual void onError(AMQP::TcpConnection* connection, + const char* message) override + { + DBG(RMQPublisherHandler, + "[rank=%d] fatal error when establishing TCP connection: %s\n", + _rank, + message) + } + + /** + * Final method that is called. This signals that no further calls to your + * handler will be made about the connection. + * @param connection The connection that can be destructed + */ + virtual void onDetached(AMQP::TcpConnection* connection) override + { + // add your own implementation, like cleanup resources or exit the application + DBG(RMQPublisherHandler, "[rank=%d] Connection is detached.\n", _rank) + } +}; // class RMQPublisherHandler + + +/** + * @brief Class that manages a RabbitMQ broker and handles connection, event + * loop and set up various handlers. + */ +template +class RMQPublisher +{ +private: + /** @brief Connection to the broker */ + AMQP::TcpConnection* _connection; + /** @brief name of the queue to send data */ + std::string _queue; + /** @brief TLS certificate file */ + std::string _cacert; + /** @brief MPI rank (if MPI is used, otherwise 0) */ + int _rank; + /** @brief The event loop for sender (usually the default one in libevent) */ + std::shared_ptr _loop; + /** @brief The handler which contains various callbacks for the sender */ + std::shared_ptr> _handler; + +public: + RMQPublisher(const RMQPublisher&) = delete; + RMQPublisher& operator=(const RMQPublisher&) = delete; + + RMQPublisher(const AMQP::Address& address, + std::string cacert, + std::string queue) + : _rank(0), _queue(queue), _cacert(cacert), _handler(nullptr) + { +#ifdef __ENABLE_MPI__ + MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); +#endif +#ifdef EVTHREAD_USE_PTHREADS_IMPLEMENTED + evthread_use_pthreads(); +#endif + CDEBUG(RMQPublisher, + _rank == 0, + "Libevent %s (LIBEVENT_VERSION_NUMBER = %#010x)", + event_get_version(), + event_get_version_number()); + CDEBUG(RMQPublisher, + _rank == 0, + "%s (OPENSSL_VERSION_NUMBER = %#010x)", + OPENSSL_VERSION_TEXT, + OPENSSL_VERSION_NUMBER); +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSL_library_init(); +#else + OPENSSL_init_ssl(0, NULL); +#endif + CINFO(RMQPublisher, + _rank == 0, + "RabbitMQ address: %s:%d/%s (queue = %s)", + address.hostname().c_str(), + address.port(), + address.vhost().c_str(), + _queue.c_str()) + + _loop = std::shared_ptr(event_base_new(), + [](struct event_base* event) { + event_base_free(event); + }); + + _handler = std::make_shared>(_loop, + _cacert, + _queue); + _connection = new AMQP::TcpConnection(_handler.get(), address); + } + + /** + * @brief Check if the underlying RabbitMQ connection is ready and usable + * @return True if the publisher is ready to publish + */ + bool ready_publish() { return _connection->ready() && _connection->usable(); } + + /** + * @brief Wait that the connection is ready (blocking call) + * @return True if the publisher is ready to publish + */ + void wait_ready(int ms = 500, int timeout_sec = 30) + { + // We wait for the connection to be ready + int total_time = 0; + while (!ready_publish()) { + std::this_thread::sleep_for(std::chrono::milliseconds(ms)); + DBG(RMQPublisher, + "[rank=%d] Waiting for connection to be ready...", + _rank) + total_time += ms; + if (total_time > timeout_sec * 1000) { + DBG(RMQPublisher, "[rank=%d] Connection timeout", _rank) + break; + // TODO: if connection is not working -> revert to classic file DB. + } + } + } + + /** + * @brief Start the underlying I/O loop (blocking call) + */ + void start() + { + event_base_dispatch(_loop.get()); + // We wait for the connection to be ready + wait_ready(); + } + + /** + * @brief Stop the underlying I/O loop + */ + void stop() { event_base_loopexit(_loop.get(), NULL); } + + void publish(const AMSMessage& message) + { + _handler->publish(message); + } + + ~RMQPublisher() { delete _connection; } +}; // class RMQPublisher + +/** + * @brief Class that manages a RabbitMQ broker and handles connection, event + * loop and set up various handlers. + * @details This class manages a specific type of database backend in AMSLib. + * Instead of writing inputs/outputs directly to files (CSV or HDF5), we + * send these elements (a collection of inputs and their corresponding outputs) + * to a service called RabbitMQ which is listening on a given IP and port. + * + * This class requires a RabbitMQ server to be running somewhere, + * the credentials of that server should be formatted as a JSON file as follows: + * + * { + * "rabbitmq-name": "testamsrabbitmq", + * "rabbitmq-password": "XXX", + * "rabbitmq-user": "pottier1", + * "rabbitmq-vhost": "ams", + * "service-port": 31495, + * "service-host": "url.czapps.llnl.gov", + * "rabbitmq-cert": "tls-cert.crt", + * "rabbitmq-inbound-queue": "test4", + * "rabbitmq-outbound-queue": "test3" + * } + * + * The TLS certificate must be generated by the user and the absolute paths are preferred. + * A TLS certificate can be generated with the following command: + * + * openssl s_client \ + * -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null \ + * 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' > tls.crt + * + * RabbitMQDB creates two RabbitMQ connections per MPI rank, one for publishing data to RMQ and one for consuming data. + * Each connection has its own I/O loop (based on Libevent) running in a dedicated thread because I/O loop are blocking. + * Therefore, we have two threads per MPI rank. + * + * 1. Publishing data: When the store() method is being called, it triggers a series of calls: + * + * RabbitMQDB::store() -> RMQPublisher::publish() -> RMQPublisherHandler::publish() + * + * Here, RMQPublisherHandler::publish() has access to internal RabbitMQ channels and can publish the message + * on the outbound queue (rabbitmq-outbound-queue in the JSON configuration). + * Note that storing data like that is much faster than with writing files as a call to RabbitMQDB::store() + * is virtually free, the actual data sending part is taking place in a thread and does not slow down + * the main simulation (MPI). + * + * 2. Consuming data: The inbound queue (rabbitmq-inbound-queue in the JSON configuration) is the queue for incoming data. The + * RMQConsumer is listening on that queue for messages. In the AMSLib approach, that queue is used to communicate + * updates to rank regarding the ML surrrogate model. RMQConsumer will automatically populate a std::vector with all + * messages received since the execution of AMS started. + * + * Glabal note: Most calls dealing with RabbitMQ (to establish a RMQ connection, opening a channel, publish data etc) + * are asynchronous callbacks (similar to asyncio in Python or future in C++). + * So, the simulation can have already started and the RMQ connection might not be valid which is why most part + * of the code that deals with RMQ are wrapped into callbacks that will get run only in case of success. + * For example, we create a channel only if the underlying connection has been succesfuly initiated + * (see RMQPublisherHandler::onReady()). + */ +template +class RabbitMQDB final : public BaseDB +{ +private: + /** @brief Path of the config file (JSON) */ + std::string _config; + /** @brief name of the queue to send data */ + std::string _queue_sender; + /** @brief name of the queue to receive data */ + std::string _queue_receiver; + /** @brief MPI rank (if MPI is used, otherwise 0) */ + int _rank; + /** @brief Represent the ID of the last message sent */ + int _msg_tag; + /** @brief Publisher sending messages to RMQ server */ + std::shared_ptr> _publisher; + /** @brief Thread in charge of the publisher */ + std::thread _publisher_thread; + /** @brief Consumer listening to RMQ and consuming messages */ + std::shared_ptr> _consumer; + /** @brief Thread in charge of the consumer */ + std::thread _consumer_thread; + + /** + * @brief Read a JSON and create a hashmap + * @param[in] fn Path of the RabbitMQ JSON config file + * @return a hashmap (std::unordered_map) of the JSON file + */ + std::unordered_map _read_config(std::string fn) + { + std::ifstream config; + std::unordered_map connection_info = { + {"rabbitmq-erlang-cookie", ""}, + {"rabbitmq-name", ""}, + {"rabbitmq-password", ""}, + {"rabbitmq-user", ""}, + {"rabbitmq-vhost", ""}, + {"service-port", ""}, + {"service-host", ""}, + {"rabbitmq-cert", ""}, + {"rabbitmq-inbound-queue", ""}, + {"rabbitmq-outbound-queue", ""}, + }; + + config.open(fn, std::ifstream::in); + + if (config.is_open()) { + std::string line; + while (std::getline(config, line)) { + if (line.find("{") != std::string::npos || + line.find("}") != std::string::npos) { + continue; + } + line.erase(std::remove(line.begin(), line.end(), ' '), line.end()); + line.erase(std::remove(line.begin(), line.end(), ','), line.end()); + line.erase(std::remove(line.begin(), line.end(), '"'), line.end()); + + std::string key = line.substr(0, line.find(':')); + line.erase(0, line.find(":") + 1); + connection_info[key] = line; + } + config.close(); + } else { + std::string err = "Could not open JSON file: " + fn; + CFATAL(RabbitMQDB, false, err.c_str()); + } + return connection_info; + } + +public: + RabbitMQDB(const RabbitMQDB&) = delete; + RabbitMQDB& operator=(const RabbitMQDB&) = delete; + + RabbitMQDB(char* config, uint64_t id) + : BaseDB(id), + _rank(0), + _msg_tag(0), + _config(std::string(config)), + _publisher(nullptr), + _consumer(nullptr) + { + std::unordered_map rmq_config = + _read_config(_config); + _queue_sender = + rmq_config["rabbitmq-outbound-queue"]; // Queue to send data to + _queue_receiver = + rmq_config["rabbitmq-inbound-queue"]; // Queue to receive data from PDS + bool is_secure = true; + + if (rmq_config["service-port"].empty()) { + CFATAL(RabbitMQDB, + false, + "service-port is empty, make sure the port number is present in " + "the JSON configuration") + return; + } + if (rmq_config["service-host"].empty()) { + CFATAL(RabbitMQDB, + false, + "service-host is empty, make sure the host is present in the JSON " + "configuration") + return; + } + + uint16_t port = std::stoi(rmq_config["service-port"]); + if (_queue_sender.empty() || _queue_receiver.empty()) { + CFATAL(RabbitMQDB, + false, + "Queues are empty, please check your credentials file and make " + "sure rabbitmq-inbound-queue and rabbitmq-outbound-queue exist") + return; + } + + AMQP::Login login(rmq_config["rabbitmq-user"], + rmq_config["rabbitmq-password"]); + AMQP::Address address(rmq_config["service-host"], + port, + login, + rmq_config["rabbitmq-vhost"], + is_secure); + + std::string cacert = rmq_config["rabbitmq-cert"]; + _publisher = std::make_shared>(address, + cacert, + _queue_sender); + _consumer = std::make_shared>(address, + cacert, + _queue_receiver); + + _publisher_thread = std::thread([&]() { _publisher->start(); }); + _consumer_thread = std::thread([&]() { _consumer->start(); }); + } + + /** + * @brief Takes an input and an output vector each holding 1-D vectors data, and push + * it onto the libevent buffer. + * @param[in] num_elements Number of elements of each 1-D vector + * @param[in] inputs Vector of 1-D vectors containing the inputs to be sent + * @param[in] outputs Vector of 1-D vectors, each 1-D vectors contains + * 'num_elements' values to be sent + */ + PERFFASPECT() + void store(size_t num_elements, + std::vector& inputs, + std::vector& outputs) override + { + DBG(RabbitMQDB, + "[tag=%d] %s stores %ld elements of input/output " + "dimensions (%d, %d)", + _msg_tag, + type().c_str(), + num_elements, + inputs.size(), + outputs.size()) + + auto msg = AMSMessage(_msg_tag, num_elements, inputs, outputs); + _publisher->publish(msg); + _msg_tag++; + } + + /** + * @brief Return the type of this broker + * @return The type of the broker + */ + std::string type() override { return "rabbitmq"; } + + ~RabbitMQDB() + { + _publisher->stop(); + _consumer->stop(); + _publisher_thread.join(); + _consumer_thread.join(); + } +}; // class RabbitMQDB + +#endif // __ENABLE_RMQ__ + + +/** + * @brief Create an object of the respective database. + * This should never be used for large scale simulations as txt/csv format will + * be extremely slow. + * @param[in] dbPath path to the directory storing the data + * @param[in] dbType Type of the database to create + * @param[in] rId a unique Id for each process taking part in a distributed + * execution (rank-id) + */ +template +BaseDB* createDB(char* dbPath, AMSDBType dbType, uint64_t rId = 0) +{ + DBG(DB, "Instantiating data base"); +#ifdef __ENABLE_DB__ + if (dbPath == nullptr) { + std::cerr << " [WARNING] Path of DB is NULL, Please provide a valid path " + "to enable db\n"; + std::cerr << " [WARNING] Continueing\n"; + return nullptr; + } + + switch (dbType) { + case AMSDBType::CSV: + return new csvDB(dbPath, rId); +#ifdef __ENABLE_REDIS__ + case AMSDBType::REDIS: + return new RedisDB(dbPath, rId); +#endif +#ifdef __ENABLE_HDF5__ + case AMSDBType::HDF5: + return new hdf5DB(dbPath, rId); +#endif +#ifdef __ENABLE_RMQ__ + case AMSDBType::RMQ: + return new RabbitMQDB(dbPath, rId); +#endif + default: + return nullptr; + } +#else + return nullptr; +#endif +} + +#endif // __AMS_BASE_DB__ \ No newline at end of file diff --git a/src/wf/cuda/utilities.cuh b/src/AMSlib/wf/cuda/utilities.cuh similarity index 100% rename from src/wf/cuda/utilities.cuh rename to src/AMSlib/wf/cuda/utilities.cuh diff --git a/src/wf/data_handler.hpp b/src/AMSlib/wf/data_handler.hpp similarity index 100% rename from src/wf/data_handler.hpp rename to src/AMSlib/wf/data_handler.hpp diff --git a/src/wf/debug.h b/src/AMSlib/wf/debug.h similarity index 100% rename from src/wf/debug.h rename to src/AMSlib/wf/debug.h diff --git a/src/wf/device.hpp b/src/AMSlib/wf/device.hpp similarity index 100% rename from src/wf/device.hpp rename to src/AMSlib/wf/device.hpp diff --git a/src/wf/redist_load.hpp b/src/AMSlib/wf/redist_load.hpp similarity index 100% rename from src/wf/redist_load.hpp rename to src/AMSlib/wf/redist_load.hpp diff --git a/src/wf/resource_manager.cpp b/src/AMSlib/wf/resource_manager.cpp similarity index 100% rename from src/wf/resource_manager.cpp rename to src/AMSlib/wf/resource_manager.cpp diff --git a/src/wf/resource_manager.hpp b/src/AMSlib/wf/resource_manager.hpp similarity index 100% rename from src/wf/resource_manager.hpp rename to src/AMSlib/wf/resource_manager.hpp diff --git a/src/wf/utils.hpp b/src/AMSlib/wf/utils.hpp similarity index 100% rename from src/wf/utils.hpp rename to src/AMSlib/wf/utils.hpp diff --git a/src/wf/workflow.hpp b/src/AMSlib/wf/workflow.hpp similarity index 100% rename from src/wf/workflow.hpp rename to src/AMSlib/wf/workflow.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5171d166..67f7d824 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,82 +1,10 @@ -# Copyright (c) Lawrence Livermore National Security, LLC and other AMS -# Project developers. See top-level LICENSE AND COPYRIGHT files for dates and -# other details. No copyright assignment is required to contribute +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# ------------------------------------------------------------------------------ -# handle sources and headers -file(GLOB_RECURSE MINIAPP_INCLUDES "*.hpp") -#set global library path to link with tests if necessary -set(LIBRARY_OUTPUT_PATH ${AMS_LIB_OUT_PATH}) -set(AMS_LIB_SRC ${MINIAPP_INCLUDES} AMS.cpp wf/resource_manager.cpp wf/base64.c) -# two targets: a shared lib and an exec -add_library(AMS ${AMS_LIB_SRC} ${MINIAPP_INCLUDES}) +add_subdirectory(AMSlib) -# ------------------------------------------------------------------------------ -if (WITH_CUDA) - set_target_properties(AMS PROPERTIES CUDA_ARCHITECTURES ${AMS_CUDA_ARCH}) - - # if (BUILD_SHARED_LIBS) - # set_target_properties(AMS PROPERTIES CUDA_SEPARABLE_COMPILATION ON) - # else() - # set_target_properties(AMS PROPERTIES CUDA_SEPARABLE_COMPILATION ON) - # set_target_properties(AMS PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) - # endif() - - set_source_files_properties(AMS.cpp PROPERTIES LANGUAGE CUDA) - set_source_files_properties(AMS.cpp PROPERTIES CUDA_ARCHITECTURES ${AMS_CUDA_ARCH}) - set_source_files_properties(AMS.cpp PROPERTIES COMPILE_FLAGS "--expt-extended-lambda") - - if (WITH_PERFFLOWASPECT) - set_property(SOURCE AMS.cpp APPEND_STRING PROPERTY COMPILE_FLAGS " -ccbin clang++ -Xcompiler=-Xclang -Xcompiler=-load -Xcompiler=-Xclang -Xcompiler=${PERFFLOWASPECT_LIB_DIR}/libWeavePass.so") - set_source_files_properties(wf/resource_manager.cpp COMPILE_FLAGS "-Xclang -load -Xclang ${PERFFLOWASPECT_LIB_DIR}/libWeavePass.so") - endif() -endif() - -# ------------------------------------------------------------------------------ -# setup the lib first -message(STATUS "ALL INCLUDES ARE ${AMS_APP_INCLUDES}") -target_compile_definitions(AMS PRIVATE ${AMS_APP_DEFINES}) -target_include_directories(AMS PRIVATE ${AMS_APP_INCLUDES}) -target_include_directories(AMS PUBLIC - $ - $) -target_include_directories(AMS PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) -target_link_directories(AMS PUBLIC ${AMS_APP_LIB_DIRS}) -target_link_libraries(AMS PUBLIC ${AMS_APP_LIBRARIES} stdc++fs) - -#------------------------------------------------------------------------------- -# create the configuration header file with the respective information -#------------------------------------------------------------------------------- -set(CALIPER_DEFINES "// #define __AMS_ENABLE_CALIPER__") -set(MPI_DEFINES "// #define __AMS_ENABLE_MPI__") -set(PERFF_DEFINES "// #define __AMS_ENABLE_PERFFLOWASPECT__") - -if (${WITH_CALIPER}) - set(CALIPER_DEFINES "#define __AMS_ENABLE_CALIPER__") +if (WITH_WORKFLOW) + add_subdirectory(AMSWorkflow) endif() - -if (${WITH_MPI}) - set(MPI_DEFINES "#define __AMS_ENABLE_MPI__") -endif() - -if (${WITH_PERFFLOWASPECT}) - set(PERFF_DEFINES "#define __AMS_ENABLE_PERFFLOWASPECT__") -endif() - -configure_file ("${CMAKE_CURRENT_SOURCE_DIR}/include/AMS-config.h.in" "${PROJECT_BINARY_DIR}/include/AMS-config.h") -configure_file ("${CMAKE_CURRENT_SOURCE_DIR}/include/AMS.h" "${PROJECT_BINARY_DIR}/include/AMS.h" COPYONLY) - -# setup the exec -#SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-rpath -Wl,$ORIGIN") -# ------------------------------------------------------------------------------ -# installation paths -install(TARGETS AMS - EXPORT AMSTargets - DESTINATION lib) - -install(EXPORT AMSTargets - FILE AMS.cmake - DESTINATION lib/cmake/AMS) - -install(FILES ${PROJECT_BINARY_DIR}/include/AMS.h DESTINATION include) -install(FILES ${PROJECT_BINARY_DIR}/include/AMS-config.h DESTINATION include) diff --git a/src/wf/base64.c b/src/wf/base64.c deleted file mode 100644 index d5d71f02..00000000 --- a/src/wf/base64.c +++ /dev/null @@ -1,282 +0,0 @@ -/** - * Disclaimer: This code comes from flux-core libccan library, it provides - * utility functions to use base64 encoding on strings. base64 is used to - * encode data for data transfers and encapsulation. - * base64 represents 24-bit groups of input bits as output - * strings of 4 encoded characters (cf rfc4648). - * - * This code is licensed under BSD-MIT (see below). - * The original code can be found under flux-core repository. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include "base64.h" - -#include -#include -#include -#include - - -/** - * @brief sixbit_to_b64 - maps a 6-bit value to the base64 alphabet - * @param[in] map A base 64 map (see base64_init_map) - * @param[in] sixbit Six-bit value to map - * @return a base 64 character - */ -static char sixbit_to_b64(const base64_maps_t *maps, const uint8_t sixbit) -{ - assert(sixbit <= 63); - - return maps->encode_map[(unsigned char)sixbit]; -} - -/** - * @brief sixbit_from_b64 - maps a base64-alphabet character to its 6-bit value - * @param[in] maps A base 64 maps structure (see base64_init_maps) - * @param[in] sixbit Six-bit value to map - * @return a six-bit value - */ -static int8_t sixbit_from_b64(const base64_maps_t *maps, - const unsigned char b64letter) -{ - int8_t ret; - - ret = maps->decode_map[(unsigned char)b64letter]; - if (ret == (int8_t)0xff) { - errno = EDOM; - return -1; - } - - return ret; -} - -bool base64_char_in_alphabet(const base64_maps_t *maps, const char b64char) -{ - return (maps->decode_map[(const unsigned char)b64char] != (int8_t)0xff); -} - -void base64_init_maps(base64_maps_t *dest, const char src[64]) -{ - unsigned char i; - - memcpy(dest->encode_map,src,64); - memset(dest->decode_map,0xff,256); - for (i=0; i<64; i++) { - dest->decode_map[(unsigned char)src[i]] = i; - } -} - -size_t base64_encoded_length(size_t srclen) -{ - return ((srclen + 2) / 3) * 4; -} - -void base64_encode_triplet_using_maps(const base64_maps_t *maps, - char dest[4], const char src[3]) -{ - char a = src[0]; - char b = src[1]; - char c = src[2]; - - dest[0] = sixbit_to_b64(maps, (a & 0xfc) >> 2); - dest[1] = sixbit_to_b64(maps, ((a & 0x3) << 4) | ((b & 0xf0) >> 4)); - dest[2] = sixbit_to_b64(maps, ((c & 0xc0) >> 6) | ((b & 0xf) << 2)); - dest[3] = sixbit_to_b64(maps, c & 0x3f); -} - -void base64_encode_tail_using_maps(const base64_maps_t *maps, char dest[4], - const char *src, const size_t srclen) -{ - char longsrc[3] = { 0 }; - - assert(srclen <= 3); - - memcpy(longsrc, src, srclen); - base64_encode_triplet_using_maps(maps, dest, longsrc); - memset(dest+1+srclen, '=', 3-srclen); -} - -ssize_t base64_encode_using_maps(const base64_maps_t *maps, - char *dest, const size_t destlen, - const char *src, const size_t srclen) -{ - size_t src_offset = 0; - size_t dest_offset = 0; - - if (destlen < base64_encoded_length(srclen)) { - errno = EOVERFLOW; - return -1; - } - - while (srclen - src_offset >= 3) { - base64_encode_triplet_using_maps(maps, &dest[dest_offset], &src[src_offset]); - src_offset += 3; - dest_offset += 4; - } - - if (src_offset < srclen) { - base64_encode_tail_using_maps(maps, &dest[dest_offset], &src[src_offset], srclen-src_offset); - dest_offset += 4; - } - - memset(&dest[dest_offset], '\0', destlen-dest_offset); - - return dest_offset; -} - -size_t base64_decoded_length(size_t srclen) -{ - return ((srclen+3)/4*3); -} - -ssize_t base64_decode_quartet_using_maps(const base64_maps_t *maps, char dest[3], - const char src[4]) -{ - signed char a; - signed char b; - signed char c; - signed char d; - - a = sixbit_from_b64(maps, src[0]); - b = sixbit_from_b64(maps, src[1]); - c = sixbit_from_b64(maps, src[2]); - d = sixbit_from_b64(maps, src[3]); - - if ((a == -1) || (b == -1) || (c == -1) || (d == -1)) { - return -1; - } - - dest[0] = (a << 2) | (b >> 4); - dest[1] = ((b & 0xf) << 4) | (c >> 2); - dest[2] = ((c & 0x3) << 6) | d; - - return 0; -} - - -ssize_t base64_decode_tail_using_maps(const base64_maps_t *maps, char dest[3], - const char * src, const size_t srclen) -{ - char longsrc[4]; - int quartet_result; - size_t insize = srclen; - - while (insize != 0 && - src[insize-1] == '=') { /* throw away padding symbols */ - insize--; - } - if (insize == 0) { - return 0; - } - if (insize == 1) { - /* the input is malformed.... */ - errno = EINVAL; - return -1; - } - memcpy(longsrc, src, insize); - memset(longsrc+insize, 'A', 4-insize); - quartet_result = base64_decode_quartet_using_maps(maps, dest, longsrc); - if (quartet_result == -1) { - return -1; - } - - return insize - 1; -} - -ssize_t base64_decode_using_maps(const base64_maps_t *maps, - char *dest, const size_t destlen, - const char *src, const size_t srclen) -{ - ssize_t dest_offset = 0; - ssize_t i; - ssize_t more; - - if (destlen < base64_decoded_length(srclen)) { - errno = EOVERFLOW; - return -1; - } - - for(i=0; srclen - i > 4; i+=4) { - if (base64_decode_quartet_using_maps(maps, &dest[dest_offset], &src[i]) == -1) { - return -1; - } - dest_offset += 3; - } - - more = base64_decode_tail_using_maps(maps, &dest[dest_offset], &src[i], srclen - i); - if (more == -1) { - return -1; - } - dest_offset += more; - - memset(&dest[dest_offset], '\0', destlen-dest_offset); - - return dest_offset; -} - - - - -/** - * @brief base64_maps_rfc4648 - pregenerated maps struct for rfc4648 - */ -const base64_maps_t base64_maps_rfc4648 = { -"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", - -"\xff\xff\xff\xff\xff" /* 0 */ \ -"\xff\xff\xff\xff\xff" /* 5 */ \ -"\xff\xff\xff\xff\xff" /* 10 */ \ -"\xff\xff\xff\xff\xff" /* 15 */ \ -"\xff\xff\xff\xff\xff" /* 20 */ \ -"\xff\xff\xff\xff\xff" /* 25 */ \ -"\xff\xff\xff\xff\xff" /* 30 */ \ -"\xff\xff\xff\xff\xff" /* 35 */ \ -"\xff\xff\xff\x3e\xff" /* 40 */ \ -"\xff\xff\x3f\x34\x35" /* 45 */ \ -"\x36\x37\x38\x39\x3a" /* 50 */ \ -"\x3b\x3c\x3d\xff\xff" /* 55 */ \ -"\xff\xff\xff\xff\xff" /* 60 */ \ -"\x00\x01\x02\x03\x04" /* 65 A */ \ -"\x05\x06\x07\x08\x09" /* 70 */ \ -"\x0a\x0b\x0c\x0d\x0e" /* 75 */ \ -"\x0f\x10\x11\x12\x13" /* 80 */ \ -"\x14\x15\x16\x17\x18" /* 85 */ \ -"\x19\xff\xff\xff\xff" /* 90 */ \ -"\xff\xff\x1a\x1b\x1c" /* 95 */ \ -"\x1d\x1e\x1f\x20\x21" /* 100 */ \ -"\x22\x23\x24\x25\x26" /* 105 */ \ -"\x27\x28\x29\x2a\x2b" /* 110 */ \ -"\x2c\x2d\x2e\x2f\x30" /* 115 */ \ -"\x31\x32\x33\xff\xff" /* 120 */ \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" /* 125 */ \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" /* 155 */ \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" /* 185 */ \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" /* 215 */ \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" \ -"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" /* 245 */ -}; diff --git a/src/wf/base64.h b/src/wf/base64.h deleted file mode 100644 index a2ee4b91..00000000 --- a/src/wf/base64.h +++ /dev/null @@ -1,271 +0,0 @@ -/** - * Disclaimer: This code comes from flux-core libccan library, it provides - * utility functions to use base64 encoding on strings. base64 is used to - * encode data for data transfers and encapsulation. - * base64 represents 24-bit groups of input bits as output - * strings of 4 encoded characters (cf rfc4648). - * - * This code is licensed under BSD-MIT (see below). - * The original code can be found under flux-core repository. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#ifndef CCAN_BASE64_H -#define CCAN_BASE64_H - -#include -#include -#include - - -/** - * @brief base64_maps_t - structure to hold maps for encode/decode - */ -typedef struct { - char encode_map[64]; - signed char decode_map[256]; -} base64_maps_t; - -/** - * @brief base64_encoded_length - Calculate encode buffer length - * @param[in] srclen the size of the data to be encoded - * @note add 1 to this to get null-termination - * @return Buffer length required for encode - */ -size_t base64_encoded_length(size_t srclen); - -/** - * @brief base64_decoded_length - Calculate decode buffer length - * @param[in] srclen Length of the data to be decoded - * @note This does not return the size of the decoded data! see base64_decode - * @return Minimum buffer length for safe decode - */ -size_t base64_decoded_length(size_t srclen); - -/** - * @brief base64_init_maps - populate a base64_maps_t based on a supplied alphabet - * @param[out] dest A base64 maps object - * @param[in] src Alphabet to populate the maps from (e.g. base64_alphabet_rfc4648) - */ -void base64_init_maps(base64_maps_t *dest, const char src[64]); - - -/** - * @brief base64_encode_triplet_using_maps - encode 3 bytes into base64 using a specific alphabet - * @param[in] maps Maps to use for encoding (see base64_init_maps) - * @param[out] dest Buffer containing 3 bytes - * @param[in] src Buffer containing 4 characters - */ -void base64_encode_triplet_using_maps(const base64_maps_t *maps, - char dest[4], const char src[3]); - -/** - * @brief base64_encode_tail_using_maps - encode the final bytes of a source using a specific alphabet - * @param[in] maps Maps to use for encoding (see base64_init_maps) - * @param[out] dest Buffer containing 4 bytes - * @param[in] src Buffer containing srclen bytes - * @param[in] srclen Number of bytes (<= 3) to encode in src - */ -void base64_encode_tail_using_maps(const base64_maps_t *maps, char dest[4], - const char *src, size_t srclen); - -/** - * @brief base64_encode_using_maps - encode a buffer into base64 using a specific alphabet - * @param[in] maps Maps to use for encoding (see base64_init_maps) - * @param[out] dest Buffer to encode into - * @param[out] destlen Length of dest - * @param[in] src Buffer to encode - * @param[in] srclen Length of the data to encode - * @return Number of encoded bytes set in dest. -1 on error (and errno set) - * @note dest will be nul-padded to destlen (past any required padding) - * @note sets errno = EOVERFLOW if destlen is too small - */ -ssize_t base64_encode_using_maps(const base64_maps_t *maps, - char *dest, size_t destlen, - const char *src, size_t srclen); - -/** - * @brief base64_char_in_alphabet - returns true if character can be part of an encoded string - * @param[in] maps A base64 maps object (see base64_init_maps) - * @param[in] b64char Character to check - * @return True if character can be part of an encoded string - */ -bool base64_char_in_alphabet(const base64_maps_t *maps, char b64char); - -/** - * @brief base64_decode_using_maps - decode a base64-encoded string using a specific alphabet - * @param[in] maps A base64 maps object (see base64_init_maps) - * @param[out] dest Buffer to decode into - * @param[out] destlen length of dest - * @param[in] src the buffer to decode - * @param[in] srclen the length of the data to decode - * @return Number of decoded bytes set in dest. -1 on error (and errno set) - * @note dest will be nul-padded to destlen - * @note sets errno = EOVERFLOW if destlen is too small - * @note sets errno = EDOM if src contains invalid characters - */ -ssize_t base64_decode_using_maps(const base64_maps_t *maps, - char *dest, size_t destlen, - const char *src, size_t srclen); - -/** - * @brief base64_decode_quartet_using_maps - decode 4 bytes from base64 using a specific alphabet - * @param[in] maps A base64 maps object (see base64_init_maps) - * @param[out] dest Buffer containing 3 bytes - * @param[in] src Buffer containing 4 bytes - * @return Number of decoded bytes set in dest. -1 on error (and errno set) - * @note sets errno = EDOM if src contains invalid characters - */ -ssize_t base64_decode_quartet_using_maps(const base64_maps_t *maps, - char dest[3], const char src[4]); - -/** - * @brief base64_decode_tail_using_maps - decode the final bytes of a base64 string using a specific alphabet - * @param[in] maps A base64 maps object (see base64_init_maps) - * @param[out] dest Buffer containing 3 bytes - * @param[in] src Buffer containing 4 bytes - padded with '=' as required - * @param[in] srclen Number of bytes to decode in src - * @return Number of decoded bytes set in dest. -1 on error (and errno set) - * @note sets errno = EDOM if src contains invalid characters - * @note sets errno = EINVAL if src is an invalid base64 tail - */ -ssize_t base64_decode_tail_using_maps(const base64_maps_t *maps, char dest[3], - const char *src, size_t srclen); - - -/* The rfc4648 functions: */ - -extern const base64_maps_t base64_maps_rfc4648; - -/** - * @brief base64_encode - Encode a buffer into base64 according to rfc4648 - * @param[out] dest Buffer to encode into - * @param[out] destlen Length of the destination buffer - * @param[in] src Buffer to encode - * @param[in] srclen Length of the data to encode - * @return Number of encoded bytes set in dest. -1 on error (and errno set) - * @note dest will be nul-padded to destlen (past any required padding) - * @note sets errno = EOVERFLOW if destlen is too small - * - * @details This function encodes src according to http://tools.ietf.org/html/rfc4648 - * - * Example: - * size_t encoded_length; - * char dest[100]; - * const char *src = "This string gets encoded"; - * encoded_length = base64_encode(dest, sizeof(dest), src, strlen(src)); - * printf("Returned data of length %zd @%p\n", encoded_length, &dest); - */ -static inline -ssize_t base64_encode(char *dest, size_t destlen, - const char *src, size_t srclen) -{ - return base64_encode_using_maps(&base64_maps_rfc4648, - dest, destlen, src, srclen); -} - -/** - * @brief base64_encode_triplet - encode 3 bytes into base64 according to rfc4648 - * @param[out] dest Buffer containing 4 bytes - * @param[in] src Buffer containing 3 bytes - */ -static inline -void base64_encode_triplet(char dest[4], const char src[3]) -{ - base64_encode_triplet_using_maps(&base64_maps_rfc4648, dest, src); -} - -/** - * @brief base64_encode_tail - encode the final bytes of a source according to rfc4648 - * @param[out] dest Buffer containing 4 bytes - * @param[in] src Buffer containing srclen bytes - * @param[in] srclen Number of bytes (<= 3) to encode in src - */ -static inline -void base64_encode_tail(char dest[4], const char *src, size_t srclen) -{ - base64_encode_tail_using_maps(&base64_maps_rfc4648, dest, src, srclen); -} - - -/** - * @brief base64_decode - decode An rfc4648 base64-encoded string - * @param[out] dest Buffer to decode into - * @param[out] destlen Length of the destination buffer - * @param[in] src Buffer to decode - * @param[in] srclen Length of the data to decode - * @return Number of decoded bytes set in dest. -1 on error (and errno set) - * @note dest will be nul-padded to destlen - * @note sets errno = EOVERFLOW if destlen is too small - * @note sets errno = EDOM if src contains invalid characters - * - * @details This function decodes the buffer according to - * http://tools.ietf.org/html/rfc4648 - * - * Example: - * size_t decoded_length; - * char ret[100]; - * const char *src = "Zm9vYmFyYmF6"; - * decoded_length = base64_decode(ret, sizeof(ret), src, strlen(src)); - * printf("Returned data of length %zd @%p\n", decoded_length, &ret); - */ -static inline -ssize_t base64_decode(char *dest, size_t destlen, - const char *src, size_t srclen) -{ - return base64_decode_using_maps(&base64_maps_rfc4648, - dest, destlen, src, srclen); -} - -/** - * @brief base64_decode_quartet - decode the first 4 characters in src into dest - * @param[out] dest Buffer containing 3 bytes - * @param[in] src Buffer containing 4 characters - * @return Number of decoded bytes set in dest. -1 on error (and errno set) - * @note sets errno = EDOM if src contains invalid characters - */ -static inline -ssize_t base64_decode_quartet(char dest[3], const char src[4]) -{ - return base64_decode_quartet_using_maps(&base64_maps_rfc4648, - dest, src); -} - -/** - * @brief decode the final bytes of a base64 string from src into dest - * @param[out] dest Buffer containing 3 bytes - * @param[in] src Buffer containing 4 bytes - padded with '=' as required - * @param[in] srclen Number of bytes to decode in src - * @return Number of decoded bytes set in dest. -1 on error (and errno set) - * @note sets errno = EDOM if src contains invalid characters - * @note sets errno = EINVAL if src is an invalid base64 tail - */ -static inline -ssize_t base64_decode_tail(char dest[3], const char *src, size_t srclen) -{ - return base64_decode_tail_using_maps(&base64_maps_rfc4648, - dest, src, srclen); -} - -/* end rfc4648 functions */ - - - -#endif /* CCAN_BASE64_H */ \ No newline at end of file diff --git a/src/wf/basedb.hpp b/src/wf/basedb.hpp deleted file mode 100644 index 74bb2e3e..00000000 --- a/src/wf/basedb.hpp +++ /dev/null @@ -1,1661 +0,0 @@ -/* - * Copyright 2021-2023 Lawrence Livermore National Security, LLC and other - * AMSLib Project Developers - * - * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - */ - -#ifndef __AMS_BASE_DB__ -#define __AMS_BASE_DB__ - -#include -#include -#include -#include -#include -#include - -#include "AMS.h" -#include "wf/debug.h" -#include "wf/utils.hpp" - -namespace fs = std::experimental::filesystem; - -#ifdef __ENABLE_REDIS__ -#include - -#include -// TODO: We should comment out "using" in header files as -// it propagates to every other file including this file -#warning Redis is currently not supported/tested -using namespace sw::redis; -#endif - - -#ifdef __ENABLE_HDF5__ -#include - -#define HDF5_ERROR(Eid) \ - if (Eid < 0) { \ - std::cerr << "[Error] Happened in " << __FILE__ << ":" \ - << __PRETTY_FUNCTION__ << " ( " << __LINE__ << ")\n"; \ - exit(-1); \ - } -#endif - -#ifdef __ENABLE_RMQ__ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif -#include "base64.h" -#ifdef __cplusplus -} -#endif - -#endif // __ENABLE_RMQ__ - -/** - * @brief A simple pure virtual interface to store data in some - * persistent storage device - */ -template -class BaseDB -{ - /** @brief unique id of the process running this simulation */ - uint64_t id; - -public: - BaseDB(const BaseDB&) = delete; - BaseDB& operator=(const BaseDB&) = delete; - - BaseDB(uint64_t id) : id(id) {} - virtual ~BaseDB() {} - - /** - * @brief Define the type of the DB (File, Redis etc) - */ - virtual std::string type() = 0; - - /** - * @brief Takes an input and an output vector each holding 1-D vectors data, and - * store. them in persistent data storage. - * @param[in] num_elements Number of elements of each 1-D vector - * @param[in] inputs Vector of 1-D vectors containing the inputs to be stored - * @param[in] inputs Vector of 1-D vectors, each 1-D vectors contains - * 'num_elements' values to be stored - * @param[in] outputs Vector of 1-D vectors, each 1-D vectors contains - * 'num_elements' values to be stored - */ - virtual void store(size_t num_elements, - std::vector& inputs, - std::vector& outputs) = 0; -}; - -/** - * @brief A pure virtual interface for data bases storing data using - * some file format (filesystem DB). - */ -template -class FileDB : public BaseDB -{ -protected: - /** @brief Path to file to write data to */ - std::string fn; - /** @brief absolute path to directory storing the data */ - std::string fp; - - /** - * @brief check error code, if it exists print message and exit application - * @param[in] ec error code - */ - void checkError(std::error_code& ec) - { - if (ec) { - std::cerr << "Error in is_regular_file: " << ec.message(); - exit(-1); - } - } - -public: - /** - * @brief Takes an input and an output vector each holding 1-D vectors data, and - * store. them in persistent data storage. - * @param[in] path Path to an existing directory where to store our data - * @param[in] suffix The suffix of the file to write to - * @param[in] rId a unique Id for each process taking part in a distributed - * execution (rank-id) - * */ - FileDB(std::string path, const std::string suffix, uint64_t rId) - : BaseDB(rId) - { - fs::path Path(path); - std::error_code ec; - - if (!fs::exists(Path, ec)) { - std::cerr << "[ERROR]: Path:'" << path << "' does not exist\n"; - exit(-1); - } - - checkError(ec); - - if (!fs::is_directory(Path, ec)) { - std::cerr << "[ERROR]: Path:'" << path << "' is a file NOT a directory\n"; - exit(-1); - } - - Path = fs::absolute(Path); - fp = Path.string(); - - // We can now create the filename - std::string dbfn("data_"); - dbfn += std::to_string(rId) + suffix; - Path /= fs::path(dbfn); - fn = Path.string(); - DBG(DB, "File System DB writes to file %s", fn.c_str()) - } -}; - - -template -class csvDB final : public FileDB -{ -private: - /** @brief file descriptor */ - std::fstream fd; - -public: - csvDB(const csvDB&) = delete; - csvDB& operator=(const csvDB&) = delete; - - /** - * @brief constructs the class and opens the file to write to - * @param[in] fn Name of the file to store data to - * @param[in] rId a unique Id for each process taking part in a distributed - * execution (rank-id) - */ - csvDB(std::string path, uint64_t rId) : FileDB(path, ".csv", rId) - { - fd.open(this->fn, std::ios_base::app | std::ios_base::out); - if (!fd.is_open()) { - std::cerr << "Cannot open db file: " << this->fn << std::endl; - } - DBG(DB, "DB Type: %s", type()) - } - - /** - * @brief deconstructs the class and closes the file - */ - ~csvDB() { fd.close(); } - - /** - * @brief Define the type of the DB (File, Redis etc) - */ - std::string type() override { return "csv"; } - - /** - * @brief Takes an input and an output vector each holding 1-D vectors data, and - * store them into a csv file delimited by ':'. This should never be used for - * large scale simulations as txt/csv format will be extremely slow. - * @param[in] num_elements Number of elements of each 1-D vector - * @param[in] inputs Vector of 1-D vectors containing the inputs to bestored - * @param[in] inputs Vector of 1-D vectors, each 1-D vectors contains - * 'num_elements' values to be stored - * @param[in] outputs Vector of 1-D vectors, each 1-D vectors contains - * 'num_elements' values to be stored - */ - PERFFASPECT() - virtual void store(size_t num_elements, - std::vector& inputs, - std::vector& outputs) override - { - DBG(DB, - "DB of type %s stores %ld elements of input/output dimensions (%d, %d)", - type().c_str(), - num_elements, - inputs.size(), - outputs.size()) - - const size_t num_in = inputs.size(); - const size_t num_out = outputs.size(); - - for (size_t i = 0; i < num_elements; i++) { - for (size_t j = 0; j < num_in; j++) { - fd << inputs[j][i] << ":"; - } - - for (size_t j = 0; j < num_out - 1; j++) { - fd << outputs[j][i] << ":"; - } - fd << outputs[num_out - 1][i] << "\n"; - } - } -}; - -#ifdef __ENABLE_HDF5__ - -template -class hdf5DB final : public FileDB -{ -private: - /** @brief file descriptor */ - hid_t HFile; - /** @brief vector holding the hdf5 dataset descriptor. - * We currently store every input on a separate dataset - */ - std::vector HDIsets; - - /** @brief vector holding the hdf5 dataset descriptor. - * We currently store every output on a separate dataset - */ - std::vector HDOsets; - - /** @brief Total number of elements we have in our file */ - hsize_t totalElements; - - /** @brief HDF5 associated data type with specific TypeValue type */ - hid_t HDType; - - /** @brief create or get existing hdf5 dataset with the provided name - * storing data as Ckunked pieces. The Chunk value controls the chunking - * performed by HDF5 and thus controls the write performance - * @param[in] group in which we will store data under - * @param[in] dName name of the data set - * @param[in] Chunk chunk size of dataset used by HDF5. - * @reval dataset HDF5 key value - */ - hid_t getDataSet(hid_t group, - std::string dName, - const size_t Chunk = 32L * 1024L * 1024L) - { - // Our datasets a.t.m are 1-D vectors - const int nDims = 1; - // We always start from 0 - hsize_t dims = 0; - hid_t dset = -1; - - int exists = H5Lexists(group, dName.c_str(), H5P_DEFAULT); - - if (exists > 0) { - dset = H5Dopen(group, dName.c_str(), H5P_DEFAULT); - HDF5_ERROR(dset); - // We are assuming symmetrical data sets a.t.m - if (totalElements == 0) { - hid_t dspace = H5Dget_space(dset); - const int ndims = H5Sget_simple_extent_ndims(dspace); - hsize_t dims[ndims]; - H5Sget_simple_extent_dims(dspace, dims, NULL); - totalElements = dims[0]; - } - return dset; - } else { - // We will extend the data-set size, so we use unlimited option - hsize_t maxDims = H5S_UNLIMITED; - hid_t fileSpace = H5Screate_simple(nDims, &dims, &maxDims); - HDF5_ERROR(fileSpace); - - hid_t pList = H5Pcreate(H5P_DATASET_CREATE); - HDF5_ERROR(pList); - - herr_t ec = H5Pset_layout(pList, H5D_CHUNKED); - HDF5_ERROR(ec); - - // cDims impacts performance considerably. - // TODO: Align this with the caching mechanism for this option to work - // out. - hsize_t cDims = Chunk; - H5Pset_chunk(pList, nDims, &cDims); - dset = H5Dcreate(group, - dName.c_str(), - HDType, - fileSpace, - H5P_DEFAULT, - pList, - H5P_DEFAULT); - HDF5_ERROR(dset); - H5Sclose(fileSpace); - H5Pclose(pList); - } - return dset; - } - - /** - * @brief Create the HDF5 datasets and store their descriptors in the in/out - * vectors - * @param[in] num_elements of every vector - * @param[in] numIn number of input 1-D vectors - * @param[in] numOut number of output 1-D vectors - */ - void createDataSets(size_t numElements, - const size_t numIn, - const size_t numOut) - { - for (int i = 0; i < numIn; i++) { - hid_t dSet = getDataSet(HFile, std::string("input_") + std::to_string(i)); - HDIsets.push_back(dSet); - } - - for (int i = 0; i < numOut; i++) { - hid_t dSet = - getDataSet(HFile, std::string("output_") + std::to_string(i)); - HDOsets.push_back(dSet); - } - } - - /** - * @brief Write all the data in the vectors in the respective datasets. - * @param[in] dsets Vector containing the hdf5-dataset descriptor for every - * vector to be written - * @param[in] data vectors containing 1-D vectors of numElements values each - * to be written in the db. - * @param[in] numElements The number of elements each vector has - */ - void writeDataToDataset(std::vector& dsets, - std::vector& data, - size_t numElements) - { - int index = 0; - for (auto* I : data) { - writeVecToDataset(dsets[index++], static_cast(I), numElements); - } - } - - /** @brief Writes a single 1-D vector to the dataset - * @param[in] dSet the dataset to write the data to - * @param[in] data the data we need to write - * @param[in] elements the number of data elements we have - */ - void writeVecToDataset(hid_t dSet, void* data, size_t elements) - { - const int nDims = 1; - hsize_t dims = elements; - hsize_t start; - hsize_t count; - hid_t memSpace = H5Screate_simple(nDims, &dims, NULL); - HDF5_ERROR(memSpace); - - dims = totalElements + elements; - H5Dset_extent(dSet, &dims); - - hid_t fileSpace = H5Dget_space(dSet); - HDF5_ERROR(fileSpace); - - // Data set starts at offset totalElements - start = totalElements; - // And we append additional elements - count = elements; - // Select hyperslab - herr_t err = H5Sselect_hyperslab( - fileSpace, H5S_SELECT_SET, &start, NULL, &count, NULL); - HDF5_ERROR(err); - - H5Dwrite(dSet, HDType, memSpace, fileSpace, H5P_DEFAULT, data); - H5Sclose(fileSpace); - } - -public: - // Delete copy constructors. We do not want to copy the DB around - hdf5DB(const hdf5DB&) = delete; - hdf5DB& operator=(const hdf5DB&) = delete; - - /** - * @brief constructs the class and opens the hdf5 file to write to - * @param[in] fn Name of the file to store data to - * @param[in] rId a unique Id for each process taking part in a distributed - * execution (rank-id) - */ - hdf5DB(std::string path, uint64_t rId) : FileDB(path, ".h5", rId) - { - if (isDouble::default_value()) - HDType = H5T_NATIVE_DOUBLE; - else - HDType = H5T_NATIVE_FLOAT; - std::error_code ec; - bool exists = fs::exists(this->fn); - this->checkError(ec); - - if (exists) - HFile = H5Fopen(this->fn.c_str(), H5F_ACC_RDWR, H5P_DEFAULT); - else - HFile = - H5Fcreate(this->fn.c_str(), H5F_ACC_EXCL, H5P_DEFAULT, H5P_DEFAULT); - HDF5_ERROR(HFile); - totalElements = 0; - } - - /** - * @brief deconstructs the class and closes the file - */ - ~hdf5DB() - { - // HDF5 Automatically closes all opened fds at exit of application. - // herr_t err = H5Fclose(HFile); - // HDF5_ERROR(err); - } - - /** - * @brief Define the type of the DB - */ - std::string type() override { return "hdf5"; } - - /** - * @brief Takes an input and an output vector each holding 1-D vectors data, - * and store them into a hdf5 file delimited by ':'. This should never be used - * for large scale simulations as txt/hdf5 format will be extremely slow. - * @param[in] num_elements Number of elements of each 1-D vector - * @param[in] inputs Vector of 1-D vectors containing the inputs to bestored - * @param[in] inputs Vector of 1-D vectors, each 1-D vectors contains - * 'num_elements' values to be stored - * @param[in] outputs Vector of 1-D vectors, each 1-D vectors contains - * 'num_elements' values to be stored - */ - PERFFASPECT() - virtual void store(size_t num_elements, - std::vector& inputs, - std::vector& outputs) override - { - - DBG(DB, - "DB of type %s stores %ld elements of input/output dimensions (%d, %d)", - type().c_str(), - num_elements, - inputs.size(), - outputs.size()) - const size_t num_in = inputs.size(); - const size_t num_out = outputs.size(); - - if (HDIsets.empty()) { - createDataSets(num_elements, num_in, num_out); - } - - if (HDIsets.size() != num_in || HDOsets.size() != num_out) { - std::cerr << "The data dimensionality is different than the one in the " - "DB\n"; - exit(-1); - } - - writeDataToDataset(HDIsets, inputs, num_elements); - writeDataToDataset(HDOsets, outputs, num_elements); - totalElements += num_elements; - } -}; -#endif - -#ifdef __ENABLE_REDIS__ -template -class RedisDB : public BaseDB -{ - const std::string _fn; // path to the file storing the DB access config - uint64_t _dbid; - Redis* _redis; - uint64_t keyId; - -public: - RedisDB(const RedisDB&) = delete; - RedisDB& operator=(const RedisDB&) = delete; - - /** - * @brief constructs the class and opens the file to write to - * @param[in] fn Name of the file to store data to - * @param[in] rId a unique Id for each process taking part in a distributed - * execution (rank-id) - */ - RedisDB(std::string fn, uint64_t rId) - : BaseDB(rId), _fn(fn), _redis(nullptr), keyId(0) - { - _dbid = reinterpret_cast(this); - auto connection_info = read_json(fn); - - ConnectionOptions connection_options; - connection_options.type = ConnectionType::TCP; - connection_options.host = connection_info["host"]; - connection_options.port = std::stoi(connection_info["service-port"]); - connection_options.password = connection_info["database-password"]; - connection_options.db = 0; // Optionnal, 0 is the default - connection_options.tls.enabled = - true; // Required to connect to PDS within LC - connection_options.tls.cacert = connection_info["cert"]; - - ConnectionPoolOptions pool_options; - pool_options.size = 100; // Pool size, i.e. max number of connections. - - _redis = new Redis(connection_options, pool_options); - } - - ~RedisDB() - { - std::cerr << "Deleting RedisDB object\n"; - delete _redis; - } - - inline std::string type() override { return "RedisDB"; } - - inline std::string info() { return _redis->info(); } - - // Return the number of keys in the DB - inline long long dbsize() { return _redis->dbsize(); } - - /* ! - * ! WARNING: Flush the entire Redis, accross all DBs! - * ! - */ - inline void flushall() { _redis->flushall(); } - - /* - * ! WARNING: Flush the entire current DB! - * ! - */ - inline void flushdb() { _redis->flushdb(); } - - std::unordered_map read_json(std::string fn) - { - std::ifstream config; - std::unordered_map connection_info = { - {"database-password", ""}, - {"host", ""}, - {"service-port", ""}, - {"cert", ""}, - }; - - config.open(fn, std::ifstream::in); - if (config.is_open()) { - std::string line; - // Quite inefficient parsing (to say the least..) but the file to parse is - // small (4 lines) - // TODO: maybe use Boost or another JSON library - while (std::getline(config, line)) { - if (line.find("{") != std::string::npos || - line.find("}") != std::string::npos) { - continue; - } - line.erase(std::remove(line.begin(), line.end(), ' '), line.end()); - line.erase(std::remove(line.begin(), line.end(), ','), line.end()); - line.erase(std::remove(line.begin(), line.end(), '"'), line.end()); - - std::string key = line.substr(0, line.find(':')); - line.erase(0, line.find(":") + 1); - connection_info[key] = line; - // std::cerr << "key=" << key << " and value=" << line << std::endl; - } - config.close(); - } else { - std::cerr << "Config located at: " << fn << std::endl; - throw std::runtime_error("Could not open Redis config file"); - } - return connection_info; - } - - void store(size_t num_elements, - std::vector& inputs, - std::vector& outputs) - { - - const size_t num_in = inputs.size(); - const size_t num_out = outputs.size(); - - // TODO: - // Make insertion more efficient. - // Right now it's pretty naive and expensive - auto start = std::chrono::high_resolution_clock::now(); - - for (size_t i = 0; i < num_elements; i++) { - std::string key = std::to_string(_dbid) + ":" + std::to_string(keyId) + - ":" + - std::to_string(i); // In Redis a key must be a string - std::ostringstream fd; - for (size_t j = 0; j < num_in; j++) { - fd << inputs[j][i] << ":"; - } - for (size_t j = 0; j < num_out - 1; j++) { - fd << outputs[j][i] << ":"; - } - fd << outputs[num_out - 1][i]; - std::string val(fd.str()); - _redis->set(key, val); - } - - keyId += 1; - - auto stop = std::chrono::high_resolution_clock::now(); - auto duration = - std::chrono::duration_cast(stop - start); - auto nb_keys = this->dbsize(); - - std::cout << std::setprecision(2) << "Inserted " << num_elements - << " keys [Total keys = " << nb_keys << "] into RedisDB [Total " - << duration.count() << "ms, " - << static_cast(num_elements) / duration.count() - << " key/ms]" << std::endl; - } -}; - -#endif // __ENABLE_HDF5__ - -#ifdef __ENABLE_RMQ__ - -/** @brief Structure that represents a JSON structure */ -typedef std::unordered_map json; -/** @brief Structure that represents a received RabbitMQ message */ -typedef std::tuple - inbound_msg; - -/** - * @brief Structure that is passed to each worker thread that sends data. - */ -struct rmq_sender { - struct event_base* loop; - pthread_t id; -}; - -/** - * @brief Worker function responsible of starting the event loop for each thread. - * @param[in] arg a pointer on a worker structure - */ -void* start_worker_sender(void* arg) -{ - struct rmq_sender* w = (struct rmq_sender*)arg; - event_base_dispatch(w->loop); - return NULL; -} - -/** - * @brief Structure that is passed to each worker thread receiving data. - */ -struct rmq_consumer { - struct event_base* loop; - pthread_t id; - std::shared_ptr channel; - std::string queue; - std::shared_ptr> messages; // Messages received -}; - -/** - * @brief Worker function responsible of starting the event loop for each thread. - * @param[in] arg a pointer on a worker structure - */ -void* start_worker_consumer(void* arg) -{ - struct rmq_consumer* w = (struct rmq_consumer*)arg; - - // callback function that is called when the consume operation starts - auto startCb = [](const std::string& consumertag) { - DBG(RabbitMQDB, - "consume operation started with tag: %s", - consumertag.c_str()) - }; - - // callback function that is called when the consume operation failed - auto errorCb = [](const char* message) { - CFATAL(RabbitMQDB, false, "consume operation failed: %s", message); - }; - // callback operation when a message was received - auto messageCb = [w](const AMQP::Message& message, - uint64_t deliveryTag, - bool redelivered) { - // acknowledge the message - w->channel->ack(deliveryTag); - std::string s(message.body(), message.bodySize()); - w->messages->push_back(std::make_tuple(std::move(s), - message.exchange(), - message.routingkey(), - deliveryTag, - redelivered)); - DBG(RabbitMQDB, - "message received [tag=%d] : '%s' of size %d B from '%s'/'%s'", - deliveryTag, - s.c_str(), - message.bodySize(), - message.exchange().c_str(), - message.routingkey().c_str()) - }; - - /* callback that is called when the consumer is cancelled by RabbitMQ (this - * only happens in rare situations, for example when someone removes the queue - * that you are consuming from) - */ - auto cancelledCb = [](const std::string& consumertag) { - WARNING(RabbitMQDB, - "consume operation cancelled by the RabbitMQ server: %s", - consumertag.c_str()) - }; - - // start consuming from the queue, and install the callbacks - w->channel->consume(w->queue) - .onReceived(messageCb) - .onSuccess(startCb) - .onCancelled(cancelledCb) - .onError(errorCb); - - // We start the event loop - event_base_dispatch(w->loop); - return NULL; -} - -/** - * @brief Specific handler for RabbitMQ connections based on libevent. - */ -class RabbitMQHandler : public AMQP::LibEventHandler -{ -private: - /** @brief Path to TLS certificate */ - const char* _cacert; - /** @brief The MPI rank (0 if MPI is not used) */ - int _rank; - -public: - /** - * @brief Constructor - * @param[in] loop Event Loop - * @param[in] cacert SSL Cacert - * @param[in] rank MPI rank - */ - RabbitMQHandler(int rank, struct event_base* loop, std::string cacert) - : AMQP::LibEventHandler(loop), _rank(rank), _cacert(cacert.c_str()) - { - } - virtual ~RabbitMQHandler() = default; - -private: - /** - * @brief Method that is called after a TCP connection has been set up, and - * right before the SSL handshake is going to be performed to secure the - * connection (only for amqps:// connections). This method can be overridden - * in user space to load client side certificates. - * @param[in] connection The connection for which TLS was just started - * @param[in] ssl Pointer to the SSL structure that can be - * modified - * @return bool True to proceed / accept the connection, false - * to break up - */ - virtual bool onSecuring(AMQP::TcpConnection* connection, SSL* ssl) - { - ERR_clear_error(); - unsigned long err; -#if OPENSSL_VERSION_NUMBER < 0x10100000L - int ret = SSL_use_certificate_file(ssl, _cacert, SSL_FILETYPE_PEM); -#else - int ret = SSL_use_certificate_chain_file(ssl, _cacert); -#endif - if (ret != 1) { - std::string error("openssl: error loading ca-chain from ["); - SSL_get_error(ssl, ret); - if ((err = ERR_get_error())) { - error += std::string(ERR_reason_error_string(err)); - } - error += "]"; - throw std::runtime_error(error); - } else { - DBG(RabbitMQDB, "Success logged with ca-chain %s", _cacert) - return true; - } - } - - /** - * @brief Method that is called when the secure TLS connection has been - * established. This is only called for amqps:// connections. It allows you to - * inspect whether the connection is secure enough for your liking (you can - * for example check the server certificate). The AMQP protocol still has - * to be started. - * @param[in] connection The connection that has been secured - * @param[in] ssl SSL structure from openssl library - * @return bool True if connection can be used - */ - virtual bool onSecured(AMQP::TcpConnection* connection, - const SSL* ssl) override - { - DBG(RabbitMQDB, - "[rank=%d][ info ] Secured TLS connection has been established", - _rank) - return true; - } - - /** - * @brief Method that is called by the AMQP library when the login attempt - * succeeded. After this the connection is ready to use. - * @param[in] connection The connection that can now be used - */ - virtual void onReady(AMQP::TcpConnection* connection) override - { - DBG(RabbitMQDB, - "[rank=%d][ ok ] Sucessfuly logged in. Connection ready to use!\n", - _rank) - } - - /** - * @brief Method that is called by the AMQP library when a fatal error occurs - * on the connection, for example because data received from RabbitMQ - * could not be recognized, or the underlying connection is lost. This - * call is normally followed by a call to onLost() (if the error occurred - * after the TCP connection was established) and onDetached(). - * @param[in] connection The connection on which the error occurred - * @param[in] message A human readable error message - */ - virtual void onError(AMQP::TcpConnection* connection, - const char* message) override - { - DBG(RabbitMQDB, - "[rank=%d] fatal error when establishing TCP connection: %s\n", - _rank, - message) - } -}; // class RabbitMQHandler - -/** - * @brief An EventBuffer encapsulates an evbuffer (libevent structure). - * Each time data is pushed to the underlying evbuffer, the callback will be - * called. - */ -template -class EventBuffer -{ -private: - /** @brief AMQP reliable channel (wrapper of a classic channel with added functionalities) */ - std::shared_ptr> _rchannel; - /** @brief Name of the RabbitMQ queue */ - std::string _queue; - /** @brief Internal queue of messages to send (data, num_elements) */ - std::deque> _messages; - /** @brief Total number of bytes that must be send */ - size_t _byte_to_send; - /** @brief MPI rank */ - int _rank; - /** @brief Thread ID */ - pthread_t _tid; - /** @brief Event loop */ - struct event_base* _loop; - /** @brief The buffer event structure */ - struct evbuffer* _buffer; - /** @brief Signal event for exiting properly the loop */ - struct event* _signal_exit; - /** @brief Signal event for exiting properly the loop */ - struct event* _signal_term; - /** @brief Custom signal code (by default SIGUSR1) that can be intercepted */ - int _sig_exit; - /** @brief Internal counter for number of messages acknowledged */ - int _counter_ack; - /** @brief Internal counter for number of messages negatively acknowledged */ - int _counter_nack; - - /** - * @brief Callback method that is called by libevent when data is being - * added to the buffer event - * @param[in] fd The loop in which the event was triggered - * @param[in] event Internal timer object - * @param[in] argc The events that triggered this call - */ - static void callback_commit(struct evbuffer* buffer, - const struct evbuffer_cb_info* info, - void* arg) - { - EventBuffer* self = static_cast(arg); - // we remove only if some byte got added (this callback will get - // trigger when data is added AND removed from the buffer - DBG(RabbitMQDB, - "evbuffer_cb_info(lenght=%zu): n_added=%zu B, n_deleted=%zu B, " - "orig_size=%zu B", - evbuffer_get_length(buffer), - info->n_added, - info->n_deleted, - info->orig_size) - - if (info->n_added > 0) { - // Destination buffer (of TypeValue size, either float or double) - size_t datlen = info->n_added; // Total number of bytes - int k = datlen / sizeof(TypeValue); - if (datlen % sizeof(TypeValue) != 0) { - CFATAL(RabbitMQDB, - false, - "Buffer seems corrupted or not the right type of TypeValue"); - } - auto data = std::make_unique(datlen); - - evbuffer_lock(buffer); - // Now we drain the evbuffer structure to fill up the destination buffer - int nbyte_drained = evbuffer_remove(buffer, data.get(), datlen); - if (nbyte_drained < 0) { - WARNING(RabbitMQDB, - "evbuffer_remove(): cannot remove %d data from buffer", - nbyte_drained); - } - evbuffer_unlock(buffer); - - std::string result = - std::to_string(self->_rank) + ":"; - for (int i = 0; i < k - 1; i++) { - result.append(std::to_string(data[i]) + ":"); - } - result.append(std::to_string(data[k - 1]) + "\n"); - - // For resiliency reasons we encode the result in base64 - // Not that it increases the size (n) of messages by approx 4*(n/3) - std::string result_b64 = self->encode64(result); - DBG(RabbitMQDB, - "[rank=%d] #elements (float/double) = %d, stringify size = %d, size in base64 " - "= %d", - self->_rank, - k, - result.size(), - result_b64.size()) - if (result_b64.size() % 4 != 0) { - WARNING(EventBuffer, - "[rank=%d] Frame size (%d elements)" - "cannot be %d more than a multiple of 4!", - self->_rank, - result_b64.size(), - result_b64.size() % 4) - } - - // publish a message via the reliable-channel - self->_rchannel->publish("", self->_queue, result_b64) - .onAck([self, nbyte_drained]() { - DBG(RabbitMQDB, - "[rank=%d] message got ack successfully", - self->_rank) - self->_byte_to_send = self->_byte_to_send - nbyte_drained; - self->_counter_ack++; - }) - .onNack([self]() { - self->_counter_nack++; - WARNING(RabbitMQDB, "[rank=%d] message negative ack", self->_rank) - }) - .onLost([self]() { - CFATAL(RabbitMQDB, false, "[rank=%d] message got lost", self->_rank) - }) - .onError([self](const char* message) { - CFATAL(RabbitMQDB, - false, - "[rank=%d] message did not get send: %s", - self->_rank, - message) - }); - } - } - - /** - * @brief Callback method that is called by libevent when the signal sig is - * intercepted - * @param[in] fd The loop in which the event was triggered - * @param[in] event Internal event object (evsignal in this case) - * @param[in] argc The events that triggered this call - */ - static void callback_exit(int fd, short event, void* argc) - { - EventBuffer* self = static_cast(argc); - DBG(RabbitMQDB, - "caught an interrupt signal; exiting cleanly event loop after %d " - "messages ack (%d negative ack) ...", - self->_counter_ack, - self->_counter_nack) - event_base_loopexit(self->_loop, NULL); - } - -public: - /** - * @brief Constructor - * @param[in] loop Event loop (Libevent in this case) - * @param[in] channel AMQP TCP channel - * @param[in] queue Name of the queue the Event Buffer will publish on - */ - EventBuffer(int rank, - struct event_base* loop, - std::shared_ptr channel, - std::string queue) - : _rank(rank), - _loop(loop), - _buffer(nullptr), - _rchannel(std::make_shared>(*channel.get())), - _queue(std::move(queue)), - _byte_to_send(0), - _counter_ack(0), - _counter_nack(0) - { - pthread_t _tid = pthread_self(); - // initialize the libev buff event structure - _buffer = evbuffer_new(); - evbuffer_add_cb(_buffer, callback_commit, this); - /** - * Force all the callbacks on an evbuffer to be run not immediately after - * the evbuffer is altered, but instead from inside the event loop. - * Without that, the call to callback() would block the main thread. - */ - evbuffer_defer_callbacks(_buffer, _loop); - // We install signal callbacks - _sig_exit = SIGUSR1; - _signal_exit = evsignal_new(_loop, _sig_exit, callback_exit, this); - event_add(_signal_exit, NULL); - _signal_term = evsignal_new(_loop, SIGTERM, callback_exit, this); - event_add(_signal_term, NULL); - } - - /** - * @brief Return the size of the buffer in bytes. - * @return Buffer size in bytes. - */ - size_t size() { return evbuffer_get_length(_buffer); } - - /** - * @brief Return True if the buffer is empty. - * @return True if the number of bytes that has to be sent is equals to 0. - */ - bool is_drained() - { - return (_byte_to_send == 0) && (evbuffer_get_length(_buffer) == 0); - } - - /** - * @brief Push data to the underlying event buffer, which - * will trigger the callback. - * @return The number of bytes that has to be sent. - */ - size_t get_byte_to_send() { return _byte_to_send; } - - /** - * @brief Push data to the underlying event buffer, which - * will trigger the callback. - * @param[in] data The data pointer - * @param[in] data_size The number of bytes in the data pointer - */ - void push(void* data, size_t data_size) - { - evbuffer_lock(_buffer); - DBG(RabbitMQDB, - "[push()] adding %zu B to buffer => " - "size of evbuffer %zu", - data_size, - size()) - - int ret = evbuffer_add(_buffer, data, data_size); - if (ret == -1) { - perror("Error evbuffer_add()\n"); - } - _byte_to_send = _byte_to_send + data_size; - evbuffer_unlock(_buffer); - } - - /** - * @brief Method to encode a string into base64 - * @param[in] input The input string - * @return The encoded string - */ - std::string encode64(const std::string& input) - { - if (input.size() == 0) return ""; - size_t unencoded_length = input.size(); - size_t encoded_length = base64_encoded_length(unencoded_length); - char* base64_encoded_string = - (char*)malloc((encoded_length + 1) * sizeof(char)); - ssize_t encoded_size = base64_encode(base64_encoded_string, - encoded_length + 1, - input.c_str(), - unencoded_length); - std::string result(base64_encoded_string); - free(base64_encoded_string); - return result; - } - - /** @brief Destructor */ - ~EventBuffer() - { - evbuffer_free(_buffer); - event_free(_signal_exit); - event_free(_signal_term); - } -}; // class EventBuffer - -/** - * @brief Class that manages a RabbitMQ broker and handles connection, event - * loop and set up various handlers. - */ -template -class RabbitMQDB final : public BaseDB -{ -private: - /** @brief Path of the config file (JSON) */ - std::string _config; - /** @brief Connection to the broker */ - AMQP::TcpConnection* _connection; - /** @brief main channel used to send data to the broker */ - std::shared_ptr _channel_send; - /** @brief main channel used to receive data from the broker */ - std::shared_ptr _channel_receive; - /** @brief Broker address */ - AMQP::Address* _address; - /** @brief name of the queue to send data */ - std::string _queue_sender; - /** @brief name of the queue to receive data */ - std::string _queue_receiver; - /** @brief MPI rank (if MPI is used, otherwise 0) */ - int _rank; - /** @brief The event loop for sender (usually the default one in libevent) */ - struct event_base* _loop_sender; - /** @brief The event loop receiver */ - struct event_base* _loop_receiver; - /** @brief The handler which contains various callbacks for the sender */ - std::shared_ptr _handler_sender; - /** @brief The handler which contains various callbacks for the receiver */ - std::shared_ptr _handler_receiver; - /** @brief evbuffer that is responsible to offload data to RabbitMQ*/ - EventBuffer* _evbuffer; - /** @brief The worker in charge of sending data to the broker (dedicated - * thread) */ - std::shared_ptr _sender; - /** @brief The worker in charge of sending data to the broker (dedicated - * thread) */ - std::shared_ptr _receiver; - /** @brief The number of messages to be sent */ - int _nb_msg_send; - /** @brief Queue that contains all the messages received on receiver queue */ - std::shared_ptr> _messages; - - /** - * @brief Read a JSON and create a hashmap - * @param[in] fn Path of the RabbitMQ JSON config file - * @return a hashmap (std::unordered_map) of the JSON file - */ - json _read_config(std::string fn) - { - std::ifstream config; - json connection_info = { - {"rabbitmq-erlang-cookie", ""}, - {"rabbitmq-name", ""}, - {"rabbitmq-password", ""}, - {"rabbitmq-user", ""}, - {"rabbitmq-vhost", ""}, - {"service-port", ""}, - {"service-host", ""}, - {"rabbitmq-cert", ""}, - {"rabbitmq-inbound-queue", ""}, - {"rabbitmq-outbound-queue", ""}, - }; - - config.open(fn, std::ifstream::in); - - if (config.is_open()) { - std::string line; - while (std::getline(config, line)) { - if (line.find("{") != std::string::npos || - line.find("}") != std::string::npos) { - continue; - } - line.erase(std::remove(line.begin(), line.end(), ' '), line.end()); - line.erase(std::remove(line.begin(), line.end(), ','), line.end()); - line.erase(std::remove(line.begin(), line.end(), '"'), line.end()); - - std::string key = line.substr(0, line.find(':')); - line.erase(0, line.find(":") + 1); - connection_info[key] = line; - } - config.close(); - } else { - std::string err = "Could not open JSON file: " + fn; - throw std::runtime_error(err); - } - return connection_info; - } - - /** @brief linearize all elements of a vector of C-vectors - * in a single C-vector. Data are transposed. - * - * @tparam TypeInValue Type of the source value. - * @param[in] n The number of elements of the vectors. - * @param[in] features A vector containing C-vector of feature values. - * @param[in] add_dims A bool. if true we add the dimensions of the - * flatten feature as first element and second element of the result array. - * @return A pointer to a C-vector containing the linearized values. The - * C-vector is_same resident in the same device as the input feature pointers. - */ - template - PERFFASPECT() - static inline TypeValue* flatten_features( - const size_t n, - const std::vector& features, - bool add_dims = false) - { - const size_t nfeatures = features.size(); - const size_t nvalues = n * nfeatures; - - size_t offset = 0; - // Offset to have space to write off the dimensions at the begiginning of the array - if (add_dims) - offset = 2; - - TypeValue* data = (TypeValue*) malloc((nvalues + offset) * sizeof(TypeValue)); - - if (add_dims) { - data[0] = (TypeValue) n; - data[1] = (TypeValue) features.size(); - } - - for (size_t d = 0; d < nfeatures; d++) { - for (size_t i = 0; i < n; i++) { - data[offset + (i * nfeatures) + d] = static_cast(features[d][i]); - } - } - return data; - } - - /** - * @brief Initialize the connection with the broker, open a channel and set up a - * queue. Then it also sets up a worker thread and start its even loop. Now - * the broker is ready for push operation. - * @param[in] queue The name of the queue to declare - */ - void start_sender(const std::string& queue) - { - _channel_send = - std::make_shared(_connection); - _channel_send->onError([&_rank = _rank](const char* message) { - CFATAL(RabbitMQDB, - false, - "[rank=%d] Error while creating broker channel: %s", - _rank, - message) - // TODO: throw dedicated excpetion and try to recover - // from it (re-trying to open a queue, testing if the RM server is alive - // etc) - throw std::runtime_error(message); - }); - - _channel_send->declareQueue(queue) - .onSuccess([queue, &_rank = _rank](const std::string& name, - uint32_t messagecount, - uint32_t consumercount) { - if (messagecount > 0 || consumercount > 1) { - WARNING(RabbitMQDB, - "[rank=%d] declared queue: %s (messagecount=%d, " - "consumercount=%d)", - _rank, - queue.c_str(), - messagecount, - consumercount) - } - }) - .onError([queue, &_rank = _rank](const char* message) { - CFATAL(RabbitMQDB, - false, - "[ERROR][rank=%d] Error while creating broker queue (%s): %s", - _rank, - queue.c_str(), - message) - // TODO: throw dedicated excpetion and try to recover - // from it (re-trying to open a queue, testing if the RM server is - // alive etc) - throw std::runtime_error(message); - }); - - _sender = std::make_shared(); - _sender->loop = _loop_sender; - _evbuffer = new EventBuffer(_rank, - _loop_sender, - _channel_send, - _queue_sender); - if (pthread_create(&_sender->id, NULL, start_worker_sender, _sender.get())) { - FATAL(RabbitMQDB, "error pthread_create for sender worker"); - } - } - - /** - * @brief Initialize the connection with the broker, open a channel and set up a - * queue. Then it also sets up a worker thread and start its even loop. Now - * the broker is ready for push operation. - * @param[in] queue The name of the queue to declare - */ - void start_receiver(const std::string& queue) - { - _channel_receive = - std::make_shared(_connection); - _channel_receive->onError([&_rank = _rank](const char* message) { - CFATAL(RabbitMQDB, - false, - "[rank=%d] Error while creating broker channel: %s", - _rank, - message) - // TODO: throw dedicated excpetion and try to recover - // from it (re-trying to open a queue, testing if the RM server is alive - // etc) - throw std::runtime_error(message); - }); - - _channel_receive->declareQueue(queue) - .onSuccess([queue, &_rank = _rank](const std::string& name, - uint32_t messagecount, - uint32_t consumercount) { - if (messagecount > 0 || consumercount > 1) { - WARNING(RabbitMQDB, - "[rank=%d] declared queue: %s (messagecount=%d, " - "consumercount=%d)", - _rank, - queue.c_str(), - messagecount, - consumercount) - } - }) - .onError([queue, &_rank = _rank](const char* message) { - CFATAL(RabbitMQDB, - false, - "[ERROR][rank=%d] Error while creating broker queue (%s): %s", - _rank, - queue.c_str(), - message) - // TODO: throw dedicated excpetion and try to recover - // from it (re-trying to open a queue, testing if the RM server is - // alive etc) - throw std::runtime_error(message); - }); - - _receiver = std::make_shared(); - _receiver->loop = _loop_receiver; - _receiver->channel = _channel_receive; - // Structure that will contain all messages received - _receiver->messages = std::make_shared>(); - if (pthread_create( - &_receiver->id, NULL, start_worker_consumer, _receiver.get())) { - FATAL(RabbitMQDB, "error pthread_create for receiver worker"); - } - } - -public: - RabbitMQDB(const RabbitMQDB&) = delete; - RabbitMQDB& operator=(const RabbitMQDB&) = delete; - - RabbitMQDB(char* config, uint64_t id) - : BaseDB(id), - _rank(0), - _nb_msg_send(0), - _handler_sender(nullptr), - _evbuffer(nullptr), - _address(nullptr), - _sender(nullptr), - _receiver(nullptr) - { - _config = std::string(config); - auto rmq_config = _read_config(this->_config); - _queue_sender = - rmq_config["rabbitmq-outbound-queue"]; // Queue to send data to - _queue_receiver = - rmq_config["rabbitmq-inbound-queue"]; // Queue to receive data from PDS - bool is_secure = true; - - if (rmq_config["service-port"].empty()) { - FATAL(RabbitMQDB, - "service-port is empty, make sure the port number is present in " - "the JSON configuration") - } - if (rmq_config["service-host"].empty()) { - FATAL(RabbitMQDB, - "service-host is empty, make sure the host is present in the JSON " - "configuration") - } - - uint16_t port = std::stoi(rmq_config["service-port"]); - if (_queue_sender.empty() || _queue_receiver.empty()) { - FATAL(RabbitMQDB, - "Queues are empty, please check your credentials file and make " - "sure rabbitmq-inbound-queue and rabbitmq-outbound-queue exist") - } - -#ifdef __ENABLE_MPI__ - MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); -#endif -#ifdef EVTHREAD_USE_PTHREADS_IMPLEMENTED - evthread_use_pthreads(); -#endif - _loop_sender = event_base_new(); - _loop_receiver = event_base_new(); - CDEBUG(RabbitMQDB, _rank == 0, "Libevent %s\n", event_get_version()); - CDEBUG(RabbitMQDB, - _rank == 0, - "%s (OPENSSL_VERSION_NUMBER = %#010x)\n", - OPENSSL_VERSION_TEXT, - OPENSSL_VERSION_NUMBER); -#if OPENSSL_VERSION_NUMBER < 0x10100000L - SSL_library_init(); -#else - OPENSSL_init_ssl(0, NULL); -#endif - _handler_sender = - std::make_shared(_rank, _loop_sender, rmq_config["rabbitmq-cert"]); - _handler_receiver = - std::make_shared(_rank, _loop_receiver, rmq_config["rabbitmq-cert"]); - - AMQP::Login login(rmq_config["rabbitmq-user"], - rmq_config["rabbitmq-password"]); - - _address = new AMQP::Address(rmq_config["service-host"], - port, - login, - rmq_config["rabbitmq-vhost"], - is_secure); - - if (_address == nullptr) - throw std::runtime_error("something is wrong, address is NULL"); - - CDEBUG(RabbitMQDB, - _rank == 0, - "RabbitMQ address: %s:%d/%s (sender queue = %s, receiver queue = " - "%s)", - _address->hostname().c_str(), - _address->port(), - _address->vhost().c_str(), - _queue_sender.c_str(), - _queue_receiver.c_str()) - - _connection = new AMQP::TcpConnection(_handler_sender.get(), *_address); - start_sender(_queue_sender); - start_receiver(_queue_receiver); - // mandatory to give some time to OpenSSL and RMQ to set things up, - // TODO: find a way to remove that magic sleep and actually check if OpenSSL - // + RMQ are up and running - std::this_thread::sleep_for(std::chrono::seconds(2)); - } - - /** - * @brief Make sure the evbuffer is being drained. - * This function blocks until the buffer is empty. - * @param[in] sleep_time Number of seconds between two checking (active - * pooling) - */ - void drain(int sleep_time = 1) - { - if (!(_sender && _evbuffer)) { - return; - } - while (true) { - if (_evbuffer->is_drained()) { - break; - } - sleep(sleep_time); - } - } - - /** - * @brief Return the most recent messages and delete it - * @return A structure inbound_msg which is a std::tuple (see typedef) - */ - inbound_msg pop_messages() - { - if (!(_messages->empty())) { - inbound_msg msg = _messages->back(); - _messages->pop_back(); - return msg; - } - return std::make_tuple("", "", "", -1, false); - } - - /** - * @brief Return the message corresponding to the delivery tag. Do not delete the - * message. - * @param[in] delivery_tag Delivery tag that will be returned (if found) - * @return A structure inbound_msg which is a std::tuple (see typedef) - */ - inbound_msg get_messages(uint64_t delivery_tag) - { - if (!(_messages->empty())) { - auto it = std::find_if(_messages->begin(), - _messages->end(), - [&delivery_tag](const inbound_msg& e) { - return std::get<3>(e) == delivery_tag; - }); - if (it != _messages->end()) return *it; - } - return std::make_tuple("", "", "", -1, false); - } - - /** - * @brief Takes an input and an output vector each holding 1-D vectors data, and push - * it onto the libevent buffer. We flatten the inputs/outputs and send one - * message for each. The first two elements of the message are num_elements - * and the feature size. These elements are needed to reconstruct the inputs - * and outputs on the other side (RabbitMQ). - * - * @param[in] num_elements Number of elements of each 1-D vector - * @param[in] inputs Vector of 1-D vectors containing the inputs to be sent - * @param[in] outputs Vector of 1-D vectors, each 1-D vectors contains - * 'num_elements' values to be sent - */ - PERFFASPECT() - void store(size_t num_elements, - std::vector& inputs, - std::vector& outputs) override - { - CINFO(RabbitMQDB, - true, - "RabbitMQDB of type %s stores %ld elements of input/output " - "dimensions (%d, %d)", - type().c_str(), - num_elements, - inputs.size(), - outputs.size()) - - const size_t inputs_size = num_elements * inputs.size(); - auto inputs_data = flatten_features(num_elements, inputs, true); - DBG(RabbitMQDB, - "[store(%d, %d, %d)] input sent %d B", - num_elements, - inputs.size(), - outputs.size(), - inputs_size * sizeof(TypeValue)) - _nb_msg_send++; - _evbuffer->push(static_cast(inputs_data), - inputs_size * sizeof(TypeValue)); - - // TODO: investigate that - // Necessary for some reasons, other the event buffer overheat - // "[err] buffer.c:1066: Assertion chain || datlen==0 failed in - // evbuffer_copyout" potentially segfault, and with CUDA could also lead to - // packet losses - std::this_thread::sleep_for(std::chrono::milliseconds(2000)); - - const size_t outputs_size = num_elements * outputs.size(); - auto outputs_data = flatten_features(num_elements, outputs, true); - DBG(RabbitMQDB, - "[store(%d, %d, %d)] output sent %d B", - num_elements, - inputs.size(), - outputs.size(), - outputs_size * sizeof(TypeValue)) - _nb_msg_send++; - _evbuffer->push(static_cast(outputs_data), - outputs_size * sizeof(TypeValue)); - - free(inputs_data); - free(outputs_data); - } - - /** - * @brief Return the type of this broker - * @return The type of the broker - */ - std::string type() override { return "rabbitmq"; } - - /** - * @brief Return the number of messages that - * has been push to the buffer - * @return The number of messages sent - */ - int nb_msg() const { return _nb_msg_send; } - - ~RabbitMQDB() - { - drain(); - pthread_kill(_sender->id, SIGUSR1); - pthread_kill(_receiver->id, SIGUSR1); - _channel_send->close(); - _channel_receive->close(); - event_base_free(_loop_sender); - event_base_free(_loop_receiver); - delete _evbuffer; - delete _address; - _connection->close(); - free(_connection); - } -}; // class RabbitMQDB - -#endif // __ENABLE_RMQ__ - - -/** - * @brief Create an object of the respective database. - * This should never be used for large scale simulations as txt/csv format will - * be extremely slow. - * @param[in] dbPath path to the directory storing the data - * @param[in] dbType Type of the database to create - * @param[in] rId a unique Id for each process taking part in a distributed - * execution (rank-id) - */ -template -BaseDB* createDB(char* dbPath, AMSDBType dbType, uint64_t rId = 0) -{ - DBG(DB, "Instantiating data base"); -#ifdef __ENABLE_DB__ - if (dbPath == nullptr) { - std::cerr << " [WARNING] Path of DB is NULL, Please provide a valid path " - "to enable db\n"; - std::cerr << " [WARNING] Continueing\n"; - return nullptr; - } - - switch (dbType) { - case AMSDBType::CSV: - return new csvDB(dbPath, rId); -#ifdef __ENABLE_REDIS__ - case AMSDBType::REDIS: - return new RedisDB(dbPath, rId); -#endif -#ifdef __ENABLE_HDF5__ - case AMSDBType::HDF5: - return new hdf5DB(dbPath, rId); -#endif -#ifdef __ENABLE_RMQ__ - case AMSDBType::RMQ: - return new RabbitMQDB(dbPath, rId); -#endif - default: - return nullptr; - } -#else - return nullptr; -#endif -} - -#endif // __AMS_BASE_DB__ diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6d6f086f..1f64d8a1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -5,7 +5,7 @@ function (ADDTEST binary_name src_file test_name) add_executable(${binary_name} ${src_file}) - target_include_directories(${binary_name} PRIVATE "${PROJECT_SOURCE_DIR}/src" umpire ${caliper_INCLUDE_DIR} ${MPI_INCLUDE_PATH}) + target_include_directories(${binary_name} PRIVATE "${PROJECT_SOURCE_DIR}/src/AMSlib" umpire ${caliper_INCLUDE_DIR} ${MPI_INCLUDE_PATH}) target_link_directories(${binary_name} PRIVATE ${AMS_APP_LIB_DIRS}) target_link_libraries(${binary_name} PRIVATE AMS umpire MPI::MPI_CXX) diff --git a/tests/lb.cpp b/tests/lb.cpp new file mode 100644 index 00000000..0f3dd690 --- /dev/null +++ b/tests/lb.cpp @@ -0,0 +1,102 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "wf/redist_load.hpp" + +#define SIZE (10) + +void init(double *data, int elements, double value) +{ + for (int i = 0; i < elements; i++) { + data[i] = value; + } +} + +void evaluate(double *data, double *src, int elements) +{ + ams::ResourceManager::copy(src, data, elements * sizeof(double)); +} + +int verify(double *data, double *src, int elements, int rId) +{ + return std::memcmp(data, src, elements * sizeof(double)); +} + +int main(int argc, char *argv[]) +{ + using namespace ams; + int device = std::atoi(argv[1]); + MPI_Init(&argc, &argv); + AMSSetupAllocator(AMSResourceType::HOST); + AMSResourceType resource = AMSResourceType::HOST; + AMSSetDefaultAllocator(AMSResourceType::HOST); + int rId, wS; + MPI_Comm_size(MPI_COMM_WORLD, &wS); + MPI_Comm_rank(MPI_COMM_WORLD, &rId); + srand(rId); + std::default_random_engine generator; + std::normal_distribution distribution(0.5, 0.3); + srand(rId); + double threshold; + for ( int i = 0; i <= rId; i++){ + threshold = distribution(generator); + } + + int computeElements = (threshold * SIZE); // / sizeof(double); + + double *srcData, *destData; + double *srcHData = srcData = + ResourceManager::allocate(computeElements, AMSResourceType::HOST); + double *destHData = destData = + ResourceManager::allocate(computeElements, AMSResourceType::HOST); + + init(srcHData, computeElements, static_cast(rId)); + + if (device == 1) { + AMSSetupAllocator(AMSResourceType::DEVICE); + AMSSetDefaultAllocator(AMSResourceType::DEVICE); + resource = AMSResourceType::DEVICE; + srcData = ResourceManager::allocate(computeElements, + AMSResourceType::DEVICE); + destData = ResourceManager::allocate(computeElements, + AMSResourceType::DEVICE); + } + + std::vector inputs({srcData}); + std::vector outputs({destData}); + + { + + std::cerr << "Resource is " << resource << "\n"; + AMSLoadBalancer lBalancer( + rId, wS, computeElements, MPI_COMM_WORLD, 1, 1, resource); + lBalancer.scatterInputs(inputs, resource); + double **lbInputs = lBalancer.inputs(); + double **lbOutputs = lBalancer.outputs(); + evaluate(*lbOutputs, *lbInputs, lBalancer.getBalancedSize()); + lBalancer.gatherOutputs(outputs, resource); + } + + if (device == 1) { + ResourceManager::copy(destData, + destHData, + computeElements * sizeof(double)); + ResourceManager::deallocate(destData, AMSResourceType::DEVICE); + ResourceManager::deallocate(srcData, AMSResourceType::DEVICE); + } + + int ret = verify(destHData, srcHData, computeElements, rId); + + ResourceManager::deallocate(destHData, AMSResourceType::HOST); + ResourceManager::deallocate(srcHData, AMSResourceType::HOST); + + MPI_Finalize(); + return ret; +}