Skip to content

Commit 7f0b9b6

Browse files
committed
add np.load.
1 parent 254ba33 commit 7f0b9b6

5 files changed

Lines changed: 175 additions & 3 deletions

File tree

src/TensorFlowNET.Core/NumPy/Implementation/NumPyImpl.Creation.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.IO;
4+
using System.Linq;
35
using System.Text;
6+
using Tensorflow.Util;
47
using static Tensorflow.Binding;
58

69
namespace Tensorflow.NumPy
@@ -67,6 +70,19 @@ public NDArray linspace<T>(T start, T stop, int num = 50, bool endpoint = true,
6770
return new NDArray(result);
6871
}
6972

73+
Array ReadValueMatrix(BinaryReader reader, Array matrix, int bytes, Type type, int[] shape)
74+
{
75+
int total = 1;
76+
for (int i = 0; i < shape.Length; i++)
77+
total *= shape[i];
78+
var buffer = new byte[bytes * total];
79+
80+
reader.Read(buffer, 0, buffer.Length);
81+
System.Buffer.BlockCopy(buffer, 0, matrix, 0, buffer.Length);
82+
83+
return matrix;
84+
}
85+
7086
public (NDArray, NDArray) meshgrid<T>(T[] array, bool copy = true, bool sparse = false)
7187
{
7288
var tensors = array_ops.meshgrid(array, copy: copy, sparse: sparse);
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
using System;
2+
using System.Collections;
3+
using System.Collections.Generic;
4+
using System.IO;
5+
using System.Linq;
6+
using System.Reflection;
7+
using System.Text;
8+
using Tensorflow.Util;
9+
10+
namespace Tensorflow.NumPy
11+
{
12+
public partial class NumPyImpl
13+
{
14+
public NDArray load(string file)
15+
{
16+
using var stream = new FileStream(file, FileMode.Open);
17+
using var reader = new BinaryReader(stream, Encoding.ASCII, leaveOpen: true);
18+
int bytes;
19+
Type type;
20+
int[] shape;
21+
if (!ParseReader(reader, out bytes, out type, out shape))
22+
throw new FormatException();
23+
24+
Array array = Create(type, shape.Aggregate((dims, dim) => dims * dim));
25+
26+
var result = new NDArray(ReadValueMatrix(reader, array, bytes, type, shape));
27+
return result.reshape(shape);
28+
}
29+
30+
bool ParseReader(BinaryReader reader, out int bytes, out Type t, out int[] shape)
31+
{
32+
bytes = 0;
33+
t = null;
34+
shape = null;
35+
36+
// The first 6 bytes are a magic string: exactly "x93NUMPY"
37+
if (reader.ReadChar() != 63) return false;
38+
if (reader.ReadChar() != 'N') return false;
39+
if (reader.ReadChar() != 'U') return false;
40+
if (reader.ReadChar() != 'M') return false;
41+
if (reader.ReadChar() != 'P') return false;
42+
if (reader.ReadChar() != 'Y') return false;
43+
44+
byte major = reader.ReadByte(); // 1
45+
byte minor = reader.ReadByte(); // 0
46+
47+
if (major != 1 || minor != 0)
48+
throw new NotSupportedException();
49+
50+
ushort len = reader.ReadUInt16();
51+
52+
string header = new String(reader.ReadChars(len));
53+
string mark = "'descr': '";
54+
int s = header.IndexOf(mark) + mark.Length;
55+
int e = header.IndexOf("'", s + 1);
56+
string type = header.Substring(s, e - s);
57+
bool? isLittleEndian;
58+
t = GetType(type, out bytes, out isLittleEndian);
59+
60+
if (isLittleEndian.HasValue && isLittleEndian.Value == false)
61+
throw new Exception();
62+
63+
mark = "'fortran_order': ";
64+
s = header.IndexOf(mark) + mark.Length;
65+
e = header.IndexOf(",", s + 1);
66+
bool fortran = bool.Parse(header.Substring(s, e - s));
67+
68+
if (fortran)
69+
throw new Exception();
70+
71+
mark = "'shape': (";
72+
s = header.IndexOf(mark) + mark.Length;
73+
e = header.IndexOf(")", s + 1);
74+
shape = header.Substring(s, e - s).Split(',').Where(v => !String.IsNullOrEmpty(v)).Select(Int32.Parse).ToArray();
75+
76+
return true;
77+
}
78+
79+
Type GetType(string dtype, out int bytes, out bool? isLittleEndian)
80+
{
81+
isLittleEndian = IsLittleEndian(dtype);
82+
bytes = Int32.Parse(dtype.Substring(2));
83+
84+
string typeCode = dtype.Substring(1);
85+
86+
if (typeCode == "b1")
87+
return typeof(bool);
88+
if (typeCode == "i1")
89+
return typeof(Byte);
90+
if (typeCode == "i2")
91+
return typeof(Int16);
92+
if (typeCode == "i4")
93+
return typeof(Int32);
94+
if (typeCode == "i8")
95+
return typeof(Int64);
96+
if (typeCode == "u1")
97+
return typeof(Byte);
98+
if (typeCode == "u2")
99+
return typeof(UInt16);
100+
if (typeCode == "u4")
101+
return typeof(UInt32);
102+
if (typeCode == "u8")
103+
return typeof(UInt64);
104+
if (typeCode == "f4")
105+
return typeof(Single);
106+
if (typeCode == "f8")
107+
return typeof(Double);
108+
if (typeCode.StartsWith("S"))
109+
return typeof(String);
110+
111+
throw new NotSupportedException();
112+
}
113+
114+
bool? IsLittleEndian(string type)
115+
{
116+
bool? littleEndian = null;
117+
118+
switch (type[0])
119+
{
120+
case '<':
121+
littleEndian = true;
122+
break;
123+
case '>':
124+
littleEndian = false;
125+
break;
126+
case '|':
127+
littleEndian = null;
128+
break;
129+
default:
130+
throw new Exception();
131+
}
132+
133+
return littleEndian;
134+
}
135+
136+
Array Create(Type type, int length)
137+
{
138+
// ReSharper disable once PossibleNullReferenceException
139+
while (type.IsArray)
140+
type = type.GetElementType();
141+
142+
return Array.CreateInstance(type, length);
143+
}
144+
}
145+
}

src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ public partial class NDArray
88
{
99
public void Deconstruct(out byte blue, out byte green, out byte red)
1010
{
11-
blue = (byte)dims[0];
12-
green = (byte)dims[1];
13-
red = (byte)dims[2];
11+
var data = Data<byte>();
12+
blue = data[0];
13+
green = data[1];
14+
red = data[2];
1415
}
1516

1617
public static implicit operator NDArray(Array array)

src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ public partial class NDArray
1919
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype);
2020
public NDArray(Tensor value, Shape? shape = null) => Init(value, shape);
2121
public NDArray(byte[] bytes, Shape shape, TF_DataType dtype) => Init(bytes, shape, dtype);
22+
public NDArray(IntPtr address, Shape shape, TF_DataType dtype) => Init(address, shape, dtype);
2223

2324
public static NDArray Scalar<T>(T value) where T : unmanaged
2425
=> value switch
@@ -75,5 +76,11 @@ void Init(byte[] bytes, Shape shape, TF_DataType dtype)
7576
_tensor = new Tensor(bytes, shape, dtype);
7677
_tensor.SetReferencedByNDArray();
7778
}
79+
80+
void Init(IntPtr address, Shape shape, TF_DataType dtype)
81+
{
82+
_tensor = new Tensor(address, shape, dtype);
83+
_tensor.SetReferencedByNDArray();
84+
}
7885
}
7986
}

src/TensorFlowNET.Core/Numpy/Numpy.Creation.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ public static NDArray linspace<T>(T start, T stop, int num = 50, bool endpoint =
4040
TF_DataType dtype = TF_DataType.TF_DOUBLE, int axis = 0) where T : unmanaged
4141
=> tf.numpy.linspace(start, stop, num: num, endpoint: endpoint, retstep: retstep, dtype: dtype, axis: axis);
4242

43+
public static NDArray load(string file)
44+
=> tf.numpy.load(file);
45+
4346
public static (NDArray, NDArray) meshgrid<T>(T x, T y, bool copy = true, bool sparse = false)
4447
=> tf.numpy.meshgrid(new[] { x, y }, copy: copy, sparse: sparse);
4548

0 commit comments

Comments
 (0)