forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstate_ops.cs
More file actions
112 lines (101 loc) · 4.5 KB
/
state_ops.cs
File metadata and controls
112 lines (101 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/
using System;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow
{
public class state_ops
{
    /// <summary>
    /// Builds a variable operation by delegating to <c>gen_state_ops.variable_v2</c>.
    /// </summary>
    /// <param name="shape">Shape forwarded unchanged to the generated op.</param>
    /// <param name="dtype">Data type forwarded unchanged to the generated op.</param>
    /// <param name="name">Operation name; defaults to "Variable".</param>
    /// <param name="container">Forwarded unchanged; defaults to an empty string.</param>
    /// <param name="shared_name">Forwarded unchanged; defaults to an empty string.</param>
    /// <returns>The result of <c>gen_state_ops.variable_v2</c>.</returns>
    public static Tensor variable_op_v2(int[] shape,
        TF_DataType dtype,
        string name = "Variable",
        string container = "",
        string shared_name = "")
    {
        return gen_state_ops.variable_v2(shape,
            dtype,
            name: name,
            container: container,
            shared_name: shared_name);
    }

    /// <summary>
    /// Assigns <paramref name="value"/> to a tensor whose dtype is a reference dtype
    /// by delegating to <c>gen_state_ops.assign</c>.
    /// </summary>
    /// <param name="ref">Target tensor; its dtype must satisfy <c>is_ref_dtype()</c>.</param>
    /// <param name="value">Value forwarded unchanged to the generated op.</param>
    /// <param name="validate_shape">Forwarded unchanged; defaults to <c>true</c>.</param>
    /// <param name="use_locking">Forwarded unchanged; defaults to <c>true</c>.</param>
    /// <param name="name">Optional operation name.</param>
    /// <returns>The result of <c>gen_state_ops.assign</c>.</returns>
    /// <exception cref="NotImplementedException">
    /// Thrown when the dtype of <paramref name="ref"/> is not a reference dtype.
    /// </exception>
    public static Tensor assign(Tensor @ref, object value,
        bool validate_shape = true,
        bool use_locking = true,
        string name = null)
    {
        // Only reference-typed tensors are supported at the moment. The non-ref path
        // was sketched as "@ref.assign(value, name: name)" but is not enabled.
        if (!@ref.dtype.is_ref_dtype())
            throw new NotImplementedException("state_ops.assign");

        return gen_state_ops.assign(@ref,
            value,
            validate_shape: validate_shape,
            use_locking: use_locking,
            name: name);
    }

    /// <summary>
    /// Assigns <paramref name="value"/> to a <c>RefVariable</c> by delegating to
    /// <c>gen_state_ops.assign</c>. No dtype check is performed in this overload.
    /// </summary>
    /// <param name="ref">Target variable.</param>
    /// <param name="value">Value forwarded unchanged to the generated op.</param>
    /// <param name="validate_shape">Forwarded unchanged; defaults to <c>true</c>.</param>
    /// <param name="use_locking">Forwarded unchanged; defaults to <c>true</c>.</param>
    /// <param name="name">Optional operation name.</param>
    /// <returns>The result of <c>gen_state_ops.assign</c>.</returns>
    public static Tensor assign(RefVariable @ref, object value,
        bool validate_shape = true,
        bool use_locking = true,
        string name = null)
        => gen_state_ops.assign(@ref,
            value,
            validate_shape: validate_shape,
            use_locking: use_locking,
            name: name);

    /// <summary>
    /// Delegates to <c>gen_state_ops.assign_sub</c> for the given variable and value.
    /// </summary>
    /// <param name="ref">Target variable.</param>
    /// <param name="value">Tensor forwarded unchanged to the generated op.</param>
    /// <param name="use_locking">Forwarded unchanged; defaults to <c>false</c>.</param>
    /// <param name="name">Optional operation name.</param>
    /// <returns>The result of <c>gen_state_ops.assign_sub</c>.</returns>
    public static Tensor assign_sub(RefVariable @ref,
        Tensor value,
        bool use_locking = false,
        string name = null)
    {
        return gen_state_ops.assign_sub(@ref,
            value,
            use_locking: use_locking,
            name: name);
    }

    /// <summary>
    /// Update 'ref' by adding 'value' to it.
    /// </summary>
    /// <remarks>
    /// This operation outputs "ref" after the update is done, which makes it easier
    /// to chain operations that need to use the updated value.
    /// </remarks>
    /// <param name="ref">
    /// A mutable tensor that should come from a <c>Variable</c> node. Must be one of:
    /// <c>float32</c>, <c>float64</c>, <c>int64</c>, <c>int32</c>, <c>uint8</c>,
    /// <c>uint16</c>, <c>int16</c>, <c>int8</c>, <c>complex64</c>, <c>complex128</c>,
    /// <c>qint8</c>, <c>quint8</c>, <c>qint32</c>, <c>half</c>.
    /// </param>
    /// <param name="value">
    /// The value to be added to the variable. Must have the same type as 'ref'.
    /// </param>
    /// <param name="use_locking">
    /// Defaults to <c>false</c>. If true, the addition will be protected by a lock;
    /// otherwise the behavior is undefined, but may exhibit less contention.
    /// </param>
    /// <param name="name">A name for the operation (optional).</param>
    /// <returns>
    /// Same as "ref". Returned as a convenience for operations that want to use the
    /// new value after the variable has been updated.
    /// </returns>
    public static Tensor assign_add(RefVariable @ref,
        Tensor value,
        bool use_locking = false,
        string name = null)
    {
        return gen_state_ops.assign_add(@ref, value, use_locking: use_locking, name: name);
    }

    /// <summary>
    /// Delegates to <c>gen_state_ops.scatter_add</c> when the variable's dtype is a
    /// reference dtype.
    /// </summary>
    /// <param name="ref">Target variable; its dtype must satisfy <c>is_ref_dtype()</c>.</param>
    /// <param name="indices">Indices tensor forwarded unchanged to the generated op.</param>
    /// <param name="updates">Updates tensor forwarded unchanged to the generated op.</param>
    /// <param name="use_locking">Forwarded unchanged; defaults to <c>false</c>.</param>
    /// <param name="name">Optional operation name.</param>
    /// <returns>The result of <c>gen_state_ops.scatter_add</c>.</returns>
    /// <exception cref="NotImplementedException">
    /// Thrown when the dtype of <paramref name="ref"/> is not a reference dtype.
    /// </exception>
    public static Tensor scatter_add(RefVariable @ref, Tensor indices, Tensor updates, bool use_locking = false, string name = null)
    {
        // Non-reference dtypes have no implementation here yet.
        if (!@ref.dtype.is_ref_dtype())
            throw new NotImplementedException("scatter_add");

        return gen_state_ops.scatter_add(@ref, indices, updates, use_locking: use_locking, name: name);
    }
}
}