Skip to content

Commit 851fcf6

Browse files
authored
Simplify graph initializers (tensorflow#466)
* Simplify graph initializers * Check for exact init op name when adding it in graph def * Apply spotless * Explaining the usage of the ^ symbol
1 parent 9cfea86 commit 851fcf6

File tree

18 files changed

+177
-283
lines changed

18 files changed

+177
-283
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8131,36 +8131,18 @@ public Ops withSubScope(String childScopeName) {
81318131
}
81328132

81338133
/**
8134-
* Returns an API that builds init operations. {@link #liftToInitScope(Operand)} will be called for all created operations.
8134+
* Returns an API that builds init operations.
81358135
* <p>
81368136
* Init operations will be initialized at session creation, will have their inputs (and control inputs) made init ops as well,
81378137
* and are ignored when used as control dependencies.
81388138
* Additionally, this scope ignores any control dependencies.
81398139
* <p>
81408140
* If an input can not be made an init op (i.e. a Placeholder), will throw an {@link IllegalStateException} on op creation.
8141-
* @see #liftToInitScope(Operand)
81428141
*/
81438142
public Ops withInitScope() {
81448143
return new Ops(scope.withInitScope());
81458144
}
81468145

8147-
/**
8148-
* Make {@code op} an init operation, doing the same for all of it's inputs (and control inputs).
8149-
* <p>
8150-
* Init operations will be initialized at session creation, will have their inputs (and control inputs) made init ops as well,
8151-
* and are ignored when used as control dependencies.
8152-
* Additionally, this scope ignores any control dependencies.
8153-
* <p>
8154-
* If an input can not be made an init op (i.e. a Placeholder), will throw an {@link IllegalStateException} on op creation.
8155-
* @see ExecutionEnvironment#registerInitOp(Operation)
8156-
*
8157-
* @throws IllegalStateException if the op or one of its inputs can't be made an init op.
8158-
*/
8159-
public <T extends Operand> T liftToInitScope(T op) {
8160-
scope.env().registerInitOp(op.op());
8161-
return op;
8162-
}
8163-
81648146
/**
81658147
* Returns an API that uses the provided name for an op.
81668148
*

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,7 @@ final class EagerOperationBuilder implements OperationBuilder {
7575
public EagerOperation build() {
7676
scope.apply(this);
7777
TFE_TensorHandle[] tensorHandles = execute(opHandle, session);
78-
EagerOperation op = new EagerOperation(session, opHandle, tensorHandles, type, name);
79-
scope.onOpCreated(op);
80-
return op;
78+
return new EagerOperation(session, opHandle, tensorHandles, type, name);
8179
}
8280

8381
@Override

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -335,14 +335,10 @@ public Scope baseScope() {
335335

336336
/** Noop, initialization is meaningless for eager sessions */
337337
@Override
338-
public boolean isInitOp(Operation op) {
338+
public boolean isInitializer(Operation op) {
339339
return false;
340340
}
341341

342-
/** Noop, initialization is meaningless for eager sessions */
343-
@Override
344-
public void registerInitOp(Operation op) {}
345-
346342
TFE_Context nativeHandle() {
347343
checkSession();
348344
return nativeHandle;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ExecutionEnvironment.java

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -101,29 +101,10 @@ default boolean isGraph() {
101101
*/
102102
Scope baseScope();
103103

104-
/**
105-
* Get the execution environment to use for initialization. In most cases is {@code this}.
106-
*
107-
* <p><b>Should generally only be used internally.</b>
108-
*/
109-
default ExecutionEnvironment initEnv() {
110-
return this;
111-
}
112-
113-
/**
114-
* Register an op and all of its inputs (and control inputs) as an initialization op.
115-
*
116-
* <p><b>Should generally only be used internally, prefer {@link
117-
* org.tensorflow.op.Ops#withInitScope()}.</b>
118-
*
119-
* @throws IllegalStateException if the op or one of its inputs can't be made an init op.
120-
*/
121-
void registerInitOp(Operation op);
122-
123104
/**
124105
* Get whether an op is an initialization op.
125106
*
126107
* <p><b>Should generally only be used internally.</b>
127108
*/
128-
boolean isInitOp(Operation op);
109+
boolean isInitializer(Operation op);
129110
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 70 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ public void importGraphDef(GraphDef graphDef) throws IllegalArgumentException {
515515
importGraphDef(graphDef, "");
516516
}
517517

518-
private static final String INIT_OP_BASE_NAME = "Init";
518+
static final String INIT_OP_NAME = "Init";
519519

520520
/**
521521
* Import a representation of a TensorFlow graph.
@@ -538,44 +538,23 @@ public void importGraphDef(GraphDef graphDef, String prefix) throws IllegalArgum
538538
String initPrefix;
539539
if (!prefix.isEmpty()) {
540540
if (prefix.endsWith("/")) {
541-
initPrefix = prefix + INIT_OP_BASE_NAME;
541+
initPrefix = prefix + INIT_OP_NAME;
542542
} else {
543-
initPrefix = prefix + "/" + INIT_OP_BASE_NAME;
543+
initPrefix = prefix + "/" + INIT_OP_NAME;
544544
}
545545
} else {
546-
initPrefix = INIT_OP_BASE_NAME;
546+
initPrefix = INIT_OP_NAME;
547547
}
548548

549549
operations()
550550
.forEachRemaining(
551551
op -> {
552552
if (op.name().startsWith(initPrefix)) {
553-
registerInitOp(op);
553+
registerInitializer(op, false);
554554
}
555555
});
556556
}
557557

558-
/**
559-
* Create and return a NoOp that will run all init ops. If {@code required} is false and there are
560-
* no new init ops since the last call, will do nothing and return null.
561-
*/
562-
synchronized GraphOperation addInitOp(boolean required) {
563-
if (!newInitializers && !required) {
564-
return null;
565-
}
566-
if (initializers.isEmpty() && !required) {
567-
return null;
568-
}
569-
570-
baseScope.refreshNames();
571-
OperationBuilder builder =
572-
baseScope().withInitScope().opBuilder(NoOp.OP_NAME, INIT_OP_BASE_NAME);
573-
initializers.forEach(builder::addControlInput);
574-
GraphOperation initOp = (GraphOperation) builder.build();
575-
newInitializers = false;
576-
return initOp;
577-
}
578-
579558
/**
580559
* Generate a representation of the Graph.
581560
*
@@ -587,62 +566,18 @@ synchronized GraphOperation addInitOp(boolean required) {
587566
* @see #importGraphDef(GraphDef, String)
588567
*/
589568
public GraphDef toGraphDef() {
590-
addInitOp(false);
569+
GraphDef graphDef;
591570
synchronized (nativeHandleLock) {
592-
return toGraphDef(nativeHandle);
571+
graphDef = toGraphDef(nativeHandle);
593572
}
594-
}
595-
596-
private boolean registerInitOpHelper(Operation op) {
597-
if (isInitOp(op)) return false;
598-
checkInput(op);
599-
600-
if (!(op instanceof GraphOperation)) {
601-
throw new IllegalStateException("Can't use a non-graph op as a graph's init op.");
602-
}
603-
GraphOperation graphOp = (GraphOperation) op;
604-
605-
for (GraphOperation controlInput : graphOp.controlInputs()) {
606-
registerInitOpHelper(controlInput);
607-
}
608-
609-
for (Operand<?> input : graphOp.inputs()) {
610-
registerInitOpHelper(input.op());
611-
}
612-
return initializers.add(op);
573+
return addOrUpdateInit(graphDef);
613574
}
614575

615576
@Override
616-
public void registerInitOp(Operation op) {
617-
if (registerInitOpHelper(op)) {
618-
newInitializers = true;
619-
}
620-
}
621-
622-
@Override
623-
public boolean isInitOp(Operation op) {
577+
public boolean isInitializer(Operation op) {
624578
return initializers.contains(op);
625579
}
626580

627-
/**
628-
* Returns a set of ops that will run all initializers added to the graph via {@link
629-
* #registerInitOp(Operation)}.
630-
*
631-
* <p>Note that NoOps aren't included in this list, since any inputs or control dependencies are
632-
* guaranteed to also be in this list, and including the no-ops wouldn't change the initialization
633-
* result.
634-
*/
635-
public Set<Operation> initializers() {
636-
return initializers.stream()
637-
.filter(x -> !x.type().equals(NoOp.OP_NAME))
638-
.collect(Collectors.toSet());
639-
}
640-
641-
/** Get whether the graph has any initializers */
642-
public boolean hasInitializers() {
643-
return !initializers.isEmpty();
644-
}
645-
646581
/**
647582
* Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
648583
* {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
@@ -893,6 +828,39 @@ synchronized SaverDef saverDef() {
893828
return saverDef;
894829
}
895830

831+
/**
832+
* Register an op as an initialization op.
833+
*
834+
* @throws IllegalArgumentException if the op or one of its inputs can't be made an init op.
835+
*/
836+
synchronized void registerInitializer(GraphOperation op, boolean isNew) {
837+
if (isInitializer(op)) {
838+
return;
839+
}
840+
checkInput(op);
841+
for (GraphOperation controlInput : op.controlInputs()) {
842+
checkInput(controlInput);
843+
}
844+
for (Operand<?> input : op.inputs()) {
845+
checkInput(input.op());
846+
}
847+
if (initializers.add(op) && isNew && newInitializersMarker < 0) {
848+
newInitializersMarker = initializers.size() - 1;
849+
}
850+
}
851+
852+
/**
853+
* Returns a set of ops that will run all initializers added to the graph via {@link
854+
* #registerInitOp(Operation)}.
855+
*
856+
* <p>Note that NoOps aren't included in this list, since any inputs or control dependencies are
857+
* guaranteed to also be in this list, and including the no-ops wouldn't change the initialization
858+
* result.
859+
*/
860+
Set<Operation> initializers() {
861+
return initializers;
862+
}
863+
896864
private final Object nativeHandleLock = new Object();
897865
private TF_Graph nativeHandle;
898866
private int refcount = 0;
@@ -902,7 +870,7 @@ synchronized SaverDef saverDef() {
902870
private boolean dangerousGradientBuilder;
903871

904872
private final Set<Operation> initializers = Collections.synchronizedSet(new LinkedHashSet<>());
905-
private boolean newInitializers = false;
873+
private int newInitializersMarker = -1;
906874

907875
/**
908876
* Use builders without locking. This should only be used during custom gradient building.
@@ -1091,6 +1059,32 @@ private static GraphDef toGraphDef(TF_Graph handle) {
10911059
}
10921060
}
10931061

1062+
private GraphDef addOrUpdateInit(GraphDef graphDef) {
1063+
if (newInitializersMarker < 0) {
1064+
return graphDef;
1065+
}
1066+
var graphDefWithInitBuilder = graphDef.toBuilder();
1067+
var initNode =
1068+
graphDefWithInitBuilder.getNodeBuilderList().stream()
1069+
.filter(n -> n.getName().equals(INIT_OP_NAME))
1070+
.findFirst()
1071+
.orElseGet(
1072+
() -> {
1073+
return graphDefWithInitBuilder
1074+
.addNodeBuilder()
1075+
.setName(INIT_OP_NAME)
1076+
.setOp(NoOp.OP_NAME);
1077+
});
1078+
1079+
// Register each initializer as a control dependency of the Init op by adding their names
1080+
// prefixed with the '^' symbol to the list of inputs
1081+
initializers.stream()
1082+
.skip(newInitializersMarker)
1083+
.forEach(op -> initNode.addInput("^" + op.name()));
1084+
1085+
return graphDefWithInitBuilder.build();
1086+
}
1087+
10941088
static void resolveOutputs(
10951089
String type, TF_Operation[] srcOps, int[] srcIndices, TF_Output dst, int n) {
10961090
if (srcOps.length != n) {
@@ -1316,7 +1310,7 @@ private static SaverDef addVariableSaver(Graph graph) {
13161310
.build();
13171311
}
13181312

1319-
private static Set<Graph> allGraphs =
1313+
private static final Set<Graph> allGraphs =
13201314
Collections.synchronizedSet(Collections.newSetFromMap(new WeakHashMap<>()));
13211315

13221316
/**

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ public GraphOperation build() {
101101
}
102102
GraphOperation op = new GraphOperation(graph, built);
103103
unsafeNativeHandle = null;
104-
scope.onOpCreated(op);
104+
if (scope.isInit()) {
105+
graph.registerInitializer(op, true);
106+
}
105107
return op;
106108
}
107109
}

0 commit comments

Comments
 (0)