@@ -38,26 +38,41 @@ class MethodTest : public ::testing::Test {
3838 const char * path = std::getenv (" ET_MODULE_ADD_PATH" );
3939 Result<FileDataLoader> loader = FileDataLoader::From (path);
4040 ASSERT_EQ (loader.error (), Error::Ok);
41- loader_ = std::make_unique<FileDataLoader>(std::move (loader.get ()));
41+ add_loader_ = std::make_unique<FileDataLoader>(std::move (loader.get ()));
4242
4343 // Use it to load the program.
4444 Result<Program> program = Program::Load (
45- loader_ .get (), Program::Verification::InternalConsistency);
45+ add_loader_ .get (), Program::Verification::InternalConsistency);
4646 ASSERT_EQ (program.error (), Error::Ok);
47- program_ = std::make_unique<Program>(std::move (program.get ()));
47+ add_program_ = std::make_unique<Program>(std::move (program.get ()));
48+
49+ // Create a loader for the serialized ModuleIndex program.
50+ const char * index_path = std::getenv (" ET_MODULE_INDEX_PATH" );
51+ Result<FileDataLoader> index_loader = FileDataLoader::From (index_path);
52+ ASSERT_EQ (index_loader.error (), Error::Ok);
53+ index_loader_ =
54+ std::make_unique<FileDataLoader>(std::move (index_loader.get ()));
55+
56+ // Use it to load the program.
57+ Result<Program> index_program = Program::Load (
58+ index_loader_.get (), Program::Verification::InternalConsistency);
59+ ASSERT_EQ (index_program.error (), Error::Ok);
60+ index_program_ = std::make_unique<Program>(std::move (index_program.get ()));
4861 }
4962
5063 private:
5164 // Must outlive program_, but tests shouldn't need to touch it.
52- std::unique_ptr<FileDataLoader> loader_;
65+ std::unique_ptr<FileDataLoader> add_loader_;
66+ std::unique_ptr<FileDataLoader> index_loader_;
5367
5468 protected:
55- std::unique_ptr<Program> program_;
69+ std::unique_ptr<Program> add_program_;
70+ std::unique_ptr<Program> index_program_;
5671};
5772
5873TEST_F (MethodTest, MoveTest) {
5974 ManagedMemoryManager mmm (kDefaultNonConstMemBytes , kDefaultRuntimeMemBytes );
60- Result<Method> method = program_ ->load_method (" forward" , &mmm.get ());
75+ Result<Method> method = add_program_ ->load_method (" forward" , &mmm.get ());
6176 ASSERT_EQ (method.error (), Error::Ok);
6277
6378 // Can execute the method.
@@ -79,3 +94,29 @@ TEST_F(MethodTest, MoveTest) {
7994
8095 torch::executor::util::FreeInputs (inputs);
8196}
97+
98+ // TODO(T161163608): Test is disabled due to a resize bug in tensor_index_out of
99+ // the portable op lib
100+
101+ // TEST_F(MethodTest, OptionalTensorListDeserialization) {
102+ // ManagedMemoryManager mmm(kDefaultNonConstMemBytes,
103+ // kDefaultRuntimeMemBytes); Result<Method> method =
104+ // index_program_->load_method("forward", &mmm.get());
105+ // ASSERT_EQ(method.error(), Error::Ok);
106+
107+ // // Can execute the method.
108+ // exec_aten::ArrayRef<void*> inputs =
109+ // torch::executor::util::PrepareInputTensors(*method);
110+ // Error err = method->execute();
111+ // ASSERT_EQ(err, Error::Ok);
112+
113+ // EXPECT_EQ(method->inputs_size(), 1);
114+
115+ // auto outputs = method->get_output(0);
116+ // EXPECT_EQ(outputs.toTensor().dim(), 3);
117+ // EXPECT_EQ(outputs.toTensor().size(0), 5);
118+ // EXPECT_EQ(outputs.toTensor().size(1), 2);
119+ // EXPECT_EQ(outputs.toTensor().size(2), 10);
120+
121+ // torch::executor::util::FreeInputs(inputs);
122+ // }
0 commit comments