Skip to content

Commit

Permalink
get dict list tuple inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Dec 18, 2024
1 parent 98c78c4 commit 4909fc3
Showing 1 changed file with 82 additions and 99 deletions.
181 changes: 82 additions & 99 deletions tools/pnnx/src/load_torchscript.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,35 @@ static void print_shape_list(const std::vector<std::vector<int64_t> >& shapes, c
}
}

static void append_input(std::vector<std::vector<int64_t> >& input_shapes, std::vector<std::string>& input_types, const torch::jit::IValue& v)
{
if (v.isTensor())
{
const auto& tensor = v.toTensor();
input_shapes.push_back(tensor.sizes().vec());
input_types.push_back(get_at_tensor_type_str(tensor.scalar_type()));
}
else if (v.isList())
{
for (const auto& v2 : v.toList())
append_input(input_shapes, input_types, v2);
}
else if (v.isTuple())
{
for (const auto& v2 : v.toTuple()->elements())
append_input(input_shapes, input_types, v2);
}
else if (v.isGenericDict())
{
for (const auto& kv2 : v.toGenericDict())
append_input(input_shapes, input_types, kv2.value());
}
else
{
fprintf(stderr, "unsupported traced input type %s\n", v.tagKind().c_str());
}
}

static void get_traced_input_shape(const std::string& ptpath, std::vector<std::vector<int64_t> >& input_shapes, std::vector<std::string>& input_types)
{
try
Expand All @@ -486,25 +515,7 @@ static void get_traced_input_shape(const std::string& ptpath, std::vector<std::v
if (entry.key() != "forward")
continue;

auto inputs = entry.value().toList().vec();

input_shapes.resize(inputs.size());
input_types.resize(inputs.size());

for (size_t i = 0; i < inputs.size(); i++)
{
const auto& tensor = inputs[i].toTensor();
std::vector<long> shape = tensor.sizes().vec();
at::ScalarType datatype = tensor.scalar_type();

input_shapes[i] = shape;
input_types[i] = get_at_tensor_type_str(datatype);
}

fprintf(stderr, "use inputshape from traced inputs\n");
fprintf(stderr, "inputshape = ");
print_shape_list(input_shapes, input_types);
fprintf(stderr, "\n");
append_input(input_shapes, input_types, entry.value());
break;
}
}
Expand All @@ -514,91 +525,63 @@ static void get_traced_input_shape(const std::string& ptpath, std::vector<std::v
}
}

static bool check_input_shape(const std::string& ptpath, const std::vector<std::vector<int64_t> >& input_shapes, const std::vector<std::string>& input_types)
static bool check_input_shape(const std::vector<std::vector<int64_t> >& traced_input_shapes, const std::vector<std::string>& traced_input_types, const std::vector<std::vector<int64_t> >& input_shapes, const std::vector<std::string>& input_types)
{
try
if (input_shapes.size() != traced_input_shapes.size())
{
// read traced_inputs.pkl
caffe2::serialize::PyTorchStreamReader reader(ptpath);
auto dict = torch::jit::readArchiveAndTensors("traced_inputs", "", "traced_inputs/", std::nullopt, std::nullopt, std::nullopt, reader).toGenericDict();
fprintf(stderr, "input_shape expect %d tensors but got %d\n", (int)traced_input_shapes.size(), (int)input_shapes.size());
return false;
}

for (const auto& entry : dict)
for (size_t i = 0; i < traced_input_shapes.size(); i++)
{
bool matched = true;

if (input_shapes[i].size() != traced_input_shapes[i].size())
{
if (entry.key() != "forward")
continue;
matched = false;
}
else
{
for (size_t j = 0; j < traced_input_shapes[i].size(); j++)
{
if (input_shapes[i][j] != traced_input_shapes[i][j])
matched = false;
}
}

auto inputs = entry.value().toList().vec();
if (input_types[i] != traced_input_types[i])
matched = false;

if (!input_shapes.empty() && input_shapes.size() != inputs.size())
if (!matched)
{
fprintf(stderr, "input_shapes[%d] expect [", (int)i);
for (size_t j = 0; j < traced_input_shapes[i].size(); j++)
{
fprintf(stderr, "input_shape expect %d tensors but got %d\n", (int)inputs.size(), (int)input_shapes.size());
return false;
fprintf(stderr, "%ld", traced_input_shapes[i][j]);
if (j + 1 != traced_input_shapes[i].size())
fprintf(stderr, ",");
}

for (size_t i = 0; i < inputs.size(); i++)
fprintf(stderr, "]%s but got ", traced_input_types[i].c_str());
if (input_shapes.empty())
{
const auto& tensor = inputs[i].toTensor();
std::vector<long> shape = tensor.sizes().vec();
at::ScalarType datatype = tensor.scalar_type();

std::cerr << "input " << tensor.sizes() << " " << tensor.scalar_type() << std::endl;

bool matched = true;

if (input_shapes[i].size() != shape.size())
{
matched = false;
}
else
{
for (size_t j = 0; j < shape.size(); j++)
{
if (input_shapes[i][j] != shape[j])
matched = false;
}
}

if (input_types[i] != get_at_tensor_type_str(datatype))
matched = false;

if (!matched)
fprintf(stderr, "nothing\n");
}
else
{
fprintf(stderr, "[");
for (size_t j = 0; j < input_shapes[i].size(); j++)
{
fprintf(stderr, "input_shapes[%d] expect [", (int)i);
for (size_t j = 0; j < shape.size(); j++)
{
fprintf(stderr, "%ld", shape[j]);
if (j + 1 != shape.size())
fprintf(stderr, ",");
}
fprintf(stderr, "]%s but got ", get_at_tensor_type_str(datatype));
if (input_shapes.empty())
{
fprintf(stderr, "nothing\n");
}
else
{
fprintf(stderr, "[");
for (size_t j = 0; j < input_shapes[i].size(); j++)
{
fprintf(stderr, "%ld", input_shapes[i][j]);
if (j + 1 != input_shapes[i].size())
fprintf(stderr, ",");
}
fprintf(stderr, "]%s\n", input_types[i].c_str());
}

return false;
fprintf(stderr, "%ld", input_shapes[i][j]);
if (j + 1 != input_shapes[i].size())
fprintf(stderr, ",");
}
fprintf(stderr, "]%s\n", input_types[i].c_str());
}

break;
return false;
}
}
catch (...)
{
// no traced_inputs.pkl pass
return true;
}

return true;
}
Expand All @@ -614,23 +597,23 @@ int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
const std::string& foldable_constants_zippath,
std::set<std::string>& foldable_constants)
{
// get input shape from traced torchscript
std::vector<std::vector<int64_t> > traced_input_shapes;
std::vector<std::string> traced_input_types;
if (input_shapes.empty())
{
// get input shape from traced torchscript
get_traced_input_shape(ptpath, traced_input_shapes, traced_input_types);
}
else
get_traced_input_shape(ptpath, traced_input_shapes, traced_input_types);

fprintf(stderr, "get inputshape from traced inputs\n");
fprintf(stderr, "inputshape = ");
print_shape_list(traced_input_shapes, traced_input_types);
fprintf(stderr, "\n");

if (!input_shapes.empty())
{
// input shape sanity check
if (!check_input_shape(ptpath, input_shapes, input_types))
if (!check_input_shape(traced_input_shapes, traced_input_types, input_shapes, input_types))
{
return -1;
}

traced_input_shapes = input_shapes;
traced_input_types = input_types;
}
// traced torchscript always has static input shapes
// if (!input_shapes2.empty() && !check_input_shape(ptpath, input_shapes2, input_types2))
Expand Down

0 comments on commit 4909fc3

Please sign in to comment.