Fix failure when a variable receives zero gradient #483
Conversation
```java
    return operation.getUnsafeNativeHandle(index);
  }

  public boolean isClosed() {
```
Does this behave correctly on an EagerSession? Also it needs some javadoc and I think it'll fail the spotless check.
Fixed the formatting problem.
I tried the following small example with EagerSession, and it correctly returns false:
```java
import org.tensorflow.EagerSession;
import org.tensorflow.op.Ops;
import org.tensorflow.framework.losses.MeanSquaredError;

public class App {
  public static void main(String[] args) {
    try (EagerSession s = EagerSession.create()) {
      Ops tf = Ops.create(s);
      // y = a*x + b, with gradient propagation stopped at b
      var x = tf.constant(new float[] {1.0f, 3.0f, 4.0f});
      var y = tf.constant(new float[] {2.5f, 6.5f, 8.5f});
      var b = tf.constant(1.0f);
      var a = tf.constant(1.8f);
      var yPred = tf.math.add(tf.math.mul(a, x), tf.stopGradient(b));
      var lossFn = new MeanSquaredError();
      var loss = lossFn.call(tf, y, yPred); // labels first, then predictions
      System.out.println("IsClosed " + loss.asOutput().isClosed());
    }
  }
}
```
Running it, the output is:

```
2023-01-06 16:32:31.701434: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
IsClosed false
```
So it appears to work correctly.
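For context, the contract under discussion can be sketched in plain Java. The class and field names below are hypothetical illustrations, not the tensorflow-java API: a handle reports `isClosed() == false` while its native resource is live, and `true` once it has been released.

```java
// Hypothetical sketch of the closed-handle contract; not the tensorflow-java API.
public class NativeHandleSketch implements AutoCloseable {
  private long nativeHandle; // 0 means the native resource has been released

  public NativeHandleSketch(long handle) {
    this.nativeHandle = handle;
  }

  /** Returns true once the underlying native resource has been released. */
  public boolean isClosed() {
    return nativeHandle == 0;
  }

  @Override
  public void close() {
    nativeHandle = 0; // release the native resource
  }

  public static void main(String[] args) {
    NativeHandleSketch h = new NativeHandleSketch(42L);
    System.out.println("IsClosed " + h.isClosed()); // prints "IsClosed false"
    h.close();
    System.out.println("IsClosed " + h.isClosed()); // prints "IsClosed true"
  }
}
```

This mirrors why the eager example above prints `IsClosed false`: the tensor is still live inside the try-with-resources block.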
Ok. Can you add the javadoc to the method?
Yes, I have just added the javadoc.
Thanks @ramon-garcia for your contribution!
Thank you for accepting this contribution. It will make my work easier, now that I can stop gradient propagation.
Fixes #482. A check is added in the optimizer code that the gradient of a variable is valid; if it is not, that variable is left out of the update.
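A minimal sketch of that guard in plain Java, with hypothetical names that do not match the tensorflow-java optimizer internals: gradient/variable pairs whose gradient is absent (for example because `stopGradient` cut off every path to the variable) are skipped rather than applied.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of skipping variables with no valid gradient;
// names are illustrative, not the actual tensorflow-java optimizer code.
public class GradientFilterSketch {
  /** A (gradient, variable) pair; gradient may be null when no path reaches the variable. */
  public record GradAndVar(float[] gradient, String variableName) {}

  /** Keeps only pairs with a valid gradient; the rest are left out of the update. */
  public static List<GradAndVar> filterValidGradients(List<GradAndVar> pairs) {
    List<GradAndVar> valid = new ArrayList<>();
    for (GradAndVar p : pairs) {
      if (p.gradient() != null) { // e.g. stopGradient severed this variable from the loss
        valid.add(p);
      }
    }
    return valid;
  }

  public static void main(String[] args) {
    List<GradAndVar> pairs = List.of(
        new GradAndVar(new float[] {0.1f}, "a"),
        new GradAndVar(null, "b")); // b's gradient was stopped, so it is skipped
    System.out.println(filterValidGradients(pairs).size()); // prints 1
  }
}
```

The design choice is simply to treat a missing gradient as "no update" for that variable, instead of failing the whole apply-gradients step.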