diff --git a/src_cpp/opennnBridge/get_set_weights.h b/src_cpp/opennnBridge/get_set_weights.h index 15a616d92..6a81a5aab 100644 --- a/src_cpp/opennnBridge/get_set_weights.h +++ b/src_cpp/opennnBridge/get_set_weights.h @@ -7,6 +7,8 @@ #include "bridgeController.h" #include "nerlWorkerOpenNN.h" +#define GET_WEIGHTS_ATOM_STR "get_weights" + using namespace opennn; using namespace nerlnet; @@ -18,7 +20,6 @@ class GetWeightsParams ErlNifTid tid; }; - inline void* get_weights(void* arg) { std::shared_ptr* pGetWeigthsParamsPtr= static_cast*>(arg); @@ -31,7 +32,8 @@ inline void* get_weights(void* arg) fTensor1D parameters; fTensor1DPtr parameters_ptr; nifpp::str_atom nerltensor_type = "float"; - std::tuple message_tuple; // returned nerltensor of parameters + nifpp::str_atom get_weights_atom = GET_WEIGHTS_ATOM_STR; + std::tuple message_tuple; // returned nerltensor of parameters //get neural network parameters which are weights and biases valuse as a 1D vector BridgeController &bridge_controller = BridgeController::GetInstance(); @@ -45,7 +47,7 @@ inline void* get_weights(void* arg) nifpp::make_tensor_1d(env, nerltensor_parameters_bin, parameters_ptr); //binary tensor // create returned tuple - message_tuple = { nerltensor_parameters_bin , nifpp::make(env, nerltensor_type) }; + message_tuple = { nifpp::make(env, get_weights_atom), nerltensor_parameters_bin , nifpp::make(env, nerltensor_type) }; nifpp::TERM message = nifpp::make(env, message_tuple); if(enif_send(NULL,&(getWeigthsParamsPtr->pid), env, message)) //TODO check this value in Erlang! diff --git a/src_erl/NerlnetApp/src/Bridge/Common/workerDefinitions.hrl b/src_erl/NerlnetApp/src/Bridge/Common/workerDefinitions.hrl index 966d54b32..9c914b745 100644 --- a/src_erl/NerlnetApp/src/Bridge/Common/workerDefinitions.hrl +++ b/src_erl/NerlnetApp/src/Bridge/Common/workerDefinitions.hrl @@ -1,4 +1,4 @@ --define(ETS_KEYVAL_KEY_IDX, 2). +-define(ETS_KEYVAL_KEY_IDX, 1). -define(ETS_KEYVAL_VAL_IDX, 2). -define(TENSOR_DATA_IDX, 1). diff --git a/src_erl/NerlnetApp/src/Bridge/onnWorkers/nerlNIF.erl b/src_erl/NerlnetApp/src/Bridge/onnWorkers/nerlNIF.erl index 7d87d30fb..376733ed8 100644 --- a/src_erl/NerlnetApp/src/Bridge/onnWorkers/nerlNIF.erl +++ b/src_erl/NerlnetApp/src/Bridge/onnWorkers/nerlNIF.erl @@ -2,6 +2,9 @@ -include_lib("kernel/include/logger.hrl"). -include("../nerlTensor.hrl"). +-define(ETS_KEYVAL_KEY_IDX, 1). +-define(ETS_KEYVAL_VAL_IDX, 2). + -export([init/0,nif_preload/0,get_active_models_ids_list/0, train_nif/3,update_nerlworker_train_params_nif/6,call_to_train/5,predict_nif/3,call_to_predict/5,get_weights_nif/1,printTensor/2]). -export([call_to_get_weights/1,call_to_set_weights/2]). -export([decode_nif/2, nerltensor_binary_decode/2]). @@ -78,22 +81,35 @@ call_to_predict(ModelID, {BatchTensor, Type}, WorkerPid, BatchID , SourceName)-> end. % This function calls to get_weights_nif() and waits for the result using receive block +% It uses NIF and receive block in a new process to avoid from blocking the main process of the FSM % Returns {NerlTensorWeights , BinaryType} call_to_get_weights(ModelID)-> try - ?LOG_INFO("Calling get weights in model ~p~n",{ModelID}), - _RetVal = get_weights_nif(ModelID), - recv_call_loop() + WeightsEts = ets:new(weights_ets, [set,public]), + ets:insert(WeightsEts, {weights_status, waiting}), + spawn_link(fun() -> get_weights_nif(ModelID), recv_call_loop(WeightsEts) end), + get_weights_sync(WeightsEts), % sync on ETS update + ets:lookup_element(WeightsEts, weights, ?ETS_KEYVAL_VAL_IDX) % return weights NerlTensor catch Err:E -> ?LOG_ERROR("Couldnt get weights from worker~n~p~n",{Err,E}), [] end. -%% sometimes the receive loop gets OTP calls that its not supposed to in high freq. wait for nerktensor of weights -recv_call_loop() -> +get_weights_sync(WeightsEts) -> + WeightsEtsStats = ets:lookup_element(WeightsEts, weights_status, ?ETS_KEYVAL_VAL_IDX), + case WeightsEtsStats of + updated -> finished; + waiting -> get_weights_sync(WeightsEts) + end. + +%% This function runs in a spwaned thread to avoid blocking the main process +%% Moreover the main process belongs to the FSM and we don't want to catch messages of the FSM ({'$gen_cast',_Any} {'$gen_call', _Any} etc...) +recv_call_loop(WeightsEts) -> receive - {'$gen_cast', _Any} -> ?LOG_WARNING("Missed batch in call of get_weigths"), - recv_call_loop(); - NerlTensorWeights -> NerlTensorWeights + {get_weights, NerlTensorWeights, NerlTensorType} -> + ets:insert(WeightsEts, {weights, {NerlTensorWeights, NerlTensorType}}), + ets:update_element(WeightsEts, weights_status, {?ETS_KEYVAL_VAL_IDX, updated}); % save weights to temporary ets - TODO try to optimize + _Else -> ?LOG_ERROR("Received wrong message in get_weights_nif~n"), + recv_call_loop(WeightsEts) end. call_to_set_weights(ModelID,{WeightsNerlTensor, Type})->