Skip to content

Commit 1deaa75

Browse files
committed
Fix strided_slice_grad for Eager mode.
1 parent dcb3f8b commit 1deaa75

4 files changed

Lines changed: 53 additions & 24 deletions

File tree

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,9 @@ public static bool issubset<T>(this IEnumerable<T> subset, IEnumerable<T> src)
424424
return true;
425425
}
426426

427+
public static bool empty<T>(this Queue<T> queue)
428+
=> queue.Count == 0;
429+
427430
public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
428431
{
429432
if (dic.ContainsKey(key))

src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace Tensorflow.Keras.Engine
1111
public partial class Layer
1212
{
1313
protected List<Layer> _layers = new List<Layer>();
14+
public List<Layer> Layers => _layers;
1415

1516
protected Layer Dense(int units,
1617
Activation activation = null,

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ public abstract partial class Layer : AutoTrackable
6161
protected List<IVariableV1> trainable_weights;
6262

6363
public virtual List<IVariableV1> trainable_variables => trainable_weights;
64-
6564

6665
protected List<IVariableV1> non_trainable_weights;
6766
public List<IVariableV1> non_trainable_variables => non_trainable_weights;
@@ -83,7 +82,8 @@ public abstract partial class Layer : AutoTrackable
8382
ThreadLocal<CallContext> callContext;
8483
public CallContext CallContext => callContext.Value;
8584
public Tensor[] input => inboundNodes[0].input_tensors;
86-
85+
public Dictionary<int, List<Node>> NodesByDepth { get; set; }
86+
public TensorShape output_shape => inboundNodes[0].Outputs.shape;
8787
public Layer(LayerArgs args)
8888
{
8989
this.args = args;
@@ -224,5 +224,23 @@ protected virtual void _init_set_name(string name, bool zero_based = true)
224224
this.name = base_layer_utils.unique_layer_name(base_name, zero_based: zero_based);
225225
}
226226
}
227+
228+
public int count_params()
229+
{
230+
if (Trainable)
231+
return layer_utils.count_params(this, weights);
232+
return 0;
233+
}
234+
235+
public List<IVariableV1> weights
236+
{
237+
get
238+
{
239+
var weights = new List<IVariableV1>();
240+
weights.AddRange(trainable_weights);
241+
weights.AddRange(non_trainable_weights);
242+
return weights;
243+
}
244+
}
227245
}
228246
}

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -388,19 +388,19 @@ public static Tensor placeholder_with_default<T>(T input, int[] shape, string na
388388
return _op.outputs[0];
389389
}
390390

391-
public static Tensor select<Tx, Ty>(Tensor condition, Tx t, Ty e, string name = null)
391+
public static Tensor select<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
392392
{
393393
if (tf.Context.executing_eagerly())
394394
{
395395
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
396-
"SelectV2", name,
396+
"Select", name,
397397
null,
398-
condition, t, e);
398+
condition, x, y);
399399

400400
return results[0];
401401
}
402402

403-
var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t, e });
403+
var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t = x, e = y });
404404
return _op.outputs[0];
405405
}
406406

@@ -580,26 +580,33 @@ public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] stri
580580
/// <param name="shrink_axis_mask">An optional `int`. Defaults to `0`.</param>
581581
/// <param name="name">A name for the operation (optional).</param>
582582
/// <returns>A `Tensor`. Has the same type as `dy`.</returns>
583-
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy,
583+
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy,
584584
int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0,
585585
int shrink_axis_mask = 0, string name = null)
586-
{
587-
var op = tf.OpDefLib._apply_op_helper("StridedSliceGrad", name: name, args: new
588-
{
589-
shape,
590-
begin,
591-
end,
592-
strides,
593-
dy,
594-
begin_mask,
595-
end_mask,
596-
ellipsis_mask,
597-
new_axis_mask,
598-
shrink_axis_mask
599-
});
600-
601-
return op.output;
602-
}
586+
=> tf.Context.RunInAutoMode(()
587+
=> tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new
588+
{
589+
shape,
590+
begin,
591+
end,
592+
strides,
593+
dy,
594+
begin_mask,
595+
end_mask,
596+
ellipsis_mask,
597+
new_axis_mask,
598+
shrink_axis_mask
599+
}).output, ()
600+
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
601+
"StridedSliceGrad", name,
602+
null,
603+
shape, begin, end, strides, dy,
604+
"begin_mask", begin_mask,
605+
"end_mask", end_mask,
606+
"ellipsis_mask", ellipsis_mask,
607+
"new_axis_mask", new_axis_mask,
608+
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
609+
shape, begin, end, strides, dy);
603610

604611
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
605612
{

0 commit comments

Comments
 (0)