Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c57a2e7
Merge pull request #3 from tensorflow/master
JimClarke5 Oct 8, 2020
09fc07e
Merge pull request #4 from tensorflow/master
JimClarke5 Oct 27, 2020
a99dcb4
Merge pull request #5 from tensorflow/master
JimClarke5 Nov 17, 2020
ba294ea
Merge pull request #6 from tensorflow/master
JimClarke5 Nov 19, 2020
04f419a
Merge pull request #7 from tensorflow/master
JimClarke5 Dec 30, 2020
02e7ebf
Merge pull request #8 from tensorflow/master
JimClarke5 Jan 29, 2021
e0c9ed8
Merge pull request #9 from tensorflow/master
JimClarke5 Feb 1, 2021
5b0374b
Merge pull request #10 from tensorflow/master
JimClarke5 Feb 11, 2021
e038bbd
Merge pull request #11 from tensorflow/master
JimClarke5 Feb 23, 2021
28a34dd
Clean up generics, remove generics from class and fix call method to …
JimClarke5 Mar 3, 2021
309b834
resynch with master, for some reason when I build on mac, the order f…
JimClarke5 Mar 3, 2021
def3051
Merge pull request #13 from tensorflow/master
JimClarke5 Mar 3, 2021
3a9ae37
Merge branch 'master' of https://github.com/JimClarke5/java into Gene…
JimClarke5 Mar 3, 2021
c5d37bf
Add GeLU activation present in TF 2.4
JimClarke5 Mar 4, 2021
11f8ac9
Fix @param<T> and reformat
JimClarke5 Mar 4, 2021
40a95af
Fix JavaDoc to add @param <T>
JimClarke5 Mar 6, 2021
d0e8de9
Refactor to add generic to base class and change signature of call me…
JimClarke5 Mar 6, 2021
478b78a
Add check for scalar.
JimClarke5 Mar 6, 2021
f53fa08
Change to accept TString value.
JimClarke5 Mar 7, 2021
79594da
Fix GeLU equations with separate Operands
JimClarke5 Mar 9, 2021
112c740
Fix Constant to handle TString properly
JimClarke5 Mar 9, 2021
61e6206
Added Stddev check for not less than 0.
JimClarke5 Mar 9, 2021
3b4b607
Fix fix fill to cast the 1 to the approriate type before the fill
JimClarke5 Mar 9, 2021
98df654
Code reformat
JimClarke5 Mar 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added Stddev check for not less than 0.
  • Loading branch information
JimClarke5 committed Mar 9, 2021
commit 61e620620a24985425f2a2fb2f3ea45265088a8f
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,13 @@ public RandomNormal(Ops tf, double mean, long seed) {
* @param stddev Standard deviation of the random values to generate.
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
* @throws IllegalArgumentException if standard deviation is less than 0.
*/
public RandomNormal(Ops tf, double mean, double stddev, long seed) {
super(tf);
if(stddev < 0) {
throw new IllegalArgumentException("Standard deviation (stddev) cannot be less than 0, got " + stddev);
}
this.mean = mean;
this.stddev = stddev;
this.seed = seed;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
=======================================================================*/
package org.tensorflow.framework.initializers;

import org.junit.jupiter.api.*;
import org.junit.jupiter.api.Test;
import org.tensorflow.Operand;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;

import static org.junit.jupiter.api.Assertions.assertThrows;

/** Test the RandomNormal initializer */
public class RandomNormalTest {

Expand All @@ -32,18 +34,6 @@ public class RandomNormalTest {

public RandomNormalTest() {}

@BeforeAll
public static void setUpClass() {}

@AfterAll
public static void tearDownClass() {}

@BeforeEach
public void setUp() {}

@AfterEach
public void tearDown() {}

/** Test of call method, of class RandomNormal. */
@Test
public void testCalltestSoftmaxFloat() {
Expand Down Expand Up @@ -86,4 +76,19 @@ public void testReproducible() {
session.evaluate(operand1, operand2);
}
}

@Test
public void testInvalidStdDev() {
for (TestSession.Mode tfMode : tfModes)
assertThrows(
IllegalArgumentException.class,
() -> {
try (TestSession session = TestSession.createTestSession(tfMode)) {
Ops tf = session.getTF();
Shape shape = Shape.of(2, 2);

RandomNormal instance = new RandomNormal(tf, MEAN_VALUE, -2.5, SEED);
}
});
}
}