Skip to content

Commit 2602fb9

Browse files
committed
Implement EagerRunner.
1 parent 91fb30f commit 2602fb9

11 files changed

Lines changed: 726 additions & 15 deletions
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Gradients;
6+
using static Tensorflow.Binding;
7+
using static Tensorflow.tensorflow;
8+
9+
namespace Tensorflow.Eager
10+
{
11+
public partial class EagerRunner
12+
{
13+
bool RecordGradient(string op_name,
14+
Tensor[] inputs,
15+
object[] attrs,
16+
Tensor[] results)
17+
{
18+
var input_ids = MakeTensorIDList(inputs);
19+
var input_dtypes = MakeTensorDtypeList(inputs);
20+
21+
bool should_record = false;
22+
foreach (var tape in tf.GetTapeSet())
23+
{
24+
if(tape.ShouldRecord(input_ids, input_dtypes))
25+
{
26+
should_record = true;
27+
break;
28+
}
29+
}
30+
31+
if (!should_record)
32+
{
33+
/*for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet())
34+
{
35+
if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes))
36+
{
37+
should_record = true;
38+
break;
39+
}
40+
}*/
41+
}
42+
43+
if (!should_record) return should_record;
44+
45+
Tensor[] op_outputs;
46+
bool op_outputs_tuple_created = false;
47+
var unused_output_indices = gradient_exclustions.OpGradientUnusedOutputIndices(op_name);
48+
if (unused_output_indices != null)
49+
{
50+
if (unused_output_indices.Length == 0)
51+
op_outputs = new Tensor[0];
52+
else
53+
{
54+
op_outputs_tuple_created = true;
55+
// op_outputs = CopySequenceSettingIndicesToNull(results, *unused_output_indices);
56+
}
57+
}
58+
else
59+
op_outputs = results;
60+
61+
Tensor[] op_inputs;
62+
bool op_inputs_tuple_created = false;
63+
var unused_input_indices = gradient_exclustions.OpGradientUnusedInputIndices(op_name);
64+
if(unused_input_indices != null)
65+
{
66+
if (unused_input_indices.Length == 0)
67+
op_inputs = new Tensor[0];
68+
else
69+
{
70+
op_inputs_tuple_created = true;
71+
// op_inputs = CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
72+
}
73+
}
74+
else
75+
op_inputs = inputs;
76+
77+
TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes,
78+
() => GetGradientFunction(op_name, inputs, attrs, results));
79+
80+
81+
return true;
82+
}
83+
84+
BackwardFunction GetGradientFunction(string op_name,
85+
Tensor[] op_inputs,
86+
object[] attrs,
87+
Tensor[] op_outputs)
88+
=> (output_grads, unneeded_gradients) =>
89+
{
90+
var gradients = ops.gradientFunctions[op_name](new EagerOperation
91+
{
92+
Name = op_name,
93+
NumInputs = op_inputs.Length,
94+
Inputs = op_inputs,
95+
NumOutputs = op_outputs.Length,
96+
Outputs = op_outputs,
97+
SkipInputIndices = unneeded_gradients,
98+
Attrs = attrs
99+
}, output_grads);
100+
101+
return gradients;
102+
};
103+
104+
bool CouldForwardprop()
105+
{
106+
return HasAccumulator();
107+
}
108+
109+
bool CouldBackprop()
110+
{
111+
return HasGradientTape();
112+
}
113+
114+
long[] MakeTensorIDList(Tensor[] tensors)
115+
{
116+
return tensors.Select(x => x.Id).ToArray();
117+
}
118+
119+
TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
120+
{
121+
return tensors.Select(x => x.dtype).ToArray();
122+
}
123+
}
124+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Eager
6+
{
7+
public partial class EagerRunner
8+
{
9+
bool RunCallbacks(FastPathOpExecInfo op_exec_info,
10+
int num_inferred_attrs,
11+
Tensor[] inputs,
12+
object[] attrs,
13+
Tensor[] flattened_result)
14+
{
15+
if (op_exec_info.run_gradient_callback)
16+
{
17+
if (!RecordGradient(op_exec_info.op_name, inputs, attrs,
18+
flattened_result))
19+
{
20+
return false;
21+
}
22+
}
23+
24+
if (op_exec_info.run_post_exec_callbacks)
25+
{
26+
27+
}
28+
29+
return true;
30+
}
31+
}
32+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using System.Collections.Generic;
2+
using System.Linq;
3+
using System;
4+
using static Tensorflow.Binding;
5+
6+
namespace Tensorflow.Eager
7+
{
8+
/// <summary>
9+
/// python\eager\pywrap_tfe_src.cc
10+
/// </summary>
11+
public partial class EagerRunner
12+
{
13+
public Tensor[] TFE_Execute(Context ctx,
14+
string device_name,
15+
string op_name,
16+
Tensor[] inputs,
17+
object[] attrs,
18+
int num_outputs)
19+
=> TFE_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs, num_outputs);
20+
21+
public Tensor[] TFE_ExecuteCancelable(Context ctx,
22+
string device_name,
23+
string op_name,
24+
Tensor[] inputs,
25+
object[] attrs,
26+
int num_outputs)
27+
{
28+
var status = tf.status;
29+
var op = GetOp(ctx, op_name, status);
30+
status.Check(true);
31+
c_api.TFE_OpSetDevice(op, device_name, status.Handle);
32+
if (status.ok())
33+
{
34+
for (int i = 0; i < inputs.Length; ++i)
35+
{
36+
IntPtr tensor_handle;
37+
switch (inputs[i])
38+
{
39+
case EagerTensor et:
40+
tensor_handle = et.EagerTensorHandle;
41+
break;
42+
default:
43+
tensor_handle = c_api.TFE_NewTensorHandle(inputs[i], status.Handle);
44+
break;
45+
}
46+
c_api.TFE_OpAddInput(op, tensor_handle, status.Handle);
47+
}
48+
}
49+
if (status.ok())
50+
SetOpAttrs(op, attrs, status.Handle);
51+
52+
var outputs = new IntPtr[num_outputs];
53+
if (status.ok())
54+
{
55+
c_api.TFE_Execute(op, outputs, ref num_outputs, status.Handle);
56+
status.Check(true);
57+
}
58+
return outputs.Select(x => new EagerTensor(x)).ToArray();
59+
}
60+
}
61+
}

0 commit comments

Comments
 (0)