Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
Yes
Source
source
TensorFlow version
tf 2.20.0
Custom code
Yes
OS platform and distribution
Linux Ubuntu 24.04
Mobile device
No response
Python version
3.9
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
No response
GPU model and memory
No response
Current behavior?
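When `tf.cond` is called with a predicate followed by three lambdas (one more than its `true_fn`/`false_fn` pair) inside a function decorated with `@tf.function(jit_compile=True)`, a `TypeError` is raised with the message "expected string or bytes-like object". Because the TF2 signature is `tf.cond(pred, true_fn=None, false_fn=None, name=None)`, the third lambda is bound to `name` instead of being rejected as an extra branch. The error message gives no hint of this root cause (too many branch callables), which makes debugging difficult.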
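To make the argument binding concrete, the pure-Python sketch below uses a stand-in, `cond_signature_stub` (hypothetical, not part of TensorFlow), declared with the same parameter list as TF2's `tf.cond`. It only shows how Python binds the reproducer's four positional arguments; it does not import or exercise TensorFlow itself.

```python
import inspect


def cond_signature_stub(pred, true_fn=None, false_fn=None, name=None):
    """Hypothetical stand-in with the same parameters as TF2's tf.cond."""
    raise NotImplementedError  # only the signature matters here


def bind_cond_args(*args, **kwargs):
    """Return how Python would bind the given arguments to those parameters."""
    bound = inspect.signature(cond_signature_stub).bind(*args, **kwargs)
    return dict(bound.arguments)


# Mirrors the reproducer: a predicate followed by three lambdas.
binding = bind_cond_args(True, lambda: "f32", lambda: "f32", lambda: "f64")
for param, value in binding.items():
    print(param, "->", type(value).__name__)
```

The last lambda lands in `name`, so nothing at the Python call level flags the extra branch; only a fifth positional argument would trigger Python's own "too many positional arguments" error.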
Standalone code to reproduce the issue
```python
import tensorflow as tf


class TestModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.d1 = tf.keras.layers.Dense(64, activation='tanh')
        self.d2 = tf.keras.layers.Dense(32, activation='sigmoid')
        self.d3 = tf.keras.layers.Dense(16)

    def call(self, x):
        x = self.d1(x)
        x = tf.nn.leaky_relu(self.d2(x), alpha=0.1)
        # Intentionally passes three branch lambdas; tf.cond binds the third one to `name`.
        x = tf.cond(tf.reduce_mean(x) > 0,
                    lambda: tf.cast(x, dtype=tf.float32),
                    lambda: tf.cast(x, tf.float32),
                    lambda: tf.cast(x, tf.float64))
        return self.d3(x)


def get_default_model():
    return TestModel()


def get_sample_inputs():
    x = tf.random.normal([8, 64], dtype=tf.float32)
    return (x,)


def main():
    model = get_default_model()
    inputs = get_sample_inputs()

    # Eager forward pass.
    output = model(*inputs)
    print(f'input shape: {inputs[0].shape}')
    print(f'output shape: {output.shape}')

    # XLA-compiled forward pass.
    @tf.function(jit_compile=True)
    def compiled_forward(*args):
        return model(*args)

    compiled_out = compiled_forward(*inputs)
    print('XLA Output shape:', compiled_out.shape)


if __name__ == '__main__':
    main()
```
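Until the error is improved upstream, a thin user-side wrapper can fail fast with a clearer message. `checked_cond` below is a hypothetical helper, not a TensorFlow API. As an assumption for portability, it takes the cond implementation as its first argument (in real code, `tf.cond`), so the sketch runs without TensorFlow; `python_cond` is a plain-Python stand-in used only to exercise it.

```python
def checked_cond(cond_fn, pred, *branches, name=None):
    """Validate arguments, then call cond_fn(pred, true_fn, false_fn, name=name).

    Hypothetical helper; cond_fn would typically be tf.cond.
    """
    if len(branches) != 2:
        raise TypeError(
            f"tf.cond takes exactly two branch callables (true_fn, false_fn), "
            f"got {len(branches)}; for more than two outcomes consider tf.case "
            f"or nested tf.cond calls"
        )
    true_fn, false_fn = branches
    if not (callable(true_fn) and callable(false_fn)):
        raise TypeError("true_fn and false_fn must be callable")
    if name is not None and not isinstance(name, str):
        raise TypeError(f"name must be a string or None, got {type(name).__name__}")
    return cond_fn(pred, true_fn, false_fn, name=name)


def python_cond(pred, true_fn, false_fn, name=None):
    """Plain-Python stand-in for tf.cond, used only to exercise the helper."""
    return true_fn() if pred else false_fn()


print(checked_cond(python_cond, True, lambda: "true branch", lambda: "false branch"))
```

In the reproducer, `checked_cond(tf.cond, tf.reduce_mean(x) > 0, lambda: ..., lambda: ..., lambda: ...)` would raise the argument-count `TypeError` before TensorFlow ever sees a lambda in place of `name`.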
Relevant log output
```
TypeError: Exception encountered when calling TestModel.call().
expected string or bytes-like object
Arguments received by TestModel.call():
  • x=tf.Tensor(shape=(8, 64), dtype=float32)
```
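The wording "expected string or bytes-like object" is what Python's `re` module raises when a non-string is matched against a pattern (recent Python versions append the offending type). Since `tf.cond`'s fourth positional parameter is `name`, a plausible but unconfirmed explanation is that TensorFlow checks that name with a regular expression and the lambda fails there. The sketch assumes such a check; `VALID_NAME` and `validate_op_name` are hypothetical and only approximate TensorFlow's naming rule.

```python
import re

# Hypothetical op-name pattern (an assumption, not TensorFlow's exact regex).
VALID_NAME = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_./-]*$")


def validate_op_name(name):
    """Hypothetical name check; re raises TypeError if name is not str/bytes."""
    return VALID_NAME.match(name) is not None


print(validate_op_name("cond"))

try:
    # A lambda in the `name` slot, as in the reproducer.
    validate_op_name(lambda: None)
except TypeError as exc:
    print("TypeError:", exc)
```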