diff --git a/tools/pnnx/src/pass_level2/Tensor_new_empty.cpp b/tools/pnnx/src/pass_level2/Tensor_new_empty.cpp
index 6f0f0cb68d2..215b17e2c13 100644
--- a/tools/pnnx/src/pass_level2/Tensor_new_empty.cpp
+++ b/tools/pnnx/src/pass_level2/Tensor_new_empty.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/Tensor_new_ones.cpp b/tools/pnnx/src/pass_level2/Tensor_new_ones.cpp
index 03df3028b8b..3fe4c3390be 100644
--- a/tools/pnnx/src/pass_level2/Tensor_new_ones.cpp
+++ b/tools/pnnx/src/pass_level2/Tensor_new_ones.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp b/tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp
index c3aa069dc63..93963f2a266 100644
--- a/tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp
+++ b/tools/pnnx/src/pass_level2/Tensor_new_zeros.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/Tensor_to.cpp b/tools/pnnx/src/pass_level2/Tensor_to.cpp
index 8ab1f124960..6d7cd9e7dc6 100644
--- a/tools/pnnx/src/pass_level2/Tensor_to.cpp
+++ b/tools/pnnx/src/pass_level2/Tensor_to.cpp
@@ -40,18 +40,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
 
         op->params["copy"] = captured_params.at("copy");
 
diff --git a/tools/pnnx/src/pass_level2/torch_empty.cpp b/tools/pnnx/src/pass_level2/torch_empty.cpp
index 92244e2e456..3c6a074cbd0 100644
--- a/tools/pnnx/src/pass_level2/torch_empty.cpp
+++ b/tools/pnnx/src/pass_level2/torch_empty.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/torch_empty_like.cpp b/tools/pnnx/src/pass_level2/torch_empty_like.cpp
index 13c145c969e..baa2f74c0cf 100644
--- a/tools/pnnx/src/pass_level2/torch_empty_like.cpp
+++ b/tools/pnnx/src/pass_level2/torch_empty_like.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/torch_full.cpp b/tools/pnnx/src/pass_level2/torch_full.cpp
index 293fad2e9b6..718a0796a53 100644
--- a/tools/pnnx/src/pass_level2/torch_full.cpp
+++ b/tools/pnnx/src/pass_level2/torch_full.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/torch_full_like.cpp b/tools/pnnx/src/pass_level2/torch_full_like.cpp
index 67f2a6f58b6..4d58df9c7c7 100644
--- a/tools/pnnx/src/pass_level2/torch_full_like.cpp
+++ b/tools/pnnx/src/pass_level2/torch_full_like.cpp
@@ -42,18 +42,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/torch_ones.cpp b/tools/pnnx/src/pass_level2/torch_ones.cpp
index d055b346631..888397a97c5 100644
--- a/tools/pnnx/src/pass_level2/torch_ones.cpp
+++ b/tools/pnnx/src/pass_level2/torch_ones.cpp
@@ -40,18 +40,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/torch_ones_like.cpp b/tools/pnnx/src/pass_level2/torch_ones_like.cpp
index 312ea8ed95a..8837b0fdd5f 100644
--- a/tools/pnnx/src/pass_level2/torch_ones_like.cpp
+++ b/tools/pnnx/src/pass_level2/torch_ones_like.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/torch_randn.cpp b/tools/pnnx/src/pass_level2/torch_randn.cpp
index 5cbfc33fea9..345c4e495d5 100644
--- a/tools/pnnx/src/pass_level2/torch_randn.cpp
+++ b/tools/pnnx/src/pass_level2/torch_randn.cpp
@@ -40,18 +40,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/torch_randn_like.cpp b/tools/pnnx/src/pass_level2/torch_randn_like.cpp
index da74dec04a2..68c1dc9dcb6 100644
--- a/tools/pnnx/src/pass_level2/torch_randn_like.cpp
+++ b/tools/pnnx/src/pass_level2/torch_randn_like.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/torch_zeros.cpp b/tools/pnnx/src/pass_level2/torch_zeros.cpp
index 8b53d1652b0..90213fdde5b 100644
--- a/tools/pnnx/src/pass_level2/torch_zeros.cpp
+++ b/tools/pnnx/src/pass_level2/torch_zeros.cpp
@@ -40,18 +40,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/torch_zeros_like.cpp b/tools/pnnx/src/pass_level2/torch_zeros_like.cpp
index 85a0bd22490..5babbbb55a7 100644
--- a/tools/pnnx/src/pass_level2/torch_zeros_like.cpp
+++ b/tools/pnnx/src/pass_level2/torch_zeros_like.cpp
@@ -41,18 +41,25 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
-        if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
-        if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
-        if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
-        if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
-        if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
-        if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
-        if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
-        if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
-        if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
-        if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
-        if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        if (captured_params.at("dtype").type == 0)
+        {
+            op->params["dtype"] = Parameter();
+        }
+        else // if (captured_params.at("dtype").type == 2)
+        {
+            if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
+            if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
+            if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
+            if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
+            if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
+            if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
+            if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
+            if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
+            if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
+            if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
+            if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
+            if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";
+        }
     }
 };
 
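Every hunk above applies the same change: when the captured dtype parameter carries no value (type 0, i.e. dtype=None in the traced call), the pass now writes an empty Parameter instead of running the integer comparisons against an unset .i field; otherwise the integer ScalarType index is mapped to a torch dtype string exactly as before. A minimal sketch of that shared logic follows, not part of the patch: the helper name write_dtype_param and the "ir.h" include are assumptions for illustration, while the type/.i semantics and the dtype table come from the hunks above.

```cpp
#include <map>
#include <string>
#include "ir.h" // assumed pnnx header declaring Operator and Parameter

namespace pnnx {

// Hypothetical helper: the dtype handling each pass repeats, assuming
// Parameter.type == 0 means "no value captured" and type == 2 means an
// integer ScalarType index, as in the hunks above.
static void write_dtype_param(Operator* op, const Parameter& dtype)
{
    if (dtype.type == 0)
    {
        // dtype was None in the traced graph -> keep an empty parameter
        op->params["dtype"] = Parameter();
        return;
    }

    // ScalarType index -> torch dtype string, same table as the patch uses
    static const char* const names[] = {
        "torch.uint8", "torch.int8", "torch.short", "torch.int",
        "torch.long", "torch.half", "torch.float", "torch.double",
        "torch.complex32", "torch.complex64", "torch.complex128", "torch.bool"
    };
    if (dtype.i >= 0 && dtype.i < 12)
        op->params["dtype"] = names[dtype.i];
}

} // namespace pnnx
```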