forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTensors.cs
More file actions
107 lines (86 loc) · 3.1 KB
/
Tensors.cs
File metadata and controls
107 lines (86 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
using Tensorflow.NumPy;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
namespace Tensorflow
{
/// <summary>
/// Tensors is used to represent a single Tensor or an array of Tensor.
/// It simplifies the API surface: Tensor and Tensor[] are converted to
/// Tensors implicitly, and Tensors is converted back to Tensor and
/// Tensor[] implicitly.
/// Implicit conversions from a (Tensor, Tensor) tuple, a List of Tensor
/// and an NDArray are provided as well.
/// </summary>
public class Tensors : IEnumerable<Tensor>, IDisposable
{
    // Backing store for the wrapped tensors; never reassigned after construction.
    readonly List<Tensor> items = new List<Tensor>();

    /// <summary>
    /// Data type of the first tensor.
    /// Throws <see cref="InvalidOperationException"/> when the collection is empty.
    /// </summary>
    public TF_DataType dtype => items.First().dtype;

    /// <summary>
    /// Shape of the first tensor.
    /// Throws <see cref="InvalidOperationException"/> when the collection is empty.
    /// </summary>
    public Shape shape => items.First().shape;

    /// <summary>
    /// Rank of the first tensor.
    /// Throws <see cref="InvalidOperationException"/> when the collection is empty.
    /// </summary>
    public int rank => items.First().rank;

    /// <summary>
    /// Graph of the first tensor.
    /// Throws <see cref="InvalidOperationException"/> when the collection is empty.
    /// </summary>
    public Graph graph => items.First().graph;

    /// <summary>
    /// Caller-controlled flag. This class neither reads nor writes it;
    /// NOTE(review): presumably marks that the value should be treated as a list - confirm with callers.
    /// </summary>
    public bool IsList { get; set; }

    /// <summary>
    /// Number of tensors currently held.
    /// </summary>
    public int Length => items.Count;

    /// <summary>
    /// Gets or sets the tensor at the given position.
    /// </summary>
    /// <param name="index">Zero-based position in the collection.</param>
    public Tensor this[int index]
    {
        get => items[index];
        set => items[index] = value;
    }

    /// <summary>
    /// Forwards the slice specification to the string indexer of the first tensor.
    /// Throws <see cref="InvalidOperationException"/> when the collection is empty.
    /// </summary>
    /// <param name="slices">Slice strings passed through unchanged.</param>
    public Tensor this[params string[] slices]
        => items.First()[slices];

    /// <summary>
    /// Creates a collection holding the given tensors in order.
    /// </summary>
    public Tensors(params Tensor[] tensors)
    {
        items.AddRange(tensors);
    }

    /// <summary>
    /// Creates a collection holding every tensor produced by the sequence.
    /// </summary>
    public Tensors(IEnumerable<Tensor> tensors)
    {
        items.AddRange(tensors);
    }

    /// <summary>
    /// Creates a single-element collection from the result of
    /// <c>ops.convert_to_tensor(nd)</c>.
    /// </summary>
    public Tensors(NDArray nd)
    {
        items.Add(ops.convert_to_tensor(nd));
    }

    /// <summary>
    /// Enumerates the held tensors in insertion order.
    /// </summary>
    public IEnumerator<Tensor> GetEnumerator()
    {
        foreach (var tensor in items)
        {
            yield return tensor;
        }
    }

    /// <summary>
    /// Appends a tensor to the end of the collection.
    /// </summary>
    public void Add(Tensor tensor)
        => items.Add(tensor);

    /// <summary>
    /// Appends all given tensors to the end of the collection.
    /// </summary>
    public void AddRange(Tensor[] tensors)
        => items.AddRange(tensors);

    /// <summary>
    /// Inserts a tensor at the given zero-based position.
    /// </summary>
    public void Insert(int index, Tensor tensor)
        => items.Insert(index, tensor);

    IEnumerator IEnumerable.GetEnumerator()
        => GetEnumerator();

    /// <summary>Wraps a single tensor.</summary>
    public static implicit operator Tensors(Tensor tensor)
        => new Tensors(tensor);

    /// <summary>Wraps both elements of the tuple, in order.</summary>
    public static implicit operator Tensors((Tensor, Tensor) tuple)
        => new Tensors(tuple.Item1, tuple.Item2);

    /// <summary>Wraps the tensor converted from the NDArray.</summary>
    [AutoNumPy]
    public static implicit operator Tensors(NDArray nd)
        => new Tensors(nd);

    /// <summary>Wraps every tensor of the array, in order.</summary>
    public static implicit operator Tensors(Tensor[] tensors)
        => new Tensors(tensors);

    /// <summary>Wraps a snapshot of the list contents, in order.</summary>
    public static implicit operator Tensors(List<Tensor> tensors)
        => new Tensors(tensors.ToArray());

    /// <summary>
    /// Yields the first tensor, or null when the collection is empty.
    /// A null <see cref="Tensors"/> converts to null instead of throwing,
    /// because an implicit conversion should never raise.
    /// </summary>
    public static implicit operator Tensor(Tensors tensors)
        => tensors?.items.FirstOrDefault();

    /// <summary>
    /// Yields a new array containing the held tensors.
    /// A null <see cref="Tensors"/> converts to null instead of throwing,
    /// because an implicit conversion should never raise.
    /// </summary>
    public static implicit operator Tensor[](Tensors tensors)
        => tensors?.items.ToArray();

    /// <summary>
    /// Supports <c>var (a, b) = tensors;</c> by exposing the first two tensors.
    /// Throws <see cref="ArgumentOutOfRangeException"/> when fewer than two are held.
    /// </summary>
    public void Deconstruct(out Tensor a, out Tensor b)
    {
        a = items[0];
        b = items[1];
    }

    /// <summary>
    /// For exactly one tensor, returns that tensor's own string form;
    /// otherwise returns the count followed by the comma-separated tensor names.
    /// </summary>
    public override string ToString()
        => items.Count == 1
            ? items[0].ToString()
            : $"{items.Count} Tensors. {string.Join(", ", items.Select(x => x.name))}";

    /// <summary>
    /// Disposes every held tensor. Null entries are skipped so that
    /// disposal itself does not throw.
    /// </summary>
    public void Dispose()
    {
        foreach (var item in items)
        {
            item?.Dispose();
        }
    }
}
}