Skip to content

Commit

Permalink
[luci] Support quantized inputs in ReplaceNonConstFCWithBatchMatMulPa…
Browse files Browse the repository at this point in the history
…ss (#14487)

This supports quantized inputs in ReplaceNonConstFCWithBatchMatMulPass.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Dec 24, 2024
1 parent 5c65065 commit 5754f4a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleQuantParam.h>
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h>

Expand Down Expand Up @@ -95,6 +96,8 @@ luci::CircleReshape *create_reshape(luci::CircleFullyConnected *node)

reshape->shape(shape_const);

luci::copy_quantparam(node, reshape);

return reshape;
}

Expand Down Expand Up @@ -165,9 +168,6 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc)
x = loco::must_cast<luci::CircleNode *>(fc->input());
}

if (x->dtype() != loco::DataType::FLOAT32 || y->dtype() != loco::DataType::FLOAT32)
return false;

auto bc = dynamic_cast<luci::CircleConst *>(fc->bias());
// NOTE bias can be empty as CircleOutputExclude type
// NOTE we can only handle bias as FLOAT32 type as of now
Expand All @@ -185,6 +185,8 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc)
matmul->name(name);
matmul->dtype(fc->dtype());

luci::copy_quantparam(fc, matmul);

luci::add_origin(matmul, luci::get_origin(fc));

auto reshape = create_reshape(fc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,56 @@ struct FCGraphlet
luci::CircleInput *_y = nullptr;
};

struct S16FCGraphlet
{
public:
S16FCGraphlet() = default;
virtual ~S16FCGraphlet() = default;

void init(loco::Graph *g, const ShapeU32 r_shape)
{
_tr_x = g->nodes()->create<luci::CircleTranspose>();
_tr_x->a(_x);
std::vector<int32_t> tr_x_val = {1, 0};
_tr_x->perm(luci::create_const_node(g, loco::DataType::S32, {2}, tr_x_val));
_tr_x->dtype(loco::DataType::S16);

_tr_y = g->nodes()->create<luci::CircleTranspose>();
_tr_y->a(_y);
std::vector<int32_t> tr_y_val = {1, 0};
_tr_y->perm(luci::create_const_node(g, loco::DataType::S32, {2}, tr_y_val));
_tr_y->dtype(loco::DataType::S16);

_fc = g->nodes()->create<luci::CircleFullyConnected>();
_fc->input(_tr_x);
_fc->weights(_tr_y);
_fc->fusedActivationFunction(luci::FusedActFunc::NONE);
_fc->dtype(loco::DataType::S16);
_fc->shape(r_shape);

auto no_bias = g->nodes()->create<luci::CircleOutputExclude>();
_fc->bias(no_bias);
_fc->name("fc");

auto qparam = std::make_unique<luci::CircleQuantParam>();
{
qparam->scale = {1.0};
qparam->zerop = {0};
}
_fc->quantparam(std::move(qparam));
}

public:
luci::CircleFullyConnected *fc() { return _fc; }

protected:
luci::CircleFullyConnected *_fc = nullptr;
luci::CircleTranspose *_tr_x = nullptr;
luci::CircleTranspose *_tr_y = nullptr;
luci::CircleInput *_x = nullptr;
luci::CircleInput *_y = nullptr;
};

struct FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public FCGraphlet
{
FCGraph() = default;
Expand All @@ -99,13 +149,33 @@ struct FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public FCGraphl
}
};

struct S16FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public S16FCGraphlet
{
void init(const ShapeU32 x_shape, const ShapeU32 y_shape, const ShapeU32 r_shape)
{
TestIsGraphlet<2>::init(g(), {x_shape, y_shape});
TestOGraphlet::init(g(), r_shape);
_x = input(0);
_y = input(1);
S16FCGraphlet::init(g(), r_shape);
output()->from(_fc);
}
};

class ReplaceNonConstFCWithBatchMatMulPassTest : public ::testing::Test
{
public:
FCGraph g;
luci::ReplaceNonConstFCWithBatchMatMulPass pass;
};

class ReplaceNonConstS16FCWithBatchMatMulPassTest : public ::testing::Test
{
public:
S16FCGraph g;
luci::ReplaceNonConstFCWithBatchMatMulPass pass;
};

} // namespace

TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, simple_test)
Expand All @@ -130,6 +200,22 @@ TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, nonzero_bias_test)
EXPECT_NE(nullptr, mm);
}

TEST_F(ReplaceNonConstS16FCWithBatchMatMulPassTest, s16_test)
{
g.init({2, 3}, {2, 3}, {2, 2});

auto ret = pass.run(g.g());
EXPECT_EQ(true, ret);

auto res = dynamic_cast<luci::CircleReshape *>(g.output()->from());
EXPECT_NE(nullptr, res);

auto qparam = res->quantparam();
EXPECT_NE(nullptr, qparam);
EXPECT_EQ(1.0, qparam->scale[0]);
EXPECT_EQ(0, qparam->zerop[0]);
}

TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, wrong_op_NEG)
{
loco::Graph g;
Expand Down

0 comments on commit 5754f4a

Please sign in to comment.