Skip to content

Commit 12ea493

Browse files
committed
Add ConcreteFunction to support dataset map.
1 parent 14f8adb commit 12ea493

9 files changed

Lines changed: 84 additions & 12 deletions

File tree

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs
5252

5353
public IDatasetV2 map(Func<Tensor, Tensor> map_func,
5454
bool use_inter_op_parallelism = true,
55-
bool preserve_cardinality = false,
55+
bool preserve_cardinality = true,
5656
bool use_legacy_function = false)
5757
=> new MapDataset(this,
5858
map_func,

src/TensorFlowNET.Core/Data/MapDataset.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
5+
using Tensorflow.Functions;
6+
using Tensorflow.Graphs;
7+
using static Tensorflow.Binding;
48

59
namespace Tensorflow
610
{
@@ -15,12 +19,10 @@ public MapDataset(IDatasetV2 input_dataset,
1519
bool preserve_cardinality = false,
1620
bool use_legacy_function = false) : base(input_dataset)
1721
{
18-
foreach(var input in input_dataset)
19-
{
20-
var data = map_func(input.Item1);
21-
}
22+
var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype);
2223

2324
variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
25+
func,
2426
output_types,
2527
output_shapes);
2628
}

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Tensorflow.Util;
77
using System.Runtime.InteropServices;
88
using Tensorflow.Contexts;
9+
using Tensorflow.Functions;
910

1011
namespace Tensorflow.Eager
1112
{
@@ -385,7 +386,10 @@ bool SetOpAttrScalar(Context ctx, SafeOpHandle op,
385386
status.Check(true);
386387
break;
387388
case TF_AttrType.TF_ATTR_FUNC:
388-
c_api.TFE_OpSetAttrFunctionName(op, key, value.ToString(), value.ToString().Length);
389+
if (value is ConcreteFunction func)
390+
c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length);
391+
else
392+
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
389393
break;
390394
default:
391395
throw new NotImplementedException($"SetOpAttrScalar for {type}");
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Graphs;
6+
using static Tensorflow.Binding;
7+
8+
namespace Tensorflow.Functions
9+
{
10+
/// <summary>
11+
///
12+
/// </summary>
13+
public class ConcreteFunction : IDisposable
14+
{
15+
public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
16+
IntPtr _handle;
17+
18+
public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
19+
{
20+
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
21+
22+
tf.compat.v1.disable_eager_execution();
23+
24+
// IntPtr func_handle;
25+
using (var graph = new FuncGraph(func_name))
26+
{
27+
graph.as_default();
28+
var input = tf.placeholder(dtype);
29+
var output = func(input);
30+
31+
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
32+
_handle = graph.ToGraph(opers,
33+
new Operation[] { input },
34+
new Operation[] { output },
35+
null);
36+
37+
c_api.TFE_ContextAddFunction(tf.Context.Handle, _handle, tf.Status.Handle);
38+
}
39+
40+
tf.enable_eager_execution();
41+
}
42+
43+
public Tensor Execute(Tensor arg)
44+
{
45+
var result = tf.Runner.TFE_Execute(tf.Context,
46+
tf.Context.DeviceName,
47+
Name,
48+
new[] { arg },
49+
null,
50+
1);
51+
return result[0];
52+
}
53+
54+
public void Dispose()
55+
{
56+
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle);
57+
}
58+
}
59+
}

src/TensorFlowNET.Core/Functions/c_api.function.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ namespace Tensorflow
2121
{
2222
public partial class c_api
2323
{
24+
[DllImport(TensorFlowLibName)]
25+
public static extern void TF_DeleteFunction(IntPtr handle);
26+
2427
/// <summary>
2528
/// Write out a serialized representation of `func` (as a FunctionDef protocol
2629
/// message) to `output_func_def` (allocated by TF_NewBuffer()).

src/TensorFlowNET.Core/Graphs/FuncGraph.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Linq;
4+
using System.Runtime.InteropServices;
45
using System.Text;
6+
using Tensorflow.Functions;
57
using static Tensorflow.Binding;
68

79
namespace Tensorflow.Graphs

src/TensorFlowNET.Core/Operations/dataset_ops.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Text;
44
using Tensorflow.Framework.Models;
5+
using Tensorflow.Functions;
56
using static Tensorflow.Binding;
67

78
namespace Tensorflow
@@ -419,7 +420,7 @@ public ITensorOrOperation make_iterator(Tensor dataset, Tensor iterator, string
419420
/// <param name="iterator"></param>
420421
/// <param name="name"></param>
421422
/// <returns></returns>
422-
public Tensor map_dataset(Tensor dataset, TF_DataType[] output_types, TensorShape[] output_shapes,
423+
public Tensor map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, TensorShape[] output_shapes,
423424
bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null)
424425
{
425426
if (tf.Context.executing_eagerly())
@@ -428,7 +429,7 @@ public Tensor map_dataset(Tensor dataset, TF_DataType[] output_types, TensorShap
428429
"MapDataset", name,
429430
null,
430431
dataset, new Tensor[0],
431-
"f", "MapDataset",
432+
"f", f,
432433
"output_types", output_types,
433434
"output_shapes", output_shapes,
434435
"use_inter_op_parallelism", use_inter_op_parallelism,

test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,17 @@ public void Skip()
118118
}
119119
}
120120

121-
[TestMethod, Ignore]
121+
[TestMethod]
122122
public void Map()
123123
{
124124
long value = 0;
125125

126-
var dataset = tf.data.Dataset.range(3);
127-
var dataset1 = dataset.map(x => x);
126+
var dataset = tf.data.Dataset.range(0, 2);
127+
dataset = dataset.map(x => x + 10);
128128

129129
foreach (var item in dataset)
130130
{
131-
Assert.AreEqual(value, (long)item.Item1);
131+
Assert.AreEqual(value + 10, (long)item.Item1);
132132
value++;
133133
}
134134
}

test/TensorFlowNET.UnitTest/NativeAPI/c_test_util.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Diagnostics.CodeAnalysis;
33
using System.Runtime.CompilerServices;
44
using Tensorflow;
5+
using Tensorflow.Functions;
56
using Tensorflow.Util;
67
using Buffer = Tensorflow.Buffer;
78

0 commit comments

Comments
 (0)