diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java index a7e48fcb9ee..cf7537eadfd 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java @@ -147,6 +147,15 @@ Pointer getUnsafeNativeHandle() { return operation.getUnsafeNativeHandle(index); } + /** + * Returns whether the underlying operation has no valid handle. Makes the opposite check as + * GraphOperation.requireHandle * + */ + public boolean isClosed() { + Pointer handle = operation.getUnsafeNativeHandle(index); + return handle == null || handle.isNull(); + } + private final AbstractOperation operation; private final int index; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index a12e46f82e5..dc7047337e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -174,7 +174,9 @@ public Op applyGradients(List> gradsAndVars, String List updateOps = new ArrayList<>(); prepOp.ifPresent(updateOps::add); for (GradAndVar pair : gradsAndVars) { - updateOps.add(applyDense(pair)); + if (!pair.gradient.isClosed()) { + updateOps.add(applyDense(pair)); + } } return finish(updateOps, name);