Changes from 1 commit
Commits (24)
c57a2e7
Merge pull request #3 from tensorflow/master
JimClarke5 Oct 8, 2020
09fc07e
Merge pull request #4 from tensorflow/master
JimClarke5 Oct 27, 2020
a99dcb4
Merge pull request #5 from tensorflow/master
JimClarke5 Nov 17, 2020
ba294ea
Merge pull request #6 from tensorflow/master
JimClarke5 Nov 19, 2020
04f419a
Merge pull request #7 from tensorflow/master
JimClarke5 Dec 30, 2020
02e7ebf
Merge pull request #8 from tensorflow/master
JimClarke5 Jan 29, 2021
e0c9ed8
Merge pull request #9 from tensorflow/master
JimClarke5 Feb 1, 2021
5b0374b
Merge pull request #10 from tensorflow/master
JimClarke5 Feb 11, 2021
e038bbd
Merge pull request #11 from tensorflow/master
JimClarke5 Feb 23, 2021
28a34dd
Clean up generics, remove generics from class and fix call method to …
JimClarke5 Mar 3, 2021
309b834
resynch with master, for some reason when I build on mac, the order f…
JimClarke5 Mar 3, 2021
def3051
Merge pull request #13 from tensorflow/master
JimClarke5 Mar 3, 2021
3a9ae37
Merge branch 'master' of https://github.com/JimClarke5/java into Gene…
JimClarke5 Mar 3, 2021
c5d37bf
Add GeLU activation present in TF 2.4
JimClarke5 Mar 4, 2021
11f8ac9
Fix @param<T> and reformat
JimClarke5 Mar 4, 2021
40a95af
Fix JavaDoc to add @param <T>
JimClarke5 Mar 6, 2021
d0e8de9
Refactor to add generic to base class and change signature of call me…
JimClarke5 Mar 6, 2021
478b78a
Add check for scalar.
JimClarke5 Mar 6, 2021
f53fa08
Change to accept TString value.
JimClarke5 Mar 7, 2021
79594da
Fix GeLU equations with separate Operands
JimClarke5 Mar 9, 2021
112c740
Fix Constant to handle TString properly
JimClarke5 Mar 9, 2021
61e6206
Added Stddev check for not less than 0.
JimClarke5 Mar 9, 2021
3b4b607
Fix fill to cast the 1 to the appropriate type before the fill
JimClarke5 Mar 9, 2021
98df654
Code reformat
JimClarke5 Mar 9, 2021
Add GeLU activation present in TF 2.4
JimClarke5 committed Mar 4, 2021
commit c5d37bf7ee07f43985ecc742a9ae5236abbd7d94
org/tensorflow/framework/activations/GeLU.java
@@ -0,0 +1,132 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.activations;

import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TFloating;
import org.tensorflow.types.family.TNumber;

import static org.tensorflow.framework.utils.CastHelper.cast;

/**
* Applies the Gaussian error linear unit (GELU) activation function.
*
* <p>Gaussian error linear unit (GELU) computes {@code x * P(X <= x)}, where {@code P(X) ~ N(0,
* 1)}. The (GELU) nonlinearity weights inputs by their value, rather than gates inputs by their
* sign as in ReLU. If <code>approximate</code> is <code>true</code>:
*
* <pre>
* 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
* </pre>
*
* <p>or, if <code>approximate</code> is <code>false</code>:
*
* <pre>
* x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2))),
* </pre>
*
* where <code>P(X) ~ N(0, 1)</code>.
*
* @see <a href="https://arxiv.org/abs/1606.08415">Hendrycks, Dan and Gimpel, Kevin, 2016-2020,
* Gaussian Error Linear Units (GELUs)</a>
*/
// Note: input is restricted to TFloating types; enforced at runtime in call()
public class GeLU extends Activation {

private final boolean approximate;

/**
* Creates a Gaussian error linear unit (GELU) activation, with approximate set to false.
*
* @param tf The TensorFlow ops
*/
public GeLU(Ops tf) {
this(tf, false);
}

/**
* Creates a Gaussian error linear unit (GELU) activation.
*
* @param tf The TensorFlow ops
* @param approximate whether to enable the tanh-based approximation
*/
public GeLU(Ops tf, boolean approximate) {
super(tf);
this.approximate = approximate;
}

/** {@inheritDoc} */
@Override
public <T extends TNumber> Operand<T> call(Operand<T> input) {
if (!TFloating.class.isAssignableFrom(input.type())) {
throw new IllegalArgumentException(
"Tensor type must be a floating point type: " + input.type().getSimpleName());
}
if (approximate) {
/*
coeff = math_ops.cast(0.044715, features.dtype)
return 0.5 * features * (
1.0 + math_ops.tanh(0.7978845608028654 *
(features + coeff * math_ops.pow(features, 3))))
*/
Operand<T> coeff = cast(tf, tf.constant(0.044715), input.type());
Operand<T> point5 = cast(tf, tf.constant(0.5), input.type());
Operand<T> one = cast(tf, tf.constant(1.0), input.type());

return tf.math.mul(
point5,
tf.math.mul(
input,
tf.math.add(
one,
tf.math.tanh(
tf.math.mul(
// sqrt(2.0 / PI)
cast(tf, tf.constant(0.7978845608028654), input.type()),
Collaborator

Why isn't this one pulled out like the others?

Contributor Author

It was mainly for debugging and keeping the parts of the equation manageable. I will change this one and add one for the constant "three".

BTW: It would be nice if we could pass a type to tf.constant, something like tf.constant(3, input.dtype()) to return the correct type.

Collaborator

Sounds like a good extension to have. That should be fairly straightforward.

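For illustration, such a helper could be a thin wrapper over the cast(tf, tf.constant(...), type) pattern already used in this class (the name typedConstant is hypothetical, not an existing API):

static <T extends TNumber> Operand<T> typedConstant(Ops tf, double value, Class<T> type) {
  // Hypothetical convenience method: build a constant, then cast it to the requested type.
  return cast(tf, tf.constant(value), type);
}
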
tf.math.add(
input,
tf.math.mul(
coeff,
tf.math.pow(input, cast(tf, tf.constant(3), input.type()))) // mul
) // add
) // mul
) // tanh
) // add
) // mul
); // mul

} else {
/*
return 0.5 * features * (1.0 + math_ops.erf(
features / math_ops.cast(1.4142135623730951, features.dtype)))
*/
return tf.math.mul(
cast(tf, tf.constant(0.5), input.type()),
Collaborator

Maybe hoist this and the one below out of the if statement and use local variables?

Contributor Author

OK

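The suggested hoisting could take this shape (a sketch of the reviewer's idea, not the code in this commit):

// Hoist the constants shared by both branches so each is built once.
Operand<T> point5 = cast(tf, tf.constant(0.5), input.type());
Operand<T> one = cast(tf, tf.constant(1.0), input.type());
Operand<T> inner =
    approximate
        ? tf.math.tanh(
            tf.math.mul(
                cast(tf, tf.constant(0.7978845608028654), input.type()), // sqrt(2 / pi)
                tf.math.add(
                    input,
                    tf.math.mul(
                        cast(tf, tf.constant(0.044715), input.type()),
                        tf.math.pow(input, cast(tf, tf.constant(3.0), input.type()))))))
        : tf.math.erf(
            tf.math.div(input, cast(tf, tf.constant(1.4142135623730951), input.type()))); // sqrt(2)
return tf.math.mul(point5, tf.math.mul(input, tf.math.add(one, inner)));
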
tf.math.mul(
input,
tf.math.add(
cast(tf, tf.constant(1), input.type()),
tf.math.erf(
tf.math.div(
input, cast(tf, tf.constant(1.4142135623730951), input.type()))))));
}
}

}
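
As a quick sanity check, the approximate formula can be evaluated with plain java.lang.Math against the first value in the test vectors below (a standalone sketch, not part of this PR):

public class GeluApproxCheck {
  public static void main(String[] args) {
    double x = 0.22805803;
    // 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    double y = 0.5 * x * (1.0 + Math.tanh(0.7978845608028654 * (x + 0.044715 * Math.pow(x, 3))));
    System.out.println(y); // ~0.13459886, matching the first expected value in GeLUTest
  }
}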
org/tensorflow/framework/activations/GeLUTest.java
@@ -0,0 +1,88 @@
package org.tensorflow.framework.activations;

import org.junit.jupiter.api.Test;
import org.tensorflow.Operand;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;


class GeLUTest {

private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};

/** Test of GeLU call method */
@Test
public void testCallFloat() {
float[][] input = {
{0.22805803f, 0.60407318f, 0.91519962f, 0.35643331f, 0.28702669f},
{0.11558246f, 0.57658853f, 0.47569648f, 0.02271072f, 0.24709974f}};
float[][] expected = {{0.13459972f, 0.43922312f, 0.75042395f, 0.22784713f, 0.17593417f},
{0.06310898f, 0.41392788f, 0.32483157f, 0.01156111f, 0.14766297f}};
for (TestSession.Mode tfMode : tfModes)
try (TestSession session = TestSession.createTestSession(tfMode)) {
Ops tf = session.getTF();
GeLU instance = new GeLU(tf);
Operand<TFloat32> result = instance.call(tf.constant(input));
session.evaluate(tf.constant(expected), result);
}
}

/** Test of GeLU call method */
@Test
public void testCallDouble() {
double[][] input = {
{0.22805803, 0.60407318, 0.91519962, 0.35643331, 0.28702669},
{0.11558246, 0.57658853, 0.47569648, 0.02271072, 0.24709974}};
double[][] expected = {
{0.13459972, 0.43922312, 0.75042395, 0.22784713, 0.17593417},
{0.06310898, 0.41392788, 0.32483157, 0.01156111, 0.14766297}
};
for (TestSession.Mode tfMode : tfModes)
try (TestSession session = TestSession.createTestSession(tfMode)) {
Ops tf = session.getTF();
GeLU instance = new GeLU(tf);
Operand<TFloat64> result = instance.call(tf.constant(input));
session.evaluate(tf.constant(expected), result);
}
}

/** Test of GeLU call method */
@Test
public void testCallFloatApproximate() {
float[][] input = {
{0.22805803f, 0.60407318f, 0.91519962f, 0.35643331f, 0.28702669f},
{0.11558246f, 0.57658853f, 0.47569648f, 0.02271072f, 0.24709974f}};
float[][] expected = {{0.13459886f, 0.43918941f, 0.75030122f, 0.22784227f, 0.17593207f},
{0.06310892f, 0.41389921f, 0.32481722f, 0.01156111f, 0.14766179f}};
for (TestSession.Mode tfMode : tfModes)
try (TestSession session = TestSession.createTestSession(tfMode)) {
Ops tf = session.getTF();
GeLU instance = new GeLU(tf, true);
Operand<TFloat32> result = instance.call(tf.constant(input));
session.evaluate(tf.constant(expected), result);
}
}

/** Test of GeLU call method */
@Test
public void testCallDoubleApproximate() {
double[][] input = {
{0.22805803, 0.60407318, 0.91519962, 0.35643331, 0.28702669},
{0.11558246, 0.57658853, 0.47569648, 0.02271072, 0.24709974}};
double[][] expected = {{0.13459886, 0.43918941, 0.75030122, 0.22784227, 0.17593207},
{0.06310892, 0.41389921, 0.32481722, 0.01156111, 0.14766179}};
for (TestSession.Mode tfMode : tfModes)
try (TestSession session = TestSession.createTestSession(tfMode)) {
Ops tf = session.getTF();
GeLU instance = new GeLU(tf, true);
Operand<TFloat64> result = instance.call(tf.constant(input));
session.evaluate(tf.constant(expected), result);
}
}


}