diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp index 18617e3b77b..89fecbdf827 100644 --- a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp +++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp @@ -107,6 +107,8 @@ void overwrite_quantparam(const luci::CircleNode *source, luci::CircleNode *targ target_qparam->scale = source_qparam->scale; target_qparam->zerop = source_qparam->zerop; target_qparam->quantized_dimension = source_qparam->quantized_dimension; + + target->dtype(source->dtype()); } /** @@ -341,7 +343,6 @@ void propagate_concat_quantparam(luci::CircleConcatenation *concat) continue; // Non-const input must have been quantized - assert(node->quantparam() != nullptr); overwrite_quantparam(concat, node); } } diff --git a/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp b/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp index face706b2cc..b95c7ea150c 100644 --- a/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp +++ b/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp @@ -14,6 +14,7 @@ */ #include "luci/Pass/QuantizeOnnxFakeQuantModelPass.h" +#include "luci/Pass/PropagateQParamBackwardPass.h" #include "QuantizeOnnxQDQPass.h" #include "QuantizeOnnxDequantizeLinearPass.h" #include "QuantizeWithPredecessorPass.h" @@ -92,6 +93,12 @@ bool QuantizeOnnxFakeQuantModelPass::run(loco::Graph *g) pass.run(g); } + // Backward propagation of activation qparam + { + PropagateQParamBackwardPass pqbp(_ctx->default_activation_dtype); + pqbp.run(g); + } + // Update qparam of output of special Ops for (auto node : loco::active_nodes(loco::output_nodes(g))) {