
Failure when a TensorFlow variable has no gradient #482

@ramon-garcia

Description


TensorFlow Java fails when a variable is hidden from gradient computation using tf.stopGradient.

Variables without gradients are used to keep persistent state between runs. In my case, I want some customizable values in the graph; with a constant, these values could not be changed.

For instance, I want to normalize input values so that their squared deviation is (more or less) 1. To do so, one introduces a reasonable value for it as a TensorFlow variable, defined with

var signalVar = tf.withName("signalVar").variable(tf.constant(1.f));

It is initialized to 1 but can be modified with a variable assignment (omitted here). It is then used to divide the input signal, but we don't want to train it:

var signalNormalized = tf.math.div(signal, tf.stopGradient(signalVar));
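To make the premise concrete, here is a minimal pure-Java model (not TensorFlow Java code; the class and method names are hypothetical) of what wrapping the scale in a stop-gradient means for signal / scale: the forward value is unchanged, but the scale receives no gradient at all, which is modeled here as an empty OptionalDouble.

```java
import java.util.OptionalDouble;

// Hypothetical pure-Java model of tf.stopGradient applied to the scale in signal / scale.
// It does not use TensorFlow; it only mirrors the semantics described in the issue.
public class StopGradientSketch {

    // Forward pass: stopGradient does not change the value, so this is plain division.
    static double forward(double signal, double scale) {
        return signal / scale;
    }

    // Gradient of signal / scale with respect to scale.
    // Without stopGradient it is -signal / scale^2; with stopGradient the scale is
    // treated as a constant, so no gradient exists (modeled as an empty OptionalDouble).
    static OptionalDouble gradWrtScale(double signal, double scale, boolean stopGradient) {
        if (stopGradient) {
            return OptionalDouble.empty();
        }
        return OptionalDouble.of(-signal / (scale * scale));
    }

    public static void main(String[] args) {
        System.out.println(forward(2.0, 1.0));                           // prints 2.0
        System.out.println(gradWrtScale(2.0, 1.0, false).getAsDouble()); // prints -2.0
        System.out.println(gradWrtScale(2.0, 1.0, true).isPresent());    // prints false
    }
}
```

The empty result corresponds to the situation the optimizer runs into: the variable exists in the graph, but there is no gradient for it.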

Then, when calling optimizer.minimize, an IllegalStateException is thrown: "close() has been called on the Graph this Operation was a part of".

var optimizer = new Adam(graph, learningRate, betaOne, betaTwo, epsilon);
optimizer.minimize(loss, "train"); // IllegalStateException thrown here.

The reason is that when the optimizer scans the variables, it checks their gradients and the gradients' data types, but a variable protected with stopGradient() has no gradient, so its gradient TF_Operation is never initialized.
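One way such a scan could tolerate this case is to skip variables whose gradient is missing before checking data types. The sketch below is a hypothetical pure-Java model of that idea, not the actual TensorFlow Java optimizer code: GradAndVar and trainableVariables are illustrative names, and data types are represented as plain strings instead of graph Operands.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

public class GradientScanSketch {

    // Hypothetical stand-in for a (variable, gradient) pair; an empty gradType
    // means no gradient was produced (e.g. the variable is behind stopGradient).
    record GradAndVar(String variable, String varType, Optional<String> gradType) {}

    // Keep only variables that actually have a gradient, then check that the
    // gradient's data type matches the variable's.
    static List<String> trainableVariables(List<GradAndVar> pairs) {
        List<String> result = new ArrayList<>();
        for (GradAndVar p : pairs) {
            if (p.gradType().isEmpty()) {
                continue; // nothing to update: skip instead of touching a missing gradient
            }
            if (!p.gradType().get().equals(p.varType())) {
                throw new IllegalArgumentException("Gradient type mismatch for " + p.variable());
            }
            result.add(p.variable());
        }
        return result;
    }

    public static void main(String[] args) {
        List<GradAndVar> pairs = List.of(
                new GradAndVar("weights", "TFloat32", Optional.of("TFloat32")),
                new GradAndVar("signalVar", "TFloat32", Optional.empty()));
        System.out.println(trainableVariables(pairs)); // prints [weights]
    }
}
```

With this filtering, signalVar is simply left out of the update step instead of causing a failure, while type checking still applies to every variable that does have a gradient.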
