Skip to content

Commit 9b6ce65

Browse files
committed
Bug fixes
Signed-off-by: Ryan Nett <JNett96@gmail.com>
1 parent 354aec8 commit 9b6ce65

6 files changed

Lines changed: 63 additions & 25 deletions

File tree

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
package org.tensorflow.op;
1919

2020
import java.nio.charset.Charset;
21+
import java.util.Arrays;
2122
import java.util.List;
2223
import java.util.Map;
2324
import org.tensorflow.ConcreteFunction;
2425
import org.tensorflow.DeviceSpec;
2526
import org.tensorflow.EagerSession;
2627
import org.tensorflow.ExecutionEnvironment;
2728
import org.tensorflow.Operand;
29+
import org.tensorflow.Operation;
2830
import org.tensorflow.ndarray.BooleanNdArray;
2931
import org.tensorflow.ndarray.ByteNdArray;
3032
import org.tensorflow.ndarray.DoubleNdArray;
@@ -8257,6 +8259,33 @@ public Ops withControlDependencies(Iterable<Op> controls) {
82578259
return new Ops(scope.withControlDependencies(controls));
82588260
}
82598261

8262+
/**
8263+
* Returns an API that adds operations to the graph with the provided control dependencies.
8264+
*
8265+
* @see {@link Scope#withControlDependencies(Iterable<Op<?>>)}
8266+
*/
8267+
public Ops withControlDependencies(Op... controls) {
8268+
return withControlDependencies(Arrays.asList(controls));
8269+
}
8270+
8271+
/**
8272+
* Returns an API that adds operations to the graph with the provided control dependencies.
8273+
*
8274+
* @see {@link Scope#withControlDependencyOps(Iterable<Operation>)}
8275+
*/
8276+
public Ops withControlDependencyOps(Iterable<Operation> controls) {
8277+
return new Ops(scope.withControlDependencyOps(controls));
8278+
}
8279+
8280+
/**
8281+
* Returns an API that adds operations to the graph with the provided control dependencies.
8282+
*
8283+
* @see {@link Scope#withControlDependencyOps(Iterable<Operation>)}
8284+
*/
8285+
public Ops withControlDependencyOps(Operation... controls) {
8286+
return withControlDependencyOps(Arrays.asList(controls));
8287+
}
8288+
82608289
/**
82618290
* Returns the current {@link Scope scope} of this API
82628291
*/

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,6 +1239,8 @@ private static Object[] whileLoop(
12391239
private static SaverDef addVariableSaver(Graph graph) {
12401240
Ops tf = Ops.create(graph).withSubScope("save");
12411241

1242+
// TODO handle resource variables, too
1243+
12421244
List<String> varNames = new ArrayList<>();
12431245
List<Operand<?>> varOutputs = new ArrayList<>();
12441246
List<Class<? extends TType>> varTypes = new ArrayList<>();

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ public void export() throws IOException {
280280
functions.forEach((k, f) -> metaGraphDef.putSignatureDef(k, f.signature().asSignatureDef()));
281281

282282
if (!functions.containsKey(INIT_OP_SIGNATURE_KEY)) {
283+
283284
metaGraphDef.putSignatureDef(
284285
INIT_OP_SIGNATURE_KEY,
285286
SignatureDef.newBuilder()
@@ -387,7 +388,10 @@ public Session session() {
387388

388389
/** Return the signature of all functions available in this saved model. */
389390
public List<Signature> signatures() {
390-
return functions.values().stream().map(f -> f.signature()).collect(Collectors.toList());
391+
return functions.values().stream()
392+
.map(SessionFunction::signature)
393+
.filter(signature -> !signature.key().equals(INIT_OP_SIGNATURE_KEY))
394+
.collect(Collectors.toList());
391395
}
392396

393397
/**
@@ -592,6 +596,9 @@ private static SavedModelBundle load(
592596
throw new TensorFlowException("Cannot parse MetaGraphDef protocol buffer", e);
593597
}
594598
}
599+
bundle.session.initialize();
600+
601+
// bundle.session.restore(exportDir + "/variables/variables");
595602

596603
return bundle;
597604
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,8 @@ public void exportFunctionWithVariables() throws IOException {
173173
SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) {
174174
assertNotNull(savedModel.metaGraphDef());
175175
assertNotNull(savedModel.metaGraphDef().getSaverDef());
176-
assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount());
177-
assertEquals(
178-
Signature.DEFAULT_KEY,
179-
savedModel.metaGraphDef().getSignatureDefMap().keySet().iterator().next());
176+
assertEquals(2, savedModel.metaGraphDef().getSignatureDefCount());
177+
assertTrue(savedModel.metaGraphDef().getSignatureDefMap().containsKey(Signature.DEFAULT_KEY));
180178

181179
TensorFunction function = savedModel.function(Signature.DEFAULT_KEY);
182180
assertNotNull(function);

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/Names.java

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
/*
2-
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
33
4-
Licensed under the Apache License, Version 2.0 (the "License");
5-
you may not use this file except in compliance with the License.
6-
You may obtain a copy of the License at
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
77
8-
http://www.apache.org/licenses/LICENSE-2.0
8+
http://www.apache.org/licenses/LICENSE-2.0
99
10-
Unless required by applicable law or agreed to in writing, software
11-
distributed under the License is distributed on an "AS IS" BASIS,
12-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
See the License for the specific language governing permissions and
14-
limitations under the License.
15-
==============================================================================
16-
*/
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
1717
package org.tensorflow;
1818

1919
import com.squareup.javapoet.ArrayTypeName;
2020
import com.squareup.javapoet.ClassName;
2121
import com.squareup.javapoet.ParameterizedTypeName;
2222
import com.squareup.javapoet.TypeName;
23-
import java.util.Arrays;
2423

2524
public class Names {
2625

@@ -52,9 +51,12 @@ public class Names {
5251
public static final ClassName RawOp = ClassName.get(OpPackage, "RawOp");
5352
public static final ClassName Operation = ClassName.get(TensorflowPackage, "Operation");
5453
public static final ClassName Operands = ClassName.get(OpPackage, "Operands");
55-
public static final ClassName OperationBuilder = ClassName.get(TensorflowPackage, "OperationBuilder");
56-
public static final TypeName IterableOp = ParameterizedTypeName.get(ClassName.get(Iterable.class), Op);
57-
public static final TypeName IterableOperation = ParameterizedTypeName.get(ClassName.get(Iterable.class), Operation);
54+
public static final ClassName OperationBuilder =
55+
ClassName.get(TensorflowPackage, "OperationBuilder");
56+
public static final TypeName IterableOp =
57+
ParameterizedTypeName.get(ClassName.get(Iterable.class), Op);
58+
public static final TypeName IterableOperation =
59+
ParameterizedTypeName.get(ClassName.get(Iterable.class), Operation);
5860
public static final TypeName ArrayOp = ArrayTypeName.of(Op);
5961
public static final TypeName ArrayOperation = ArrayTypeName.of(Operation);
6062

@@ -63,7 +65,8 @@ public class Names {
6365

6466
public static final ClassName Shape = ClassName.get(TensorflowPackage + ".ndarray", "Shape");
6567
public static final ClassName Tensor = ClassName.get(TensorflowPackage, "Tensor");
66-
public static final ClassName ConcreteFunction = ClassName.get(TensorflowPackage, "ConcreteFunction");
68+
public static final ClassName ConcreteFunction =
69+
ClassName.get(TensorflowPackage, "ConcreteFunction");
6770

6871
public static final ClassName Scope = ClassName.get(OpPackage, "Scope");
6972
public static final TypeName DeviceSpec = ClassName.get(TensorflowPackage, "DeviceSpec");
@@ -75,5 +78,4 @@ public class Names {
7578

7679
public static final TypeName String = ClassName.get(String.class);
7780
public static final ClassName Arrays = ClassName.get(java.util.Arrays.class);
78-
7981
}

tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/processor/operator/OperatorProcessor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
625625
.addParameter(Names.ArrayOp, "controls")
626626
.varargs()
627627
.returns(Names.Ops)
628-
.addStatement("return withControlDependencies(%T.asList(controls))", Names.Arrays)
628+
.addStatement("return withControlDependencies($T.asList(controls))", Names.Arrays)
629629
.addJavadoc(
630630
"Returns an API that adds operations to the graph with the provided control dependencies.\n\n"
631631
+ "@see {@link $T#withControlDependencies(Iterable<Op<?>>)}\n",
@@ -650,7 +650,7 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
650650
.addParameter(Names.ArrayOperation, "controls")
651651
.varargs()
652652
.returns(Names.Ops)
653-
.addStatement("return withControlDependencyOps(%T.asList(controls))", Names.Arrays)
653+
.addStatement("return withControlDependencyOps($T.asList(controls))", Names.Arrays)
654654
.addJavadoc(
655655
"Returns an API that adds operations to the graph with the provided control dependencies.\n\n"
656656
+ "@see {@link $T#withControlDependencyOps(Iterable<Operation>)}\n",

0 commit comments

Comments
 (0)