diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py index 525f710d..8b8b3776 100644 --- a/python/jittor/contrib.py +++ b/python/jittor/contrib.py @@ -226,13 +226,11 @@ def concat(arr, dim): shape = list(a.shape) shape[dim] = total_dim s = jt.empty(shape, a.dtype) - slices = [slice(None)]*len(a.shape) + slices = [slice(None)]*(dim+1) for a in arr: if a.shape[dim] == 0: continue slices[dim] = slice(cdim, cdim+a.shape[dim]) - # print(slices, type(a)) s = s.setitem(tuple(slices), a) - # s = jt.setitem(s, tuple(slices), a) cdim += a.shape[dim] return s diff --git a/python/jittor/test/test_setitem.py b/python/jittor/test/test_setitem.py index d1c6014b..f8be4b9e 100644 --- a/python/jittor/test/test_setitem.py +++ b/python/jittor/test/test_setitem.py @@ -8,80 +8,222 @@ import unittest import jittor as jt import numpy as np +from jittor.test.test_log import find_log_with_re skip_this_test = False @unittest.skipIf(skip_this_test, "No Torch found") class TestSetitem(unittest.TestCase): - def test_setitem(self): - arr0 = jt.random((4,2,2)) - data0 = jt.ones((2,2)) - arr0[1] = data0 - arr0.sync() - data0.data[0,0] = 0 - assert arr0[1,0,0] == 0 + def test_getitem_grad_opt1(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + with jt.flag_scope(trace_py_var=2): + v = jt.random((1,2,3,4,5,6)) + for i in range(6): + loss = 0. + ss = 1 + if v.shape[i] % 2 == 0: + ss = 2 + res = v.split(split_size=ss, dim=i) + t_ = [] + for j in range(len(res)): + t = np.random.random(res[j].shape).astype("float32") + loss += res[j] * t + t_.append(t) + dv = jt.grad(loss, v) + dv.sync() + data = jt.dump_trace_data() + jt.clear_trace_data() + assert (dv.numpy() == np.concatenate(t_, i)).all() + logs = find_log_with_re(rep, "getitem_grad_opt happens") + assert len(logs) == 4 - arr00 = jt.random((4,2,2)) - data00 = jt.ones((2,2)) - # share memory will fail if d has an edge to other nodes. - tmp = data00 + 1 - arr00[1] = data00 - arr00.sync() - data00.data[0,0] = 0 - assert arr00[1,0,0] == 0 + def test_setitem_grad_opt1(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + v = jt.random((4,10,10)) + a = jt.ones((10,10)) + b = jt.ones((10,10)) + c = jt.ones((10,10)) + d = jt.ones((10,10)) + v[0] = a + v[1] = b + v[2] = c + v[3] = d + t = np.random.random((4,10,10)).astype("float32") + loss = v*t + da, db, dc, dd = jt.grad(loss, [a,b,c,d]) + jt.sync([da, db, dc, dd]) + assert (da.numpy() == t[0]).all() + assert (db.numpy() == t[1]).all() + assert (dc.numpy() == t[2]).all() + assert (dd.numpy() == t[3]).all() + logs = find_log_with_re(rep, "setitem_grad_opt happens") + logs1 = find_log_with_re(rep, "setitem_grad_opt set success") + assert len(logs) == 1 + assert len(logs1) == 3 + + def test_setitem_grad_opt2(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + v = jt.random((4,10,10)) + a = jt.ones((10,10)) + b = jt.ones((10,10)) + c = jt.ones((10,10)) + d = jt.ones((10,10)) + v[0] = a + v[2] = c + v[1] = b + v[3] = d + t = np.random.random((4,10,10)).astype("float32") + loss = v*t + da, db, dc, dd = jt.grad(loss, [a,b,c,d]) + jt.sync([da, db, dc, dd]) + assert (da.numpy() == t[0]).all() + assert (db.numpy() == t[1]).all() + assert (dc.numpy() == t[2]).all() + assert (dd.numpy() == t[3]).all() + logs = find_log_with_re(rep, "setitem_grad_opt happens") + logs1 = find_log_with_re(rep, "setitem_grad_opt set success") + assert len(logs) == 1 + assert len(logs1) == 2 + + def test_setitem_grad_opt3(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + a = jt.ones((10,10,10)) + b = jt.ones((10,10,10)) + c = jt.ones((10,10,10)) + d = jt.ones((10,10,10)) + v = jt.contrib.concat([a,b,c,d], dim=1) + t = np.random.random((10,40,10)).astype("float32") + loss = v*t + da, db, dc, dd = jt.grad(loss, [a,b,c,d]) + jt.sync([da, db, dc, dd]) + assert (da.numpy() == t[:,:10]).all() + assert (db.numpy() == t[:,10:20]).all() + assert (dc.numpy() == t[:,20:30]).all() + assert (dd.numpy() == t[:,30:40]).all() + logs = find_log_with_re(rep, "setitem_grad_opt happens") + logs1 = find_log_with_re(rep, "setitem_grad_opt set success") + assert len(logs) == 1 + assert len(logs1) == 3 - arr1 = jt.random((4,2,2)) - data1 = jt.zeros((2,2)) - arr1[3,:,0:2] = data1 - arr1.sync() - data1.data[0,0] = 1 - assert arr1[3,0,0] == 1 + def test_setitem1(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + v = jt.random((4,2,2)) + d = jt.ones((2,2)) + v[1] = d + del d + v.sync() + logs = find_log_with_re(rep, "setitem_inplace happens") + assert len(logs) == 1 - arr21 = jt.ones((2,2)) - arr22 = jt.ones((2,2)) * 2 - arr2 = jt.contrib.concat([arr21, arr22], dim=0) - arr2.sync() - arr21.data[0,0] = 3 - arr22.data[0,0] = 4 - assert arr2[0,0] == 3 - assert arr2[2,0] == 4 + def test_setitem2(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + v = jt.random((4,2,2)) + d = jt.ones((2,2)) + v[1] = d + v.sync() + logs = find_log_with_re(rep, "setitem_inplace happens") + assert len(logs) == 0 - def test_getitem(self): - # test for different slice type - arr0 = jt.random((4,3)) - arr0_res = arr0[2,:] - arr0_res.data[1] = 1 - assert arr0[2,1] == 1 + def test_setitem3(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + v = jt.random((4,2,2)) + v1 = v[1:3] + d = jt.ones((2,2)) + v1[1] = d + v1.sync() + logs = find_log_with_re(rep, "setitem_inplace happens") + assert len(logs) == 0 + + def test_setitem4(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + v = jt.random((4,2,2)) + v1 = v[1:3,0] + d = jt.ones((2,)) + v1[1] = d + del d + v1.sync() + logs = find_log_with_re(rep, "setitem_inplace happens") + assert len(logs) == 1 + + def test_setitem5(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + v = jt.random((4,2,2)) + v1 = v[1:3,0] + d = jt.ones((2,2)) + d1 = d[0] + v1[1] = d1 + v1.sync() + logs = find_log_with_re(rep, "setitem_inplace happens") + assert len(logs) == 0 + + def test_setitem6(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + v = jt.random((4,2,2)) + v1 = v[1:3,0] + d = jt.ones((2,2,2)) + d1 = d[0,0] + v1[1] = d1 + del d1 + v1.sync() + logs = find_log_with_re(rep, "setitem_inplace happens") + assert len(logs) == 1 - arr1 = jt.array([1,2,3,4]) - arr1_res = arr1[None] - arr1_res.data[0,2] = -1 - assert arr1[2] == -1 + def test_getitem1(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + v = jt.random((4,3)) + v_res = v[2,:] + v_res.data[1] = 1 + logs = find_log_with_re(rep, "getitem_inplace happens") + assert len(logs) == 1 - arr2 = jt.array([1,2,3,4]) - arr2_res = arr2[...] - arr2_res.data[2] = -1 - assert arr2[2] == -1 + def test_getitem2(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + v1 = jt.array([1,2,3,4]) + v1_res = v1[None] + v1_res.data[0,2] = -1 + logs = find_log_with_re(rep, "getitem_inplace happens") + assert len(logs) == 1 - arr3 = jt.array([1,2,3,4]) - arr3_res = arr3[3] - arr3_res.data[0] = -1 - assert arr3[3] == -1 + def test_getitem3(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + arr3 = jt.array([1,2,3,4]) + arr3_res = arr3[3] + arr3_res.data[0] = -1 + logs = find_log_with_re(rep, "getitem_inplace happens") + assert len(logs) == 1 - arr4 = jt.random((4,2,3,3)) - arr4_res = arr4[...,:,:] - arr4_res.data[0,0,1,1] = 1 - assert arr4[0,0,1,1] == 1 + def test_getitem4(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + arr4 = jt.random((4,2,3,3)) + arr4_res = arr4[...,:,:] + arr4_res.data[0,0,1,1] = 1 + logs = find_log_with_re(rep, "getitem_inplace happens") + assert len(logs) == 1 - arr5 = jt.random((4,2,3,3)) - arr5_res = arr5[1:3,:,:,:] - arr5_res.data[1,0,1,1] = 1 - assert arr5[2,0,1,1] == 1 + def test_getitem5(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + arr5 = jt.random((4,2,3,3)) + arr5_res = arr5[1:3,:,:,:] + arr5_res.data[1,0,1,1] = 1 + logs = find_log_with_re(rep, "getitem_inplace happens") + assert len(logs) == 1 - arr6 = jt.random((4,2,3,3)) - arr6_res = arr6[1] - arr6_res.data[0,1,1] = 1 - assert arr6[1,0,1,1] == 1 + def test_getitem6(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + arr6 = jt.random((4,2,3,3)) + arr6_res = arr6[1] + arr6_res.data[0,1,1] = 1 + logs = find_log_with_re(rep, "getitem_inplace happens") + assert len(logs) == 1 + + def test_getitem7(self): + with jt.log_capture_scope(log_vprefix="setitem_gopt=1000") as rep: + arr2 = jt.array([1,2,3,4]) + arr2_res = arr2[...] + arr2_res.data[2] = -1 + logs = find_log_with_re(rep, "getitem_inplace happens") + assert len(logs) == 1 + def test_getitem8(self): # test for different data type (float32/float64/bool/int8/int32) arr_float32 = jt.random((4,2,3)) arr_float32_res = arr_float32[1:3,:,:] diff --git a/src/executor.cc b/src/executor.cc index 978a444e..bb8cd500 100644 --- a/src/executor.cc +++ b/src/executor.cc @@ -128,8 +128,10 @@ void Executor::run_sync(vector vars, bool device_sync) { } } if (!need_opt) break; + toplogical_sort_forward_inplace(bfs_q, [&](Node* n) {}); + SetupFreeBuffer setup_free_buffer; for (Node* n : bfs_q) { - if (n->flags.get(NodeFlags::_has_gopt)) { + if (!n->need_free() && n->flags.get(NodeFlags::_has_gopt)) { n->op()->graph_optimize(); n->flags.set(NodeFlags::_has_gopt, 0); } diff --git a/src/graph.h b/src/graph.h index 406480f2..971c3f16 100644 --- a/src/graph.h +++ b/src/graph.h @@ -111,6 +111,34 @@ void toplogical_sort_forward(vector& nodes, vector& sorted, Func&& ASSERTop(nodes.size(),==,sorted.size()); } +template +void toplogical_sort_forward_inplace(vector& nodes, Func&& func) { + auto t = ++Node::tflag_count; + int sorted_size = 0; + for (auto node : nodes) node->tflag = t; + for (auto node : nodes) { + auto& deps = node->custom_data; + deps = 0; + for (auto i : node->_inputs) + if (i.node->tflag == t) + deps++; + if (deps == 0) + nodes[sorted_size++] = node; + } + int i=0; + while (i < sorted_size) { + Node* node = nodes[i++]; + for (auto o : node->_outputs) + if (o.node->tflag == t) { + o.node->custom_data--; + if (o.node->custom_data == 0) + nodes[sorted_size++] = o.node; + } + func(node); + } + ASSERTop(nodes.size(),==,sorted_size); +} + template void toplogical_sort_backward(vector& nodes, vector& sorted, Func&& func) { diff --git a/src/ops/getitem_op.cc b/src/ops/getitem_op.cc index d316283d..cc283b62 100644 --- a/src/ops/getitem_op.cc +++ b/src/ops/getitem_op.cc @@ -447,6 +447,9 @@ void GetitemOp::jit_run() { auto in = inputs().front(); auto out = outputs().front(); if (out->num == 0) return; + if (out->allocation == in->allocation && + out->allocator == in->allocator) + return; @for(i, 0, ODIM, index_t oshape@i = o_shape[@i];) @if(ODIM>0, diff --git a/src/opt/gopt/setitem_gopt.cc b/src/opt/gopt/setitem_gopt.cc index 0d0bd5b1..02ac3702 100644 --- a/src/opt/gopt/setitem_gopt.cc +++ b/src/opt/gopt/setitem_gopt.cc @@ -8,6 +8,7 @@ #include "var.h" #include "ops/setitem_op.h" #include "ops/getitem_op.h" +#include "ops/op_register.h" namespace jittor { @@ -16,11 +17,15 @@ inline static bool fast_strcmp(const char* a, const char* b) { return !*b; } +static auto make_empty = get_op_info("empty") + .get_constructor(); + static void setitem_inplace(SetitemOp* op) { - // LOGir << "in setitem_inplace"; - auto input = op->inputs().front(); - if (!(input->outputs().size() == 1 && - input->forward_liveness<=1 && + LOGvvvv << "setitem_inplace"; + if (!op->flags.get(NodeFlags::_has_gopt)) + return; + auto input = op->inputs().front(); + if (!(input->backward_liveness<=1 && (op->op == ns_void || op->op == ns_add || op->op == ns_subtract))) { return; } @@ -28,18 +33,16 @@ static void setitem_inplace(SetitemOp* op) { if (input_op) { // make sure input op will not use input auto input_name = input_op->name(); - if (!(input_op->type() == OpType::broadcast || - input_op->inputs().size() == 0 || - fast_strcmp(input_name, "setitem") || - fast_strcmp(input_name, "getitem"))) - // TODO: inplace getitem maybe risky, getitem maybe inplace too + // if it is not setitem and been inplaced + if (!fast_strcmp(input_name, "setitem") && + (!input->mem_ptr && input->allocator)) return; } + auto output = op->outputs().front(); output->share_with(input); - // LOGir << "pass setitem optim one"; - + // data shares memory with input auto data = op->input(1); input_op = input->input(); @@ -51,7 +54,12 @@ static void setitem_inplace(SetitemOp* op) { } VarSlices vs = op->vs; - if (!(data->is_finished() == 0 && (data->outputs().size() == 1 || (!input_op || input_op->inputs().size() == 0)))) + /* data can share memory with input, which must suit: + - data must be not finished + - data has no other output + */ + if (data->is_finished() || data->backward_liveness>1) + // || input_op || input_op->inputs().size() == 0))) return; auto in_shape = input->shape; @@ -74,21 +82,76 @@ static void setitem_inplace(SetitemOp* op) { data->input()->add_inputs(vector{input}); data->share_with(input, size); - // LOGir << "pass setitem optim two"; + op->flags.set(NodeFlags::_has_gopt, 0); + LOGvvvv << "setitem_inplace happens"; } -struct BBox { - int n = 0; - int* minmax = nullptr; - +static void getitem_grad_opt(SetitemOp* op) { + LOGvvvv << "getitem_grad_opt"; + if (!op->flags.get(NodeFlags::_has_gopt)) + return; + bool last = true; + SetitemOp* last_set_op = nullptr; + Var* last_dv = nullptr; + while (1) { + // op is a setitem op, out is dv + auto cur_dv = op->outputs().front(); + setitem_inplace(op); - void load_var_slice(const VarSlice& vs) { + // out_op is a binary.add op + auto dv_out_op = cur_dv->outputs().front(); + + if (dv_out_op == nullptr) return; + if (dv_out_op && !fast_strcmp(dv_out_op->name(), "binary")) return; + + Var* pre_dv = nullptr; + for (auto* tmp : dv_out_op->inputs()) { + if (tmp != cur_dv) { pre_dv = tmp; break; } + } + + auto pre_dv_in_op = pre_dv->input(); + + if (last) { + last_dv = dv_out_op->outputs().front(); + last_set_op = op; + last = false; + } + + if (fast_strcmp(pre_dv_in_op->name(), "binary")) { + for (auto* tmp : pre_dv_in_op->inputs()) { + if (fast_strcmp(tmp->input()->name(), "setitem")) { + pre_dv = tmp; + break; + } + } + op->set_inputs(list{pre_dv, op->inputs().back()}); + op = (SetitemOp *)(pre_dv->input()); + op->flags.set(NodeFlags::_has_gopt, 0); + } + else if (fast_strcmp(pre_dv_in_op->name(), "setitem")) { + op->set_inputs(list{pre_dv, op->inputs().back()}); + auto ori_v = pre_dv->input()->inputs().front(); + auto tmp_v = make_empty(ori_v->shape, ori_v->dtype()); + ori_v->set_inputs({{tmp_v->input()}}); + tmp_v->set_inputs({}); + op->flags.set(NodeFlags::_has_gopt, 0); + pre_dv->input()->flags.set(NodeFlags::_has_gopt, 0); + break; + } + } -}; + last_dv->set_inputs({{last_set_op}}); + if (last_set_op->outputs().size() == 2) { + last_set_op->outputs().front()->set_inputs({}); + ASSERT(last_set_op->outputs().size() == 1) << last_set_op->outputs(); + } + LOGvvvv << "getitem_grad_opt happens"; +} static void setitem_grad_opt(GetitemOp* op) { + LOGvvvv << "setitem_grad_opt"; if (!op->flags.get(NodeFlags::_has_gopt)) return; auto get_in = op->inputs().front(); @@ -126,21 +189,84 @@ static void setitem_grad_opt(GetitemOp* op) { last_set = next; chain.push_back(next); } - // LOGir << "find setitem chain" << chain.size() << chain; + if (chain.size() == 0) return; + SetitemOp* cur_op = chain[0]; + VarSlices vs = chain[0]->vs; + + // only suppot :*n, int or slice now + int idx_min = -1; + int idx_max = -1; + int idx = -1; + for (int i = 0; i < vs.n; ++i) { + VarSlice s = vs.slices[i]; + if (s.is_int()) { + idx_min = s.i; + idx_max = s.i; + if (idx != -1) return; + idx = i; + } + else if (s.is_slice()) { + idx_min = s.slice.start; + idx_max = s.slice.stop; + if (idx_min == 0 && idx_max == -1) continue; + if (idx != -1) return; + idx = i; + } + } + for (auto* sop : chain) { - // LOGig << sop << sop->vs; auto out_var = sop->outputs().front(); + auto in_var = cur_op->inputs().front(); for (auto* out : out_var->outputs()) { if (fast_strcmp(out->name(), "getitem")) { - out->flags.set(NodeFlags::_has_gopt, 0); + GetitemOp* cur_get_op = (GetitemOp*)out; + VarSlices vs = cur_get_op->vs; + + int cur_idx_min = -1, cur_idx_max = -1; + for (int i = 0; i < vs.n; ++i) { + VarSlice s = vs.slices[i]; + if (s.is_int()) { + cur_idx_min = s.i; + cur_idx_max = s.i; + if (i != idx) return; + } + else if (s.is_slice()) { + cur_idx_min = s.slice.start; + cur_idx_max = s.slice.stop-1; + if (cur_idx_min == 0 && cur_idx_max == -2) continue; + if (i != idx) return; + } + } + + int flag = 0; + // 括号数组,如果当前的区间与之前记录的区间没有overlap就直接share memory + if (cur_idx_max < idx_min) { + idx_min = cur_idx_min; + flag = 1; + } + else if (cur_idx_min > idx_max) { + idx_max = cur_idx_max; + flag = 1; + } + else + cur_op = sop; + + if (flag == 1) { + LOGvvvv << "setitem_grad_opt set success"; + cur_get_op->set_inputs({in_var}); + } } + out->flags.set(NodeFlags::_has_gopt, 0); } } - + LOGvvvv << "setitem_grad_opt happens"; } static void getitem_inplace(GetitemOp* op) { - // LOGir << "in getitem_inplace"; + LOGvvvv << "getitem_inplace"; + + if (!op->flags.get(NodeFlags::_has_gopt)) + return; auto in = op->inputs().front(); auto ou = op->outputs().front(); @@ -169,20 +295,21 @@ static void getitem_inplace(GetitemOp* op) { else if (s.is_slice()) size = s.slice.start * in->size / in_shape[0]; ou->share_with(in, size); - // LOGir << "pass getitem_inplace"; + op->flags.set(NodeFlags::_has_gopt, 0); + LOGvvvv << "getitem_inplace happens"; } void SetitemOp::graph_optimize() { - // LOGir << "hello graph_optimize"; + // LOGvvvv << "hello graph_optimize"; + // (void)getitem_grad_opt; + getitem_grad_opt(this); setitem_inplace(this); } void GetitemOp::graph_optimize() { // This optimize is still WIP - // LOGir << "hello getitem graph_optimize"; - // setitem_grad_opt(this); - (void)setitem_grad_opt; - // (void)getitem_inplace; + // LOGvvvv << "hello getitem graph_optimize"; + setitem_grad_opt(this); getitem_inplace(this); }