From 4909fc3c80404797e9f379ea35750d52b67d3207 Mon Sep 17 00:00:00 2001
From: nihuini
Date: Wed, 18 Dec 2024 15:56:59 +0800
Subject: [PATCH] get dict list tuple inputs

---
 tools/pnnx/src/load_torchscript.cpp | 181 +++++++++++++---------------
 1 file changed, 82 insertions(+), 99 deletions(-)

diff --git a/tools/pnnx/src/load_torchscript.cpp b/tools/pnnx/src/load_torchscript.cpp
index 9a60b346be9..f9d9543c719 100644
--- a/tools/pnnx/src/load_torchscript.cpp
+++ b/tools/pnnx/src/load_torchscript.cpp
@@ -473,6 +473,35 @@ static void print_shape_list(const std::vector<std::vector<int64_t> >& shapes, c
     }
 }
 
+static void append_input(std::vector<std::vector<int64_t> >& input_shapes, std::vector<std::string>& input_types, const torch::jit::IValue& v)
+{
+    if (v.isTensor())
+    {
+        const auto& tensor = v.toTensor();
+        input_shapes.push_back(tensor.sizes().vec());
+        input_types.push_back(get_at_tensor_type_str(tensor.scalar_type()));
+    }
+    else if (v.isList())
+    {
+        for (const auto& v2 : v.toList())
+            append_input(input_shapes, input_types, v2);
+    }
+    else if (v.isTuple())
+    {
+        for (const auto& v2 : v.toTuple()->elements())
+            append_input(input_shapes, input_types, v2);
+    }
+    else if (v.isGenericDict())
+    {
+        for (const auto& kv2 : v.toGenericDict())
+            append_input(input_shapes, input_types, kv2.value());
+    }
+    else
+    {
+        fprintf(stderr, "unsupported traced input type %s\n", v.tagKind().c_str());
+    }
+}
+
 static void get_traced_input_shape(const std::string& ptpath, std::vector<std::vector<int64_t> >& input_shapes, std::vector<std::string>& input_types)
 {
     try
@@ -486,25 +515,7 @@ static void get_traced_input_shape(const std::string& ptpath, std::vector<std::
             if (entry.key() != "forward")
                 continue;
 
-            auto inputs = entry.value().toList().vec();
-
-            input_shapes.resize(inputs.size());
-            input_types.resize(inputs.size());
-
-            for (size_t i = 0; i < inputs.size(); i++)
-            {
-                const auto& tensor = inputs[i].toTensor();
-                std::vector<int64_t> shape = tensor.sizes().vec();
-                at::ScalarType datatype = tensor.scalar_type();
-
-                input_shapes[i] = shape;
-                input_types[i] = get_at_tensor_type_str(datatype);
-            }
-
-            fprintf(stderr, "use inputshape from traced inputs\n");
-            fprintf(stderr, "inputshape = ");
-            print_shape_list(input_shapes, input_types);
-            fprintf(stderr, "\n");
+            append_input(input_shapes, input_types, entry.value());
             break;
         }
     }
@@ -514,91 +525,63 @@
     }
 }
 
-static bool check_input_shape(const std::string& ptpath, const std::vector<std::vector<int64_t> >& input_shapes, const std::vector<std::string>& input_types)
-{
-    try
-    {
-        // read traced_inputs.pkl
-        caffe2::serialize::PyTorchStreamReader reader(ptpath);
-        auto dict = torch::jit::readArchiveAndTensors("traced_inputs", "", "traced_inputs/", std::nullopt, std::nullopt, std::nullopt, reader).toGenericDict();
-
-        for (const auto& entry : dict)
-        {
-            if (entry.key() != "forward")
-                continue;
-
-            auto inputs = entry.value().toList().vec();
-
-            if (!input_shapes.empty() && input_shapes.size() != inputs.size())
-            {
-                fprintf(stderr, "input_shape expect %d tensors but got %d\n", (int)inputs.size(), (int)input_shapes.size());
-                return false;
-            }
-
-            for (size_t i = 0; i < inputs.size(); i++)
-            {
-                const auto& tensor = inputs[i].toTensor();
-                std::vector<int64_t> shape = tensor.sizes().vec();
-                at::ScalarType datatype = tensor.scalar_type();
-
-                std::cerr << "input " << tensor.sizes() << " " << tensor.scalar_type() << std::endl;
-
-                bool matched = true;
-
-                if (input_shapes[i].size() != shape.size())
-                {
-                    matched = false;
-                }
-                else
-                {
-                    for (size_t j = 0; j < shape.size(); j++)
-                    {
-                        if (input_shapes[i][j] != shape[j])
-                            matched = false;
-                    }
-                }
-
-                if (input_types[i] != get_at_tensor_type_str(datatype))
-                    matched = false;
-
-                if (!matched)
-                {
-                    fprintf(stderr, "input_shapes[%d] expect [", (int)i);
-                    for (size_t j = 0; j < shape.size(); j++)
-                    {
-                        fprintf(stderr, "%ld", shape[j]);
-                        if (j + 1 != shape.size())
-                            fprintf(stderr, ",");
-                    }
-                    fprintf(stderr, "]%s but got ", get_at_tensor_type_str(datatype));
-                    if (input_shapes.empty())
-                    {
-                        fprintf(stderr, "nothing\n");
-                    }
-                    else
-                    {
-                        fprintf(stderr, "[");
-                        for (size_t j = 0; j < input_shapes[i].size(); j++)
-                        {
-                            fprintf(stderr, "%ld", input_shapes[i][j]);
-                            if (j + 1 != input_shapes[i].size())
-                                fprintf(stderr, ",");
-                        }
-                        fprintf(stderr, "]%s\n", input_types[i].c_str());
-                    }
-
-                    return false;
-                }
-            }
-
-            break;
-        }
-    }
-    catch (...)
-    {
-        // no traced_inputs.pkl pass
-        return true;
-    }
+static bool check_input_shape(const std::vector<std::vector<int64_t> >& traced_input_shapes, const std::vector<std::string>& traced_input_types, const std::vector<std::vector<int64_t> >& input_shapes, const std::vector<std::string>& input_types)
+{
+    if (input_shapes.size() != traced_input_shapes.size())
+    {
+        fprintf(stderr, "input_shape expect %d tensors but got %d\n", (int)traced_input_shapes.size(), (int)input_shapes.size());
+        return false;
+    }
+
+    for (size_t i = 0; i < traced_input_shapes.size(); i++)
+    {
+        bool matched = true;
+
+        if (input_shapes[i].size() != traced_input_shapes[i].size())
+        {
+            matched = false;
+        }
+        else
+        {
+            for (size_t j = 0; j < traced_input_shapes[i].size(); j++)
+            {
+                if (input_shapes[i][j] != traced_input_shapes[i][j])
+                    matched = false;
+            }
+        }
+
+        if (input_types[i] != traced_input_types[i])
+            matched = false;
+
+        if (!matched)
+        {
+            fprintf(stderr, "input_shapes[%d] expect [", (int)i);
+            for (size_t j = 0; j < traced_input_shapes[i].size(); j++)
+            {
+                fprintf(stderr, "%ld", traced_input_shapes[i][j]);
+                if (j + 1 != traced_input_shapes[i].size())
+                    fprintf(stderr, ",");
+            }
+            fprintf(stderr, "]%s but got ", traced_input_types[i].c_str());
+            if (input_shapes.empty())
+            {
+                fprintf(stderr, "nothing\n");
+            }
+            else
+            {
+                fprintf(stderr, "[");
+                for (size_t j = 0; j < input_shapes[i].size(); j++)
+                {
+                    fprintf(stderr, "%ld", input_shapes[i][j]);
+                    if (j + 1 != input_shapes[i].size())
+                        fprintf(stderr, ",");
+                }
+                fprintf(stderr, "]%s\n", input_types[i].c_str());
+            }
+
+            return false;
+        }
+    }
 
     return true;
 }
@@ -614,23 +597,23 @@ int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
                      const std::string& foldable_constants_zippath, std::set<std::string>& foldable_constants)
 {
+    // get input shape from traced torchscript
     std::vector<std::vector<int64_t> > traced_input_shapes;
     std::vector<std::string> traced_input_types;
-    if (input_shapes.empty())
-    {
-        // get input shape from traced torchscript
-        get_traced_input_shape(ptpath, traced_input_shapes, traced_input_types);
-    }
-    else
+    get_traced_input_shape(ptpath, traced_input_shapes, traced_input_types);
+
+    fprintf(stderr, "get inputshape from traced inputs\n");
+    fprintf(stderr, "inputshape = ");
+    print_shape_list(traced_input_shapes, traced_input_types);
+    fprintf(stderr, "\n");
+
+    if (!input_shapes.empty())
     {
         // input shape sanity check
-        if (!check_input_shape(ptpath, input_shapes, input_types))
+        if (!check_input_shape(traced_input_shapes, traced_input_types, input_shapes, input_types))
         {
             return -1;
         }
-
-        traced_input_shapes = input_shapes;
-        traced_input_types = input_types;
     }
 
     // traced torchscript always has static input shapes
     // if (!input_shapes2.empty() && !check_input_shape(ptpath, input_shapes2, input_types2))
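
Illustration only, not part of the patch: a minimal standalone sketch of the flattening behaviour that append_input() introduces above. flatten_input() below is a hypothetical stand-in with the same recursion; it uses c10::toString() for the dtype string instead of pnnx's get_at_tensor_type_str(), and it assumes a libtorch build environment.

// sketch.cpp -- hypothetical example, not part of load_torchscript.cpp
#include <torch/script.h>

#include <cstdio>
#include <string>
#include <vector>

// Same traversal idea as append_input(): record tensors, walk
// lists/tuples/dicts recursively, report anything else.
static void flatten_input(std::vector<std::vector<int64_t> >& shapes, std::vector<std::string>& types, const c10::IValue& v)
{
    if (v.isTensor())
    {
        const at::Tensor& t = v.toTensor();
        shapes.push_back(t.sizes().vec());
        types.push_back(c10::toString(t.scalar_type()));
    }
    else if (v.isList())
    {
        for (const auto& v2 : v.toList())
            flatten_input(shapes, types, v2);
    }
    else if (v.isTuple())
    {
        for (const auto& v2 : v.toTuple()->elements())
            flatten_input(shapes, types, v2);
    }
    else if (v.isGenericDict())
    {
        for (const auto& kv : v.toGenericDict())
            flatten_input(shapes, types, kv.value());
    }
    else
    {
        fprintf(stderr, "unsupported input type %s\n", v.tagKind().c_str());
    }
}

int main()
{
    // a nested traced input: (tensor, [tensor, tensor])
    c10::List<at::Tensor> lst;
    lst.push_back(torch::ones({1, 8}));
    lst.push_back(torch::zeros({4}, torch::kInt64));

    c10::IValue nested = c10::ivalue::Tuple::create({c10::IValue(torch::ones({1, 3, 224, 224})), c10::IValue(lst)});

    std::vector<std::vector<int64_t> > shapes;
    std::vector<std::string> types;
    flatten_input(shapes, types, nested);

    // expected flattened order: [1,3,224,224], [1,8], [4]
    for (size_t i = 0; i < shapes.size(); i++)
    {
        fprintf(stderr, "input %d: [", (int)i);
        for (size_t j = 0; j < shapes[i].size(); j++)
        {
            fprintf(stderr, "%ld", (long)shapes[i][j]);
            if (j + 1 != shapes[i].size())
                fprintf(stderr, ",");
        }
        fprintf(stderr, "] %s\n", types[i].c_str());
    }

    return 0;
}

The point of the recursion is that nested dict/list/tuple traced inputs collapse into one flat shape/type list in traversal order, which is the same flat form the existing input_shape sanity check compares against.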