forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path Tape.cs
More file actions
72 lines (62 loc) · 2.03 KB
/
Tape.cs
File metadata and controls
72 lines (62 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow.Eager;
namespace Tensorflow.Gradients
{
/// <summary>
/// Managed wrapper around a native eager gradient tape. The handle is created by
/// <c>c_api.TFE_TapeSetNew</c> and every operation here forwards to the native C API.
/// </summary>
public class Tape : DisposableObject
{
    // The exact set of dtypes for which IsDtypeTrainable returns true;
    // any DataType value not listed here yields false.
    private static readonly HashSet<DataType> s_trainableDtypes = new HashSet<DataType>
    {
        DataType.DtHalf,
        DataType.DtBfloat16,
        DataType.DtFloat,
        DataType.DtDouble,
        DataType.DtComplex64,
        DataType.DtComplex128,
        DataType.DtResource,
        DataType.DtVariant,
    };

    /// <summary>
    /// Nesting identifier for this tape. It is neither read nor written inside this class.
    /// NOTE(review): its meaning is defined by callers — confirm before relying on it.
    /// </summary>
    public int nesting_id { get; set; }

    /// <summary>
    /// Creates a native tape via <c>TFE_TapeSetNew</c> and stores the returned handle.
    /// </summary>
    /// <param name="persistent">Forwarded unchanged to <c>TFE_TapeSetNew</c>.</param>
    /// <param name="watch_accessed_variables">Forwarded unchanged to <c>TFE_TapeSetNew</c>.</param>
    public Tape(bool persistent, bool watch_accessed_variables)
        => _handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables);

    /// <summary>
    /// Asks the native tape to watch <paramref name="x"/> by passing its eager tensor handle
    /// to <c>TFE_TapeWatch</c>.
    /// </summary>
    /// <param name="x">Tensor to watch; its <c>EagerTensorHandle</c> is dereferenced.</param>
    public void watch(EagerTensor x)
        => c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle);

    /// <summary>
    /// Removes <paramref name="tape"/> from the native tape set via <c>TFE_TapeSetRemove</c>.
    /// This acts on the argument, not on the instance the method is called on.
    /// </summary>
    /// <param name="tape">Tape to remove; converted to its raw handle implicitly.</param>
    public void pop_tape(Tape tape)
        => c_api.TFE_TapeSetRemove(tape);

    /// <summary>
    /// Reports to the native layer, via <c>TFE_TapeVariableAccessed</c>, that
    /// <paramref name="variable"/> was accessed.
    /// </summary>
    /// <param name="variable">The accessed variable, passed straight to the C API.</param>
    public static void variable_accessed(ResourceVariable variable)
        => c_api.TFE_TapeVariableAccessed(variable);

    /// <summary>
    /// Returns the variables the native tape reports as watched. Each native entry is
    /// wrapped in a new <see cref="ResourceVariable"/> together with its resource handle.
    /// </summary>
    /// <returns>One wrapper per entry of the native array, in native order.</returns>
    public unsafe ResourceVariable[] watched_variables()
    {
        BindingArray native = c_api.TFE_TapeWatchedVariables(_handle);
        // The native array is read as a contiguous run of IntPtr entries.
        // NOTE(review): the BindingArray is not released here — confirm who owns it.
        var entries = (IntPtr*)native.array;
        var watched = new ResourceVariable[native.length];

        for (var index = 0; index < watched.Length; index++)
        {
            IntPtr variableHandle = entries[index];
            watched[index] = new ResourceVariable(
                variableHandle,
                c_api.ResourceVariable_Handle(variableHandle));
        }

        return watched;
    }

    /// <summary>
    /// Reports whether <paramref name="dtype"/> is one of the dtypes treated as trainable:
    /// half, bfloat16, float, double, complex64, complex128, resource or variant.
    /// </summary>
    /// <param name="dtype">The dtype to classify.</param>
    /// <returns><c>true</c> for the eight listed dtypes, <c>false</c> for anything else.</returns>
    public static bool IsDtypeTrainable(DataType dtype)
        => s_trainableDtypes.Contains(dtype);

    /// <summary>
    /// Releases nothing: this override neither frees nor removes the native tape.
    /// NOTE(review): confirm the native tape's lifetime is handled elsewhere (e.g. by pop_tape).
    /// </summary>
    /// <param name="handle">The native handle; unused here.</param>
    protected override void DisposeUnmanagedResources(IntPtr handle)
    {
    }

    /// <summary>
    /// Exposes the raw native handle so a <see cref="Tape"/> can be passed directly to C API
    /// calls. Converting a null tape throws <see cref="NullReferenceException"/>.
    /// </summary>
    public static implicit operator IntPtr(Tape tape)
        => tape._handle;
}
}