Skip to content

Commit 420e195

Browse files
committed
TensorShape: Added implicit conversions for object type.
1 parent f70f0a9 commit 420e195

2 files changed

Lines changed: 76 additions & 0 deletions

File tree

src/TensorFlowNET.Core/Tensors/TensorShape.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,22 @@ public override string ToString()
254254

255255
public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
256256
public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
257+
258+
public static explicit operator (int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 7 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6]) : (0, 0, 0, 0, 0, 0, 0);
259+
public static implicit operator TensorShape((int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
260+
261+
public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0);
262+
public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
263+
264+
public static implicit operator TensorShape(int?[] dims) => new TensorShape(dims);
265+
public static implicit operator TensorShape(int? dim) => new TensorShape(dim);
266+
public static implicit operator TensorShape((object, object) dims) => new TensorShape(dims.Item1, dims.Item2);
267+
public static implicit operator TensorShape((object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
268+
public static implicit operator TensorShape((object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
269+
public static implicit operator TensorShape((object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);
270+
public static implicit operator TensorShape((object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
271+
public static implicit operator TensorShape((object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7);
272+
public static implicit operator TensorShape((object, object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8);
257273

258274
}
259275
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
using System;
2+
using Microsoft.VisualStudio.TestTools.UnitTesting;
3+
using NumSharp;
4+
using Tensorflow;
5+
using static Tensorflow.Binding;
6+
7+
namespace TensorFlowNET.UnitTest
8+
{
9+
[TestClass]
10+
public class TensorShapeTest
11+
{
12+
[TestMethod]
13+
public void Case1()
14+
{
15+
int? a = 2;
16+
int? b = 3;
17+
var dims = new object[] {(int?) None, a, b};
18+
new TensorShape(dims).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
19+
}
20+
21+
[TestMethod]
22+
public void Case2()
23+
{
24+
int? a = 2;
25+
int? b = 3;
26+
var dims = new object[] {(int?) None, a, b};
27+
new TensorShape(new object[] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
28+
}
29+
30+
[TestMethod]
31+
public void Case3()
32+
{
33+
int? a = 2;
34+
int? b = null;
35+
var dims = new object[] {(int?) None, a, b};
36+
new TensorShape(new object[] {dims}).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, -1);
37+
}
38+
39+
[TestMethod]
40+
public void Case4()
41+
{
42+
TensorShape shape = (None, None);
43+
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, -1);
44+
}
45+
46+
[TestMethod]
47+
public void Case5()
48+
{
49+
TensorShape shape = (1, None, 3);
50+
shape.GetPrivate<Shape>("shape").Should().BeShaped(1, -1, 3);
51+
}
52+
53+
[TestMethod]
54+
public void Case6()
55+
{
56+
TensorShape shape = (None, 1, 2, 3, None);
57+
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1);
58+
}
59+
}
60+
}

0 commit comments

Comments
 (0)