pnnx fix some undefined dtype (#5382)
nihui authored Mar 20, 2024
1 parent a55fe1c commit 02ba676
Showing 14 changed files with 266 additions and 168 deletions.
31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/Tensor_new_empty.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };

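The same dtype handling is repeated verbatim in each of the 14 passes touched by this commit: when the captured "dtype" parameter has type 0 (no value was captured, i.e. dtype was left unset at the call site), the pass now writes an undefined Parameter instead of reading a nonexistent integer. As a reference only, the sketch below collects that logic into one helper. The helper name write_dtype_param is hypothetical and not part of this commit; Operator and Parameter are the pnnx IR types already used by these files, with type == 0 meaning no captured value and type == 2 meaning an integer ScalarType index, as suggested by the comments in the diff.

// Hypothetical helper, not part of this commit: the dtype handling repeated
// in each pass above, collected in one place.
#include <map>
#include <string>

static void write_dtype_param(Operator* op, const std::map<std::string, Parameter>& captured_params)
{
    const Parameter& dtype = captured_params.at("dtype");

    if (dtype.type == 0)
    {
        // dtype was not specified, keep the parameter undefined
        op->params["dtype"] = Parameter();
        return;
    }

    // ScalarType index -> torch dtype string, same table as in the passes above
    static const char* const dtype_strs[] = {
        "torch.uint8", "torch.int8", "torch.short", "torch.int",
        "torch.long", "torch.half", "torch.float", "torch.double",
        "torch.complex32", "torch.complex64", "torch.complex128", "torch.bool"
    };

    if (dtype.i >= 0 && dtype.i < 12)
        op->params["dtype"] = dtype_strs[dtype.i];
}

With such a helper, each write() body above would reduce to a single call, write_dtype_param(op, captured_params); (Tensor_to.cpp would additionally keep its "copy" parameter line.)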
31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/Tensor_new_ones.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };

31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };

31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/Tensor_to.cpp
@@ -40,18 +40,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
 
         op->params["copy"] = captured_params.at("copy");

31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/torch_empty.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };

31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/torch_empty_like.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };

31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/torch_full.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };

31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/torch_full_like.cpp
@@ -42,18 +42,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
