2323#include < executorch/extension/data_loader/buffer_data_loader.h>
2424#include < executorch/extension/data_loader/mmap_data_loader.h>
2525#include < executorch/extension/memory_allocator/malloc_memory_allocator.h>
26+ #include < executorch/extension/module/bundled_module.h>
2627#include < executorch/extension/threadpool/threadpool.h>
2728#include < executorch/runtime/backend/interface.h>
2829#include < executorch/runtime/core/data_loader.h>
@@ -96,6 +97,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Program;
9697using ::executorch::extension::BufferDataLoader;
9798using ::executorch::extension::MallocMemoryAllocator;
9899using ::executorch::extension::MmapDataLoader;
100+ using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule;
99101using ::executorch::runtime::ArrayRef;
100102using ::executorch::runtime::DataLoader;
101103using ::executorch::runtime::Error;
@@ -440,13 +442,54 @@ inline std::unique_ptr<Module> load_module_from_file(
440442 program_verification);
441443}
442444
445+ inline py::list get_outputs_as_py_list (
446+ const std::vector<EValue>& outputs,
447+ bool clone_outputs = true ) {
448+ const auto outputs_size = outputs.size ();
449+ py::list list (outputs_size);
450+ for (size_t i = 0 ; i < outputs_size; ++i) {
451+ auto & v = outputs[i];
452+ if (Tag::None == v.tag ) {
453+ list[i] = py::none ();
454+ } else if (Tag::Int == v.tag ) {
455+ list[i] = py::cast (v.toInt ());
456+ } else if (Tag::Double == v.tag ) {
457+ list[i] = py::cast (v.toDouble ());
458+ } else if (Tag::Bool == v.tag ) {
459+ list[i] = py::cast (v.toBool ());
460+ } else if (Tag::String == v.tag ) {
461+ list[i] = py::cast (std::string (v.toString ().data ()));
462+ } else if (Tag::Tensor == v.tag ) {
463+ #ifdef USE_ATEN_LIB
464+ // Clone so the outputs in python do not share a lifetime with the
465+ // module object
466+ if (clone_outputs) {
467+ list[i] = py::cast (v.toTensor ().clone ());
468+ } else {
469+ list[i] = py::cast (v.toTensor ());
470+ }
471+ #else
472+ if (clone_outputs) {
473+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
474+ } else {
475+ list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
476+ }
477+ #endif
478+ } else {
479+ ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
480+ }
481+ }
482+ return list;
483+ }
484+
// Default size (16 KiB) of the memory pool used for bundled-program inputs.
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
444486
445- struct PyBundledModule final {
487+ struct PyBundledModule : public BundledModule {
446488 explicit PyBundledModule (
447489 const py::bytes& buffer,
448490 uint32_t bundled_input_pool_size)
449- : bundled_program_ptr_(buffer),
491+ : BundledModule(buffer.cast<std::string_view>().data()),
492+ bundled_program_ptr_(buffer),
450493 program_ptr_(static_cast <const void *>(
451494 bundled_program_flatbuffer::GetBundledProgram (
452495 get_bundled_program_ptr ())
@@ -475,6 +518,33 @@ struct PyBundledModule final {
475518 return program_len_;
476519 }
477520
  /// Runs `method_name` with the inputs of bundled test set `testset_idx`,
  /// then checks the produced outputs against the expected outputs stored in
  /// the bundled program, within the given tolerances.
  ///
  /// @param method_name Name of the method to execute.
  /// @param testset_idx Index of the bundled input/expected-output set.
  /// @param rtol Relative tolerance used by output verification.
  /// @param atol Absolute tolerance used by output verification.
  /// @returns The method outputs converted to a Python list (tensors cloned,
  ///     so they outlive this module).
  /// @throws A Python-visible error (via THROW_IF_ERROR) if execution or
  ///     verification fails.
  py::list verify_result_with_bundled_expected_output(
      const std::string& method_name,
      size_t testset_idx,
      double rtol = 1e-5,
      double atol = 1e-8) {
    // Execute the method
    auto result = BundledModule::execute(method_name, testset_idx);
    if (!result.ok()) {
      THROW_IF_ERROR(
          result.error(),
          "Method execution failed with status 0x%" PRIx32,
          static_cast<uint32_t>(result.error()));
    }

    // Convert outputs to py::list before verification so the caller gets the
    // actual values back even though verification runs afterwards.
    const auto& outputs = result.get();
    py::list py_outputs = get_outputs_as_py_list(outputs);

    Error status = BundledModule::verify_method_outputs(
        method_name, testset_idx, rtol, atol);
    THROW_IF_ERROR(
        status,
        "Result verification failed with status %" PRIu32,
        static_cast<uint32_t>(status));
    return py_outputs;
  }
547+
478548 private:
479549 // Store the bytes object instead of a raw pointer so that this module will
480550 // keep the bytes alive.
@@ -831,43 +901,6 @@ struct PyModule final {
831901 }
832902 }
833903
834- void load_bundled_input (
835- PyBundledModule& m,
836- const std::string method_name,
837- size_t testset_idx) {
838- const void * bundled_program_ptr = m.get_bundled_program_ptr ();
839- Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
840- module_->get_method (method_name), bundled_program_ptr, testset_idx);
841- THROW_IF_ERROR (
842- status,
843- " load_bundled_input failed with status 0x%" PRIx32,
844- static_cast <uint32_t >(status));
845- }
846-
847- py::list verify_result_with_bundled_expected_output (
848- PyBundledModule& m,
849- const std::string method_name,
850- size_t testset_idx,
851- double rtol = 1e-5 ,
852- double atol = 1e-8 ) {
853- const void * bundled_program_ptr = m.get_bundled_program_ptr ();
854- auto & method = module_->get_method (method_name);
855- Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input (
856- method, bundled_program_ptr, testset_idx);
857- THROW_IF_ERROR (
858- status,
859- " load_bundled_input failed with status 0x%" PRIx32,
860- static_cast <uint32_t >(status));
861- py::list outputs = plan_execute (method_name);
862- status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs (
863- method, bundled_program_ptr, testset_idx, rtol, atol);
864- THROW_IF_ERROR (
865- status,
866- " Result verification failed with status %" PRIu32,
867- static_cast <uint32_t >(status));
868- return outputs;
869- }
870-
871904 py::list plan_execute (
872905 const std::string method_name,
873906 bool clone_outputs = true ) {
@@ -890,46 +923,6 @@ struct PyModule final {
890923 return get_outputs_as_py_list (outputs, clone_outputs);
891924 }
892925
893- py::list get_outputs_as_py_list (
894- const std::vector<EValue>& outputs,
895- bool clone_outputs = true ) {
896- const auto outputs_size = outputs.size ();
897- py::list list (outputs_size);
898- for (size_t i = 0 ; i < outputs_size; ++i) {
899- auto & v = outputs[i];
900- if (Tag::None == v.tag ) {
901- list[i] = py::none ();
902- } else if (Tag::Int == v.tag ) {
903- list[i] = py::cast (v.toInt ());
904- } else if (Tag::Double == v.tag ) {
905- list[i] = py::cast (v.toDouble ());
906- } else if (Tag::Bool == v.tag ) {
907- list[i] = py::cast (v.toBool ());
908- } else if (Tag::String == v.tag ) {
909- list[i] = py::cast (std::string (v.toString ().data ()));
910- } else if (Tag::Tensor == v.tag ) {
911- #ifdef USE_ATEN_LIB
912- // Clone so the outputs in python do not share a lifetime with the
913- // module object
914- if (clone_outputs) {
915- list[i] = py::cast (v.toTensor ().clone ());
916- } else {
917- list[i] = py::cast (v.toTensor ());
918- }
919- #else
920- if (clone_outputs) {
921- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()).clone ());
922- } else {
923- list[i] = py::cast (alias_attensor_to_etensor (v.toTensor ()));
924- }
925- #endif
926- } else {
927- ET_ASSERT_UNREACHABLE_MSG (" Invalid model output type" );
928- }
929- }
930- return list;
931- }
932-
933926 std::unique_ptr<PyMethodMeta> method_meta (const std::string method_name) {
934927 auto & method = module_->get_method (method_name);
935928 return std::make_unique<PyMethodMeta>(module_, method.method_meta ());
@@ -1089,16 +1082,6 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
10891082 call_guard);
10901083
10911084 py::class_<PyModule>(m, " ExecuTorchModule" )
1092- .def (" load_bundled_input" , &PyModule::load_bundled_input, call_guard)
1093- .def (
1094- " verify_result_with_bundled_expected_output" ,
1095- &PyModule::verify_result_with_bundled_expected_output,
1096- py::arg (" bundle" ),
1097- py::arg (" method_name" ),
1098- py::arg (" testset_idx" ),
1099- py::arg (" rtol" ) = 1e-5 ,
1100- py::arg (" atol" ) = 1e-8 ,
1101- call_guard)
11021085 .def (
11031086 " plan_execute" ,
11041087 &PyModule::plan_execute,
@@ -1144,7 +1127,16 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
11441127 py::arg (" clone_outputs" ) = true ,
11451128 call_guard);
11461129
1147- py::class_<PyBundledModule>(m, " BundledModule" );
1130+ py::class_<PyBundledModule>(m, " BundledModule" )
1131+ .def (
1132+ " verify_result_with_bundled_expected_output" ,
1133+ &PyBundledModule::verify_result_with_bundled_expected_output,
1134+ py::arg (" method_name" ),
1135+ py::arg (" testset_idx" ),
1136+ py::arg (" rtol" ) = 1e-5 ,
1137+ py::arg (" atol" ) = 1e-8 ,
1138+ call_guard);
1139+
11481140 py::class_<PyTensorInfo>(m, " TensorInfo" )
11491141 .def (" sizes" , &PyTensorInfo::sizes, call_guard)
11501142 .def (" dtype" , &PyTensorInfo::dtype, call_guard)
0 commit comments