@@ -168,20 +168,52 @@ class DataLoaderSpy : public DataLoader {
168168 public:
169169 // / A record of an operation performed on this DataLoader.
170170 struct Operation {
171- enum { Load, Free } op;
172- size_t offset; // Set for Load; zero for Free.
173- void * data; // Set for Free; nullptr for Load.
174- size_t size; // Set for Load and Free.
171+ enum { Load, Free, DeprecatedLoad } op;
172+ size_t offset; // Set for Load/DeprecatedLoad; zero for Free.
173+ void * data; // Set for Free; nullptr for Load/DeprecatedLoad.
174+ size_t size; // Set for Load/DeprecatedLoad and Free.
175+ std::unique_ptr<const DataLoader::SegmentInfo>
176+ segment_info; // Set for Load; nullptr for Free/DeprecatedLoad.
175177 };
176178
177179 explicit DataLoaderSpy (DataLoader* delegate) : delegate_(delegate) {}
178180
181+ /* *
182+ * Override the deprecated "Load" method. We will be looking to test that
183+ * this function is not called if the new "load" method is called.
184+ */
179185 Result<FreeableBuffer> Load (size_t offset, size_t size) override {
180186 Result<FreeableBuffer> buf = delegate_->Load (offset, size);
181187 if (!buf.ok ()) {
182188 return buf.error ();
183189 }
184- operations_.push_back ({Operation::Load, offset, /* data=*/ nullptr , size});
190+ operations_.push_back (
191+ {Operation::DeprecatedLoad,
192+ offset,
193+ /* data=*/ nullptr ,
194+ size,
195+ /* segment_info=*/ nullptr });
196+ auto * context = new SpyContext (&operations_, std::move (buf.get ()));
197+ // Use context->buffer since buf has been moved.
198+ return FreeableBuffer (
199+ context->buffer .data (), context->buffer .size (), FreeBuffer, context);
200+ }
201+
202+ Result<FreeableBuffer>
203+ load (size_t offset, size_t size, const SegmentInfo& segment_info) override {
204+ Result<FreeableBuffer> buf = delegate_->load (offset, size, segment_info);
205+ if (!buf.ok ()) {
206+ return buf.error ();
207+ }
208+
209+ auto segment_info_cpy =
210+ std::make_unique<const DataLoader::SegmentInfo>(segment_info);
211+ operations_.push_back (
212+ {Operation::Load,
213+ offset,
214+ /* data=*/ nullptr ,
215+ size,
216+ /* segment_info=*/ std::move (segment_info_cpy)});
185217 auto * context = new SpyContext (&operations_, std::move (buf.get ()));
186218 // Use context->buffer since buf has been moved.
187219 return FreeableBuffer (
@@ -200,6 +232,36 @@ class DataLoaderSpy : public DataLoader {
200232 return operations_;
201233 }
202234
235+ /* *
236+ * Returns true if the DataLoader::load() method was called with the correct
237+ * segment info.
238+ */
239+ bool UsedLoad (
240+ DataLoader::SegmentInfo::Type segment_type,
241+ const char * descriptor = nullptr ) const {
242+ for (const auto & op : operations_) {
243+ // We should not be using the deprecated DataLoader::Load() function.
244+ if (op.op == Operation::DeprecatedLoad) {
245+ return false ;
246+ }
247+ if (op.op != Operation::Load) {
248+ continue ;
249+ }
250+ // We have a load op.
251+ if (op.segment_info ->segment_type == segment_type) {
252+ if (segment_type != DataLoader::SegmentInfo::Type::Backend) {
253+ // For non-backend segments, the descriptor is irrelevant / a nullptr.
254+ return true ;
255+ } else {
256+ if (strcmp (op.segment_info ->descriptor , descriptor) == 0 ) {
257+ return true ;
258+ }
259+ }
260+ }
261+ }
262+ return false ;
263+ }
264+
203265 /* *
204266 * Returns true if the operations list shows that the provided data pointer
205267 * was freed.
@@ -223,7 +285,8 @@ class DataLoaderSpy : public DataLoader {
223285
224286 static void FreeBuffer (void * context, void * data, size_t size) {
225287 auto * sc = reinterpret_cast <SpyContext*>(context);
226- sc->operations ->push_back ({Operation::Free, /* offset=*/ 0 , data, size});
288+ sc->operations ->push_back (
289+ {Operation::Free, /* offset=*/ 0 , data, size, /* segment_info=*/ nullptr });
227290 delete sc;
228291 }
229292
@@ -333,7 +396,7 @@ TEST_P(BackendIntegrationTest, FreeingProcessedBufferSucceeds) {
333396 EXPECT_EQ (method_res.error (), Error::Ok);
334397
335398 // Demonstrate that our installed init was called.
336- EXPECT_EQ (init_called, true );
399+ EXPECT_TRUE (init_called);
337400
338401 // See if the processed data was freed.
339402 bool processed_was_freed = spy_loader.WasFreed (processed_data);
@@ -444,6 +507,51 @@ TEST_P(BackendIntegrationTest, EndToEndTestWithProcessedAsHandle) {
444507 EXPECT_EQ (execute_handle, destroy_handle);
445508}
446509
510+ /* *
511+ * Tests that the DataLoader's load is receiving the correct segment info for
512+ * different types of segments.
513+ */
514+ TEST_P (BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
515+ const void * processed_data = nullptr ;
516+ StubBackend::singleton ().install_init (
517+ [&](FreeableBuffer* processed,
518+ __ET_UNUSED ArrayRef<CompileSpec> compile_specs,
519+ __ET_UNUSED MemoryAllocator* runtime_allocator)
520+ -> Result<DelegateHandle*> {
521+ processed_data = processed->data ();
522+ processed->Free ();
523+ return nullptr ;
524+ });
525+
526+ // Wrap the real loader in a spy so we can see which operations were
527+ // performed.
528+ Result<FileDataLoader> loader = FileDataLoader::from (program_path ());
529+ ASSERT_EQ (loader.error (), Error::Ok);
530+ DataLoaderSpy spy_loader (&loader.get ());
531+
532+ // Load the program.
533+ Result<Program> program = Program::load (&spy_loader);
534+ ASSERT_EQ (program.error (), Error::Ok);
535+ ManagedMemoryManager mmm (kDefaultNonConstMemBytes , kDefaultRuntimeMemBytes );
536+
537+ // Expect that load was called correctly on program segments.
538+ bool program_load_was_called =
539+ spy_loader.UsedLoad (DataLoader::SegmentInfo::Type::Program, nullptr );
540+
541+ // Load a method.
542+ Result<Method> method_res = program->load_method (" forward" , &mmm.get ());
543+ EXPECT_EQ (method_res.error (), Error::Ok);
544+
545+ // Expect that load was called correctly on a backend segment.
546+ bool backend_load_was_called = spy_loader.UsedLoad (
547+ DataLoader::SegmentInfo::Type::Backend,
548+ " backend_segment" ); // TODO(jackzhxng): replace with actual mock PTE
549+ // file's backend_id in next chained PR.
550+
551+ EXPECT_TRUE (program_load_was_called);
552+ EXPECT_EQ (backend_load_was_called, using_segments ());
553+ }
554+
447555// TODO: Add more tests for the runtime-to-backend interface. E.g.:
448556// - Errors during init() or execute() result in runtime init/execution failures
449557// - Correct values are passed to init()/execute()
0 commit comments