forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathModel.cs
More file actions
70 lines (63 loc) · 2.32 KB
/
Model.cs
File metadata and controls
70 lines (63 loc) · 2.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
using NumSharp;
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Optimizers;
namespace Tensorflow.Keras.Engine
{
    /// <summary>
    /// `Model` groups layers into an object with training and inference features.
    /// </summary>
    public class Model : Layer
    {
#pragma warning disable CS0169 // The field 'Model._cloning' is never used
        bool _cloning;
#pragma warning restore CS0169 // The field 'Model._cloning' is never used
#pragma warning disable CS0108 // Member hides inherited member; missing new keyword
#pragma warning disable CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
        // Set to true by compile().
        bool _is_compiled;
#pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword
        // Loss name recorded by compile(); it is stored as a string and not resolved here.
        string loss;
        // Optimizer instance created by compile().
        IOptimizer optimizer;

        public Model(ModelArgs args)
            : base(args)
        {
        }

        /// <summary>
        /// Configures the model for training: creates the optimizer named by
        /// <paramref name="optimizerName"/> and records <paramref name="lossName"/>.
        /// </summary>
        /// <param name="optimizerName">
        /// Name of the optimizer. Only "rmsprop" (matched case-insensitively) is currently supported.
        /// </param>
        /// <param name="lossName">Name of the loss function; stored as-is without validation.</param>
        /// <exception cref="ArgumentNullException"><paramref name="optimizerName"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="optimizerName"/> is not a supported optimizer.</exception>
        public void compile(string optimizerName, string lossName)
        {
            // Resolve the optimizer before mutating any state, so an unknown name can no longer
            // leave the model flagged as compiled with a null (or stale) optimizer.
            optimizer = ResolveOptimizer(optimizerName);
            loss = lossName;
            _is_compiled = true;
            // TODO: prepare list of loss functions, same size of model outputs.
        }

        /// <summary>
        /// Maps an optimizer name to a new optimizer instance.
        /// </summary>
        /// <param name="optimizerName">Optimizer name; compared ordinally, ignoring case.</param>
        /// <returns>A freshly constructed optimizer.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="optimizerName"/> is null.</exception>
        /// <exception cref="ArgumentException">The name does not match a supported optimizer.</exception>
        static IOptimizer ResolveOptimizer(string optimizerName)
        {
            if (optimizerName == null)
            {
                throw new ArgumentNullException(nameof(optimizerName));
            }

            if (string.Equals(optimizerName, "rmsprop", StringComparison.OrdinalIgnoreCase))
            {
                return new RMSprop();
            }

            throw new ArgumentException(
                $"Unknown optimizer '{optimizerName}'. Supported optimizers: rmsprop.",
                nameof(optimizerName));
        }

        /// <summary>
        /// Generates output predictions for the input samples.
        /// </summary>
        /// <param name="x">Input samples</param>
        /// <param name="batch_size">Number of samples per batch</param>
        /// <param name="verbose">Verbosity mode</param>
        /// <param name="steps">
        /// Total number of steps (batches of samples)
        /// before declaring the prediction round finished.
        /// </param>
        /// <param name="max_queue_size">Currently unused; kept for Keras API parity.</param>
        /// <param name="workers">Currently unused; kept for Keras API parity.</param>
        /// <param name="use_multiprocessing">Currently unused; kept for Keras API parity.</param>
        /// <returns>Never returns: prediction is not implemented yet.</returns>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        public Tensor predict(Tensor x,
            int batch_size = 32,
            int verbose = 0,
            int steps = -1,
            int max_queue_size = 10,
            int workers = 1,
            bool use_multiprocessing = false)
        {
            throw new NotImplementedException("Model.predict is not implemented yet.");
        }
    }
}