forked from kaldi-asr/kaldi
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: context.h
More file actions
291 lines (228 loc) · 10.5 KB
/
context.h
File metadata and controls
291 lines (228 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
// tensor/context.h
// Copyright 2019 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_TENSOR_CONTEXT_H_
#define KALDI_TENSOR_CONTEXT_H_ 1
#include <cstdint>
#include <vector>
#include <string>
#include "tensor/tensor-common.h"
/**
This file contains certain mechanisms to set settings about default
data types and devices within scopes, some related things like
an equivalent of PyTorch's .no_grad(). Also the `Tick()` mechanism
is here.
*/
namespace kaldi {
namespace tensor {
// class Context contains various configurations that we will sometimes need
// when we do operations on Tensors.  It is a plain aggregate: callers read
// and write the members directly (see TensorOptions below, which copies them).
struct Context {
  // The default DataType for newly created Tensors.
  DataType default_dtype;
  // The default Device for newly created Tensors.
  Device default_device;
};
// ExecutionContext is used when executing Ops (or doing other things
// with them, e.g. just storing them); we explicitly pass this
// object into functions that might want to execute Ops.
class ExecutionContext: public Context {
 public:  // Fix: these members were private by default, which made the class
          // unusable (even the destructor was inaccessible to callers).

  /// This function executes the Op (op.Do()) and/or does something else
  /// relating to taking derivatives.  Subclasses override this to record
  /// autograd information (see BackpropExecutionContext, below).
  /// NOTE(review): if no definition exists in a .cc file, this should
  /// probably be pure virtual (`= 0`) — confirm against the implementation.
  virtual void Execute(const Op &op);

  // Virtual destructor so that deleting a derived context through a
  // pointer to this base class is well-defined.
  virtual ~ExecutionContext() {}
};
// SimpleExecutionContext means we just execute an Op and then immediately
// delete it. It's used when we are just computing something with no
// autograd. You could, of course, just call the version of the
// Op that doesn't take an ExecutionContext, but this option makes
// it easier to switch between autograd and no-autograd.
class SimpleExecutionContext: public ExecutionContext {
virtual void Execute(const Op &op) { op.Do(); }
virtual ~SimpleExecutionContext() {}
};
/**
Execution context that you use while doing a forward computation, that
executes the forward commands and stores the things required to later do the
backprop. See its Backprop() function for how to execute the backprop.
*/
class BackpropExecutionContext: public ExecutionContext {
/**
Constructor of BackpropExecutionContext from an existing DerivMap, which
might map, for instance, parameters to their derivatives.
@param [in] deriv_map An existing DerivMap, to which the user will
likely have added the model parameters and anything
else that derivatives are needed for, with its
Deriv() function. This is *copied*, not held as a
reference, by this object, to avoid a kind of memory
leakage.
@param [in] base_context The base execution context, which would
normally be SimpleExecutionContext; it is used to
execute both the forward and backward commands.
This class will store the pointer but will not take
ownership; it is the user's responsibility to
make sure it stays alive as long as this object is
alive.
*/
BackpropExecutionContext(const DerivMap &deriv_map,
ExecutionContext *base_context);
/**
Does the backprop on a Tensor t; propagates the derivative back to whatever
quantities you had added derivs for in the DerivMap passed to the constructor.
The backprop commands will be executed with a SimpleExecutionContext
whose Context base-class is a copy of this class's one. If you want to
do something fancier (e.g. for 2, you can use the version of Backprop
If retain_info is false, it will delete deriv_map_ and clear backward_ops_.
This is recommended in most cases; it's more memory efficient.
@param [in] t The Tensor that we are taking the derivative with
respect to.
@param [in] deriv The derivative w.r.t. t of the function we
are taking the derivative of. Might be just
1.0. Must satsify Broadcastable(deriv, t).
Note: deriv may have more axes than t, in which
case the extra leading axes are required to
have dimensions equal to deriv_map_->ExtraDims().
If deriv_map_->ExtraDims() is nonempty,
the num-axes of 'deriv' is required to equal
`t.NumAxes() + deriv_map_->ExtraDims().size()`.
*/
void Backprop(const Tensor &t,
const Tensor &deriv) {
if (deriv_map_ == nullptr)
KALDI_ERR << "You cannot call Backprop twice on the same "
"BackpropExecutionContext";
// Delete deriv_map_. This will help ensure that derivative
// quantities are deleted as soon as they are no longer needed
// (since once we delete the deriv_map_ and the ops referring
// to those derivative matrices, they will be garbage collected).
deriv_map_ = nullptr;
for (auto iter = backward_ops_.rbegin();
iter != backward_ops_.rend(); ++iter){
base_context_->Execute(**iter);
// Delete this op. Deleting the ops also deletes the associated
// derivative matrices, via shared_ptr garbage collection.
*iter = nullptr;
}
backward_ops_.clear();
}
virtual void Execute(const Op &op) {
base_context_->Execute(op);
op.GetBackwardDerivOps(&deriv_map_, &backward_ops_);
}
virtual ~BackpropExecutionContext() { }
private:
std::vector<unique_ptr<Op> > backward_ops_;
unique_ptr<DerivMap> deriv_map_;
ExecutionContext *base_context_;
};
/**
Execution context that you use while doing a forward computation, that
executes the forward commands and also computes forward derivatives
w.r.t. something.
*/
class ForwardPropExecutionContext: public ExecutionContext {
/**
Constructor of ForwardPropExecutionContext from an existing DerivMap, which
might map, for instance, some input x to dx/da, where a is the thing
we're taking the derivative of.
@param [in] deriv_map An existing DerivMap, to which the user will
likely have added the thing we are taking the derivative
w.r.t. (e.g. some input where we want to see its
effect on the computation). deriv_map is *copied*,
not held as a reference, by this object, to avoid
a kind of memory leakage.
@param [in] base_context The base execution context, which would
normally be SimpleExecutionContext; it is used to
execute both the forward and backward commands.
This class will store the pointer but will not take
ownership; it is the user's responsibility to
make sure it stays alive as long as this object is
alive.
*/
ForwardPropExecutionContext(const DerivMap &deriv_map,
ExecutionContext *base_context);
virtual void Execute(const Op &op) {
base_context_->Execute(op);
std::vector<std::unique_ptr<Op> > ops;
op.GetForwardDerivOps(&deriv_map_, &ops);
for (auto iter = ops.begin(); iter != ops.end(); ++iter)
base_context_->Execute(*iter);
// and let the ops in 'ops' go out of scope and get deleted.
}
// Returns pointer to this deriv_map_ (still owned by this class).
// May be used to query the derivative of some Tensor w.r.t. the
// input, e.g. forward_context.GetDerivMap()->DerivIfPresent(some_tensor).
DerivMap *GetDerivMap() { return deriv_map_.get(); }
};
// struct TensorOptions is used as an arg for some constructors
// when creating Tensors and Variables; it allows flexibility
// in specifying the device and/or dtype. See the examples
// shown where constructors of Tensor or Variable are declared.
struct TensorOptions {
DataType dtype;
Device device;
explicit TensorOptions(const Context &context):
dtype(context.default_dtype),
device(context.default_device) { }
explicit TensorOptions(const Context &context,
DataType dtype):
dtype(dtype), device(context.default_device) { }
explicit TensorOptions(const Context &context, Device device):
dtype(context.default_dtype), device(device) { }
explicit TensorOptions(const Context &context, DeviceType device_type):
dtype(context.default_dtype), device(device_type) { }
// Here the context is not used; we could create a new version
// that doesn't take the context object, but of course that would
// make it harder if we add more options later.
TensorOptions(const Context &context, DataType dtype,
Device device):
dtype(dtype), device(device) { }
TensorOptions(const Context &context, DataType dtype,
Device device_type):
dtype(dtype), device(device_type) { }
TensorOptions(DataType dtype, Device device_type):
dtype(dtype), device(device_type) { }
explicit TensorOptions(const TensorOptions &other):
dtype(other.dtype), device(other.device) { }
};
// Global variable, initialized to zero, that is used in GetTick().
// This is defined in tensor-settings.cc.
extern int64 g_tick_counter;

// Returns the next tick value (a globally increasing counter, used for
// ordering events).  NOTE(review): the increment of a plain int64 is not
// thread-safe; confirm single-threaded use, or make g_tick_counter a
// std::atomic<int64>.
inline int64 NextTick() { return ++g_tick_counter; }
// debug_mode activates code that checks for invalidated data in the backprop
// pass; see "Invalidated:" in glossary in tensor.h.
// Don't access these variables directly; use DebugMode() / SetDebugMode()
// and DebugTick() below.  Both are defined in tensor-settings.cc
// (presumably, alongside g_tick_counter above — TODO confirm).
extern bool debug_mode;         // Do not access directly!
extern int64 debug_start_tick;  // Do not access directly!

// Returns true if debug mode is currently enabled.
inline bool DebugMode() {
  return debug_mode;
}
// Turns debug mode on or off.  When debug mode changes from false to true,
// records the tick of that transition in debug_start_tick (queried via
// DebugTick()).
inline void SetDebugMode(bool b) {
  // Fix: previously the tick was updated whenever debug_mode was currently
  // false — including a no-op SetDebugMode(false) while already disabled.
  // Per DebugTick()'s contract ("the tick at which debug mode most recently
  // changed from false to true"), only the false -> true transition should
  // record a tick.
  if (b && !debug_mode)
    debug_start_tick = NextTick();
  debug_mode = b;
}
/**
   Returns the tick at which debug mode most recently changed from false to
   true.  It is an error (checked only by the paranoid assert below) to call
   this while debug mode is not active.
 */
inline int64 DebugTick() {
  KALDI_PARANOID_ASSERT(debug_mode);
  return debug_start_tick;
}
} // namespace tensor
} // namespace kaldi
#endif // KALDI_TENSOR_CONTEXT_H_