Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pnnx fix some undefined dtype #5382

Merged
merged 1 commit into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/Tensor_new_empty.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};

Expand Down
31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/Tensor_new_ones.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};

Expand Down
31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};

Expand Down
31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/Tensor_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}

op->params["copy"] = captured_params.at("copy");

Expand Down
31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/torch_empty.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};

Expand Down
31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/torch_empty_like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};

Expand Down
31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/torch_full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};

Expand Down
31 changes: 19 additions & 12 deletions tools/pnnx/src/pass_level2/torch_full_like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,25 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
if (captured_params.at("dtype").type == 0)
{
op->params["dtype"] = Parameter();
}
else // if (captured_params.at("dtype").type == 2)
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
}
}
};

Expand Down
Loading
Loading