Skip to content

Commit e5a615a

Browse files
TensorListElementShape should return scalar with value -1 if shape has unknown
rank. PiperOrigin-RevId: 218418543
1 parent 5d04280 commit e5a615a

2 files changed

Lines changed: 21 additions & 5 deletions

File tree

tensorflow/core/kernels/list_kernels.cc

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,13 +274,22 @@ class TensorListElementShape : public OpKernel {
274274
"list. Saw: '",
275275
c->input(0).scalar<Variant>()().DebugString(), "'"));
276276
Tensor* result;
277-
OP_REQUIRES_OK(c, c->allocate_output(
278-
0, TensorShape{l->element_shape.dims()}, &result));
279-
for (int i = 0; i < l->element_shape.dims(); ++i) {
277+
if (l->element_shape.unknown_rank()) {
278+
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &result));
280279
if (result->dtype() == DT_INT32) {
281-
result->flat<int32>()(i) = l->element_shape.dim_size(i);
280+
result->scalar<int32>()() = -1;
282281
} else {
283-
result->flat<int64>()(i) = l->element_shape.dim_size(i);
282+
result->scalar<int64>()() = -1;
283+
}
284+
} else {
285+
OP_REQUIRES_OK(c, c->allocate_output(
286+
0, TensorShape{l->element_shape.dims()}, &result));
287+
for (int i = 0; i < l->element_shape.dims(); ++i) {
288+
if (result->dtype() == DT_INT32) {
289+
result->flat<int32>()(i) = l->element_shape.dim_size(i);
290+
} else {
291+
result->flat<int64>()(i) = l->element_shape.dim_size(i);
292+
}
284293
}
285294
}
286295
}

tensorflow/python/kernel_tests/list_ops_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,13 @@ def testZerosLikeVariant(self):
678678
self.assertAllEqual(
679679
self.evaluate(outputs[0]), np.zeros((2,), dtype=dtype.as_numpy_dtype))
680680

681+
@test_util.run_in_graph_and_eager_modes
682+
def testElementShape(self):
683+
l = list_ops.empty_tensor_list(
684+
element_dtype=dtypes.float32, element_shape=-1)
685+
shape = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32)
686+
self.assertEqual(self.evaluate(shape), -1)
687+
681688

682689
if __name__ == "__main__":
683690
test.main()

0 commit comments

Comments
 (0)