Skip to content

Commit 725e968

Browse files
girvingtensorflower-gardener
authored andcommitted
Remove race conditions from TensorShapeUtils::MakeShape
Also move some functions from header to C++. Change: 117255894
1 parent b10c65c commit 725e968

4 files changed

Lines changed: 41 additions & 23 deletions

File tree

tensorflow/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,7 @@ tf_cuda_library(
799799
":lib",
800800
":lib_internal",
801801
":protos_all_cc",
802+
"//tensorflow/core/kernels:bounds_check",
802803
"//third_party/eigen3",
803804
],
804805
alwayslink = 1,

tensorflow/core/framework/tensor_shape.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "tensorflow/core/framework/tensor_shape.h"
1717

18+
#include "tensorflow/core/kernels/bounds_check.h"
1819
#include "tensorflow/core/lib/core/errors.h"
1920
#include "tensorflow/core/lib/strings/str_util.h"
2021
#include "tensorflow/core/lib/strings/strcat.h"
@@ -327,4 +328,38 @@ bool TensorShapeUtils::StartsWith(const TensorShape& shape,
327328
return true;
328329
}
329330

331+
template <typename T>
332+
static inline Status MakeShapeHelper(const T* dims, int n, TensorShape* out) {
333+
*out = TensorShape();
334+
for (int i = 0; i < n; ++i) {
335+
const T dim = internal::SubtleMustCopy(dims[i]);
336+
if (dim >= 0) {
337+
out->AddDim(dim);
338+
} else {
339+
return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
340+
}
341+
}
342+
return Status::OK();
343+
}
344+
345+
#define MAKE_SHAPE(T) \
346+
Status TensorShapeUtils::MakeShape(const T* dims, int n, TensorShape* out) { \
347+
return MakeShapeHelper(dims, n, out); \
348+
}
349+
MAKE_SHAPE(int32)
350+
MAKE_SHAPE(int64)
351+
#undef MAKE_SHAPE
352+
353+
string TensorShapeUtils::ShapeListString(
354+
const gtl::ArraySlice<TensorShape>& shapes) {
355+
string result = "[";
356+
bool first = true;
357+
for (const TensorShape& shape : shapes) {
358+
strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
359+
first = false;
360+
}
361+
strings::StrAppend(&result, "]");
362+
return result;
363+
}
364+
330365
} // namespace tensorflow

tensorflow/core/framework/tensor_shape.h

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -252,29 +252,10 @@ class TensorShapeUtils {
252252

253253
/// \brief Returns a `TensorShape` whose dimensions are
254254
/// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
255-
template <typename T>
256-
static Status MakeShape(const T* dims, int n, TensorShape* out) {
257-
*out = TensorShape();
258-
for (int i = 0; i < n; ++i) {
259-
if (dims[i] >= 0) {
260-
out->AddDim(dims[i]);
261-
} else {
262-
return errors::InvalidArgument("Dimension ", dims[i], " must be >= 0");
263-
}
264-
}
265-
return Status::OK();
266-
}
255+
static Status MakeShape(const int32* dims, int n, TensorShape* out);
256+
static Status MakeShape(const int64* dims, int n, TensorShape* out);
267257

268-
static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) {
269-
string result = "[";
270-
bool first = true;
271-
for (const TensorShape& shape : shapes) {
272-
strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
273-
first = false;
274-
}
275-
strings::StrAppend(&result, "]");
276-
return result;
277-
}
258+
static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);
278259

279260
static bool StartsWith(const TensorShape& shape0, const TensorShape& shape1);
280261
};

tensorflow/core/kernels/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ cc_library(
203203
hdrs = ["bounds_check.h"],
204204
visibility = ["//tensorflow:__subpackages__"],
205205
deps = [
206-
"//tensorflow/core:framework",
206+
"//tensorflow/core:lib",
207207
"//third_party/eigen3",
208208
],
209209
)
@@ -1002,6 +1002,7 @@ filegroup(
10021002
name = "android_srcs",
10031003
srcs = [
10041004
"avgpooling_op.h",
1005+
"bounds_check.h",
10051006
"maxpooling_op.h",
10061007
"ops_util.cc",
10071008
"ops_util.h",

0 commit comments

Comments
 (0)