Skip to content

Commit d9b9fa4

Browse files
alextptensorflower-gardener
authored andcommitted
Disable input forwarding for tensors fed from python or returned from py_func.
Change: 151591184
1 parent e9db82f commit d9b9fa4

5 files changed

Lines changed: 18 additions & 1 deletion

File tree

tensorflow/c/c_api.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ class TF_ManagedBuffer : public TensorBuffer {
135135
proto->set_requested_bytes(rb);
136136
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
137137
}
138+
139+
// Prevents input forwarding from mutating this buffer.
140+
bool OwnsMemory() const override { return false; }
138141
};
139142

140143
void* allocate_tensor(const char* operation, size_t len) {

tensorflow/core/framework/tensor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ void Tensor::UnsafeCopyFromInternal(const Tensor& other, DataType dtype,
531531
// one both for the SubBuffer _and_ the underlying TensorBuffer.
532532
bool Tensor::RefCountIsOne() const {
533533
return buf_ != nullptr && buf_->RefCountIsOne() &&
534-
buf_->root_buffer()->RefCountIsOne();
534+
buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
535535
}
536536

537537
// The macro CASES() expands to a switch statement conditioned on

tensorflow/core/framework/tensor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,9 @@ class TensorBuffer : public core::RefCounted {
502502
T* base() const {
503503
return reinterpret_cast<T*>(data());
504504
}
505+
506+
// Whether this TensorBuffer owns the underlying memory.
507+
virtual bool OwnsMemory() const { return true; }
505508
};
506509

507510
template <typename T>

tensorflow/python/kernel_tests/py_func_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,14 @@ def testCleanup(self):
160160
_ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
161161
self.assertTrue(script_ops._py_funcs.size() < 100)
162162

163+
def testAlias(self):
164+
with self.test_session():
165+
np_array = np.array([1.0, 2.0], dtype=np.float32)
166+
tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32])
167+
value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32)
168+
value.op.run()
169+
self.assertAllEqual(np_array, [1.0, 2.0])
170+
163171
def testBadNumpyReturnType(self):
164172
with self.test_session():
165173

tensorflow/python/lib/core/py_func.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,9 @@ class NumpyTensorBuffer : public TensorBuffer {
254254
return Tensor(dtype, shape, this);
255255
}
256256

257+
// Prevents input forwarding from overwriting this buffer.
258+
bool OwnsMemory() const override { return false; }
259+
257260
private:
258261
PyArrayObject* array_;
259262
size_t len_;

0 commit comments

Comments
 (0)