Commits (24)
c57a2e7 Merge pull request #3 from tensorflow/master (JimClarke5, Oct 8, 2020)
09fc07e Merge pull request #4 from tensorflow/master (JimClarke5, Oct 27, 2020)
a99dcb4 Merge pull request #5 from tensorflow/master (JimClarke5, Nov 17, 2020)
ba294ea Merge pull request #6 from tensorflow/master (JimClarke5, Nov 19, 2020)
04f419a Merge pull request #7 from tensorflow/master (JimClarke5, Dec 30, 2020)
02e7ebf Merge pull request #8 from tensorflow/master (JimClarke5, Jan 29, 2021)
e0c9ed8 Merge pull request #9 from tensorflow/master (JimClarke5, Feb 1, 2021)
5b0374b Merge pull request #10 from tensorflow/master (JimClarke5, Feb 11, 2021)
e038bbd Merge pull request #11 from tensorflow/master (JimClarke5, Feb 23, 2021)
28a34dd Clean up generics, remove generics from class and fix call method to … (JimClarke5, Mar 3, 2021)
309b834 resynch with master, for some reason when I build on mac, the order f… (JimClarke5, Mar 3, 2021)
def3051 Merge pull request #13 from tensorflow/master (JimClarke5, Mar 3, 2021)
3a9ae37 Merge branch 'master' of https://github.com/JimClarke5/java into Gene… (JimClarke5, Mar 3, 2021)
c5d37bf Add GeLU activation present in TF 2.4 (JimClarke5, Mar 4, 2021)
11f8ac9 Fix @param<T> and reformat (JimClarke5, Mar 4, 2021)
40a95af Fix JavaDoc to add @param <T> (JimClarke5, Mar 6, 2021)
d0e8de9 Refactor to add generic to base class and change signature of call me… (JimClarke5, Mar 6, 2021)
478b78a Add check for scalar. (JimClarke5, Mar 6, 2021)
f53fa08 Change to accept TString value. (JimClarke5, Mar 7, 2021)
79594da Fix GeLU equations with separate Operands (JimClarke5, Mar 9, 2021)
112c740 Fix Constant to handle TString properly (JimClarke5, Mar 9, 2021)
61e6206 Added Stddev check for not less than 0. (JimClarke5, Mar 9, 2021)
3b4b607 Fix fix fill to cast the 1 to the approriate type before the fill (JimClarke5, Mar 9, 2021)
98df654 Code reformat (JimClarke5, Mar 9, 2021)
Commit d0e8de92efcd37d103492dd6a8f5bb8afb3918c0
JimClarke5 committed Mar 6, 2021

Refactor to add generic to base class and change signature of call method to <U extends T>.

Changed all subclasses to match these signatures.
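For orientation before the diffs: the refactor moves the type bound from the call method to the class itself. The sketch below uses hypothetical names (ActivationSketch, EluSketch); only the signatures are taken from this commit, everything else is simplified. It shows why a floating-point-only subclass can now drop its run-time type check.

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TFloating;
import org.tensorflow.types.family.TNumber;

// Before this commit: an untyped base class; every subclass accepted any
// TNumber and had to reject unsupported tensor types at run time:
//   public abstract class Activation {
//     public abstract <T extends TNumber> Operand<T> call(Operand<T> input);
//   }

// After this commit: the bound is a class-level parameter, and call()
// accepts any subtype U of that bound.
abstract class ActivationSketch<T extends TNumber> {
  protected Ops tf;

  abstract <U extends T> Operand<U> call(Operand<U> input);
}

// A floating-point-only activation binds T to TFloating, so passing an
// integer operand now fails at compile time instead of throwing
// IllegalArgumentException at run time.
class EluSketch extends ActivationSketch<TFloating> {
  EluSketch(Ops tf) {
    this.tf = tf;
  }

  @Override
  <U extends TFloating> Operand<U> call(Operand<U> input) {
    return tf.nn.elu(input);
  }
}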
org/tensorflow/framework/activations/Activation.java

@@ -19,7 +19,7 @@
 import org.tensorflow.types.family.TNumber;
 
 /** Abstract base class for Activations */
-public abstract class Activation {
+public abstract class Activation<T extends TNumber> {
 
   /** The TensorFlow Ops */
   protected Ops tf;
@@ -55,8 +55,8 @@ protected void setTF(Ops tf) {
    * Gets the calculation operation for the activation.
    *
    * @param input the input tensor
-   * @param <T> the data type of the input and result
+   * @param <U> the data type of the input and result
    * @return The operand for the activation
    */
-  public abstract <T extends TNumber> Operand<T> call(Operand<T> input);
+  public abstract <U extends T> Operand<U> call(Operand<U> input);
 }
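A hypothetical usage sketch of the new signature (EluDemo is not part of this PR; the two-argument ELU constructor appears in the next file, and Ops.create() is the standard eager-mode entry point):

import org.tensorflow.Operand;
import org.tensorflow.framework.activations.ELU;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class EluDemo {
  public static void main(String[] args) {
    Ops tf = Ops.create(); // eager execution
    ELU elu = new ELU(tf, 1.0);

    Operand<TFloat32> x = tf.constant(new float[] {-1f, 0f, 2f});
    Operand<TFloat32> y = elu.call(x); // U inferred as TFloat32, a subtype of TFloating

    // Operand<TInt32> i = tf.constant(new int[] {1, 2, 3});
    // elu.call(i); // would no longer compile: TInt32 is not a TFloating
  }
}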
org/tensorflow/framework/activations/ELU.java

@@ -18,7 +18,6 @@
 import org.tensorflow.op.Ops;
 import org.tensorflow.types.TBool;
 import org.tensorflow.types.family.TFloating;
-import org.tensorflow.types.family.TNumber;
 
 /**
  * Exponential linear unit.
@@ -48,8 +47,7 @@
  * @see <a href="https://arxiv.org/abs/1511.07289">Clevert et al, 2016, Fast and Accurate Deep
  *     Network Learning by Exponential Linear Units (ELUs)</a>
  */
-// TFloating
-public class ELU extends Activation {
+public class ELU extends Activation<TFloating> {
 
   private static final double ALPHA_DEFAULT = 1.0;
 
@@ -79,19 +77,14 @@ public ELU(Ops tf, double alpha) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> Operand<T> call(Operand<T> input) {
-
-    if (!TFloating.class.isAssignableFrom(input.type()) ) {
-      throw new IllegalArgumentException(
-          "Tensor type must be numeric or boolean: " + input.type().getSimpleName());
-    }
+  public <U extends TFloating> Operand<U> call(Operand<U> input) {
 
-    Operand<T> result = tf.nn.elu(input);
+    Operand<U> result = tf.nn.elu(input);
     if (alpha == 1.0) {
       return result;
     } else {
-      Class<T> inputType = input.type();
-      Operand<T> y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType));
+      Class<U> inputType = input.type();
+      Operand<U> y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType));
       Operand<TBool> cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), inputType));
       return tf.select(cond, result, y);
     }
org/tensorflow/framework/activations/Exponential.java

@@ -17,7 +17,6 @@
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
 import org.tensorflow.types.family.TFloating;
-import org.tensorflow.types.family.TNumber;
 
 /**
  * Exponential activation function.
@@ -32,8 +31,7 @@
  * // result is [0.04978707f, 0.36787945f, 1.f, 2.7182817f, 20.085537f]
  * </pre>
  */
-// TFloating
-public class Exponential extends Activation {
+public class Exponential extends Activation<TFloating> {
 
   /**
    * Creates an Exponential activation.
@@ -48,15 +46,12 @@ public Exponential(Ops tf) {
    * Calculates the Exponential activation.
    *
    * @param input the input tensor
-   * @param <T> the data type of the input and result
+   * @param <U> the data type of the input and result
    * @return an Operand for the exponential activation: <code>exp(x)</code>.
    */
   @Override
-  public <T extends TNumber> Operand<T> call(Operand<T> input) {
-    if (!TFloating.class.isAssignableFrom(input.type()) ) {
-      throw new IllegalArgumentException(
-          "Tensor type must be numeric or boolean: " + input.type().getSimpleName());
-    }
+  public <U extends TFloating> Operand<U> call(Operand<U> input) {
 
 
     return tf.math.exp(input);
   }
org/tensorflow/framework/activations/GeLU.java

@@ -15,13 +15,8 @@
 package org.tensorflow.framework.activations;
 
 import org.tensorflow.Operand;
-import org.tensorflow.Session;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.TFloat64;
 import org.tensorflow.types.family.TFloating;
-import org.tensorflow.types.family.TNumber;
-
-import java.util.Arrays;
 
 import static org.tensorflow.framework.utils.CastHelper.cast;
 
@@ -39,16 +34,15 @@
  * <p>or, if <code>approximate</code> is <code>false</code>.
  *
  * <pre>
- * x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2))),
+ * x * P(X &lt;= x) = 0.5 * x * (1 + erf(x / sqrt(2))),
  * </pre>
  *
  * where <code>P(X) ~ N(0, 1)</code>.
  *
  * @see <a href="https://arxiv.org/abs/1606.08415">Hendrycks, Dan and Gimpel, Kevin, 2016-2020,
  *     Gaussian Error Linear Units (GELUs)</a>
  */
-// TFloating
-public class GeLU extends Activation {
+public class GeLU extends Activation<TFloating> {
 
   private final boolean approximate;
 
@@ -74,21 +68,18 @@ public GeLU(Ops tf, boolean approximate) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> Operand<T> call(Operand<T> input) {
-    if (!TFloating.class.isAssignableFrom(input.type())) {
-      throw new IllegalArgumentException(
-          "Tensor type must be numeric or boolean: " + input.type().getSimpleName());
-    }
+  public <U extends TFloating> Operand<U> call(Operand<U> input) {
+
     if (approximate) {
       /*
        coeff = math_ops.cast(0.044715, features.dtype)
        return 0.5 * features * (
            1.0 + math_ops.tanh(0.7978845608028654 *
                (features + coeff * math_ops.pow(features, 3))))
       */
-      Operand<T> coeff = cast(tf, tf.constant(0.044715), input.type());
-      Operand<T> point5 = cast(tf, tf.constant(0.5), input.type());
-      Operand<T> one = cast(tf, tf.constant(1.0), input.type());
+      Operand<U> coeff = cast(tf, tf.constant(0.044715), input.type());
+      Operand<U> point5 = cast(tf, tf.constant(0.5), input.type());
+      Operand<U> one = cast(tf, tf.constant(1.0), input.type());
 
       return tf.math.mul(
           point5,
@@ -128,5 +119,4 @@
               input, cast(tf, tf.constant(1.4142135623730951), input.type()))))));
     }
   }
-
 }
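One note on the magic numbers above: 0.7978845608028654 is sqrt(2/pi), and 1.4142135623730951 is the sqrt(2) inside erf(x / sqrt(2)) in the exact branch. A plain-Java sketch of the approximate branch (hypothetical helper, not part of this PR):

public class GeluConstants {
  // Implements 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  static double geluApproximate(double x) {
    double sqrt2OverPi = Math.sqrt(2.0 / Math.PI); // == 0.7978845608028654
    return 0.5 * x * (1.0 + Math.tanh(sqrt2OverPi * (x + 0.044715 * Math.pow(x, 3))));
  }
}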
org/tensorflow/framework/activations/HardSigmoid.java

@@ -41,7 +41,7 @@
  * // result is [0.f , 0.3f, 0.5f, 0.7f, 1.f]
  * </pre>
  */
-public class HardSigmoid extends Activation {
+public class HardSigmoid extends Activation<TNumber> {
 
   /**
    * Creates Hard sigmoid activation.
@@ -54,12 +54,12 @@ public HardSigmoid(Ops tf) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> Operand<T> call(Operand<T> input) {
-    Class<T> inputType = input.type();
-    Operand<T> point2 = tf.dtypes.cast(tf.constant(0.2), inputType);
-    Operand<T> point5 = tf.dtypes.cast(tf.constant(0.5), inputType);
+  public <U extends TNumber> Operand<U> call(Operand<U> input) {
+    Class<U> inputType = input.type();
+    Operand<U> point2 = tf.dtypes.cast(tf.constant(0.2), inputType);
+    Operand<U> point5 = tf.dtypes.cast(tf.constant(0.5), inputType);
 
-    Operand<T> x = tf.math.add(tf.math.mul(input, point2), point5);
+    Operand<U> x = tf.math.add(tf.math.mul(input, point2), point5);
     return tf.clipByValue(
         x, tf.dtypes.cast(tf.constant(0), inputType), tf.dtypes.cast(tf.constant(1), inputType));
   }
org/tensorflow/framework/activations/Linear.java

@@ -33,7 +33,7 @@
  * // result is [-3.0f,-1.0f, 0.0f,1.0f,3.0f]
  * </pre>
  */
-public class Linear extends Activation {
+public class Linear extends Activation<TNumber> {
 
   /**
    * Creates a linear activation.
@@ -46,7 +46,7 @@ public Linear(Ops tf) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> Operand<T> call(Operand<T> input) {
+  public <U extends TNumber> Operand<U> call(Operand<U> input) {
     return input;
   }
 }
org/tensorflow/framework/activations/ReLU.java

@@ -56,8 +56,7 @@
  * // result is [-0.f, -0.f, 0.f, 0.f, 10.f]
  * </pre>
  */
-// TFloating
-public class ReLU extends Activation {
+public class ReLU extends Activation<TNumber> {
 
   public static final float ALPHA_DEFAULT = 0.0f;
   public static final float MAX_VALUE_DEFAULT = Float.NaN;
@@ -95,11 +94,11 @@ public ReLU(Ops tf, float alpha, float maxValue, float threshold) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> Operand<T> call(Operand<T> input) {
-    Class<T> inputType = input.type();
+  public <U extends TNumber> Operand<U> call(Operand<U> input) {
+    Class<U> inputType = input.type();
 
     boolean clipMax = !Float.isNaN(maxValue);
-    Operand<T> negativePart = null;
+    Operand<U> negativePart = null;
     if (alpha != 0) {
       if (Float.isNaN(maxValue) && threshold == 0) {
         return tf.nn.leakyRelu(input, LeakyRelu.alpha(alpha));
@@ -113,7 +112,7 @@ public <T extends TNumber> Operand<T> call(Operand<T> input) {
       }
     }
 
-    Operand<T> lInput;
+    Operand<U> lInput;
     if (threshold != 0) {
       // computes input for input > threshold else 0
       Greater greater = tf.math.greater(input, tf.dtypes.cast(tf.constant(threshold), inputType));
@@ -126,8 +125,8 @@ public <T extends TNumber> Operand<T> call(Operand<T> input) {
       lInput = tf.nn.relu(input);
     }
     if (clipMax) {
-      Operand<T> lmaxValue = tf.dtypes.cast(tf.constant(maxValue), inputType);
-      Operand<T> zero = tf.dtypes.cast(tf.constant(0), inputType);
+      Operand<U> lmaxValue = tf.dtypes.cast(tf.constant(maxValue), inputType);
+      Operand<U> zero = tf.dtypes.cast(tf.constant(0), inputType);
       lInput = tf.clipByValue(lInput, zero, lmaxValue);
     }
 
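A hedged usage sketch of the three ReLU parameters (alpha, maxValue, threshold), assuming the four-argument constructor shown above; ReluDemo is hypothetical. Configured this way the activation behaves like ReLU6, i.e. min(max(x, 0), 6):

import org.tensorflow.Operand;
import org.tensorflow.framework.activations.ReLU;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class ReluDemo {
  public static void main(String[] args) {
    Ops tf = Ops.create();
    // alpha = 0 (no leak below threshold), maxValue = 6 (clip ceiling), threshold = 0
    ReLU relu = new ReLU(tf, 0.0f, 6.0f, 0.0f);
    Operand<TFloat32> y = relu.call(tf.constant(new float[] {-1f, 3f, 10f}));
    // y evaluates to [0f, 3f, 6f]
  }
}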
org/tensorflow/framework/activations/SELU.java

@@ -16,7 +16,7 @@
 
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Scaled Exponential Linear Unit (SELU).
@@ -44,8 +44,7 @@
  *
  * @see <a href="https://arxiv.org/abs/1706.02515">Klambauer et al., 2017</a>
  */
-// TFloating
-public class SELU extends Activation {
+public class SELU extends Activation<TFloating> {
 
   /**
    * Creates a Scaled Exponential Linear Unit (SELU) activation.
@@ -58,7 +57,7 @@ public SELU(Ops tf) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> Operand<T> call(Operand<T> input) {
+  public <U extends TFloating> Operand<U> call(Operand<U> input) {
     return tf.nn.selu(input);
   }
 }
org/tensorflow/framework/activations/Sigmoid.java

@@ -16,7 +16,7 @@
 
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Sigmoid activation. <code>sigmoid(x) = 1 / (1 + exp(-x))</code>.
@@ -39,8 +39,7 @@
  * // 5.0000000e-01f,7.3105860e-01f, 1.f]
  * </pre>
  */
-// TFloating
-public class Sigmoid extends Activation {
+public class Sigmoid extends Activation<TFloating> {
 
   /**
    * Creates a Sigmoid activation.
@@ -53,7 +52,7 @@ public Sigmoid(Ops tf) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> Operand<T> call(Operand<T> input) {
+  public <U extends TFloating> Operand<U> call(Operand<U> input) {
     return tf.math.sigmoid(input);
   }
 }
org/tensorflow/framework/activations/Softmax.java

@@ -19,7 +19,7 @@
 import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.ReduceMax;
 import org.tensorflow.op.core.ReduceSum;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Softmax converts a real vector to a vector of categorical probabilities.
@@ -36,8 +36,7 @@
  *
  * <p>The input values in are the log-odds of the resulting probability.
  */
-// TFloating
-public class Softmax extends Activation {
+public class Softmax extends Activation<TFloating> {
 
   private static final int AXIS_DEFAULT = -1;
 
@@ -66,16 +65,16 @@ public Softmax(Ops tf, int axis) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> Operand<T> call(Operand<T> input) {
+  public <U extends TFloating> Operand<U> call(Operand<U> input) {
     Shape shape = input.shape();
     int numDimensions = shape.numDimensions();
     if (numDimensions == 2) {
       return tf.nn.softmax(input);
     } else {
-      Operand<T> e =
+      Operand<U> e =
           tf.math.exp(
               tf.math.sub(input, tf.reduceMax(input, tf.constant(axis), ReduceMax.keepDims(true))));
-      Operand<T> s = tf.reduceSum(e, tf.constant(axis), ReduceSum.keepDims(true));
+      Operand<U> s = tf.reduceSum(e, tf.constant(axis), ReduceSum.keepDims(true));
       return tf.math.div(e, s);
     }
   }
org/tensorflow/framework/activations/Softplus.java

@@ -16,7 +16,7 @@
 
 import org.tensorflow.Operand;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TFloating;
 
 /**
  * Softplus activation function, <code>softplus(x) = log(exp(x) + 1)</code>.
@@ -32,8 +31,7 @@
  * // 1.3132616e+00f, 2.0000000e+01f]
  * </pre>
  */
-// TFloating
-public class Softplus extends Activation {
+public class Softplus extends Activation<TFloating> {
 
   /**
    * Creates a Softplus activation function.
@@ -46,7 +45,7 @@ public Softplus(Ops tf) {
 
   /** {@inheritDoc} */
   @Override
-  public <T extends TNumber> Operand<T> call(Operand<T> input) {
+  public <U extends TFloating> Operand<U> call(Operand<U> input) {
     return tf.math.softplus(input);
   }
 }