forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimethod.cpp
64 lines (51 loc) · 2.31 KB
/
imethod.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
#include <cstdlib>

#include <gtest/gtest.h>
#include <torch/csrc/deploy/deploy.h>
#include <torch/script.h>
#include <torch/torch.h>
using namespace ::testing;
using namespace caffe2;
// Default relative locations of the generated example model artifacts used by
// the tests below (each can be overridden through an environment variable).
// `constexpr` makes the pointers immutable and gives them internal linkage, so
// these generic names can neither be reassigned nor collide with same-named
// globals in other translation units linked into the same test binary.
constexpr const char* simple = "torch/csrc/deploy/example/generated/simple";
constexpr const char* simpleJit =
    "torch/csrc/deploy/example/generated/simple_jit";
// TODO(jwtan): Try unifying cmake and buck for getting the path.
// Resolves a model location.
//
// @param envname  Name of an environment variable that may override the path.
// @param path     Default location used when `envname` is unset or empty.
// @return The environment variable's value if it is set to a non-empty
//         string, otherwise `path`. The pointer is either `path` itself or
//         storage owned by the process environment; callers must not free it.
const char* path(const char* envname, const char* path) {
  const char* env = std::getenv(envname);
  // An empty value can never name a loadable file, so treat it like "unset"
  // instead of handing "" to the loaders.
  return (env != nullptr && env[0] != '\0') ? env : path;
}
// Run `python torch/csrc/deploy/example/generate_examples.py` before running
// the following tests.
// TODO(jwtan): Figure out a way to automate the above step for development. (CI
// has it already.)
// Loads the example model in two forms -- a TorchScript module (SIMPLE_JIT)
// and a pickled Python model inside a torch::deploy package (SIMPLE) -- then
// calls `forward` through each IMethod implementation on the same input and
// checks that both return equal tensors.
TEST(IMethodTest, CallMethod) {
  auto scriptModel = torch::jit::load(path("SIMPLE_JIT", simpleJit));
  auto scriptMethod = scriptModel.get_method("forward");

  // NOTE(review): 3 interpreters; the count does not appear to matter here.
  torch::deploy::InterpreterManager manager(3);
  torch::deploy::Package package = manager.loadPackage(path("SIMPLE", simple));
  auto pyModel = package.loadPickle("model", "model.pkl");
  torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward");

  EXPECT_EQ(scriptMethod.name(), "forward");
  EXPECT_EQ(pyMethod.name(), "forward");

  auto input = torch::ones({10, 20});
  auto outputPy = pyMethod({input});
  auto outputScript = scriptMethod({input});

  // ASSERT rather than EXPECT: toTensor() below must not be reached when an
  // output is not a tensor, otherwise the failure surfaces as an exception.
  ASSERT_TRUE(outputPy.isTensor());
  ASSERT_TRUE(outputScript.isTensor());

  auto outputPyTensor = outputPy.toTensor();
  auto outputScriptTensor = outputScript.toTensor();
  EXPECT_TRUE(outputPyTensor.equal(outputScriptTensor));
  EXPECT_EQ(outputPyTensor.numel(), 200);
}
// Checks that both IMethod implementations -- TorchScript's Method and
// torch::deploy's PythonMethodWrapper -- report exactly one argument name,
// "input", for the example model's `forward`.
TEST(IMethodTest, GetArgumentNames) {
  auto scriptModel = torch::jit::load(path("SIMPLE_JIT", simpleJit));
  auto scriptMethod = scriptModel.get_method("forward");
  const auto& scriptNames = scriptMethod.getArgumentNames();
  // ASSERT rather than EXPECT: indexing [0] below would be out of bounds on
  // an empty list. `1u` keeps the size comparison unsigned.
  ASSERT_EQ(scriptNames.size(), 1u);
  EXPECT_EQ(scriptNames[0], "input");

  torch::deploy::InterpreterManager manager(3);
  torch::deploy::Package package = manager.loadPackage(path("SIMPLE", simple));
  auto pyModel = package.loadPickle("model", "model.pkl");
  torch::deploy::PythonMethodWrapper pyMethod(pyModel, "forward");
  const auto& pyNames = pyMethod.getArgumentNames();
  ASSERT_EQ(pyNames.size(), 1u);
  EXPECT_EQ(pyNames[0], "input");
}