Skip to content

Commit 78f84c0

Browse files
saudetkarllessard
authored andcommitted
Get rid of temporary output PointerScope objects
1 parent 467fb7a commit 78f84c0

3 files changed

Lines changed: 32 additions & 47 deletions

File tree

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,12 @@ private Tensor<?> resolveTensor(int outputIndex) {
133133
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
134134
// If another thread has resolved it meanwhile, release our copy and reuse the existing one
135135
// instead.
136-
try (PointerScope scope = new PointerScope()) {
137-
TF_Tensor tensorNativeHandle = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), scope);
138-
Tensor<?> tensor = Tensor.fromHandle(tensorNativeHandle, session);
139-
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
140-
session.detach(tensorNativeHandle);
141-
tensor = outputTensors.get(outputIndex);
142-
}
143-
return tensor;
136+
Tensor<?> tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session);
137+
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
138+
session.detach(tensor.nativeHandle());
139+
tensor = outputTensors.get(outputIndex);
144140
}
141+
return tensor;
145142
}
146143

147144
private TFE_Op opHandle;
@@ -159,14 +156,13 @@ private static void requireTensorHandle(TFE_TensorHandle handle) {
159156
}
160157
}
161158

162-
private static TF_Tensor resolveTensorHandle(TFE_TensorHandle handle, PointerScope outputScope) {
159+
private static Tensor<?> resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) {
163160
requireTensorHandle(handle);
164161
try (PointerScope scope = new PointerScope()) {
165162
TF_Status status = TF_Status.newStatus();
166163
TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator();
167164
status.throwExceptionIfNotOK();
168-
outputScope.attach(tensor);
169-
return tensor;
165+
return Tensor.fromHandle(tensor, session);
170166
}
171167
}
172168

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,26 +55,21 @@
5555
final class EagerOperationBuilder implements OperationBuilder {
5656

5757
EagerOperationBuilder(EagerSession session, String type, String name) {
58-
try (PointerScope scope = new PointerScope()) {
59-
this.session = session;
60-
this.type = type;
61-
this.name = name;
62-
this.opHandle = allocate(session.nativeHandle(), type, scope);
63-
session.attach(opHandle);
64-
}
58+
this.session = session;
59+
this.type = type;
60+
this.name = name;
61+
this.opHandle = allocate(session, type);
6562
}
6663

6764
@Override
6865
public EagerOperation build() {
69-
try (PointerScope scope = new PointerScope()) {
70-
TFE_TensorHandle[] tensorHandles = execute(opHandle, scope);
71-
EagerOperation operation =
72-
new EagerOperation(session, opHandle, tensorHandles, type, name);
73-
// Release our reference to the native op handle now that we transferred its
74-
// ownership to the EagerOperation
75-
session.detach(opHandle);
76-
return operation;
77-
}
66+
TFE_TensorHandle[] tensorHandles = execute(opHandle, session);
67+
EagerOperation operation =
68+
new EagerOperation(session, opHandle, tensorHandles, type, name);
69+
// Release our reference to the native op handle now that we transferred its
70+
// ownership to the EagerOperation
71+
session.detach(opHandle);
72+
return operation;
7873
}
7974

8075
@Override
@@ -257,18 +252,18 @@ private static void requireTensorHandle(TFE_TensorHandle handle) {
257252
}
258253
}
259254

260-
private static TFE_Op allocate(TFE_Context ctxHandle, String type, PointerScope outputScope) {
261-
requireContext(ctxHandle);
255+
private static TFE_Op allocate(EagerSession session, String type) {
256+
requireContext(session.nativeHandle());
262257
try (PointerScope scope = new PointerScope()) {
263258
TF_Status status = TF_Status.newStatus();
264-
TFE_Op op = TFE_Op.newOp(ctxHandle, type, status);
259+
TFE_Op op = TFE_Op.newOp(session.nativeHandle(), type, status);
265260
status.throwExceptionIfNotOK();
266-
outputScope.attach(op);
261+
session.attach(op);
267262
return op;
268263
}
269264
}
270265

271-
private static TFE_TensorHandle[] execute(TFE_Op opHandle, PointerScope outputScope) {
266+
private static TFE_TensorHandle[] execute(TFE_Op opHandle, EagerSession session) {
272267
requireOp(opHandle);
273268
try (PointerScope scope = new PointerScope()) {
274269
IntPointer numRetvals = new IntPointer(1).put(MAX_OUTPUTS_PER_OP);
@@ -280,7 +275,7 @@ private static TFE_TensorHandle[] execute(TFE_Op opHandle, PointerScope outputSc
280275
TFE_TensorHandle[] rethandles = new TFE_TensorHandle[numRetvals.get()];
281276
for (int i = 0; i < rethandles.length; ++i) {
282277
rethandles[i] = retvals.get(TFE_TensorHandle.class, i).withDeallocator();
283-
outputScope.attach(rethandles[i]);
278+
session.attach(rethandles[i]);
284279
}
285280
return rethandles;
286281
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,6 @@ private Run runHelper(boolean wantMetadata) {
311311
TF_Operation[] outputOpHandles = new TF_Operation[outputs.size()];
312312
int[] outputOpIndices = new int[outputs.size()];
313313
TF_Operation[] targetOpHandles = new TF_Operation[targets.size()];
314-
TF_Tensor[] outputTensorHandles = new TF_Tensor[outputs.size()];
315314

316315
// It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
317316
// validity of the Graph and graphRef ensures that.
@@ -338,7 +337,7 @@ private Run runHelper(boolean wantMetadata) {
338337
Reference runRef = new Reference();
339338
RunMetadata metadata = null;
340339
List<Tensor<?>> outputs = new ArrayList<>();
341-
try (PointerScope scope = new PointerScope()) {
340+
try {
342341
metadata =
343342
Session.run(
344343
nativeHandle,
@@ -350,11 +349,7 @@ private Run runHelper(boolean wantMetadata) {
350349
outputOpIndices,
351350
targetOpHandles,
352351
wantMetadata,
353-
outputTensorHandles,
354-
scope);
355-
for (TF_Tensor h : outputTensorHandles) {
356-
outputs.add(Tensor.fromHandle(h));
357-
}
352+
outputs);
358353
} catch (Exception e) {
359354
for (Tensor<?> t : outputs) {
360355
t.close();
@@ -554,8 +549,8 @@ private static void delete(TF_Session handle) {
554549
* @param targetOpHandles is the set of Operations in the graph that are to be executed but whose
555550
* output will not be returned
556551
* @param wantRunMetadata indicates whether metadata about this execution should be returned.
557-
* @param outputTensorHandles will be filled in with handles to the outputs requested. It is
558-
* required that outputTensorHandles.length == outputOpHandles.length.
552+
* @param outputTensors will be filled in with tensors to the outputs requested. It is
553+
* required that outputs.length == outputOpHandles.length.
559554
* @return if wantRunMetadata is true, a RunMetadata protocol buffer, false otherwise.
560555
*/
561556
private static RunMetadata run(
@@ -568,12 +563,11 @@ private static RunMetadata run(
568563
int[] outputOpIndices,
569564
TF_Operation[] targetOpHandles,
570565
boolean wantRunMetadata,
571-
TF_Tensor[] outputTensorHandles,
572-
PointerScope outputScope) {
566+
List<Tensor<?>> outputTensors) {
573567
requireHandle(handle);
574568

575569
int ninputs = inputTensorHandles.length;
576-
int noutputs = outputTensorHandles.length;
570+
int noutputs = outputOpHandles.length;
577571
int ntargets = targetOpHandles.length;
578572

579573
try (PointerScope scope = new PointerScope()) {
@@ -598,8 +592,8 @@ private static RunMetadata run(
598592
status.throwExceptionIfNotOK();
599593

600594
for (int i = 0; i < noutputs; ++i) {
601-
outputTensorHandles[i] = outputValues.get(TF_Tensor.class, i).withDeallocator();
602-
outputScope.attach(outputTensorHandles[i]);
595+
TF_Tensor h = outputValues.get(TF_Tensor.class, i).withDeallocator();
596+
outputTensors.add(Tensor.fromHandle(h));
603597
}
604598
try {
605599
return runMetadata != null ? RunMetadata.parseFrom(runMetadata.dataAsByteBuffer()) : null;

0 commit comments

Comments
 (0)