Skip to content

Commit e93f112

Browse files
committed
tf.Variable() in eager mode
1 parent af9b64c commit e93f112

55 files changed

Lines changed: 915 additions & 670 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/TensorFlowNET.Core/APIs/tf.gradients.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.Gradients;
18+
1719
namespace Tensorflow
1820
{
1921
public partial class tensorflow
2022
{
23+
public GradientActor GradientTape()
24+
=> new GradientActor();
25+
2126
public Tensor[] gradients(Tensor[] ys,
2227
Tensor[] xs,
2328
Tensor[] grad_ys = null,
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using Tensorflow.Keras;
18+
using Tensorflow.Keras.Engine;
19+
using Tensorflow.Keras.Optimizers;
20+
21+
namespace Tensorflow
22+
{
23+
public partial class tensorflow
24+
{
25+
public KerasOptimizers optimizers => new KerasOptimizers();
26+
27+
public class KerasOptimizers
28+
{
29+
public SGD SGD(float learning_rate) => new SGD(learning_rate);
30+
}
31+
}
32+
}

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ public static void add<T>(this IList<T> list, T element)
4141
public static void append<T>(this IList<T> list, T element)
4242
=> list.Add(element);
4343

44+
public static T[] concat<T>(this IList<T> list1, IList<T> list2)
45+
{
46+
var list = new List<T>();
47+
list.AddRange(list1);
48+
list.AddRange(list2);
49+
return list.ToArray();
50+
}
51+
4452
public static void extend<T>(this List<T> list, IEnumerable<T> elements)
4553
=> list.AddRange(elements);
4654

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using NumSharp;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow.Eager;
6+
7+
namespace Tensorflow.Eager
8+
{
9+
public partial class EagerTensor
10+
{
11+
public static explicit operator TFE_TensorHandle(EagerTensor tensor)
12+
=> tensor.tfe_tensor_handle;
13+
}
14+
}

src/TensorFlowNET.Core/Eager/EagerTensor.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,39 @@
55

66
namespace Tensorflow.Eager
77
{
8-
public class EagerTensor : Tensor
8+
public partial class EagerTensor : Tensor
99
{
10+
Status status = new Status();
11+
TFE_TensorHandle tfe_tensor_handle;
1012
public EagerTensor(IntPtr handle) : base(handle)
1113
{
14+
tfe_tensor_handle = handle;
15+
_handle = c_api.TFE_TensorHandleResolve(handle, status);
1216
}
1317

1418
public EagerTensor(string value, string device_name) : base(value)
1519
{
20+
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
1621
}
1722

1823
public EagerTensor(int value, string device_name) : base(value)
1924
{
25+
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
2026
}
2127

2228
public EagerTensor(float[] value, string device_name) : base(value)
2329
{
30+
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
2431
}
2532

2633
public EagerTensor(double[] value, string device_name) : base(value)
2734
{
35+
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
2836
}
2937

3038
public EagerTensor(NDArray value, string device_name) : base(value)
3139
{
40+
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
3241
}
3342

3443
public override string ToString()
@@ -51,6 +60,8 @@ private string GetFormattedString()
5160
{
5261
case TF_DataType.TF_STRING:
5362
return $"b'{(string)nd}'";
63+
case TF_DataType.TF_BOOL:
64+
return (nd.GetByte(0) > 0).ToString();
5465
default:
5566
return nd.ToString();
5667
}

src/TensorFlowNET.Core/Eager/Execute.cs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,31 +32,41 @@ public Tensor execute(Context ctx, string op_name, Tensor[] inputs, object[] att
3232
ctx.ensure_initialized();
3333
using (var status = new Status())
3434
{
35-
var retVals = wrap_tfe_src.TFE_Py_Execute(ctx, ctx.device_name, op_name, inputs, attrs, 1, status);
35+
var retVals = wrap_tfe_src.TFE_Execute(ctx, ctx.device_name, op_name, inputs, attrs, 1, status);
3636

37-
var t = c_api.TFE_TensorHandleResolve(retVals[0], status);
38-
status.Check(true);
39-
40-
return new EagerTensor(t);
37+
return new EagerTensor(retVals[0]);
4138
}
4239
}
4340

44-
public (TF_DataType, Tensor) args_to_matching_eager(Tensor[] l, Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid)
41+
public (TF_DataType, Tensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null)
4542
{
46-
var dtype = default_dtype;
47-
if(dtype == TF_DataType.DtInvalid)
48-
{
49-
var tensor = ops.convert_to_tensor(l, dtype, preferred_dtype: default_dtype, ctx: ctx);
43+
if (args.Length == 0 && default_dtype != TF_DataType.DtInvalid)
44+
return (default_dtype, null);
5045

51-
if (dtype == TF_DataType.DtInvalid)
52-
dtype = tensor.dtype;
46+
if (args.Count(x => x is EagerTensor) == args.Length)
47+
return ((args[0] as EagerTensor).dtype, args.Select(x => x as EagerTensor).ToArray());
5348

54-
return (dtype, tensor);
49+
var dtype = TF_DataType.DtInvalid;
50+
foreach (var x in args)
51+
{
52+
if (x is EagerTensor et)
53+
dtype = et.dtype;
5554
}
56-
else
55+
56+
if (dtype == TF_DataType.DtInvalid)
5757
{
58-
return (dtype, l[0]);
58+
var ret = new List<Tensor>();
59+
foreach (var t in args)
60+
{
61+
ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx));
62+
if (dtype == TF_DataType.DtInvalid)
63+
dtype = ret.Last().dtype;
64+
}
65+
66+
return (dtype, ret.ToArray());
5967
}
68+
else
69+
throw new NotImplementedException("");
6070
}
6171

6272
public void record_gradient(string op_name, InputList inputs, Dictionary<string, object> attrs, Tensor[] results, string name = null)

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,16 @@ public partial class c_api
101101
[DllImport(TensorFlowLibName)]
102102
public static extern TFE_Op TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status);
103103

104+
/// <summary>
105+
///
106+
/// </summary>
107+
/// <param name="ctx">TFE_Context*</param>
108+
/// <param name="op_or_function_name">const char*</param>
109+
/// <param name="status">TF_Status*</param>
110+
/// <param name="op_to_reset">TFE_Op*</param>
111+
[DllImport(TensorFlowLibName)]
112+
public static extern void TFE_OpReset(IntPtr ctx, string op_or_function_name, IntPtr status, IntPtr op_to_reset);
113+
104114
/// <summary>
105115
///
106116
/// </summary>

src/TensorFlowNET.Core/Eager/wrap_tfe_src.RecordGradient.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using System.Collections.Generic;
22
using System.Linq;
33
using System;
4-
using static Tensorflow.OpDef.Types;
4+
using Tensorflow.Gradients;
55

66
namespace Tensorflow.Eager
77
{

src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_Execute.cs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,40 @@ namespace Tensorflow.Eager
1010
/// </summary>
1111
public partial class wrap_tfe_src
1212
{
13-
public static IntPtr[] TFE_Py_Execute(Context ctx,
13+
public static IntPtr[] TFE_Execute(Context ctx,
1414
string device_name,
1515
string op_name,
1616
Tensor[] inputs,
1717
object[] attrs,
1818
int num_outputs,
1919
Status status)
20-
=> TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs, num_outputs, status);
20+
=> TFE_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs, num_outputs, status);
2121

22-
public static IntPtr[] TFE_Py_ExecuteCancelable(Context ctx,
22+
public static IntPtr[] TFE_ExecuteCancelable(Context ctx,
2323
string device_name,
2424
string op_name,
2525
Tensor[] inputs,
2626
object[] attrs,
2727
int num_outputs,
2828
Status status)
2929
{
30-
var op = c_api.TFE_NewOp(ctx, op_name, status);
30+
var op = GetOp(ctx, op_name, status);
3131
status.Check(true);
3232
c_api.TFE_OpSetDevice(op, device_name, status);
3333
if(status.ok())
3434
{
3535
for (int i = 0; i < inputs.Length; ++i)
3636
{
37-
var tensor_handle = c_api.TFE_NewTensorHandle(inputs[i], status);
37+
TFE_TensorHandle tensor_handle;
38+
switch (inputs[i])
39+
{
40+
case EagerTensor et:
41+
tensor_handle = (TFE_TensorHandle)et;
42+
break;
43+
default:
44+
tensor_handle = c_api.TFE_NewTensorHandle(inputs[i], status);
45+
break;
46+
}
3847
c_api.TFE_OpAddInput(op, tensor_handle, status);
3948
}
4049
}

src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public static EagerTensor TFE_FastPathExecute(Context ctx,
2222
var attr_list_sizes = new Dictionary<string, long>();
2323
using (var status = new Status())
2424
{
25-
var op = c_api.TFE_NewOp(ctx, opName, status);
25+
var op = GetOp(ctx, opName, status);
2626

2727
var op_def = Graph.TFE_GetOpDef(opName);
2828

@@ -101,11 +101,31 @@ public static EagerTensor TFE_FastPathExecute(Context ctx,
101101
c_api.TFE_Execute(op, retVals, ref num_retvals, status);
102102
status.Check(true);
103103

104-
var t = c_api.TFE_TensorHandleResolve(retVals[0], status);
105-
status.Check(true);
104+
return num_retvals == 0 ? null : new EagerTensor(retVals[0]);
105+
}
106+
}
106107

107-
return new EagerTensor(t);
108+
private static TFE_Op GetOp(Context ctx, string op_or_function_name, Status status)
109+
{
110+
var maybe_op = ReleaseThreadLocalOp();
111+
if (maybe_op != IntPtr.Zero)
112+
{
113+
c_api.TFE_OpReset(ctx, op_or_function_name, status, maybe_op);
114+
}
115+
else
116+
{
117+
maybe_op = c_api.TFE_NewOp(ctx, op_or_function_name, status);
118+
op = maybe_op;
108119
}
120+
121+
status.Check(true);
122+
return maybe_op;
123+
}
124+
125+
static TFE_Op op;
126+
private static TFE_Op ReleaseThreadLocalOp()
127+
{
128+
return op;
109129
}
110130

111131
/// <summary>
@@ -126,19 +146,19 @@ private static bool AddInputToOp(object inputs,
126146
{
127147
TFE_TensorHandle input_handle;
128148

149+
// ConvertToTensor();
129150
switch (inputs)
130151
{
131-
case Tensor input:
132-
input_handle = c_api.TFE_NewTensorHandle(input, status);
152+
case EagerTensor input:
153+
input_handle = (TFE_TensorHandle)input;
133154
break;
134-
case Tensor[] input_list:
135-
input_handle = c_api.TFE_NewTensorHandle(input_list[0], status);
155+
case EagerTensor[] input_list:
156+
input_handle = (TFE_TensorHandle)input_list[0];
136157
break;
137158
default:
138159
throw new NotImplementedException("");
139160
}
140161

141-
142162
if(add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr))
143163
{
144164
var dtype = c_api.TFE_TensorHandleDataType(input_handle);

0 commit comments

Comments
 (0)