From f3a97132a3c701d76c98e9f6c457af54ebb3a3a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ram=C3=B3n=20Garc=C3=ADa=20Fern=C3=A1ndez?= Date: Thu, 29 Dec 2022 15:40:06 +0100 Subject: [PATCH 1/4] Fix fail when a variable receives zero gradient --- .../src/main/java/org/tensorflow/Output.java | 5 +++++ .../java/org/tensorflow/framework/optimizers/Optimizer.java | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java index a7e48fcb9ee..3ebfc96fd37 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java @@ -147,6 +147,11 @@ Pointer getUnsafeNativeHandle() { return operation.getUnsafeNativeHandle(index); } + public boolean isClosed() { + Pointer handle = operation.getUnsafeNativeHandle(index); + return handle== null || handle.isNull() ; + } + private final AbstractOperation operation; private final int index; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index a12e46f82e5..dc7047337e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -174,7 +174,9 @@ public Op applyGradients(List> gradsAndVars, String List updateOps = new ArrayList<>(); prepOp.ifPresent(updateOps::add); for (GradAndVar pair : gradsAndVars) { - updateOps.add(applyDense(pair)); + if (!pair.gradient.isClosed()) { + updateOps.add(applyDense(pair)); + } } return finish(updateOps, name); From be762db8bf347f104df69f63190a1b1a2a51c049 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ram=C3=B3n=20Garc=C3=ADa=20Fern=C3=A1ndez?= Date: Fri, 6 Jan 2023 04:08:41 +0100 Subject: [PATCH 2/4] Fix whitespace error. --- .../src/main/java/org/tensorflow/Output.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java index 3ebfc96fd37..1d72bb4bea5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java @@ -149,7 +149,7 @@ Pointer getUnsafeNativeHandle() { public boolean isClosed() { Pointer handle = operation.getUnsafeNativeHandle(index); - return handle== null || handle.isNull() ; + return handle == null || handle.isNull(); } private final AbstractOperation operation; From 4be5886e1091870f59fb163b371d81dd6cb944fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ram=C3=B3n=20Garc=C3=ADa=20Fern=C3=A1ndez?= Date: Fri, 6 Jan 2023 19:10:05 +0100 Subject: [PATCH 3/4] Added javadoc to function isClosed --- .../tensorflow-core-api/src/main/java/org/tensorflow/Output.java | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java index 1d72bb4bea5..c66e007c83b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java @@ -147,6 +147,7 @@ Pointer getUnsafeNativeHandle() { return operation.getUnsafeNativeHandle(index); } + /** Returns whether the underlying operation has no valid handle. Makes the opposite check as GraphOperation.requireHandle **/ public boolean isClosed() { Pointer handle = operation.getUnsafeNativeHandle(index); return handle == null || handle.isNull(); From b07c9ab0e73d4f6cf419776ec7422eeec2513906 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ram=C3=B3n=20Garc=C3=ADa=20Fern=C3=A1ndez?= Date: Sat, 21 Jan 2023 22:25:53 +0100 Subject: [PATCH 4/4] Fix formatting errors from mvn spotless:check --- .../src/main/java/org/tensorflow/Output.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java index c66e007c83b..cf7537eadfd 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java @@ -147,7 +147,10 @@ Pointer getUnsafeNativeHandle() { return operation.getUnsafeNativeHandle(index); } - /** Returns whether the underlying operation has no valid handle. Makes the opposite check as GraphOperation.requireHandle **/ + /** + * Returns whether the underlying operation has no valid handle. Makes the opposite check as + * GraphOperation.requireHandle * + */ public boolean isClosed() { Pointer handle = operation.getUnsafeNativeHandle(index); return handle == null || handle.isNull();