Skip to content

Commit

Permalink
Added algorithms timing messages
Browse files Browse the repository at this point in the history
  • Loading branch information
amirbawab committed Jun 13, 2019
1 parent c7833f1 commit be8102e
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/nn-builder/examples/cpp/mnist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ int main(int argc, char *argv[]) {
}

// Load csv file
uint32_t training_limit = 3000;
uint32_t training_limit = 6000;
uint32_t testing_limit = 1000;
std::vector<std::vector<float>> train_data;
std::vector<std::vector<float>> train_labels;
Expand Down
59 changes: 59 additions & 0 deletions src/nn-builder/js/compiled_model.js
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,65 @@ class CompiledModel {
log_prediction_time: (time) => {
console.log("Prediction time:", time, "ms");
},
// Forward timing
log_forward_Time: () => {
console.log("\n>> Forward algorithm steps time:");
},
log_forward_A_1: (time) => {
console.log("A) Z[l] = W[l] . A[l-1] + b[l]");
console.log(" 1) Z[l] = W[l] . A[l-1]:", time);
},
log_forward_A_2: (time) => {
console.log(" 2) Z[l] = Z[l] + b[l]:", time);
},
log_forward_B: (time) => {
console.log("B) A[l] = g[l](Z[l]):", time);
},
// Backward timing
log_backward_Time: () => {
console.log("\n>> Backward algorithm steps time:");
},
log_backward_A: (time) => {
console.log("A) dA[L] = L(T, A[L]):", time);
},
log_backward_B_1: (time) => {
console.log("B) dZ[l] = dA[l] * g'[l](Z[l])");
console.log(" 1) dZ[l] = g'[l](Z[l]):", time);
},
log_backward_B_2: (time) => {
console.log(" 2) dZ[l] = dA[l] * dZ[l]:", time);
},
log_backward_C_1: (time) => {
console.log("C) dW[l] = (1/m) dZ[l] . A[l-1]^T");
console.log(" 1) dW[l] = dZ[l] . A[l-1]^T:", time);
},
log_backward_C_2: (time) => {
console.log(" 2) dW[l] = (1/m) dW[l]:", time);
},
log_backward_D_1: (time) => {
console.log("D) db[l] = (1/m) dZ[l]");
console.log(" 1) db[l] = SUM(dZ[l], row wise):", time);
},
log_backward_D_2: (time) => {
console.log(" 2) db[l] = (1/m) db[l]:", time);
},
log_backward_E: (time) => {
console.log("E) dA[l-1] = W[l]^T . dZ[l]:", time);
},
log_backward_F_1: (time) => {
console.log("F) W[l] = W[l] - alpha * dW[l]");
console.log(" 1) dW[l] = alpha * dW[l]:", time);
},
log_backward_F_2: (time) => {
console.log(" 2) W[l] = W[l] - dW[l]:", time);
},
log_backward_G_1: (time) => {
console.log("G) b[l] = b[l] - alpha * db[l]");
console.log(" 1) db[l] = alpha * db[l]:", time);
},
log_backward_G_2: (time) => {
console.log(" 2) b[l] = b[l] - db[l]:", time);
},
};

let system_imports = {
Expand Down
4 changes: 2 additions & 2 deletions src/nn-builder/src/arch/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -501,14 +501,14 @@ void Model::CompileTrainingFunction(uint32_t epochs, float learning_rate, const

if(options_.log_forward) {
#define LOG_TIME_MEMBER(name) \
f.Insert(MakeCall(builtins_.system.PrintF64(), { dense_forward_logging_members_.Get##name()}));
f.Insert(MakeCall(builtins_.message.LogForward##name(), { dense_forward_logging_members_.Get##name()}));
DENSE_FORWARD_TIME_MEMBERS(LOG_TIME_MEMBER)
#undef LOG_TIME_MEMBER
}

if(options_.log_backward) {
#define LOG_TIME_MEMBER(name) \
f.Insert(MakeCall(builtins_.system.PrintF64(), { dense_backward_logging_members_.Get##name()}));
f.Insert(MakeCall(builtins_.message.LogBackward##name(), { dense_backward_logging_members_.Get##name()}));
DENSE_BACKWARD_TIME_MEMBERS(LOG_TIME_MEMBER)
#undef LOG_TIME_MEMBER
}
Expand Down
10 changes: 10 additions & 0 deletions src/nn-builder/src/builtins/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ void Message::InitImports(arch::Model* model, wasmpp::ModuleManager* module_mana
log_testing_time_ = module_manager->MakeFuncImport(module_name, "log_testing_time", {{Type::F64}, {}});
log_testing_error_ = module_manager->MakeFuncImport(module_name, "log_testing_error", {{Type::F32}, {}});
log_prediction_time_ = module_manager->MakeFuncImport(module_name, "log_prediction_time", {{Type::F64}, {}});

#define LOAD_TIME_MESSAGES(name) \
log_forward_##name##_ = module_manager->MakeFuncImport(module_name, "log_forward_" #name, {{Type::F64}, {}});
FORWARD_TIME_MESSAGES(LOAD_TIME_MESSAGES)
#undef LOAD_TIME_MESSAGES

#define LOAD_TIME_MESSAGES(name) \
log_backward_##name##_ = module_manager->MakeFuncImport(module_name, "log_backward_" #name, {{Type::F64}, {}});
BACKWARD_TIME_MESSAGES(LOAD_TIME_MESSAGES)
#undef LOAD_TIME_MESSAGES
}

void Message::InitDefinitions(arch::Model* model, wasmpp::ModuleManager* module_manager) {
Expand Down
46 changes: 44 additions & 2 deletions src/nn-builder/src/builtins/message.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@
namespace nn {
namespace builtins {

#define FORWARD_TIME_MESSAGES(V) \
V(Time) \
V(A_1) \
V(A_2) \
V(B)

#define BACKWARD_TIME_MESSAGES(V) \
V(Time) \
V(A) \
V(B_1) \
V(B_2) \
V(C_1) \
V(C_2) \
V(D_1) \
V(D_2) \
V(E) \
V(F_1) \
V(F_2) \
V(G_1) \
V(G_2)

class Message : public Builtin {
private:
wabt::Var log_training_accuracy_;
Expand All @@ -15,7 +36,18 @@ class Message : public Builtin {
wabt::Var log_testing_time_;
wabt::Var log_testing_error_;
wabt::Var log_prediction_time_;
public:

#define VAR_NAMES(name) \
wabt::Var log_forward_##name##_;
FORWARD_TIME_MESSAGES(VAR_NAMES)
#undef VAR_NAMES

#define VAR_NAMES(name) \
wabt::Var log_backward_##name##_;
BACKWARD_TIME_MESSAGES(VAR_NAMES)
#undef VAR_NAMES

public:
void InitImports(arch::Model* model, wasmpp::ModuleManager* module_manager, std::string module_name) override;
void InitDefinitions(arch::Model* model, wasmpp::ModuleManager* module_manager) override;

Expand All @@ -26,7 +58,17 @@ class Message : public Builtin {
const wabt::Var& LogTestingError() const { return log_testing_error_; }
const wabt::Var& LogTestingAccuracy() const { return log_testing_accuracy_; }
const wabt::Var& LogPredictionTime() const { return log_prediction_time_; }
};

#define GET_VAR_NAMES(name) \
const wabt::Var& LogForward##name() const { return log_forward_##name##_; }
FORWARD_TIME_MESSAGES(GET_VAR_NAMES)
#undef GET_VAR_NAMES

#define GET_VAR_NAMES(name) \
const wabt::Var& LogBackward##name() const { return log_backward_##name##_; }
BACKWARD_TIME_MESSAGES(GET_VAR_NAMES)
#undef GET_VAR_NAMES
};

} // namespace builtins
} // namespace nn
Expand Down

0 comments on commit be8102e

Please sign in to comment.