@@ -62,7 +62,7 @@ class ProgramTest : public ::testing::Test {
6262
6363 add_loader_ = std::make_unique<FileDataLoader>(std::move (loader.get ()));
6464
65- // Load the serialized ModuleAdd data.
65+ // Load the serialized ModuleMultiEntry data.
6666 path = std::getenv (" ET_MODULE_MULTI_ENTRY_PATH" );
6767 Result<FileDataLoader> multi_loader = FileDataLoader::from (path);
6868 ASSERT_EQ (multi_loader.error (), Error::Ok);
@@ -98,6 +98,16 @@ class ProgramTestFriend final {
9898 return program->LoadSegment (segment_info);
9999 }
100100
101+ __ET_NODISCARD static Error load_mutable_subsegment_into (
102+ const Program* program,
103+ size_t mutable_data_segments_index,
104+ size_t offset_index,
105+ size_t size,
106+ void * buffer) {
107+ return program->load_mutable_subsegment_into (
108+ mutable_data_segments_index, offset_index, size, buffer);
109+ }
110+
101111 const static executorch_flatbuffer::Program* GetInternalProgram (
102112 const Program* program) {
103113 return program->internal_program_ ;
@@ -444,3 +454,89 @@ TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) {
444454 // The constant buffer should exist.
445455 EXPECT_GE (flatbuffer_program->constant_buffer ()->size (), 1 );
446456}
457+
458+ TEST_F (ProgramTest, LoadFromMutableSegment) {
459+ // Load the serialized ModuleSimpleTrain data.
460+ auto path = std::getenv (" ET_MODULE_SIMPLE_TRAIN_PATH" );
461+ Result<FileDataLoader> training_loader = FileDataLoader::from (path);
462+ ASSERT_EQ (training_loader.error (), Error::Ok);
463+
464+ // This file should always be compatible.
465+ Result<FreeableBuffer> training_header = training_loader->load (
466+ /* offset=*/ 0 ,
467+ Program::kMinHeadBytes ,
468+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::Program));
469+ ASSERT_EQ (training_header.error (), Error::Ok);
470+ EXPECT_EQ (
471+ Program::check_header (training_header->data (), training_header->size ()),
472+ Program::HeaderStatus::CompatibleVersion);
473+
474+ Result<Program> program = Program::load (&training_loader.get ());
475+ ASSERT_EQ (program.error (), Error::Ok);
476+
477+ // dummy buffers to load into
478+ uint8_t buffer[1 ] = {0 };
479+ uint8_t buffer2[1 ] = {0 };
480+
481+ // Load some mutable segment data
482+ Error err = ProgramTestFriend::load_mutable_subsegment_into (
483+ &program.get (), 0 , 1 , 1 , buffer);
484+ EXPECT_EQ (err, Error::Ok);
485+
486+ // Check that the data loaded correctly, and then mutate it
487+ EXPECT_EQ (buffer[0 ], 232 ); // 232 comes from inspecting the file itself. The
488+ // file is seeded so this value should be stable.
489+ buffer[0 ] = 0 ;
490+
491+ // Load the same mutable segment data from file into a different buffer.
492+ err = ProgramTestFriend::load_mutable_subsegment_into (
493+ &program.get (),
494+ 0 , // mutable_data_segments_index
495+ 1 , // offset_index
496+ 1 , // size
497+ buffer2);
498+ EXPECT_EQ (err, Error::Ok);
499+
500+ // Check that new data loaded from the file does not reflect the change to
501+ // buffer.
502+ EXPECT_EQ (buffer2[0 ], 232 );
503+
504+ const executorch_flatbuffer::Program* flatbuffer_program =
505+ ProgramTestFriend::GetInternalProgram (&program.get ());
506+
507+ // Expect 1 segment. 1 mutable segment and no constant segment.
508+ EXPECT_EQ (flatbuffer_program->segments ()->size (), 1 );
509+
510+ // Expect a mutable data segment.
511+ EXPECT_EQ (flatbuffer_program->mutable_data_segments ()->size (), 1 );
512+
513+ // Expect the 0 index to be reserved and the offsets for weight and bias of
514+ // linear to be indices 1 and 2.
515+ EXPECT_EQ (
516+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->size (),
517+ 3 );
518+ EXPECT_EQ (
519+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->Get (0 ),
520+ 0 );
521+ EXPECT_EQ (
522+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->Get (1 ),
523+ 0 );
524+ EXPECT_EQ (
525+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->Get (2 ),
526+ 36 );
527+
528+ // Loading beyond file should fail
529+ err = ProgramTestFriend::load_mutable_subsegment_into (
530+ &program.get (), 0 , 1 , 500 , buffer);
531+ EXPECT_NE (err, Error::Ok);
532+
533+ // Loading beyond offsets should fail
534+ err = ProgramTestFriend::load_mutable_subsegment_into (
535+ &program.get (), 0 , 500 , 1 , buffer);
536+ EXPECT_NE (err, Error::Ok);
537+
538+ // Loading beyond segments should fail
539+ err = ProgramTestFriend::load_mutable_subsegment_into (
540+ &program.get (), 500 , 1 , 1 , buffer);
541+ EXPECT_NE (err, Error::Ok);
542+ }
0 commit comments