Skip to content

Commit 7867751

Browse files
authored
Merge pull request SciSharp#184 from PppBr/master
building normal distribution
2 parents 0132f30 + 7f156f1 commit 7867751

6 files changed

Lines changed: 181 additions & 5 deletions

File tree

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Framework
6+
{
7+
public static class common_shapes
8+
{
9+
/// <summary>
10+
/// Returns the broadcasted shape between `shape_x` and `shape_y
11+
/// </summary>
12+
/// <param name="shape_x"></param>
13+
/// <param name="shape_y"></param>
14+
public static Tensor broadcast_shape(Tensor shape_x, Tensor shape_y)
15+
{
16+
var return_dims = _broadcast_shape_helper(shape_x, shape_y);
17+
// return tensor_shape(return_dims);
18+
throw new NotFiniteNumberException();
19+
}
20+
/// <summary>
21+
/// Helper functions for is_broadcast_compatible and broadcast_shape.
22+
/// </summary>
23+
/// <param name="shape_x"> A `TensorShape`</param>
24+
/// <param name="shape_y"> A `TensorShape`</param>
25+
/// <return> Returns None if the shapes are not broadcast compatible,
26+
/// a list of the broadcast dimensions otherwise.
27+
/// </return>
28+
public static Tensor _broadcast_shape_helper(Tensor shape_x, Tensor shape_y)
29+
{
30+
throw new NotFiniteNumberException();
31+
}
32+
}
33+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow;
5+
6+
namespace Tensorflow.Operations.Distributions
7+
{
8+
public enum DistributionEnum
9+
{
10+
11+
12+
13+
}
14+
}

src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
namespace Tensorflow
88
{
9-
abstract class _BaseDistribution : Object
9+
abstract class _BaseDistribution : Python
1010
{
1111
// Abstract base class needed for resolving subclass hierarchy.
1212
}
@@ -22,8 +22,8 @@ class Distribution : _BaseDistribution
2222
public ReparameterizationType _reparameterization_type {get;set;}
2323
public bool _validate_args {get;set;}
2424
public bool _allow_nan_stats {get;set;}
25-
public Dictionary<object, object> _parameters {get;set;}
26-
public List<object> _graph_parents {get;set;}
25+
public Dictionary<string, object> _parameters {get;set;}
26+
public List<Tensor> _graph_parents {get;set;}
2727
public string _name {get;set;}
2828

2929
/// <summary>
@@ -82,7 +82,21 @@ private Distribution (
8282
/// </summary>
8383
class ReparameterizationType
8484
{
85+
public string _rep_type { get; set; }
86+
public ReparameterizationType(string rep_type)
87+
{
88+
this._rep_type = rep_type;
89+
}
8590

91+
public void repr()
92+
{
93+
Console.WriteLine($"<Reparameteriation Type: {this._rep_type}>" );
94+
}
95+
96+
public bool eq (ReparameterizationType other)
97+
{
98+
return this.Equals(other);
99+
}
86100
}
87101

88102

src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,80 @@
1+
using System.Collections.Generic;
2+
13
namespace Tensorflow
24
{
35
class Normal : Distribution
46
{
5-
public Normal (Tensor loc, Tensor scale, bool validate_args=false, bool allow_nan_stats=true, string name="Normal")
7+
public Tensor _loc { get; set; }
8+
public Tensor _scale { get; set; }
9+
10+
Dictionary<string, object> parameters = new Dictionary<string, object>();
11+
/// <summary>
12+
/// The Normal distribution with location `loc` and `scale` parameters.
13+
/// Mathematical details
14+
/// The probability density function(pdf) is,
15+
/// '''
16+
/// pdf(x; mu, sigma) = exp(-0.5 (x - mu)**2 / sigma**2) / Z
17+
/// Z = (2 pi sigma**2)**0.5
18+
/// '''
19+
/// where `loc = mu` is the mean, `scale = sigma` is the std.deviation, and, `Z`
20+
/// is the normalization constant.
21+
/// </summary>
22+
/// <param name="loc"></param>
23+
/// <param name="scale"></param>
24+
/// <param name="validate_args"></param>
25+
/// <param name="allow_nan_stats"></param>
26+
/// <param name="name"></param>
27+
public Normal (Tensor loc, Tensor scale, bool validate_args=false, bool allow_nan_stats=true, string name="Normal")
28+
{
29+
parameters.Add("name", name);
30+
parameters.Add("loc", loc);
31+
parameters.Add("scale", scale);
32+
parameters.Add("validate_args", validate_args);
33+
parameters.Add("allow_nan_stats", allow_nan_stats);
34+
35+
with(new ops.name_scope(name, "", new { loc, scale }), scope =>
36+
{
37+
with(ops.control_dependencies(validate_args ? new Operation[] { scale.op} : new Operation[] { }), cd =>
38+
{
39+
this._loc = array_ops.identity(loc, name);
40+
this._scale = array_ops.identity(scale, name);
41+
base._dtype = this._scale.dtype;
42+
base._reparameterization_type = new ReparameterizationType("FULLY_REPARAMETERIZED");
43+
base._validate_args = validate_args;
44+
base._allow_nan_stats = allow_nan_stats;
45+
base._parameters = parameters;
46+
base._graph_parents = new List<Tensor>(new Tensor[] { this._loc, this._scale });
47+
base._name = name;
48+
});
49+
50+
});
51+
52+
}
53+
/// <summary>
54+
/// Distribution parameter for the mean.
55+
/// </summary>
56+
/// <returns></returns>
57+
public Tensor loc()
58+
{
59+
return this._loc;
60+
}
61+
/// <summary>
62+
/// Distribution parameter for standard deviation."
63+
/// </summary>
64+
/// <returns></returns>
65+
public Tensor scale()
66+
{
67+
return this._scale;
68+
}
69+
70+
public Tensor _batch_shape_tensor()
71+
{
72+
return array_ops.broadcast_dynamic_shape(array_ops.shape(this._loc), array_ops.shape(this._scale));
73+
}
74+
75+
public Tensor _batch_shape()
676
{
7-
77+
return array_ops.broadcast_static_shape(new Tensor(this._loc.shape), new Tensor(this._scale.shape));
878
}
979

1080
}

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,5 +239,34 @@ public static Tensor squeeze(Tensor input, int[] axis = null, string name = null
239239
{
240240
return gen_array_ops.squeeze(input, axis, name);
241241
}
242+
243+
public static Tensor identity(Tensor input, string name = null)
244+
{
245+
return gen_array_ops.identity(input, name);
246+
}
247+
/// <summary>
248+
/// Computes the shape of a broadcast given symbolic shapes.
249+
/// When shape_x and shape_y are Tensors representing shapes(i.e.the result of
250+
/// calling tf.shape on another Tensor) this computes a Tensor which is the shape
251+
/// of the result of a broadcasting op applied in tensors of shapes shape_x and
252+
/// shape_y.
253+
/// For example, if shape_x is [1, 2, 3] and shape_y is [5, 1, 3], the result is a
254+
/// Tensor whose value is [5, 2, 3].
255+
/// This is useful when validating the result of a broadcasting operation when the
256+
/// tensors do not have statically known shapes.
257+
/// </summary>
258+
/// <param name="shape_x"> A rank 1 integer `Tensor`, representing the shape of x.</param>
259+
/// <param name="shape_y"> A rank 1 integer `Tensor`, representing the shape of y.</param>
260+
/// <returns> A rank 1 integer `Tensor` representing the broadcasted shape.</returns>
261+
public static Tensor broadcast_dynamic_shape(Tensor shape_x, Tensor shape_y)
262+
{
263+
return gen_array_ops.broadcast_args(shape_x, shape_y);
264+
}
265+
266+
public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y)
267+
{
268+
return Framework.common_shapes.broadcast_shape(shape_x, shape_y);
269+
}
270+
242271
}
243272
}

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,5 +178,21 @@ public static Tensor squeeze(Tensor input, int[] axis = null, string name = null
178178

179179
return _op.outputs[0];
180180
}
181+
182+
/// <summary>
183+
/// Return the shape of s0 op s1 with broadcast.
184+
/// Given `s0` and `s1`, tensors that represent shapes, compute `r0`, the
185+
/// broadcasted shape. `s0`, `s1` and `r0` are all integer vectors.
186+
/// </summary>
187+
/// <param name="s0"> A `Tensor`. Must be one of the following types: `int32`, `int64`.</param>
188+
/// <param name="s1"> A `Tensor`. Must have the same type as `s0`.</param>
189+
/// <param name="name"> A name for the operation (optional).</param>
190+
/// <returns> `Tensor`. Has the same type as `s0`.</returns>
191+
public static Tensor broadcast_args(Tensor s0, Tensor s1, string name = null)
192+
{
193+
var _op = _op_def_lib._apply_op_helper("BroadcastArgs", name, args: new { s0, s1, name });
194+
195+
return _op.outputs[0];
196+
}
181197
}
182198
}

0 commit comments

Comments
 (0)