
[Bug] tf.where returns None gradient for differentiable expression #115116

@amadhan882

Description


Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

source

TensorFlow version

2.19.0

Custom code

Yes

OS platform and distribution

Linux-6.6.113+-x86_64-with-glibc2.35 (Google Colab)

Mobile device

No response

Python version

3.12.13

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No GPU (CPU-only environment)

GPU model and memory

No response

Current behavior?

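Taking the gradient of tf.where(x > 0, x, x * x) with respect to a tf.Variable returns None. Both branches are differentiable, and none of the inputs sits at the switch point x = 0, so the function is differentiable at every input used.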

Observed Output

y: [1. 2. 9.]
grad: None

The gradient is not computed at all: tape.gradient returns None instead of a tensor, and nothing signals the failure.
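For comparison, the expected gradient can be derived by hand. tf.reduce_sum(y) is a sum of independent elementwise terms, so its gradient with respect to x is 1 where x > 0 and 2x elsewhere. For the inputs in the reproduction below, that gives [-2., 1., -6.]. The following minimal sketch uses plain Python with no TensorFlow dependency and checks that value against a central finite difference. The helper names f, analytic_grad and finite_diff_grad are hypothetical and are used only for illustration.

```python
# Hedged sketch: reference values for the gradient the reproduction should yield.
# Pure Python (no TensorFlow); f, analytic_grad, finite_diff_grad are hypothetical helpers.

def f(x):
    # Scalar equivalent of tf.where(x > 0, x, x * x) for a single element
    return x if x > 0 else x * x

def analytic_grad(x):
    # Derivative of whichever branch is selected: 1 for x > 0, otherwise 2x
    return 1.0 if x > 0 else 2.0 * x

def finite_diff_grad(x, h=1e-6):
    # Central difference; valid here because no input lies at the x == 0 switch point
    return (f(x + h) - f(x - h)) / (2.0 * h)

xs = [-1.0, 2.0, -3.0]
y = [f(v) for v in xs]
expected = [analytic_grad(v) for v in xs]
numeric = [finite_diff_grad(v) for v in xs]

print("y:", y)                      # y: [1.0, 2.0, 9.0]
print("expected grad:", expected)   # expected grad: [-2.0, 1.0, -6.0]
print("numeric grad:", numeric)     # approximately the same values
```

A working gradient path through tf.where would be expected to return these values, up to floating-point error, rather than None.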

Standalone code to reproduce the issue

import tensorflow as tf

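# A trainable tf.Variable is watched by the tape automatically.
x = tf.Variable([-1.0, 2.0, -3.0])

with tf.GradientTape() as tape:
    # Elementwise: keep x where x > 0, otherwise use x * x
    y = tf.where(x > 0, x, x * x)

grad = tape.gradient(tf.reduce_sum(y), x)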

print("y:", y.numpy())
print("grad:", grad)

Relevant log output

No error or warning is raised; the gradient is silently None.
