forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathActivations.cs
More file actions
100 lines (88 loc) · 3.2 KB
/
Activations.cs
File metadata and controls
100 lines (88 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;
namespace Tensorflow.Keras
{
/// <summary>
/// Built-in Keras activation functions, exposed both as properties and by name
/// through <see cref="GetActivationFromName(string)"/>.
/// </summary>
public class Activations: IActivationsApi
{
    /// <summary>
    /// Lookup table from activation name (e.g. "relu") to its <see cref="Activation"/> instance.
    /// Filled by the static constructor; keys are compared ordinally (case-sensitive).
    /// </summary>
    private static readonly Dictionary<string, Activation> _nameActivationMap;

    /// <summary>Identity activation: returns <c>features</c> unchanged (the op name is ignored).</summary>
    private static readonly Activation _linear = new Activation()
    {
        Name = "linear",
        ActivationFunction = (features, name) => features
    };

    /// <summary>Executes the TensorFlow "Relu" op on <c>features</c>.</summary>
    private static readonly Activation _relu = new Activation()
    {
        Name = "relu",
        ActivationFunction = (features, name) => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features))
    };

    /// <summary>Executes the TensorFlow "Relu6" op on <c>features</c>.</summary>
    private static readonly Activation _relu6 = new Activation()
    {
        Name = "relu6",
        ActivationFunction = (features, name) => tf.Context.ExecuteOp("Relu6", name, new ExecuteOpArgs(features))
    };

    /// <summary>Executes the TensorFlow "Sigmoid" op on <c>features</c>.</summary>
    private static readonly Activation _sigmoid = new Activation()
    {
        Name = "sigmoid",
        ActivationFunction = (features, name) => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features))
    };

    /// <summary>Executes the TensorFlow "Softmax" op on <c>features</c>.</summary>
    private static readonly Activation _softmax = new Activation()
    {
        Name = "softmax",
        ActivationFunction = (features, name) => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features))
    };

    /// <summary>Executes the TensorFlow "Tanh" op on <c>features</c>.</summary>
    private static readonly Activation _tanh = new Activation()
    {
        Name = "tanh",
        ActivationFunction = (features, name) => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features))
    };

    /// <summary>
    /// Mish activation, composed from existing math ops as
    /// <c>features * tanh(softplus(features))</c>. The op name argument is not used.
    /// </summary>
    private static readonly Activation _mish = new Activation()
    {
        Name = "mish",
        ActivationFunction = (features, name) => features * tf.math.tanh(tf.math.softplus(features))
    };

    /// <summary>
    /// Registers <paramref name="activation"/> under its <c>Name</c>.
    /// A later registration with the same name replaces the earlier one.
    /// </summary>
    /// <param name="activation">The activation to make resolvable by name.</param>
    private static void RegisterActivation(Activation activation)
    {
        _nameActivationMap[activation.Name] = activation;
    }

    /// <summary>
    /// Builds the name lookup table from the built-in activations.
    /// Static field initializers above run before this body, so every activation is non-null here.
    /// </summary>
    static Activations()
    {
        _nameActivationMap = new Dictionary<string, Activation>(StringComparer.Ordinal);
        RegisterActivation(_relu);
        RegisterActivation(_relu6);
        RegisterActivation(_linear);
        RegisterActivation(_sigmoid);
        RegisterActivation(_softmax);
        RegisterActivation(_tanh);
        RegisterActivation(_mish);
    }

    /// <summary>Identity activation ("linear").</summary>
    public Activation Linear => _linear;

    /// <summary>Rectified linear unit ("relu").</summary>
    public Activation Relu => _relu;

    /// <summary>ReLU capped by the TensorFlow "Relu6" op ("relu6").</summary>
    public Activation Relu6 => _relu6;

    /// <summary>Logistic sigmoid ("sigmoid").</summary>
    public Activation Sigmoid => _sigmoid;

    /// <summary>Softmax via the TensorFlow "Softmax" op ("softmax").</summary>
    public Activation Softmax => _softmax;

    /// <summary>Hyperbolic tangent ("tanh").</summary>
    public Activation Tanh => _tanh;

    /// <summary>Mish, <c>x * tanh(softplus(x))</c> ("mish").</summary>
    public Activation Mish => _mish;

    /// <summary>
    /// Resolves a built-in activation by its registered name.
    /// </summary>
    /// <param name="name">
    /// Registered, case-sensitive name such as "relu" or "softmax";
    /// <c>null</c> selects the linear (identity) activation.
    /// </param>
    /// <returns>The <see cref="Activation"/> registered under <paramref name="name"/>.</returns>
    /// <exception cref="ArgumentException">No activation is registered under <paramref name="name"/>.</exception>
    public Activation GetActivationFromName(string name)
    {
        // null is treated as "no activation", i.e. identity.
        if (name is null)
        {
            return _linear;
        }

        if (_nameActivationMap.TryGetValue(name, out var activation))
        {
            return activation;
        }

        // A specific exception type instead of bare Exception; it still derives from
        // Exception, so existing catch (Exception) handlers keep working.
        throw new ArgumentException($"Activation {name} not found", nameof(name));
    }
}
}