@@ -515,7 +515,7 @@ public void importGraphDef(GraphDef graphDef) throws IllegalArgumentException {
515515 importGraphDef (graphDef , "" );
516516 }
517517
518- private static final String INIT_OP_BASE_NAME = "Init" ;
518+ static final String INIT_OP_NAME = "Init" ;
519519
520520 /**
521521 * Import a representation of a TensorFlow graph.
@@ -538,44 +538,23 @@ public void importGraphDef(GraphDef graphDef, String prefix) throws IllegalArgum
538538 String initPrefix ;
539539 if (!prefix .isEmpty ()) {
540540 if (prefix .endsWith ("/" )) {
541- initPrefix = prefix + INIT_OP_BASE_NAME ;
541+ initPrefix = prefix + INIT_OP_NAME ;
542542 } else {
543- initPrefix = prefix + "/" + INIT_OP_BASE_NAME ;
543+ initPrefix = prefix + "/" + INIT_OP_NAME ;
544544 }
545545 } else {
546- initPrefix = INIT_OP_BASE_NAME ;
546+ initPrefix = INIT_OP_NAME ;
547547 }
548548
549549 operations ()
550550 .forEachRemaining (
551551 op -> {
552552 if (op .name ().startsWith (initPrefix )) {
553- registerInitOp (op );
553+ registerInitializer (op , false );
554554 }
555555 });
556556 }
557557
558- /**
559- * Create and return a NoOp that will run all init ops. If {@code required} is false and there are
560- * no new init ops since the last call, will do nothing and return null.
561- */
562- synchronized GraphOperation addInitOp (boolean required ) {
563- if (!newInitializers && !required ) {
564- return null ;
565- }
566- if (initializers .isEmpty () && !required ) {
567- return null ;
568- }
569-
570- baseScope .refreshNames ();
571- OperationBuilder builder =
572- baseScope ().withInitScope ().opBuilder (NoOp .OP_NAME , INIT_OP_BASE_NAME );
573- initializers .forEach (builder ::addControlInput );
574- GraphOperation initOp = (GraphOperation ) builder .build ();
575- newInitializers = false ;
576- return initOp ;
577- }
578-
579558 /**
580559 * Generate a representation of the Graph.
581560 *
@@ -587,62 +566,18 @@ synchronized GraphOperation addInitOp(boolean required) {
587566 * @see #importGraphDef(GraphDef, String)
588567 */
589568 public GraphDef toGraphDef () {
590- addInitOp ( false ) ;
569+ GraphDef graphDef ;
591570 synchronized (nativeHandleLock ) {
592- return toGraphDef (nativeHandle );
571+ graphDef = toGraphDef (nativeHandle );
593572 }
594- }
595-
596- private boolean registerInitOpHelper (Operation op ) {
597- if (isInitOp (op )) return false ;
598- checkInput (op );
599-
600- if (!(op instanceof GraphOperation )) {
601- throw new IllegalStateException ("Can't use a non-graph op as a graph's init op." );
602- }
603- GraphOperation graphOp = (GraphOperation ) op ;
604-
605- for (GraphOperation controlInput : graphOp .controlInputs ()) {
606- registerInitOpHelper (controlInput );
607- }
608-
609- for (Operand <?> input : graphOp .inputs ()) {
610- registerInitOpHelper (input .op ());
611- }
612- return initializers .add (op );
573+ return addOrUpdateInit (graphDef );
613574 }
614575
615576 @ Override
616- public void registerInitOp (Operation op ) {
617- if (registerInitOpHelper (op )) {
618- newInitializers = true ;
619- }
620- }
621-
622- @ Override
623- public boolean isInitOp (Operation op ) {
577+ public boolean isInitializer (Operation op ) {
624578 return initializers .contains (op );
625579 }
626580
627- /**
628- * Returns a set of ops that will run all initializers added to the graph via {@link
629- * #registerInitOp(Operation)}.
630- *
631- * <p>Note that NoOps aren't included in this list, since any inputs or control dependencies are
632- * guaranteed to also be in this list, and including the no-ops wouldn't change the initialization
633- * result.
634- */
635- public Set <Operation > initializers () {
636- return initializers .stream ()
637- .filter (x -> !x .type ().equals (NoOp .OP_NAME ))
638- .collect (Collectors .toSet ());
639- }
640-
641- /** Get whether the graph has any initializers */
642- public boolean hasInitializers () {
643- return !initializers .isEmpty ();
644- }
645-
646581 /**
647582 * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
648583 * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
@@ -893,6 +828,39 @@ synchronized SaverDef saverDef() {
893828 return saverDef ;
894829 }
895830
831+ /**
832+ * Register an op as an initialization op.
833+ *
834+ * @throws IllegalArgumentException if the op or one of its inputs can't be made an init op.
835+ */
836+ synchronized void registerInitializer (GraphOperation op , boolean isNew ) {
837+ if (isInitializer (op )) {
838+ return ;
839+ }
840+ checkInput (op );
841+ for (GraphOperation controlInput : op .controlInputs ()) {
842+ checkInput (controlInput );
843+ }
844+ for (Operand <?> input : op .inputs ()) {
845+ checkInput (input .op ());
846+ }
847+ if (initializers .add (op ) && isNew && newInitializersMarker < 0 ) {
848+ newInitializersMarker = initializers .size () - 1 ;
849+ }
850+ }
851+
852+ /**
853+ * Returns a set of ops that will run all initializers added to the graph via {@link
854+ * #registerInitOp(Operation)}.
855+ *
856+ * <p>Note that NoOps aren't included in this list, since any inputs or control dependencies are
857+ * guaranteed to also be in this list, and including the no-ops wouldn't change the initialization
858+ * result.
859+ */
860+ Set <Operation > initializers () {
861+ return initializers ;
862+ }
863+
896864 private final Object nativeHandleLock = new Object ();
897865 private TF_Graph nativeHandle ;
898866 private int refcount = 0 ;
@@ -902,7 +870,7 @@ synchronized SaverDef saverDef() {
902870 private boolean dangerousGradientBuilder ;
903871
904872 private final Set <Operation > initializers = Collections .synchronizedSet (new LinkedHashSet <>());
905- private boolean newInitializers = false ;
873+ private int newInitializersMarker = - 1 ;
906874
907875 /**
908876 * Use builders without locking. This should only be used during custom gradient building.
@@ -1091,6 +1059,32 @@ private static GraphDef toGraphDef(TF_Graph handle) {
10911059 }
10921060 }
10931061
1062+ private GraphDef addOrUpdateInit (GraphDef graphDef ) {
1063+ if (newInitializersMarker < 0 ) {
1064+ return graphDef ;
1065+ }
1066+ var graphDefWithInitBuilder = graphDef .toBuilder ();
1067+ var initNode =
1068+ graphDefWithInitBuilder .getNodeBuilderList ().stream ()
1069+ .filter (n -> n .getName ().equals (INIT_OP_NAME ))
1070+ .findFirst ()
1071+ .orElseGet (
1072+ () -> {
1073+ return graphDefWithInitBuilder
1074+ .addNodeBuilder ()
1075+ .setName (INIT_OP_NAME )
1076+ .setOp (NoOp .OP_NAME );
1077+ });
1078+
1079+ // Register each initializer as a control dependency of the Init op by adding their names
1080+ // prefixed with the '^' symbol to the list of inputs
1081+ initializers .stream ()
1082+ .skip (newInitializersMarker )
1083+ .forEach (op -> initNode .addInput ("^" + op .name ()));
1084+
1085+ return graphDefWithInitBuilder .build ();
1086+ }
1087+
10941088 static void resolveOutputs (
10951089 String type , TF_Operation [] srcOps , int [] srcIndices , TF_Output dst , int n ) {
10961090 if (srcOps .length != n ) {
@@ -1316,7 +1310,7 @@ private static SaverDef addVariableSaver(Graph graph) {
13161310 .build ();
13171311 }
13181312
1319- private static Set <Graph > allGraphs =
1313+ private static final Set <Graph > allGraphs =
13201314 Collections .synchronizedSet (Collections .newSetFromMap (new WeakHashMap <>()));
13211315
13221316 /**
0 commit comments