Skip to content

Commit 8dfc3b7

Browse files
committed
1 parent 3b93c7b commit 8dfc3b7

13 files changed

Lines changed: 105 additions & 60 deletions

File tree

src/TensorFlowNET.Core/Gradients/math_grad.py.cs

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,7 @@ public static (Tensor, Tensor) _SubGrad(Operation op, Tensor grad)
7272

7373
public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad)
7474
{
75-
if (x.NDims == 0 && y.NDims == 0 && grad.NDims == 0) return true;
76-
77-
return string.Join(",", x.shape).Equals(string.Join(",", y.shape)) &&
78-
string.Join(",", x.shape).Equals(string.Join(",", grad.shape));
75+
return x.NDims == y.NDims && y.NDims == grad.NDims && x.NDims > -1;
7976
}
8077

8178
public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
@@ -110,14 +107,15 @@ public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad)
110107
x = math_ops.conj(x);
111108
y = math_ops.conj(y);
112109

113-
var realdiv1 = gen_math_ops.real_div(grad, y);
114-
var reduce_sum1 = math_ops.reduce_sum(realdiv1, rx);
115-
var realdiv2 = gen_math_ops.real_div(-x, y);
116-
var realdiv3 = gen_math_ops.real_div(realdiv2, y);
117-
var mul = grad * realdiv3;
118-
var reduce_sum2 = math_ops.reduce_sum(mul, ry);
110+
var realdiv1 = gen_math_ops.real_div(-x, y);
111+
var realdiv2 = gen_math_ops.real_div(realdiv1, y);
112+
var reduce_sum1 = math_ops.reduce_sum(grad * realdiv2, ry);
113+
var reshape1 = gen_array_ops.reshape(reduce_sum1, sy);
114+
var realdiv3 = gen_math_ops.real_div(grad, y);
115+
var reduce_sum2 = math_ops.reduce_sum(realdiv3, rx);
116+
var reshape2 = gen_array_ops.reshape(reduce_sum2, sx);
119117

120-
return (gen_array_ops.reshape(reduce_sum1, sx), gen_array_ops.reshape(reduce_sum2, sy));
118+
return (reshape2, reshape1);
121119
}
122120

123121
public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad)
@@ -135,17 +133,16 @@ public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad)
135133
var gx = gen_array_ops.reshape(math_ops.reduce_sum(grad * y * gen_math_ops.pow(x, y - 1.0), rx), sx);
136134
Tensor log_x = null;
137135
// Avoid false singularity at x = 0
136+
Tensor mask = null;
138137
if (x.dtype.is_complex())
139-
{
140138
throw new NotImplementedException("x.dtype.is_complex()");
141-
}
142139
else
143-
{
144-
var x1 = gen_array_ops.log(x);
145-
var y1 = array_ops.zeros_like(x);
146-
log_x = array_ops.where(x > 0.0, x1, y1);
147-
}
148-
140+
mask = x > 0.0f;
141+
var ones = array_ops.ones_like(x);
142+
var safe_x = array_ops.where(mask, x, ones);
143+
var x1 = gen_array_ops.log(safe_x);
144+
var y1 = array_ops.zeros_like(x);
145+
log_x = array_ops.where(mask, x1, y1);
149146
var gy = gen_array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy);
150147

151148
return (gx, gy);

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,13 @@ public object get_collection(string name, string scope = "")
357357
return _collections.ContainsKey(name) ? _collections[name] : null;
358358
}
359359

360+
public object get_collection_ref(string name)
361+
{
362+
if (!_collections.ContainsKey(name))
363+
_collections[name] = new List<object>();
364+
return _collections[name];
365+
}
366+
360367
public void Dispose()
361368
{
362369
c_api.TF_DeleteGraph(_handle);

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,43 @@ public static Tensor rank(Tensor input, string name = "")
5555
return math_ops.rank_internal(input, name, optimize: true);
5656
}
5757

58+
/// <summary>
59+
/// Creates a tensor with all elements set to 1.
60+
/// </summary>
61+
/// <param name="tensor"></param>
62+
/// <param name="dtype"></param>
63+
/// <param name="name"></param>
64+
/// <param name="optimize"></param>
65+
/// <returns></returns>
66+
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool optimize = true)
67+
=> ones_like_impl(tensor, dtype, name, optimize);
68+
69+
private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
70+
{
71+
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "ones_like", new { tensor }), scope =>
72+
{
73+
name = scope;
74+
var tensor1 = ops.convert_to_tensor(tensor, name: "tensor");
75+
var ones_shape = shape_internal(tensor1, optimize: optimize);
76+
if (dtype == TF_DataType.DtInvalid)
77+
dtype = tensor1.dtype;
78+
var ret = ones(ones_shape, dtype: dtype, name: name);
79+
ret.shape = tensor1.shape;
80+
return ret;
81+
});
82+
}
83+
84+
public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "")
85+
{
86+
dtype = dtype.as_base_dtype();
87+
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "ones", new { shape }), scope =>
88+
{
89+
name = scope;
90+
var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name);
91+
return output;
92+
});
93+
}
94+
5895
public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = "")
5996
{
6097
if( x == null && y == null)

src/TensorFlowNET.Core/Operations/math_ops.py.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ public static Tensor range(object start, object limit = null, object delta = nul
111111
if (delta == null)
112112
delta = 1;
113113

114-
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope =>
114+
return with<ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope =>
115115
{
116116
name = scope;
117117
var start1 = ops.convert_to_tensor(start, name: "start");
@@ -124,15 +124,15 @@ public static Tensor range(object start, object limit = null, object delta = nul
124124

125125
public static Tensor floordiv(Tensor x, Tensor y, string name = "")
126126
{
127-
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "floordiv", new object[] { }), scope =>
127+
return with<ops.name_scope, Tensor>(new ops.name_scope("", "floordiv", new { x, y }), scope =>
128128
{
129-
return gen_math_ops.floor_div(x, y, name);
129+
return gen_math_ops.floor_div(x, y, scope);
130130
});
131131
}
132132

133133
public static Tensor rank_internal(Tensor input, string name = "", bool optimize = true)
134134
{
135-
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope =>
135+
return with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope =>
136136
{
137137
name = scope;
138138
var input_tensor = ops.convert_to_tensor(input);

src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -63,31 +63,6 @@ private IntPtr Allocate(NDArray nd)
6363
break;
6464
case "Single":
6565
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size);
66-
/*if (nd.size > 1)
67-
{
68-
var bb = nd.Data<byte>();
69-
var bytes = Marshal.AllocHGlobal(bb.Length);
70-
Marshal.Copy(bb, 0, bytes, bb.Length);
71-
ulong bytes_len = c_api.TF_StringEncodedSize((ulong)bb.Length);
72-
var dataTypeByte = ToTFDataType(nd.dtype);
73-
// shape
74-
var dims2 = nd.shape.Select(x => (long)x).ToArray();
75-
76-
var tfHandle2 = c_api.TF_AllocateTensor(dataTypeByte,
77-
dims2,
78-
nd.ndim,
79-
bytes_len + sizeof(Int64));
80-
81-
dotHandle = c_api.TF_TensorData(tfHandle2);
82-
Marshal.WriteInt64(dotHandle, 0);
83-
c_api.TF_StringEncode(bytes, (ulong)bb.Length, dotHandle + sizeof(Int64), bytes_len, status);
84-
return tfHandle2;
85-
}
86-
else
87-
{
88-
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size);
89-
}*/
90-
9166
break;
9267
case "Double":
9368
Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size);

src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ public partial class Tensor
2727
public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y);
2828

2929
public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y);
30+
public static Tensor operator >(Tensor x, float y) => gen_array_ops.greater(x, y);
3031
public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y);
3132
public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y);
33+
public static Tensor operator <(Tensor x, float y) => gen_array_ops.less(x, y);
3234
public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y);
3335

3436
private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y)

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,12 @@ public long[] shape
6868
c_api.TF_GraphSetTensorShape(this.Graph, this._as_tf_output(), value, value.Length, status);
6969
}
7070
}
71-
71+
72+
public int[] _shape_tuple()
73+
{
74+
return null;
75+
}
76+
7277
/// <summary>
7378
/// number of dimensions
7479
/// 0 Scalar (magnitude only)

src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace Tensorflow
66
{
77
public class GradientDescentOptimizer : Optimizer
88
{
9-
public GradientDescentOptimizer(double learning_rate, bool use_locking = false, string name = "GradientDescent")
9+
public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent")
1010
: base(learning_rate, use_locking, name)
1111
{
1212
LearningRate = learning_rate;

src/TensorFlowNET.Core/Train/Optimizer.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ public abstract class Optimizer
2020
public static int GATE_GRAPH = 2;
2121

2222
public string Name { get; set; }
23-
public double LearningRate { get; set; }
23+
public float LearningRate { get; set; }
2424
public Tensor LearningRateTensor { get; set; }
2525
public bool _use_locking;
2626
public Dictionary<string, object> _slots;
2727
public Dictionary<string, object> _non_slot_dict;
2828
public Dictionary<string, object> _deferred_slot_restorations;
2929

30-
public Optimizer(double learning_rate, bool use_locking, string name = "")
30+
public Optimizer(float learning_rate, bool use_locking, string name = "")
3131
{
3232
if (String.IsNullOrEmpty(name))
3333
throw new NotImplementedException("Must specify the optimizer name");
@@ -114,6 +114,13 @@ public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, Te
114114

115115
}
116116

117+
if (!tf.context.executing_eagerly())
118+
{
119+
var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<object>;
120+
if (!train_op.Contains(apply_updates))
121+
train_op.Add(apply_updates);
122+
}
123+
117124
return apply_updates;
118125
});
119126
}

src/TensorFlowNET.Core/Train/tf.optimizers.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ public static partial class tf
99
{
1010
public static class train
1111
{
12-
public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate);
12+
public static Optimizer GradientDescentOptimizer(float learning_rate) => new GradientDescentOptimizer(learning_rate);
1313

1414
public static Saver Saver() => new Saver();
1515

0 commit comments

Comments
 (0)