Skip to content

Commit bd26bbd

Browse files
committed
Orthogonal initializer.
1 parent 321ddfc commit bd26bbd

16 files changed

Lines changed: 202 additions & 32 deletions

File tree

src/TensorFlowNET.Console/Program.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using Tensorflow.Keras;
23
using static Tensorflow.Binding;
34

45
namespace Tensorflow
@@ -7,6 +8,8 @@ class Program
78
{
89
static void Main(string[] args)
910
{
11+
tf.UseKeras<KerasInterface>();
12+
1013
var diag = new Diagnostician();
1114
// diag.Diagnose(@"D:\memory.txt");
1215

src/TensorFlowNET.Core/APIs/tf.linalg.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ public Tensor lstsq(Tensor matrix, Tensor rhs,
5858
NDArray l2_regularizer = null, bool fast = true, string name = null)
5959
=> ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name);
6060

61+
/// <summary>
/// Forwards to <c>ops.qr</c> and returns the tensors that call produces.
/// </summary>
/// <param name="input">Tensor handed to the underlying call.</param>
/// <param name="full_matrices">Passed through unchanged; defaults to <c>true</c> at this level.</param>
/// <param name="name">Optional name passed through to the underlying call.</param>
/// <returns>The result of <c>ops.qr</c>.</returns>
public Tensors qr(Tensor input, bool full_matrices = true, string name = null)
{
    // NOTE(review): the lower-level qr in linalg_ops defaults full_matrices to false;
    // confirm that the differing default here is intentional.
    return ops.qr(input, full_matrices: full_matrices, name: name);
}
63+
64+
/// <summary>
/// Forwards to <c>gen_array_ops.diag_part</c>.
/// </summary>
/// <param name="input">Tensor handed to the underlying call.</param>
/// <param name="name">Optional name passed through to the underlying call.</param>
/// <returns>The tensor returned by <c>gen_array_ops.diag_part</c>.</returns>
public Tensor tensor_diag_part(Tensor input, string name = null)
{
    return gen_array_ops.diag_part(input, name: name);
}
66+
6167
/// <summary>
/// Forwards to <c>math_ops.tensordot</c>.
/// </summary>
/// <param name="x">First tensor operand.</param>
/// <param name="y">Second tensor operand.</param>
/// <param name="axes">Axes argument passed through unchanged.</param>
/// <param name="name">Optional name passed through to the underlying call.</param>
/// <returns>The tensor returned by <c>math_ops.tensordot</c>.</returns>
public Tensor tensordot(Tensor x, Tensor y, NDArray axes, string name = null)
{
    return math_ops.tensordot(x, y, axes, name: name);
}
6369
}

src/TensorFlowNET.Core/APIs/tf.random.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ public Tensor normal(Shape shape,
3939
int? seed = null,
4040
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name);
4141

42+
/// <summary>
/// Forwards to <c>stateless_random_ops.stateless_random_normal</c>.
/// </summary>
/// <remarks>
/// This overload neither accepts nor forwards a seed argument.
/// </remarks>
/// <param name="shape">Shape passed through to the underlying call.</param>
/// <param name="mean">Mean passed through; defaults to 0.</param>
/// <param name="stddev">Standard deviation passed through; defaults to 1.</param>
/// <param name="dtype">Data type passed through; defaults to <c>TF_FLOAT</c>.</param>
/// <param name="name">Optional name passed through to the underlying call.</param>
/// <returns>The tensor returned by the underlying call.</returns>
public Tensor stateless_normal(Shape shape,
    float mean = 0.0f,
    float stddev = 1.0f,
    TF_DataType dtype = TF_DataType.TF_FLOAT,
    string name = null)
{
    return stateless_random_ops.stateless_random_normal(shape, mean, stddev, dtype, name: name);
}
47+
4248
/// <summary>
4349
/// Outputs random values from a truncated normal distribution.
4450
/// </summary>
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras;

/// <summary>
/// Factory surface for initializers exposed through the Keras API.
/// </summary>
public interface IInitializersApi
{
    /// <summary>
    /// Creates the <c>Orthogonal</c> initializer.
    /// </summary>
    /// <param name="gain">Gain value for the initializer; defaults to 1.</param>
    /// <param name="seed">Optional seed value.</param>
    /// <returns>An <see cref="IInitializer"/> instance.</returns>
    IInitializer Orthogonal(float gain = 1.0f, int? seed = null);
}

src/TensorFlowNET.Core/Keras/IKerasApi.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ namespace Tensorflow.Keras
88
/// <summary>
/// Entry point that groups the Keras sub-APIs.
/// </summary>
public interface IKerasApi
{
    /// <summary>Layers API.</summary>
    ILayersApi layers { get; }

    /// <summary>Initializers API.</summary>
    IInitializersApi initializers { get; }
}
1213
}

src/TensorFlowNET.Core/NumPy/NDArrayRender.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ static string Render(NDArray array)
109109
TF_DataType.TF_INT8 => Render(array.ToArray<sbyte>(), array.shape),
110110
TF_DataType.TF_INT32 => Render(array.ToArray<int>(), array.shape),
111111
TF_DataType.TF_INT64 => Render(array.ToArray<long>(), array.shape),
112+
TF_DataType.TF_UINT64 => Render(array.ToArray<ulong>(), array.shape),
112113
TF_DataType.TF_FLOAT => Render(array.ToArray<float>(), array.shape),
113114
TF_DataType.TF_DOUBLE => Render(array.ToArray<double>(), array.shape),
114115
_ => Render(array.ToArray<byte>(), array.shape)
Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,62 @@
1-
using System;
1+
/*****************************************************************************
2+
Copyright 2023 Haiping Chen. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
218
using System.Linq;
3-
using static Tensorflow.TensorShapeProto.Types;
19+
using static Tensorflow.Binding;
420

5-
namespace Tensorflow.Operations.Initializers
21+
namespace Tensorflow.Operations.Initializers;
22+
23+
/// <summary>
/// Initializer that builds its value from the Q factor of a QR decomposition
/// of a normally distributed random matrix, scaled by a gain.
/// </summary>
/// <remarks>
/// The seed given to the constructor is stored but is not consumed by the
/// sampling call used here (<c>tf.random.stateless_normal</c> takes no seed
/// argument), so it currently has no effect on the generated values.
/// </remarks>
public class Orthogonal : IInitializer
{
    readonly float _gain;
    readonly int? _seed;

    /// <summary>
    /// Creates the initializer.
    /// </summary>
    /// <param name="gain">Factor the generated matrix is multiplied by.</param>
    /// <param name="seed">Stored for later use; currently unused when sampling.</param>
    public Orthogonal(float gain = 1.0f, int? seed = null)
    {
        _gain = gain;
        _seed = seed;
    }

    /// <summary>
    /// Generates the initial value for the shape in <paramref name="args"/>.
    /// An invalid dtype (<c>DtInvalid</c>) falls back to <c>TF_FLOAT</c>.
    /// </summary>
    /// <param name="args">Carries the requested shape and dtype.</param>
    /// <returns>A tensor reshaped to <c>args.Shape</c>.</returns>
    public Tensor Apply(InitializerArgs args)
    {
        var dtype = args.DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : args.DType;
        return _generate_init_val(args.Shape, dtype);
    }

    /// <summary>
    /// Flattens all leading dimensions into rows, samples a random matrix,
    /// takes the Q factor of its QR decomposition, fixes the signs using the
    /// diagonal of R, and reshapes the result to <paramref name="shape"/>.
    /// </summary>
    /// <exception cref="ArgumentException">The shape has fewer than two dimensions.</exception>
    /// <exception cref="NotImplementedException">
    /// The flattened matrix has fewer rows than columns (the transpose step is not implemented).
    /// </exception>
    private Tensor _generate_init_val(Shape shape, TF_DataType dtype)
    {
        // A matrix needs at least two dimensions; without this check a rank-0
        // shape would fail inside Last() with an unrelated error.
        if (shape.ndim < 2)
        {
            throw new ArgumentException(
                "The tensor to initialize must be at least two-dimensional.", nameof(shape));
        }

        var num_rows = 1L;
        foreach (var dim in shape.dims.Take(shape.ndim - 1))
        {
            num_rows *= dim;
        }
        var num_cols = shape.dims.Last();

        // Fail fast: the wide case requires transposing Q, which is not
        // implemented yet, so do not sample or factorize before rejecting it.
        if (num_rows < num_cols)
        {
            // q = tf.linalg.matrix_transpose(q);
            throw new NotImplementedException(
                "Orthogonal initializer does not support shapes whose flattened row count is smaller than the last dimension.");
        }

        var flat_shape = (Math.Max(num_cols, num_rows), Math.Min(num_cols, num_rows));

        var a = tf.random.stateless_normal(flat_shape, dtype: dtype);
        // Compute the qr factorization
        var (q, r) = tf.linalg.qr(a, full_matrices: false);
        // Make Q uniform
        var d = tf.linalg.tensor_diag_part(r);
        q *= tf.sign(d);

        return _gain * tf.reshape(q, shape);
    }
}

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ public static Tensor[] concat_offset(Tensor concat_dim, Tensor[] shape, string n
113113
/// <summary>
/// Executes the "Diag" op with <paramref name="diagonal"/> as its only input.
/// </summary>
/// <param name="diagonal">Input tensor for the op.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The op's output tensor.</returns>
public static Tensor diag(Tensor diagonal, string name = null)
{
    var context = tf.Context;
    var opArgs = new ExecuteOpArgs(diagonal);
    return context.ExecuteOp("Diag", name, opArgs);
}
115115

116+
/// <summary>
/// Executes the "DiagPart" op with <paramref name="diagonal"/> as its only input.
/// </summary>
/// <param name="diagonal">
/// Input tensor for the op. NOTE(review): despite the name, this appears to be the
/// tensor the diagonal is taken from rather than a diagonal itself; confirm.
/// </param>
/// <param name="name">Optional op name.</param>
/// <returns>The op's output tensor.</returns>
public static Tensor diag_part(Tensor diagonal, string name = null)
{
    var context = tf.Context;
    var opArgs = new ExecuteOpArgs(diagonal);
    return context.ExecuteOp("DiagPart", name, opArgs);
}
118+
116119
/// <summary>
/// Executes the "ExpandDims" op, passing <paramref name="axis"/> both as an
/// input and as the <c>dim</c> attribute.
/// </summary>
/// <param name="input">Input tensor for the op.</param>
/// <param name="axis">Axis value handed to the op.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The op's output tensor.</returns>
public static Tensor expand_dims(Tensor input, int axis, string name = null)
{
    var context = tf.Context;
    var opArgs = new ExecuteOpArgs(input, axis).SetAttributes(new { dim = axis });
    return context.ExecuteOp("ExpandDims", name, opArgs);
}

src/TensorFlowNET.Core/Operations/gen_random_ops.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ You may obtain a copy of the License at
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
******************************************************************************/
16+
using static Tensorflow.ApiDef.Types;
17+
using System.Reflection;
1618
using static Tensorflow.Binding;
19+
using System.Xml.Linq;
1720

1821
namespace Tensorflow
1922
{
@@ -85,6 +88,15 @@ public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed
8588
int? seed2 = 0, string name = null)
8689
=> tf.Context.ExecuteOp("TruncatedNormal", name, new ExecuteOpArgs(shape)
8790
.SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 }));
91+
/// <summary>
/// Executes the "StatelessRandomNormalV2" op.
/// </summary>
/// <param name="shape">Shape input for the op.</param>
/// <param name="key">Key input for the op.</param>
/// <param name="counter">Counter input for the op.</param>
/// <param name="alg">Algorithm input for the op, passed as an integer.</param>
/// <param name="dtype">Value of the op's <c>dtype</c> attribute.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The op's output tensor.</returns>
public static Tensor stateless_random_normal_v2(Tensor shape, Tensor key, Tensor counter,
    int alg, TF_DataType dtype, string name = null)
{
    var context = tf.Context;
    var opArgs = new ExecuteOpArgs(shape, key, counter, alg).SetAttributes(new { dtype });
    return context.ExecuteOp("StatelessRandomNormalV2", name, opArgs);
}
96+
97+
/// <summary>
/// Executes the "StatelessRandomGetKeyCounter" op with <paramref name="seed"/>
/// as its only input.
/// </summary>
/// <param name="seed">Seed values handed to the op.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The op's outputs.</returns>
public static Tensors stateless_random_get_key_counter(int[] seed, string name = null)
{
    var context = tf.Context;
    var opArgs = new ExecuteOpArgs(seed);
    return context.ExecuteOp("StatelessRandomGetKeyCounter", name, opArgs);
}
88100

89101
public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0,
90102
int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null)

src/TensorFlowNET.Core/Operations/linalg_ops.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,5 +129,12 @@ public Tensor matrix_triangular_solve(Tensor matrix, Tensor rhs, bool lower = tr
129129
lower,
130130
adjoint
131131
}));
132+
133+
/// <summary>
/// Executes the "Qr" op on <paramref name="input"/>.
/// </summary>
/// <param name="input">Input tensor for the op.</param>
/// <param name="full_matrices">Value of the op's <c>full_matrices</c> attribute; defaults to <c>false</c>.</param>
/// <param name="name">Optional op name.</param>
/// <returns>The op's outputs.</returns>
public Tensors qr(Tensor input, bool full_matrices = false, string name = null)
{
    var context = tf.Context;
    var opArgs = new ExecuteOpArgs(input).SetAttributes(new { full_matrices });
    return context.ExecuteOp("Qr", name, opArgs);
}
132139
}
133140
}

0 commit comments

Comments
 (0)