From eacad899a7fab2d04840120704d83cc8c927a78e Mon Sep 17 00:00:00 2001 From: takuseno Date: Wed, 30 Jan 2019 19:31:22 +0900 Subject: [PATCH 1/6] Add static backward_all function to CgVariable --- include/nbla/computation_graph/variable.hpp | 11 ++++- src/nbla/computation_graph/variable.cpp | 47 ++++++++++++++++++--- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/include/nbla/computation_graph/variable.hpp b/include/nbla/computation_graph/variable.hpp index 60c5752fe..757d1c008 100644 --- a/include/nbla/computation_graph/variable.hpp +++ b/include/nbla/computation_graph/variable.hpp @@ -75,8 +75,9 @@ class CgVariable { unordered_set &fclosed, std::function forward_callback); - void visit_function_backward( - CgFunctionPtr func, std::function backward_callback, + static void visit_function_backward( + vector roots, + std::function backward_callback, vector communicator_callbacks); public: @@ -236,6 +237,12 @@ class CgVariable { backward(NdArrayPtr grad = nullptr, bool clear_buffer = false, vector communicator_callbacks = {}); + /** + */ + static void + backward_all(vector variables, bool clear_buffer = false, + vector communicator_callbacks = {}); + /** */ NBLA_API vector function_references(); diff --git a/src/nbla/computation_graph/variable.cpp b/src/nbla/computation_graph/variable.cpp index ad1a86084..309ad59ca 100644 --- a/src/nbla/computation_graph/variable.cpp +++ b/src/nbla/computation_graph/variable.cpp @@ -292,11 +292,13 @@ class BackwardCallback { } public: - BackwardCallback(CgFunctionPtr f, bool clear_buffer) + BackwardCallback(vector roots, bool clear_buffer) : clear_buffer_(clear_buffer) { // Note prohibiting clearing variable buffers where terminal. - for (auto o : f->outputs()) { - vseen_.insert({o, true}); + for (auto v : roots) { + for (auto o : v->outputs()) { + vseen_.insert({o, true}); + } } } @@ -403,7 +405,8 @@ void CgVariable::visit_function_recursive( } void CgVariable::visit_function_backward( - CgFunctionPtr p, std::function backward_callback, + vector roots, + std::function backward_callback, vector communicator_callbacks) { // Open list of next search candidate. unordered_map ids; @@ -420,7 +423,9 @@ void CgVariable::visit_function_backward( return it->second; }; set> open; - open.insert(make_tuple(-p->rank(), get_id(p), p)); + for (auto p : roots) { + open.insert(make_tuple(-p->rank(), get_id(p), p)); + } while (!open.empty()) { auto rank_func = open.begin(); auto f = get<2>(*rank_func); @@ -488,11 +493,39 @@ void CgVariable::backward( } // Create callback - BackwardCallback backward_callback(parent_, clear_buffer); + vector roots{this->parent()}; + BackwardCallback backward_callback(roots, clear_buffer); + + // Visit backward + visit_function_backward( + roots, [&backward_callback](CgFunctionPtr f) { backward_callback(f); }, + communicator_callbacks); +} + +void CgVariable::backward_all( + vector variables, bool clear_buffer, + vector communicator_callbacks) { + // setup backward at each variable + vector bak_grads; + DestructorCallback at_scope_exit([&]() { + for (int i = 0; i < variables.size(); ++i) { + variables[i]->variable()->set_grad(bak_grads[i]); + } + }); + vector roots; + for (auto v : variables) { + // backup gradients + bak_grads.push_back(v->variable()->grad()); + // set function to avoid clearing + roots.push_back(v->parent()); + } + + // Create callback + BackwardCallback backward_callback(roots, clear_buffer); // Visit backward visit_function_backward( - parent_, [&backward_callback](CgFunctionPtr f) { backward_callback(f); }, + roots, [&backward_callback](CgFunctionPtr f) { backward_callback(f); }, communicator_callbacks); } From b4effeeded1b8235b0c11ff46638f64ca23875d8 Mon Sep 17 00:00:00 2001 From: takuseno Date: Wed, 30 Jan 2019 19:32:04 +0900 Subject: [PATCH 2/6] Implement backward_all in c++ --- include/nbla/computation_graph/computation_graph.hpp | 7 +++++++ src/nbla/computation_graph/computation_graph.cpp | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/include/nbla/computation_graph/computation_graph.hpp b/include/nbla/computation_graph/computation_graph.hpp index cbb5f7a97..231098834 100644 --- a/include/nbla/computation_graph/computation_graph.hpp +++ b/include/nbla/computation_graph/computation_graph.hpp @@ -55,5 +55,12 @@ NBLA_API void steal_variable_from_to(CgVariablePtr from, CgVariablePtr to); */ NBLA_API void forward_all(const vector variables, bool clear_no_need_grad = false); + +/** Backward given variables in a single call. + * Backward all given variables in a single call. + */ +NBLA_API void backward_all( + const vector variables, bool clear_buffer = false, + const vector communicator_callbacks = {}); } #endif diff --git a/src/nbla/computation_graph/computation_graph.cpp b/src/nbla/computation_graph/computation_graph.cpp index 651ef60cc..e27e869f1 100644 --- a/src/nbla/computation_graph/computation_graph.cpp +++ b/src/nbla/computation_graph/computation_graph.cpp @@ -142,4 +142,9 @@ void forward_all(const vector variables, variables[i]->forward(false, clear_no_need_grad, &fclosed); } } + +void backward_all(const vector variables, bool clear_buffer, + const vector communicator_callbacks) { + CgVariable::backward_all(variables, clear_buffer, communicator_callbacks); +} } From 1222f3d1aae869ef2e916744f7a3b869c28d0b9d Mon Sep 17 00:00:00 2001 From: takuseno Date: Wed, 30 Jan 2019 19:32:21 +0900 Subject: [PATCH 3/6] Implement nn.backward_all in python --- python/src/nnabla/__init__.py | 2 +- python/src/nnabla/_computation_graph.pxd | 4 ++++ python/src/nnabla/_computation_graph.pyx | 24 ++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/src/nnabla/__init__.py b/python/src/nnabla/__init__.py index 9d29706eb..db7a39015 100644 --- a/python/src/nnabla/__init__.py +++ b/python/src/nnabla/__init__.py @@ -37,7 +37,7 @@ from .context import ( context_scope, set_default_context, get_current_context) from .auto_forward import auto_forward, set_auto_forward, get_auto_forward -from._computation_graph import forward_all +from._computation_graph import forward_all, backward_all # Prefer cached array by default for performance. prefer_cached_array(True) diff --git a/python/src/nnabla/_computation_graph.pxd b/python/src/nnabla/_computation_graph.pxd index 7d8de71cd..bb01700b9 100644 --- a/python/src/nnabla/_computation_graph.pxd +++ b/python/src/nnabla/_computation_graph.pxd @@ -28,3 +28,7 @@ cdef extern from "nbla/computation_graph/computation_graph.hpp" namespace "nbla" cpp_bool) except+ void steal_variable_from_to(CgVariablePtr f, CgVariablePtr t) except+ void forward_all(const vector[CgVariablePtr] &, cpp_bool) nogil except+ + void backward_all( + const vector[CgVariablePtr] &, + cpp_bool, + vector[CommunicatorBackwardCallbackPtr]) nogil except+ diff --git a/python/src/nnabla/_computation_graph.pyx b/python/src/nnabla/_computation_graph.pyx index 8dc175e66..983c926b2 100644 --- a/python/src/nnabla/_computation_graph.pyx +++ b/python/src/nnabla/_computation_graph.pyx @@ -15,7 +15,11 @@ from libcpp cimport bool as cpp_bool from libcpp.vector cimport vector from _variable cimport Variable as _Variable +from _variable cimport CommunicatorBackwardCallback +from _variable cimport CommunicatorBackwardCallbackPtr +from _variable cimport CgVariable as _CgVariable from _computation_graph cimport forward_all as cforward_all +from _computation_graph cimport backward_all as cbackward_all def forward_all(variables, cpp_bool clear_no_need_grad=False): @@ -28,3 +32,23 @@ def forward_all(variables, cpp_bool clear_no_need_grad=False): cg_variables[i] = (<_Variable?> variables[i]).var with nogil: cforward_all(cg_variables, clear_no_need_grad) + + +def backward_all(variables, cpp_bool clear_buffer=False, communicator_callbacks=None): + cdef vector[CommunicatorBackwardCallbackPtr] callback_list + if type(communicator_callbacks) == list: + for x in communicator_callbacks: + callback_list.push_back((< CommunicatorBackwardCallback?> x).var) + elif type(communicator_callbacks) != type(None): + callback_list.push_back((< CommunicatorBackwardCallback?> communicator_callbacks).var) + + cdef vector[CgVariablePtr] cg_variables + cdef int i + cdef int size + size = len(variables) + cg_variables.resize(size) + for i in range(size): + cg_variables[i] = (<_Variable?> variables[i]).var + + with nogil: + cbackward_all(cg_variables, clear_buffer, callback_list) From d5552d635d42ed79228e31b5c3321be79901afee Mon Sep 17 00:00:00 2001 From: takuseno Date: Wed, 30 Jan 2019 19:32:41 +0900 Subject: [PATCH 4/6] Modify test_foward_all --- python/test/test_forward_all.py | 60 +++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/python/test/test_forward_all.py b/python/test/test_forward_all.py index ce75e5924..8f08c1a39 100644 --- a/python/test/test_forward_all.py +++ b/python/test/test_forward_all.py @@ -52,11 +52,20 @@ def test_graph_logreg(seed): L2 = F.mean(l2) nn.forward_all([L1, L2]) + def zero_grad(): + x.g = 0 + w1.g = 0 + w2.g = 0 + b1.g = 0 + b2.g = 0 + + def backup_grads(): + grads = [x.g, w1.g, w2.g, b1.g, b2.g] + return map(lambda v: v.copy(), grads) + # Backprop for z1 # Diff should be initialized since they are always accumulated - x.g = 0 - w1.g = 0 - b1.g = 0 + zero_grad() L1.backward(clear_buffer=True) inputs = [x, w1, b1] @@ -68,9 +77,7 @@ def test_graph_logreg(seed): # Backprop for z2 # Diff should be initialized since they are always accumulated - x.g = 0 - w2.g = 0 - b2.g = 0 + zero_grad() L2.backward(clear_buffer=True) inputs = [x, w2, b2] @@ -80,6 +87,16 @@ def test_graph_logreg(seed): agrad, ngrad = grads(L2, inputs, 1e-3, False) assert np.allclose(ngrad, agrad, atol=1e-2) + zero_grad() + L1.backward(clear_buffer=True) + L2.backward(clear_buffer=True) + grad1 = backup_grads() + zero_grad() + nn.backward_all([L1, L2], clear_buffer=True) + grad2 = backup_grads() + for g1, g2 in zip(grad1, grad2): + np.allclose(g1, g2) + @pytest.mark.parametrize("seed", [311]) @pytest.mark.parametrize("model", ["mlp", "recurrent", "convolution"]) @@ -157,6 +174,15 @@ def test_graph_model(model, seed): agrad, ngrad = grads(L2, inputs, 1e-3, False) assert np.allclose(ngrad, agrad, atol=1.05e-2) + # test backward_all + initialize_grad(parameters) + L1.backward(clear_buffer=False) + L2.backward(clear_buffer=True) + backup_grads = {k: v.g.copy() for k, v in parameters.items()} + nn.backward_all([L1, L2], clear_buffer=True) + for k, g in backup_grads.items(): + np.allclose(parameters[k].g, g) + @pytest.mark.parametrize("seed", [311]) def test_graph_unlink_backward(seed): @@ -186,6 +212,11 @@ def test_graph_unlink_backward(seed): assert np.all(x0.g == 0) assert not np.all(x1.g == 0) + # test backward_all + nn.backward_all([y1, y2], clear_buffer=True) + assert np.all(x0.g == 0) + assert not np.all(x1.g == 0) + @pytest.mark.parametrize("seed", [311]) def test_graph_clear_buffer(seed): @@ -224,11 +255,7 @@ def test_graph_clear_buffer(seed): for v in nn.get_parameters().values(): v.grad.zero() nn.forward_all([L1, L2], clear_no_need_grad=cnng) - - # for now, the first backward cannot be - # called with clear_buffer=True - L1.backward(clear_buffer=False) - L2.backward(clear_buffer=cb) + nn.backward_all([L1, L2], clear_buffer=cb) if not first: first = True g = list(nn.get_parameters().values())[0].g.copy() @@ -308,3 +335,14 @@ def backup_params(): assert np.allclose(xa.d, xc.d) for b, c in zip(gb, gc): assert np.allclose(b, c) + + # test backward_all + zero_grad() + nn.backward_all([yb1, yb2], clear_buffer=True) + gb = backup_params() + zero_grad() + nn.backward_all([yc1, yc2], clear_buffer=True) + gc = backup_params() + assert np.allclose(xa.d, xc.d) + for b, c in zip(gb, gc): + assert np.allclose(b, c) From 4bbd4d93b791619e6e0db45020704c3a94055155 Mon Sep 17 00:00:00 2001 From: takuseno Date: Wed, 30 Jan 2019 19:33:33 +0900 Subject: [PATCH 5/6] Rename test_forward_all.py to test_forward_backward_all.py --- python/test/{test_forward_all.py => test_forward_backward_all.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename python/test/{test_forward_all.py => test_forward_backward_all.py} (100%) diff --git a/python/test/test_forward_all.py b/python/test/test_forward_backward_all.py similarity index 100% rename from python/test/test_forward_all.py rename to python/test/test_forward_backward_all.py From 54f4008977ab13c78ad8545d0d131ac68a9c5550 Mon Sep 17 00:00:00 2001 From: takuseno Date: Thu, 7 Feb 2019 17:17:34 +0900 Subject: [PATCH 6/6] Move visit_function_backward and BackwardCallback outside of CgVariable --- include/nbla/computation_graph/variable.hpp | 38 +- .../computation_graph/computation_graph.cpp | 23 +- src/nbla/computation_graph/variable.cpp | 416 ++++++++---------- 3 files changed, 241 insertions(+), 236 deletions(-) diff --git a/include/nbla/computation_graph/variable.hpp b/include/nbla/computation_graph/variable.hpp index 757d1c008..038d3963e 100644 --- a/include/nbla/computation_graph/variable.hpp +++ b/include/nbla/computation_graph/variable.hpp @@ -42,6 +42,13 @@ struct CommunicatorBackwardCallback { typedef shared_ptr CommunicatorBackwardCallbackPtr; +/** visit functions backward to calculate gradients + */ +void visit_function_backward( + vector roots, + std::function backward_callback, + vector communicator_callbacks); + /** Computation graph variable. A Variable object is held in this object as a data container. In addition, @@ -75,11 +82,6 @@ class CgVariable { unordered_set &fclosed, std::function forward_callback); - static void visit_function_backward( - vector roots, - std::function backward_callback, - vector communicator_callbacks); - public: typedef shared_ptr Ptr; @@ -237,12 +239,6 @@ class CgVariable { backward(NdArrayPtr grad = nullptr, bool clear_buffer = false, vector communicator_callbacks = {}); - /** - */ - static void - backward_all(vector variables, bool clear_buffer = false, - vector communicator_callbacks = {}); - /** */ NBLA_API vector function_references(); @@ -305,5 +301,25 @@ class CgVariable { /** shared_ptr typedef of CGVariable */ typedef CgVariable::Ptr CgVariablePtr; + +/** Callback invoked at backward + */ +class BackwardCallback { + bool clear_buffer_; + unordered_map vseen_; + vector get_accum(const vector &inputs, + const vector &first_visit_flags); + void force_zero_grad_if_unseen(vector outputs, + const vector &first_visit); + void clear_output_buffers(CgFunctionPtr func, + const vector &prohibit_clear); + pair, vector> query_outputs_flags( + const vector &outputs); + vector query_input_flags(const vector &inputs, + CgFunctionPtr func); +public: + BackwardCallback(vector roots, bool clear_buffer); + void operator()(CgFunctionPtr f); +}; } #endif diff --git a/src/nbla/computation_graph/computation_graph.cpp b/src/nbla/computation_graph/computation_graph.cpp index e27e869f1..239d5f9bd 100644 --- a/src/nbla/computation_graph/computation_graph.cpp +++ b/src/nbla/computation_graph/computation_graph.cpp @@ -145,6 +145,27 @@ void forward_all(const vector variables, void backward_all(const vector variables, bool clear_buffer, const vector communicator_callbacks) { - CgVariable::backward_all(variables, clear_buffer, communicator_callbacks); + // setup backward at each variable + vector bak_grads; + DestructorCallback at_scope_exit([&]() { + for (int i = 0; i < variables.size(); ++i) { + variables[i]->variable()->set_grad(bak_grads[i]); + } + }); + vector roots; + for (auto v : variables) { + // backup gradients + bak_grads.push_back(v->variable()->grad()); + // set function to avoid clearing + roots.push_back(v->parent()); + } + + // Create callback + BackwardCallback backward_callback(roots, clear_buffer); + + // Visit backward + visit_function_backward( + roots, [&backward_callback](CgFunctionPtr f) { backward_callback(f); }, + communicator_callbacks); } } diff --git a/src/nbla/computation_graph/variable.cpp b/src/nbla/computation_graph/variable.cpp index 309ad59ca..48beff6b5 100644 --- a/src/nbla/computation_graph/variable.cpp +++ b/src/nbla/computation_graph/variable.cpp @@ -55,6 +55,67 @@ CgVariable::CgVariable(VariablePtr var, bool need_grad) : CgVariable(var) { set_need_grad(need_grad); } +void visit_function_backward( + vector roots, + std::function backward_callback, + vector communicator_callbacks) { + // Open list of next search candidate. + unordered_map ids; + /* Returns the ID for each function (layer) */ + auto get_id = [&ids](const CgFunctionPtr &ptr) -> uint64_t { + auto it = ids.find(ptr); + if (it == ids.end()) { + /* Assign an ID to the function */ + auto id = ids.size(); + ids.insert({ptr, id}); + return id; + } + /* Return the previous ID if the ID is already assigned */ + return it->second; + }; + set> open; + for (auto p : roots) { + open.insert(make_tuple(-p->rank(), get_id(p), p)); + } + while (!open.empty()) { + auto rank_func = open.begin(); + auto f = get<2>(*rank_func); + DestructorCallback at_scope_exit([&]() { open.erase(rank_func); }); + // std::cout << "size: " << open.size(); + // std::cout << " --> " << open.size() << std::endl; + if (!f->need_grad()) + continue; + + // Callback + backward_callback(f); + // std::cout << (int)(get<1>(*rank_func)) << ": " << f->rank() << " " + // << f->function()->name() << " " << f.get() << " " << + // open.size() + // << std::endl; + + // + for (auto &com_callback : communicator_callbacks) { + com_callback->on_finish_function_backward(f); + } + + // Propagate down. + auto inputs = f->inputs(); + for (int i = 0; i < f->num_inputs(); i++) { + auto inp = inputs[i]; + if (!inp->need_grad_state()) + continue; + auto p_i = inp->parent(); + if (!p_i) + continue; + open.insert(make_tuple(-p_i->rank(), get_id(p_i), p_i)); + } + } + + for (auto &com_callback : communicator_callbacks) { + com_callback->on_finish_backward(); + } +} + class ForwardCallback { bool clear_buffer_{false}; bool clear_no_need_grad_{false}; @@ -176,165 +237,160 @@ class ForwardCallback { } }; -class BackwardCallback { - bool clear_buffer_; - // Visit CgVaiable list. The value is whether this is cleared during backward. - unordered_map vseen_; - - vector get_accum(const vector &inputs, - const vector &first_visit_flags) { - vector accum(inputs.size(), false); - for (int i = 0; i < inputs.size(); i++) { - // No need grad. - if (!inputs[i]->need_grad_state()) - continue; +vector BackwardCallback::get_accum( + const vector &inputs, + const vector &first_visit_flags) { + vector accum(inputs.size(), false); + for (int i = 0; i < inputs.size(); i++) { + // No need grad. + if (!inputs[i]->need_grad_state()) + continue; - // Root variable is always accumulated. - if (!inputs[i]->parent()) { - accum[i] = true; - continue; - } - // First visit gradients are copied. - if (first_visit_flags[i]) { - continue; - } + // Root variable is always accumulated. + if (!inputs[i]->parent()) { accum[i] = true; + continue; } - return accum; + // First visit gradients are copied. + if (first_visit_flags[i]) { + continue; + } + accum[i] = true; } + return accum; +} - void force_zero_grad_if_unseen(vector outputs, - const vector &first_visit) { - for (int i = 0; i < outputs.size(); i++) { - auto o = outputs[i]; - if (first_visit[i]) { - // The output variable has not been seen during this backprop, which - // means no one sets the gradient previously. To prevent to propagate - // uninitialized gradient, the output gradients are filled as 0. - // std::cout << "Zero-ing output grad of " - // << o->parent()->function()->name() << std::endl; - o->variable()->grad()->zero(); - } +void BackwardCallback::force_zero_grad_if_unseen(vector outputs, + const vector &first_visit) { + for (int i = 0; i < outputs.size(); i++) { + auto o = outputs[i]; + if (first_visit[i]) { + // The output variable has not been seen during this backprop, which + // means no one sets the gradient previously. To prevent to propagate + // uninitialized gradient, the output gradients are filled as 0. + // std::cout << "Zero-ing output grad of " + // << o->parent()->function()->name() << std::endl; + o->variable()->grad()->zero(); } } +} - void clear_output_buffers(CgFunctionPtr func, - const vector &prohibit_clear) { - if (clear_buffer_) { - auto f = func->function(); - auto inputs = func->inputs(); - auto outputs = func->outputs(); - vector> clear(outputs.size(), {true, true}); - for (int i = 0; i < inputs.size(); i++) { - if (f->inplace_data(i)) { - clear[f->inplace_data_with(i)].first = false; - } - if (f->inplace_grad(i)) { - clear[f->inplace_grad_with(i)].second = false; - } +void BackwardCallback::clear_output_buffers(CgFunctionPtr func, + const vector &prohibit_clear) { + if (clear_buffer_) { + auto f = func->function(); + auto inputs = func->inputs(); + auto outputs = func->outputs(); + vector> clear(outputs.size(), {true, true}); + for (int i = 0; i < inputs.size(); i++) { + if (f->inplace_data(i)) { + clear[f->inplace_data_with(i)].first = false; } - for (int o = 0; o < outputs.size(); ++o) { - if (prohibit_clear[o] || outputs[o]->persistent()) { - continue; - } - if (clear[o].first) { - outputs[o]->variable()->data()->array()->clear(); - } - if (clear[o].second) { - outputs[o]->variable()->grad()->array()->clear(); - } + if (f->inplace_grad(i)) { + clear[f->inplace_grad_with(i)].second = false; } } - } - - // Get first visit flags and prohibit clear flags; - // The prohibit clear flags are set by query_input_flags function with inputs - // of a previously called function. - pair, vector> - query_outputs_flags(const vector &outputs) { - vector first_visit(outputs.size()); - vector prohibit_clear(outputs.size()); - for (int i = 0; i < outputs.size(); i++) { - auto v = outputs[i]; - auto it = vseen_.find(v); - bool first = it == vseen_.end(); - if (first) { // first visit - // Terminal variable always doesn't allow to clear buffers. - prohibit_clear[i] = true; - } else { - // Propagate prohibit_clear_inputs_buffers flag from the previous seen - // inputs. - prohibit_clear[i] = it->second; + for (int o = 0; o < outputs.size(); ++o) { + if (prohibit_clear[o] || outputs[o]->persistent()) { + continue; + } + if (clear[o].first) { + outputs[o]->variable()->data()->array()->clear(); + } + if (clear[o].second) { + outputs[o]->variable()->grad()->array()->clear(); } - first_visit[i] = first; } - return {first_visit, prohibit_clear}; } +} - vector query_input_flags(const vector &inputs, - CgFunctionPtr func) { - vector ret(inputs.size()); - bool prohibit_clear = func->function()->prohibit_clear_input_buffers(); - for (int i = 0; i < ret.size(); i++) { - auto v = inputs[i]; - auto it = vseen_.find(v); - bool first_visit = it == vseen_.end(); - ret[i] = first_visit; - if (first_visit) { - vseen_.insert({v, prohibit_clear}); - continue; - } - // Prohibits clearing if any of previous function prohibits clearing +// Get first visit flags and prohibit clear flags; +// The prohibit clear flags are set by query_input_flags function with inputs +// of a previously called function. +pair, vector> +BackwardCallback::query_outputs_flags(const vector &outputs) { + vector first_visit(outputs.size()); + vector prohibit_clear(outputs.size()); + for (int i = 0; i < outputs.size(); i++) { + auto v = outputs[i]; + auto it = vseen_.find(v); + bool first = it == vseen_.end(); + if (first) { // first visit + // Terminal variable always doesn't allow to clear buffers. + prohibit_clear[i] = true; + } else { + // Propagate prohibit_clear_inputs_buffers flag from the previous seen // inputs. - it->second |= prohibit_clear; + prohibit_clear[i] = it->second; } - return ret; + first_visit[i] = first; } + return {first_visit, prohibit_clear}; +} -public: - BackwardCallback(vector roots, bool clear_buffer) - : clear_buffer_(clear_buffer) { - // Note prohibiting clearing variable buffers where terminal. - for (auto v : roots) { - for (auto o : v->outputs()) { - vseen_.insert({o, true}); - } +vector BackwardCallback::query_input_flags( + const vector &inputs, CgFunctionPtr func) { + vector ret(inputs.size()); + bool prohibit_clear = func->function()->prohibit_clear_input_buffers(); + for (int i = 0; i < ret.size(); i++) { + auto v = inputs[i]; + auto it = vseen_.find(v); + bool first_visit = it == vseen_.end(); + ret[i] = first_visit; + if (first_visit) { + vseen_.insert({v, prohibit_clear}); + continue; } + // Prohibits clearing if any of previous function prohibits clearing + // inputs. + it->second |= prohibit_clear; } + return ret; +} - void operator()(CgFunctionPtr f) { - // Check accumulation. - const auto inputs = f->inputs(); - auto first_visit_flags = query_input_flags(inputs, f); - auto accum = get_accum(inputs, first_visit_flags); - - // Get output variables - vector outputs; - vector voutputs; - std::tie(outputs, voutputs) = f->function_outputs(); - - // Query output flags according to previous trace history. - vector output_first_visit_flags; - vector output_prohibit_clear; - std::tie(output_first_visit_flags, output_prohibit_clear) = - query_outputs_flags(outputs); - - // Check if any of outputs is unseen. - force_zero_grad_if_unseen(outputs, output_first_visit_flags); - - // Call backward function - vector prop_down(accum.size()); - std::transform(inputs.begin(), inputs.end(), prop_down.begin(), - [](CgVariablePtr v) { return v->need_grad_state(); }); - // std::cout << f->function()->name() << std::endl; - // std::cout << " " << string_join(prop_down, ",") << std::endl; - // std::cout << " " << string_join(accum, ",") << std::endl; - f->function()->backward(f->function_inputs(), voutputs, prop_down, accum); - - // Clear outputs buffer - clear_output_buffers(f, output_prohibit_clear); +BackwardCallback::BackwardCallback(vector roots, + bool clear_buffer) + : clear_buffer_(clear_buffer) { + // Note prohibiting clearing variable buffers where terminal. + for (auto v : roots) { + for (auto o : v->outputs()) { + vseen_.insert({o, true}); + } } -}; +} + +void BackwardCallback::operator()(CgFunctionPtr f) { + // Check accumulation. + const auto inputs = f->inputs(); + auto first_visit_flags = query_input_flags(inputs, f); + auto accum = get_accum(inputs, first_visit_flags); + + // Get output variables + vector outputs; + vector voutputs; + std::tie(outputs, voutputs) = f->function_outputs(); + + // Query output flags according to previous trace history. + vector output_first_visit_flags; + vector output_prohibit_clear; + std::tie(output_first_visit_flags, output_prohibit_clear) = + query_outputs_flags(outputs); + + // Check if any of outputs is unseen. + force_zero_grad_if_unseen(outputs, output_first_visit_flags); + + // Call backward function + vector prop_down(accum.size()); + std::transform(inputs.begin(), inputs.end(), prop_down.begin(), + [](CgVariablePtr v) { return v->need_grad_state(); }); + // std::cout << f->function()->name() << std::endl; + // std::cout << " " << string_join(prop_down, ",") << std::endl; + // std::cout << " " << string_join(accum, ",") << std::endl; + f->function()->backward(f->function_inputs(), voutputs, prop_down, accum); + + // Clear outputs buffer + clear_output_buffers(f, output_prohibit_clear); +} void CgVariable::visit_function_recursive( CgFunctionPtr func, unordered_set &fclosed, @@ -404,67 +460,6 @@ void CgVariable::visit_function_recursive( // func.get() << std::endl; } -void CgVariable::visit_function_backward( - vector roots, - std::function backward_callback, - vector communicator_callbacks) { - // Open list of next search candidate. - unordered_map ids; - /* Returns the ID for each function (layer) */ - auto get_id = [&ids](const CgFunctionPtr &ptr) -> uint64_t { - auto it = ids.find(ptr); - if (it == ids.end()) { - /* Assign an ID to the function */ - auto id = ids.size(); - ids.insert({ptr, id}); - return id; - } - /* Return the previous ID if the ID is already assigned */ - return it->second; - }; - set> open; - for (auto p : roots) { - open.insert(make_tuple(-p->rank(), get_id(p), p)); - } - while (!open.empty()) { - auto rank_func = open.begin(); - auto f = get<2>(*rank_func); - DestructorCallback at_scope_exit([&]() { open.erase(rank_func); }); - // std::cout << "size: " << open.size(); - // std::cout << " --> " << open.size() << std::endl; - if (!f->need_grad()) - continue; - - // Callback - backward_callback(f); - // std::cout << (int)(get<1>(*rank_func)) << ": " << f->rank() << " " - // << f->function()->name() << " " << f.get() << " " << - // open.size() - // << std::endl; - - // - for (auto &com_callback : communicator_callbacks) { - com_callback->on_finish_function_backward(f); - } - - // Propagate down. - auto inputs = f->inputs(); - for (int i = 0; i < f->num_inputs(); i++) { - auto inp = inputs[i]; - if (!inp->need_grad_state()) - continue; - auto p_i = inp->parent(); - if (!p_i) - continue; - open.insert(make_tuple(-p_i->rank(), get_id(p_i), p_i)); - } - } - - for (auto &com_callback : communicator_callbacks) { - com_callback->on_finish_backward(); - } -} - void CgVariable::forward(bool clear_buffer, bool clear_no_need_grad, unordered_set *fclosed) { if (fclosed == nullptr) { @@ -502,33 +497,6 @@ void CgVariable::backward( communicator_callbacks); } -void CgVariable::backward_all( - vector variables, bool clear_buffer, - vector communicator_callbacks) { - // setup backward at each variable - vector bak_grads; - DestructorCallback at_scope_exit([&]() { - for (int i = 0; i < variables.size(); ++i) { - variables[i]->variable()->set_grad(bak_grads[i]); - } - }); - vector roots; - for (auto v : variables) { - // backup gradients - bak_grads.push_back(v->variable()->grad()); - // set function to avoid clearing - roots.push_back(v->parent()); - } - - // Create callback - BackwardCallback backward_callback(roots, clear_buffer); - - // Visit backward - visit_function_backward( - roots, [&backward_callback](CgFunctionPtr f) { backward_callback(f); }, - communicator_callbacks); -} - vector CgVariable::function_references() { vector ret(this->function_reference_count(), nullptr); int i = 0;