Skip to content

Commit fb52dd4

Browse files
committed
Change UNTRUNCATED_NORMAL to NORMAL
1 parent 1f32de2 commit fb52dd4

4 files changed

Lines changed: 11 additions & 11 deletions

File tree

tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,13 @@ public void testCallUniformReproducible() {
130130
}
131131

132132
@Test
133-
public void testCallUNTRUNCATED_NORMALReproducible() {
133+
public void testCall_NORMALReproducible() {
134134
for (TestSession.Mode tfMode : tfModes)
135135
try (TestSession session = TestSession.createTestSession(tfMode)) {
136136
Ops tf = session.getTF();
137137
Shape shape = Shape.of(2, 2);
138138
Glorot<TFloat64, TFloat64> instance =
139-
new Glorot<>(tf, Distribution.UNTRUNCATED_NORMAL, SEED);
139+
new Glorot<>(tf, Distribution.NORMAL, SEED);
140140
Operand<TFloat64> operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE);
141141
Operand<TFloat64> operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE);
142142
session.evaluate(operand1, operand2);

tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ public void testCallUniformReproducible() {
128128
}
129129

130130
@Test
131-
public void testCallUNTRUNCATED_NORMALReproducible() {
131+
public void testCall_NORMALReproducible() {
132132
for (TestSession.Mode tfMode : tfModes)
133133
try (TestSession session = TestSession.createTestSession(tfMode)) {
134134
Ops tf = session.getTF();
135135
Shape shape = Shape.of(2, 2);
136-
He<TFloat64, TFloat64> instance = new He<>(tf, Distribution.UNTRUNCATED_NORMAL, SEED);
136+
He<TFloat64, TFloat64> instance = new He<>(tf, Distribution.NORMAL, SEED);
137137
Operand<TFloat64> operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE);
138138
Operand<TFloat64> operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE);
139139
session.evaluate(operand1, operand2);

tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ public void testCallUniformReproducible() {
128128
}
129129

130130
@Test
131-
public void testCallUNTRUNCATED_NORMALReproducible() {
131+
public void testCall_NORMALReproducible() {
132132
for (TestSession.Mode tfMode : tfModes)
133133
try (TestSession session = TestSession.createTestSession(tfMode)) {
134134
Ops tf = session.getTF();
135135
Shape shape = Shape.of(2, 2);
136-
LeCun<TFloat64, TFloat64> instance = new LeCun<>(tf, Distribution.UNTRUNCATED_NORMAL, SEED);
136+
LeCun<TFloat64, TFloat64> instance = new LeCun<>(tf, Distribution.NORMAL, SEED);
137137
Operand<TFloat64> operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE);
138138
Operand<TFloat64> operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE);
139139
session.evaluate(operand1, operand2);

tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ public void testCall_Double_1_FAN_IN_TRUNCATED_NORMAL() {
8585

8686
/** Test of call method, of class VarianceScaling. */
8787
@Test
88-
public void testCall_Float_1_FAN_IN_UNTRUNCATED_NORMAL() {
88+
public void testCall_Float_1_FAN_IN_NORMAL() {
8989
float[] expected = {-0.46082667F, -0.25798687F, -0.06924929F, -0.28017485F};
9090
for (TestSession.Mode tfMode : tfModes)
9191
try (TestSession session = TestSession.createTestSession(tfMode)) {
@@ -96,15 +96,15 @@ public void testCall_Float_1_FAN_IN_UNTRUNCATED_NORMAL() {
9696
tf,
9797
1.0,
9898
VarianceScaling.Mode.FAN_IN,
99-
VarianceScaling.Distribution.UNTRUNCATED_NORMAL,
99+
VarianceScaling.Distribution.NORMAL,
100100
SEED);
101101
Operand<TFloat32> operand = instance.call(tf.constant(shape), TFloat32.DTYPE);
102102
session.evaluate(expected, operand);
103103
}
104104
}
105105

106106
@Test
107-
public void testCall_Double_1_FAN_IN_UNTRUNCATED_NORMAL() {
107+
public void testCall_Double_1_FAN_IN_NORMAL() {
108108
double[] expected = {
109109
1.3169108626945392, -1.0985224689731887, -0.13536536217837225, -1.698770780615686
110110
};
@@ -117,7 +117,7 @@ public void testCall_Double_1_FAN_IN_UNTRUNCATED_NORMAL() {
117117
tf,
118118
1.0,
119119
VarianceScaling.Mode.FAN_IN,
120-
VarianceScaling.Distribution.UNTRUNCATED_NORMAL,
120+
VarianceScaling.Distribution.NORMAL,
121121
SEED);
122122
Operand<TFloat64> operand = instance.call(tf.constant(shape), TFloat64.DTYPE);
123123
session.evaluate(expected, operand);
@@ -185,7 +185,7 @@ public void testReproducible2() {
185185
tf,
186186
1.0,
187187
VarianceScaling.Mode.FAN_IN,
188-
VarianceScaling.Distribution.UNTRUNCATED_NORMAL,
188+
VarianceScaling.Distribution.NORMAL,
189189
SEED);
190190
Operand<TFloat64> operand1 = instance.call(tf.constant(shape), TFloat64.DTYPE);
191191
Operand<TFloat64> operand2 = instance.call(tf.constant(shape), TFloat64.DTYPE);

0 commit comments

Comments
 (0)