Skip to content

Commit 6ec984e

Browse files
Adds the following new ops:
alias_inplace_add alias_inplace_subtract alias_inplace_update empty empty_like parallel_stack Some of these ops will be used to speed up the data-loading pipline. Change: 143502368
1 parent 109c03d commit 6ec984e

12 files changed

Lines changed: 1012 additions & 7 deletions

File tree

tensorflow/contrib/makefile/tf_op_files.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ tensorflow/core/kernels/maxpooling_op.cc
8888
tensorflow/core/kernels/matmul_op.cc
8989
tensorflow/core/kernels/lrn_op.cc
9090
tensorflow/core/kernels/logging_ops.cc
91+
tensorflow/core/kernels/inplace_ops.cc
9192
tensorflow/core/kernels/in_topk_op.cc
9293
tensorflow/core/kernels/immutable_constant_op.cc
9394
tensorflow/core/kernels/identity_op.cc

tensorflow/core/kernels/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,7 @@ cc_library(
476476
":gather_nd_op",
477477
":gather_op",
478478
":identity_op",
479+
":inplace_ops",
479480
":listdiff_op",
480481
":matrix_band_part_op",
481482
":matrix_diag_op",
@@ -653,6 +654,12 @@ tf_kernel_library(
653654
deps = ARRAY_DEPS + [":split_lib"],
654655
)
655656

657+
tf_kernel_library(
658+
name = "inplace_ops",
659+
prefix = "inplace_ops",
660+
deps = ARRAY_DEPS,
661+
)
662+
656663
tf_kernel_library(
657664
name = "tile_ops",
658665
prefix = "tile_ops",
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#define EIGEN_USE_THREADS
17+
18+
#include "tensorflow/core/framework/op_kernel.h"
19+
#include "tensorflow/core/framework/register_types.h"
20+
#include "tensorflow/core/framework/tensor.h"
21+
#include "tensorflow/core/framework/tensor_shape.h"
22+
#include "tensorflow/core/kernels/fill_functor.h"
23+
#include "tensorflow/core/kernels/inplace_ops_functor.h"
24+
#include "tensorflow/core/lib/core/status.h"
25+
26+
namespace tensorflow {
27+
28+
typedef Eigen::ThreadPoolDevice CPUDevice;
29+
30+
class InplaceOpBase : public OpKernel {
31+
public:
32+
explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
33+
34+
void Compute(OpKernelContext* ctx) override {
35+
auto value = ctx->input(0);
36+
auto loc = ctx->input(1);
37+
auto update = ctx->input(2);
38+
39+
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(loc.shape()),
40+
errors::InvalidArgument("loc must be a vector. ",
41+
loc.shape().DebugString()));
42+
OP_REQUIRES(
43+
ctx, value.dims() == update.dims(),
44+
errors::InvalidArgument("value and update shape doesn't match: ",
45+
value.shape().DebugString(), " vs. ",
46+
update.shape().DebugString()));
47+
for (int i = 1; i < value.dims(); ++i) {
48+
OP_REQUIRES(
49+
ctx, value.dim_size(i) == update.dim_size(i),
50+
errors::InvalidArgument("value and update shape doesn't match ",
51+
value.shape().DebugString(), " vs. ",
52+
update.shape().DebugString()));
53+
}
54+
OP_REQUIRES(ctx, loc.dim_size(0) == update.dim_size(0),
55+
errors::InvalidArgument("loc and update shape doesn't match: ",
56+
loc.shape().DebugString(), " vs. ",
57+
update.shape().DebugString()));
58+
59+
Tensor output = value; // This creates an alias intentionally.
60+
OP_REQUIRES_OK(ctx, DoCompute(ctx, update, loc, &output));
61+
ctx->set_output(0, output);
62+
}
63+
64+
protected:
65+
virtual Status DoCompute(OpKernelContext* ctx, const Tensor& value,
66+
const Tensor& loc, Tensor* output) = 0;
67+
};
68+
69+
namespace functor {
70+
71+
template <typename T>
72+
Status DoInplaceUpdate(const CPUDevice& d, InplaceOpType op,
73+
const Tensor& value, const Tensor& loc, Tensor* output) {
74+
auto Tloc = loc.flat<int64>();
75+
auto Tvalue = value.flat_outer_dims<T>();
76+
auto Toutput = output->flat_outer_dims<T>();
77+
auto nrows = Toutput.dimension(0);
78+
for (int64 j = 0; j < Tloc.size(); ++j) {
79+
auto r = (Tloc(j) % nrows + nrows) % nrows; // Guard index range.
80+
switch (op) {
81+
case I_UPDATE:
82+
Toutput.template chip<0>(r).device(d) = Tvalue.template chip<0>(j);
83+
break;
84+
case I_ADD:
85+
Toutput.template chip<0>(r).device(d) += Tvalue.template chip<0>(j);
86+
break;
87+
case I_SUB:
88+
Toutput.template chip<0>(r).device(d) -= Tvalue.template chip<0>(j);
89+
break;
90+
default:
91+
return errors::InvalidArgument("Unsupported inplace operation", op);
92+
}
93+
}
94+
return Status::OK();
95+
}
96+
97+
template <>
98+
Status DoInplace(const CPUDevice& d, InplaceOpType op, const Tensor& value,
99+
const Tensor& loc, Tensor* output) {
100+
CHECK_EQ(value.dtype(), output->dtype());
101+
switch (value.dtype()) {
102+
#define CASE(type) \
103+
case DataTypeToEnum<type>::value: \
104+
return DoInplaceUpdate<type>(d, op, value, loc, output);
105+
TF_CALL_NUMBER_TYPES(CASE);
106+
#undef CASE
107+
default:
108+
return errors::InvalidArgument("Unsupported data type: ", value.dtype());
109+
}
110+
}
111+
112+
} // end namespace functor
113+
114+
template <typename Device, functor::InplaceOpType op>
115+
class InplaceOp : public InplaceOpBase {
116+
public:
117+
explicit InplaceOp(OpKernelConstruction* ctx) : InplaceOpBase(ctx) {}
118+
119+
protected:
120+
Status DoCompute(OpKernelContext* ctx, const Tensor& value, const Tensor& loc,
121+
Tensor* output) override {
122+
const auto& d = ctx->eigen_device<Device>();
123+
return ::tensorflow::functor::DoInplace(d, op, value, loc, output);
124+
}
125+
};
126+
127+
template <typename Device, typename T>
128+
class EmptyOp : public OpKernel {
129+
public:
130+
explicit EmptyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
131+
OP_REQUIRES_OK(ctx, ctx->GetAttr("init", &init_));
132+
}
133+
134+
void Compute(OpKernelContext* ctx) override {
135+
const Tensor& shape = ctx->input(0);
136+
OP_REQUIRES(
137+
ctx, TensorShapeUtils::IsVector(shape.shape()),
138+
errors::InvalidArgument("shape must be a vector of int32, got shape ",
139+
shape.shape().DebugString()));
140+
auto dims = shape.flat<int32>();
141+
TensorShape out_shape;
142+
OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
143+
reinterpret_cast<const int32*>(dims.data()),
144+
dims.size(), &out_shape));
145+
Tensor* out = nullptr;
146+
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
147+
148+
if (init_) {
149+
functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
150+
out->flat<T>());
151+
}
152+
}
153+
154+
private:
155+
bool init_;
156+
};
157+
158+
#define REGISTER(type) \
159+
REGISTER_KERNEL_BUILDER( \
160+
Name("InplaceUpdate").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
161+
InplaceOp<CPUDevice, functor::I_UPDATE>); \
162+
REGISTER_KERNEL_BUILDER( \
163+
Name("InplaceAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
164+
InplaceOp<CPUDevice, functor::I_ADD>); \
165+
REGISTER_KERNEL_BUILDER( \
166+
Name("InplaceSubtract").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
167+
InplaceOp<CPUDevice, functor::I_SUB>);
168+
TF_CALL_NUMBER_TYPES(REGISTER)
169+
#undef REGISTER
170+
171+
#define REGISTER_EMPTY(type) \
172+
REGISTER_KERNEL_BUILDER(Name("Empty") \
173+
.Device(DEVICE_CPU) \
174+
.HostMemory("shape") \
175+
.TypeConstraint<type>("dtype"), \
176+
EmptyOp<CPUDevice, type>)
177+
178+
TF_CALL_POD_STRING_TYPES(REGISTER_EMPTY)
179+
#undef REGISTER_EMPTY
180+
181+
#if GOOGLE_CUDA
182+
183+
typedef Eigen::GpuDevice GPUDevice;
184+
185+
#define REGISTER_EMPTY(type) \
186+
REGISTER_KERNEL_BUILDER(Name("Empty") \
187+
.Device(DEVICE_GPU) \
188+
.HostMemory("shape") \
189+
.TypeConstraint<type>("dtype"), \
190+
EmptyOp<GPUDevice, type>);
191+
TF_CALL_GPU_NUMBER_TYPES(REGISTER_EMPTY)
192+
#undef REGISTER_EMPTY
193+
194+
#define REGISTER(type) \
195+
REGISTER_KERNEL_BUILDER( \
196+
Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
197+
InplaceOp<GPUDevice, functor::I_UPDATE>); \
198+
REGISTER_KERNEL_BUILDER( \
199+
Name("InplaceAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
200+
InplaceOp<GPUDevice, functor::I_ADD>); \
201+
REGISTER_KERNEL_BUILDER( \
202+
Name("InplaceSubtract").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
203+
InplaceOp<GPUDevice, functor::I_SUB>);
204+
TF_CALL_GPU_NUMBER_TYPES(REGISTER)
205+
#undef REGISTER
206+
207+
// Register versions that operate on int32 data on the CPU even though the op
208+
// has been placed on the GPU
209+
210+
REGISTER_KERNEL_BUILDER(Name("InplaceUpdate")
211+
.Device(DEVICE_GPU)
212+
.HostMemory("value")
213+
.HostMemory("loc")
214+
.HostMemory("update")
215+
.HostMemory("output")
216+
.TypeConstraint<int32>("T"),
217+
InplaceOp<CPUDevice, functor::I_UPDATE>);
218+
219+
REGISTER_KERNEL_BUILDER(Name("InplaceAdd")
220+
.Device(DEVICE_GPU)
221+
.HostMemory("value")
222+
.HostMemory("loc")
223+
.HostMemory("update")
224+
.HostMemory("output")
225+
.TypeConstraint<int32>("T"),
226+
InplaceOp<CPUDevice, functor::I_ADD>);
227+
228+
REGISTER_KERNEL_BUILDER(Name("InplaceSubtract")
229+
.Device(DEVICE_GPU)
230+
.HostMemory("value")
231+
.HostMemory("loc")
232+
.HostMemory("update")
233+
.HostMemory("output")
234+
.TypeConstraint<int32>("T"),
235+
InplaceOp<CPUDevice, functor::I_SUB>);
236+
#endif
237+
238+
} // end namespace tensorflow
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_KERNELS_INPLACE_FUNCTOR_H_
17+
#define TENSORFLOW_KERNELS_INPLACE_FUNCTOR_H_
18+
19+
#include "tensorflow/core/framework/tensor.h"
20+
#include "tensorflow/core/lib/core/status.h"
21+
22+
namespace tensorflow {
23+
namespace functor {
24+
25+
// Inplace update/add/sub values in 'y'. It computes
26+
// y[i, :] = v if op is I_UPDATE
27+
// y[i, :] += v if op is I_ADD
28+
// y[i, :] -= v if op is I_SUB
29+
enum InplaceOpType {
30+
I_UPDATE, // x = y
31+
I_ADD, // x += y
32+
I_SUB, // x -= y
33+
};
34+
35+
template <typename Device>
36+
Status DoInplace(const Device& device, InplaceOpType op, const Tensor& value,
37+
const Tensor& loc, Tensor* output);
38+
39+
} // end namespace functor
40+
} // end namespace tensorflow
41+
42+
#endif // TENSORFLOW_KERNELS_INPLACE_FUNCTOR_H_

0 commit comments

Comments
 (0)