Skip to content

Commit e802957

Browse files
committed
eager for tf.relu, tf.tanh and tf.sigmoid
1 parent dcfaa77 commit e802957

21 files changed

Lines changed: 235 additions & 88 deletions

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,11 @@ public Tensor embedding_lookup(Tensor @params,
116116
public IActivation relu() => new relu();
117117
public IActivation swish() => new swish();
118118
public IActivation tanh() => new tanh();
119+
public Tensor tanh(Tensor x, string name = null)
120+
=> gen_nn_ops.tanh(x, name);
119121

120-
public Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name);
122+
public Tensor relu(Tensor features, string name = null)
123+
=> gen_nn_ops.relu(features, name);
121124

122125
public Tensor[] fused_batch_norm(Tensor x,
123126
VariableV1 scale,
@@ -212,6 +215,14 @@ public Tensor softmax_cross_entropy_with_logits(Tensor labels, Tensor logits, in
212215
public Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
213216
=> nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
214217

218+
/// <summary>
219+
/// Computes sigmoid of `x` element-wise.
220+
/// Specifically, `y = 1 / (1 + exp(-x))`.
221+
/// </summary>
222+
/// <typeparam name="T"></typeparam>
223+
/// <param name="x"></param>
224+
/// <param name="name">A name for the operation (optional).</param>
225+
/// <returns>A Tensor with the same type as `x`.</returns>
215226
public Tensor sigmoid<T>(T x, string name = null)
216227
=> math_ops.sigmoid(x, name: name);
217228
}

src/TensorFlowNET.Core/APIs/tf.ops.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool
3333
public Tensor assign(RefVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
3434
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);
3535

36+
public Tensor assign(ResourceVariable @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
37+
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);
38+
3639
public void device(string device_name)
3740
=> get_default_graph().device(device_name);
3841

src/TensorFlowNET.Core/Clustering/KMeans.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ private RefVariable[] _create_variables(Tensor num_clusters)
9898
var cluster_counts = _use_mini_batch ? tf.Variable(ones) : null;
9999
return new RefVariable[]
100100
{
101-
cluster_centers,
101+
/*cluster_centers,
102102
cluster_centers_initialized,
103103
cluster_counts,
104-
cluster_centers_updated,
104+
cluster_centers_updated,*/
105105
update_in_steps
106106
};
107107
}

src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace Tensorflow.Eager
1111
public partial class wrap_tfe_src
1212
{
1313
static int kFastPathExecuteInputStartIndex = 0;
14-
public static EagerTensor TFE_Py_FastPathExecute(Context ctx,
14+
public static EagerTensor TFE_FastPathExecute(Context ctx,
1515
string device_name,
1616
string opName,
1717
string name,

src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ public static VariableV1 make_variable(string name,
4646
Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);
4747

4848
var variable_dtype = dtype.as_base_dtype();
49-
var v = tf.VariableV1(init_val,
50-
use_resource: use_resource,
49+
var v = tf.Variable(init_val,
5150
dtype: dtype,
5251
shape: shape,
5352
name: name);

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 22 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.Eager;
1718
using static Tensorflow.Binding;
1819

1920
namespace Tensorflow.Operations
@@ -463,50 +464,30 @@ public static (Tensor loss, Tensor backprop) sparse_softmax_cross_entropy_with_l
463464
/// <returns>A `Tensor`. Has the same type as `features`.</returns>
464465
public static Tensor relu(Tensor features, string name = null)
465466
{
467+
if (tf.context.executing_eagerly())
468+
{
469+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
470+
"Relu", name, null,
471+
features);
472+
return _result;
473+
}
466474

467-
//_ctx = _context._context
468-
//if _ctx is not None and _ctx._eager_context.is_eager:
469-
// try:
470-
// _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
471-
// _ctx._context_handle, _ctx._eager_context.device_name, "Relu", name,
472-
// _ctx._post_execution_callbacks, features)
473-
// return _result
474-
// except _core._FallbackException:
475-
// try:
476-
// return relu_eager_fallback(
477-
// features, name=name, ctx=_ctx)
478-
// except _core._SymbolicException:
479-
// pass # Add nodes to the TensorFlow graph.
480-
// except (TypeError, ValueError):
481-
// result = _dispatch.dispatch(
482-
// relu, features=features, name=name)
483-
// if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
484-
// return result
485-
// raise
486-
// except _core._NotOkStatusException as e:
487-
// if name is not None:
488-
// message = e.message + " name: " + name
489-
// else:
490-
// message = e.message
491-
// _six.raise_from(_core._status_to_exception(e.code, message), None)
492-
//# Add nodes to the TensorFlow graph.
493-
//try:
494-
OpDefLibrary _op_def_lib = new OpDefLibrary();
495475
var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new { features });
496476
return _op.outputs[0];
497-
//except (TypeError, ValueError):
498-
// result = _dispatch.dispatch(
499-
// relu, features=features, name=name)
500-
// if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
501-
// return result
502-
// raise
503-
// var _result = _op.outputs.ToArray();
504-
//_inputs_flat = _op.inputs
505-
//_attrs = ("T", _op.get_attr("T"))
506-
//_execute.record_gradient(
507-
// "Relu", _inputs_flat, _attrs, _result, name)
508-
//_result, = _result
509-
// return _result;
477+
}
478+
479+
public static Tensor tanh(Tensor x, string name = null)
480+
{
481+
if (tf.context.executing_eagerly())
482+
{
483+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
484+
"Tanh", name, null,
485+
x);
486+
return _result;
487+
}
488+
489+
var _op = _op_def_lib._apply_op_helper("Tanh", name: name, args: new { x });
490+
return _op.outputs[0];
510491
}
511492
}
512493
}

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ public static Tensor pack(Tensor[] values, int axis = 0, string name = null)
125125
{
126126
if(tf.context.executing_eagerly())
127127
{
128-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name, "Pack", name, null, values, "axis", axis);
128+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, "Pack", name, null, values, "axis", axis);
129129
return _result;
130130
}
131131

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ public static Tensor mean<T1, T2>(T1 input, T2 axis, bool keep_dims= false, stri
120120
{
121121
try
122122
{
123-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name, "Mean", name, null, input, axis, "keep_dims", keep_dims);
123+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, "Mean", name, null, input, axis, "keep_dims", keep_dims);
124124
return _result;
125125
}
126126
catch (Exception ex)
@@ -171,7 +171,7 @@ public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null)
171171
{
172172
if (tf.context.executing_eagerly())
173173
{
174-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Add", name, null, x, y);
174+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Add", name, null, x, y);
175175
return _result;
176176
}
177177

@@ -204,6 +204,14 @@ public static Tensor ceil(Tensor x, string name = null)
204204

205205
public static Tensor sin(Tensor x, string name = null)
206206
{
207+
if (tf.context.executing_eagerly())
208+
{
209+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
210+
"Sin", name, null,
211+
x);
212+
return _result;
213+
}
214+
207215
var _op = _op_def_lib._apply_op_helper("Sin", name, args: new { x });
208216

209217
return _op.outputs[0];
@@ -225,6 +233,14 @@ public static Tensor sin(Tensor x, string name = null)
225233
/// </remarks>
226234
public static Tensor sigmoid(Tensor x, string name = "Sigmoid")
227235
{
236+
if (tf.context.executing_eagerly())
237+
{
238+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
239+
"Sigmoid", name, null,
240+
x);
241+
return _result;
242+
}
243+
228244
var op = _op_def_lib._apply_op_helper("Sigmoid", name: name, new { x });
229245

230246
return op.output;
@@ -493,7 +509,7 @@ public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, stri
493509
{
494510
if (tf.context.executing_eagerly())
495511
{
496-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Cast", name, null, x, "DstT", DstT, "Truncate", Truncate);
512+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Cast", name, null, x, "DstT", DstT, "Truncate", Truncate);
497513
return _result;
498514
}
499515

@@ -520,7 +536,7 @@ public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null)
520536
{
521537
if (tf.context.executing_eagerly())
522538
{
523-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Sub", name, null, x, y);
539+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Sub", name, null, x, y);
524540
return _result;
525541
}
526542

@@ -571,7 +587,7 @@ public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null)
571587
{
572588
if (tf.context.executing_eagerly())
573589
{
574-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "Mul", name, null, x, y);
590+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "Mul", name, null, x, y);
575591
return _result;
576592
}
577593

@@ -591,7 +607,7 @@ public static Tensor real_div(Tensor x, Tensor y, string name = null)
591607
{
592608
if (tf.context.executing_eagerly())
593609
{
594-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "RealDiv", name, null, x, y);
610+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "RealDiv", name, null, x, y);
595611
return _result;
596612
}
597613

@@ -618,7 +634,7 @@ public static Tensor floor_div(Tensor x, Tensor y, string name = null)
618634
{
619635
if (tf.context.executing_eagerly())
620636
{
621-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, "", "FloorDiv", name, null, x, y);
637+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "FloorDiv", name, null, x, y);
622638
return _result;
623639
}
624640

@@ -640,7 +656,7 @@ public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool
640656
{
641657
if (tf.context.executing_eagerly())
642658
{
643-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name,
659+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
644660
"MatMul", name, null,
645661
a, b, "transpose_a", transpose_a, "transpose_b", transpose_b);
646662
return _result;
@@ -748,7 +764,7 @@ public static Tensor _sum<Tx, Ty>(Tx input, Ty axis = default, bool keep_dims =
748764
{
749765
try
750766
{
751-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name,
767+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
752768
"Sum", name, null,
753769
input, axis, "keep_dims", keep_dims);
754770
return _result;
@@ -789,7 +805,7 @@ public static Tensor range(Tensor start, Tensor limit, Tensor delta, string name
789805
{
790806
if (tf.context.executing_eagerly())
791807
{
792-
var _result = wrap_tfe_src.TFE_Py_FastPathExecute(tf.context, tf.context.device_name, "Range", name, null, start, limit, delta);
808+
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, "Range", name, null, start, limit, delta);
793809
return _result;
794810
}
795811

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,12 @@ public static Tensor reduce_prod(Tensor input_tensor, int[] axis = null, bool ke
278278
}
279279

280280
public static Tensor sigmoid<T>(T x, string name = null)
281-
{
282-
var x_tensor = ops.convert_to_tensor(x, name: "x");
283-
return gen_math_ops.sigmoid(x_tensor, name: name);
284-
}
281+
=> tf_with(ops.name_scope(name, "Sigmoid", x), scope =>
282+
{
283+
name = scope;
284+
var x_tensor = ops.convert_to_tensor(x, name: "x");
285+
return gen_math_ops.sigmoid(x_tensor, name: name);
286+
});
285287

286288
public static Tensor sign<T>(T x, string name = null)
287289
=> gen_math_ops.sign(x, name: name);

src/TensorFlowNET.Core/Tensors/constant_op.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF
9191
return new EagerTensor(str, ctx.device_name);
9292
case int int32:
9393
return new EagerTensor(int32, ctx.device_name);
94+
case float[] float32s:
95+
return new EagerTensor(float32s, ctx.device_name);
9496
default:
9597
throw new NotImplementedException($"convert_to_eager_tensor {value.GetType()}");
9698
}

0 commit comments

Comments
 (0)