@@ -560,7 +560,7 @@ def _plan(
560560 new_infra_proto = new_infra .to_proto ()
561561 infra_diff = diff_infra_protos (current_infra_proto , new_infra_proto )
562562
563- return ( registry_diff , infra_diff , new_infra )
563+ return registry_diff , infra_diff , new_infra
564564
565565 @log_exceptions_and_usage
566566 def _apply_diffs (
@@ -648,16 +648,23 @@ def apply(
648648 ]
649649 odfvs_to_update = [ob for ob in objects if isinstance (ob , OnDemandFeatureView )]
650650 services_to_update = [ob for ob in objects if isinstance (ob , FeatureService )]
651- data_sources_to_update = [ob for ob in objects if isinstance (ob , DataSource )]
652-
653- if len (entities_to_update ) + len (views_to_update ) + len (
654- request_views_to_update
655- ) + len (odfvs_to_update ) + len (services_to_update ) + len (
656- data_sources_to_update
657- ) != len (
658- objects
659- ):
660- raise ValueError ("Unknown object type provided as part of apply() call" )
651+ data_sources_set_to_update = {
652+ ob for ob in objects if isinstance (ob , DataSource )
653+ }
654+
655+ for fv in views_to_update :
656+ data_sources_set_to_update .add (fv .batch_source )
657+ if fv .stream_source :
658+ data_sources_set_to_update .add (fv .stream_source )
659+
660+ for rfv in request_views_to_update :
661+ data_sources_set_to_update .add (rfv .request_data_source )
662+
663+ for odfv in odfvs_to_update :
664+ for v in odfv .input_request_data_sources .values ():
665+ data_sources_set_to_update .add (v )
666+
667+ data_sources_to_update = list (data_sources_set_to_update )
661668
662669 # Validate all feature views and make inferences.
663670 self ._validate_all_feature_views (
0 commit comments