forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTapeTensor.cs
More file actions
65 lines (56 loc) · 1.54 KB
/
TapeTensor.cs
File metadata and controls
65 lines (56 loc) · 1.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
using OneOf;
using static Tensorflow.Binding;
namespace Tensorflow.Gradients
{
public class TapeTensor
{
    /// <summary>
    /// The tensor this record was built from. Only assigned by the
    /// <see cref="TapeTensor(Tensor)"/> constructor; the other constructors leave it unset.
    /// </summary>
    internal Tensor tensor;

    /// <summary>Identifier of the tracked tensor.</summary>
    internal long id;

    /// <summary>Element type of the tracked tensor.</summary>
    internal TF_DataType dtype;

    /// <summary>
    /// Either a static <see cref="Shape"/> (variant 0), or a <see cref="Tensor"/> (variant 1)
    /// used as a template by <see cref="ZerosLike"/> / <see cref="OnesLike"/> via
    /// <c>tf.zeros_like</c> / <c>tf.ones_like</c>.
    /// </summary>
    internal OneOf<Shape, Tensor> shape;

    /// <summary>Creates a record whose shape is statically known.</summary>
    /// <param name="id">Identifier of the tracked tensor.</param>
    /// <param name="dtype">Element type of the tracked tensor.</param>
    /// <param name="shape">Static shape used by <see cref="ZerosLike"/> / <see cref="OnesLike"/>.</param>
    public TapeTensor(long id, TF_DataType dtype, Shape shape)
    {
        this.id = id;
        this.dtype = dtype;
        this.shape = shape;
    }

    /// <summary>
    /// Creates a record whose shape is taken from a template tensor when
    /// <see cref="ZerosLike"/> / <see cref="OnesLike"/> is called.
    /// </summary>
    /// <param name="id">Identifier of the tracked tensor.</param>
    /// <param name="dtype">Element type recorded for the tracked tensor.</param>
    /// <param name="shape">
    /// Passed directly to <c>tf.zeros_like</c> / <c>tf.ones_like</c>, so this is the template
    /// tensor itself, not a 1-D shape vector. The parameter name is kept for compatibility
    /// with callers using named arguments.
    /// </param>
    public TapeTensor(long id, TF_DataType dtype, Tensor shape)
    {
        this.id = id;
        this.dtype = dtype;
        this.shape = shape;
    }

    /// <summary>Creates a record describing <paramref name="tensor"/>.</summary>
    /// <param name="tensor">The tensor to track; must not be <c>null</c>.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="tensor"/> is <c>null</c>.</exception>
    public TapeTensor(Tensor tensor)
    {
        if (tensor is null)
        {
            throw new System.ArgumentNullException(nameof(tensor));
        }

        this.id = tensor.Id;
        this.dtype = tensor.dtype;
        this.tensor = tensor;

        // tf.zeros/tf.ones need concrete dimensions. If the static shape has unknown rank or
        // unknown dims, keep the tensor itself as the template so ZerosLike/OnesLike go through
        // zeros_like/ones_like, which use the runtime shape instead.
        var staticShape = tensor.shape;
        if (staticShape.IsFullyDefined)
        {
            this.shape = staticShape;
        }
        else
        {
            this.shape = tensor;
        }
    }

    /// <summary>Returns the identifier of the tracked tensor.</summary>
    public long GetID() => id;

    /// <summary>Builds a tensor of zeros matching this record.</summary>
    /// <returns>
    /// <c>null</c> when <see cref="dtype"/> is <c>dtypes.resource</c>. Otherwise, zeros shaped
    /// like the template tensor, or zeros of the static shape with <see cref="dtype"/>.
    /// </returns>
    public Tensor ZerosLike()
    {
        // NOTE(review): resource-typed entries get no zero tensor. This appears to mirror
        // TensorFlow's own tape, where gradient functions for resource outputs accept
        // None — confirm before changing.
        if (dtype == dtypes.resource)
        {
            return null;
        }

        return shape.IsT1
            ? tf.zeros_like(shape.AsT1)
            : tf.zeros(shape.AsT0, dtype);
    }

    /// <summary>Builds a tensor of ones matching this record.</summary>
    /// <remarks>Unlike <see cref="ZerosLike"/>, there is no special case for resource dtype.</remarks>
    /// <returns>
    /// Ones shaped like the template tensor, or ones of the static shape with <see cref="dtype"/>.
    /// </returns>
    public Tensor OnesLike()
    {
        return shape.IsT1
            ? tf.ones_like(shape.AsT1)
            : tf.ones(shape.AsT0, dtype);
    }

    /// <summary>Debug representation: id, shape and NumPy-style dtype name.</summary>
    public override string ToString()
        => $"{id}, {shape}, {dtype.as_numpy_name()}";
}
}