Skip to content

Commit 30a4b36

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
unblock list optional tensor
Summary: We dont have an unboxed value to point to in the values table. Hack it to just use nullptr to indicate. I dont love this because its weird to couple the type and the boxed list like this but its probably ok for now Reviewed By: guangy10 Differential Revision: D48339367 fbshipit-source-id: d9346b072530c3cbda8240087f01becdd4757207
1 parent 54412be commit 30a4b36

8 files changed

Lines changed: 105 additions & 7 deletions

File tree

runtime/core/evalue.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/evalue.h>
10+
11+
namespace torch {
12+
namespace executor {
13+
template <>
14+
exec_aten::ArrayRef<exec_aten::optional<exec_aten::Tensor>>
15+
BoxedEvalueList<exec_aten::optional<exec_aten::Tensor>>::get() const {
16+
for (typename exec_aten::ArrayRef<
17+
exec_aten::optional<exec_aten::Tensor>>::size_type i = 0;
18+
i < wrapped_vals_.size();
19+
i++) {
20+
if (wrapped_vals_[i] == nullptr) {
21+
unwrapped_vals_[i] = exec_aten::nullopt;
22+
} else {
23+
unwrapped_vals_[i] =
24+
wrapped_vals_[i]->to<exec_aten::optional<exec_aten::Tensor>>();
25+
}
26+
}
27+
return exec_aten::ArrayRef<exec_aten::optional<exec_aten::Tensor>>{
28+
unwrapped_vals_, wrapped_vals_.size()};
29+
}
30+
} // namespace executor
31+
} // namespace torch

runtime/core/evalue.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ class BoxedEvalueList {
7373
mutable T* unwrapped_vals_;
7474
};
7575

76+
template <>
77+
exec_aten::ArrayRef<exec_aten::optional<exec_aten::Tensor>>
78+
BoxedEvalueList<exec_aten::optional<exec_aten::Tensor>>::get() const;
79+
7680
// Aggregate typing system similar to IValue only slimmed down with less
7781
// functionality, no dependencies on atomic, and fewer supported types to better
7882
// suit embedded systems (ie no intrusive ptr)
@@ -498,6 +502,7 @@ exec_aten::ArrayRef<T> BoxedEvalueList<T>::get() const {
498502
for (typename exec_aten::ArrayRef<T>::size_type i = 0;
499503
i < wrapped_vals_.size();
500504
i++) {
505+
ET_CHECK(wrapped_vals_[i] != nullptr);
501506
unwrapped_vals_[i] = wrapped_vals_[i]->template to<T>();
502507
}
503508
return exec_aten::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()};

runtime/core/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def define_common_targets():
6060
exported_headers = [
6161
"evalue.h",
6262
],
63+
srcs = ["evalue.cpp"],
6364
visibility = [
6465
"//executorch/...",
6566
"@EXECUTORCH_CLIENTS",

runtime/executor/tensor_parser.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,15 @@ parseListOptionalType(
5858
if (index == -1) {
5959
new (&optional_tensor_list[output_idx])
6060
exec_aten::optional<T>(exec_aten::nullopt);
61+
// no value to point to. BoxedEvalueList for optional tensor will convert
62+
// this to nullopt.
63+
// TODO(T161156879): do something less hacky here.
64+
evalp_list[output_idx] = nullptr;
6165
} else {
6266
new (&optional_tensor_list[output_idx])
6367
exec_aten::optional<T>(values_[index].toOptional<T>());
68+
evalp_list[output_idx] = &values_[static_cast<size_t>(index)];
6469
}
65-
evalp_list[output_idx] = &values_[static_cast<size_t>(index)];
6670
output_idx++;
6771
}
6872
return BoxedEvalueList<exec_aten::optional<T>>(

runtime/executor/test/method_test.cpp

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,26 +38,41 @@ class MethodTest : public ::testing::Test {
3838
const char* path = std::getenv("ET_MODULE_ADD_PATH");
3939
Result<FileDataLoader> loader = FileDataLoader::From(path);
4040
ASSERT_EQ(loader.error(), Error::Ok);
41-
loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
41+
add_loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
4242

4343
// Use it to load the program.
4444
Result<Program> program = Program::Load(
45-
loader_.get(), Program::Verification::InternalConsistency);
45+
add_loader_.get(), Program::Verification::InternalConsistency);
4646
ASSERT_EQ(program.error(), Error::Ok);
47-
program_ = std::make_unique<Program>(std::move(program.get()));
47+
add_program_ = std::make_unique<Program>(std::move(program.get()));
48+
49+
// Create a loader for the serialized ModuleIndex program.
50+
const char* index_path = std::getenv("ET_MODULE_INDEX_PATH");
51+
Result<FileDataLoader> index_loader = FileDataLoader::From(index_path);
52+
ASSERT_EQ(index_loader.error(), Error::Ok);
53+
index_loader_ =
54+
std::make_unique<FileDataLoader>(std::move(index_loader.get()));
55+
56+
// Use it to load the program.
57+
Result<Program> index_program = Program::Load(
58+
index_loader_.get(), Program::Verification::InternalConsistency);
59+
ASSERT_EQ(index_program.error(), Error::Ok);
60+
index_program_ = std::make_unique<Program>(std::move(index_program.get()));
4861
}
4962

5063
private:
5164
// Must outlive program_, but tests shouldn't need to touch it.
52-
std::unique_ptr<FileDataLoader> loader_;
65+
std::unique_ptr<FileDataLoader> add_loader_;
66+
std::unique_ptr<FileDataLoader> index_loader_;
5367

5468
protected:
55-
std::unique_ptr<Program> program_;
69+
std::unique_ptr<Program> add_program_;
70+
std::unique_ptr<Program> index_program_;
5671
};
5772

5873
TEST_F(MethodTest, MoveTest) {
5974
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
60-
Result<Method> method = program_->load_method("forward", &mmm.get());
75+
Result<Method> method = add_program_->load_method("forward", &mmm.get());
6176
ASSERT_EQ(method.error(), Error::Ok);
6277

6378
// Can execute the method.
@@ -79,3 +94,29 @@ TEST_F(MethodTest, MoveTest) {
7994

8095
torch::executor::util::FreeInputs(inputs);
8196
}
97+
98+
// TODO(T161163608): Test is disabled due to a resize bug in tensor_index_out of
99+
// the portable op lib
100+
101+
// TEST_F(MethodTest, OptionalTensorListDeserialization) {
102+
// ManagedMemoryManager mmm(kDefaultNonConstMemBytes,
103+
// kDefaultRuntimeMemBytes); Result<Method> method =
104+
// index_program_->load_method("forward", &mmm.get());
105+
// ASSERT_EQ(method.error(), Error::Ok);
106+
107+
// // Can execute the method.
108+
// exec_aten::ArrayRef<void*> inputs =
109+
// torch::executor::util::PrepareInputTensors(*method);
110+
// Error err = method->execute();
111+
// ASSERT_EQ(err, Error::Ok);
112+
113+
// EXPECT_EQ(method->inputs_size(), 1);
114+
115+
// auto outputs = method->get_output(0);
116+
// EXPECT_EQ(outputs.toTensor().dim(), 3);
117+
// EXPECT_EQ(outputs.toTensor().size(0), 5);
118+
// EXPECT_EQ(outputs.toTensor().size(1), 2);
119+
// EXPECT_EQ(outputs.toTensor().size(2), 10);
120+
121+
// torch::executor::util::FreeInputs(inputs);
122+
// }

runtime/executor/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def define_common_targets(is_fbcode = False):
7979
# an fbcode target path because the authoring/export tools
8080
# intentionally don't work in xplat (since they're host-only tools).
8181
"ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
82+
"ET_MODULE_INDEX_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleIndex.pte])",
8283
"ET_MODULE_MULTI_ENTRY_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleMultipleEntry.pte])",
8384
}
8485

test/models/export_program.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,20 @@ def get_export_kwargs() -> Dict[str, Any]:
4646
}
4747

4848

49+
class ModuleIndex(nn.Module):
50+
def __init__(self):
51+
super(ModuleIndex, self).__init__()
52+
53+
def forward(self, x):
54+
# Weird index that happens to generate a None in torch.index.Tensor_out
55+
# which is desirable for deserialization testing. A modified form of
56+
# an example index from https://pytorch.org/cppdocs/notes/tensor_indexing.html.
57+
return x[1::2, torch.tensor([1, 2])]
58+
59+
def get_random_inputs(self):
60+
return (torch.randn(10, 10, 10),)
61+
62+
4963
class ModuleNoOp(nn.Module):
5064
def __init__(self):
5165
super(ModuleNoOp, self).__init__()

test/models/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def define_common_targets():
6363
"ModuleBasic",
6464
"ModuleLinear",
6565
"ModuleMultipleEntry",
66+
"ModuleIndex",
6667
]
6768

6869
# Generates Executorch .pte program files for various modules at build time.

0 commit comments

Comments
 (0)