[XLA] tf.cond with more than 2 branches causes TypeError with jit_compile=True #113345

@Blooming-Tree

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

source

TensorFlow version

tf 2.20.0

Custom code

Yes

OS platform and distribution

Linux Ubuntu 24.04

Mobile device

No response

Python version

3.9

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

When tf.cond is called with four positional arguments (pred plus three lambdas) inside a function decorated with @tf.function(jit_compile=True), a TypeError is raised with the message "expected string or bytes-like object". Since the signature is tf.cond(pred, true_fn=None, false_fn=None, name=None), the third lambda is bound to name rather than treated as a branch. The error message does not point to this root cause (too many arguments), which makes debugging difficult.
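To make the argument binding concrete, here is a minimal sketch that does not import TensorFlow. The function `cond` below is a stand-in that only copies the documented parameter list of tf.cond, and `SCOPE_NAME_RE` and `name_check_error` are hypothetical names for a regex-based name check. That such a check is what raises the error in TensorFlow is an assumption, not something confirmed in this report.

```python
import inspect
import re


# Stand-in with the same parameter list as the documented tf.cond signature.
# This is NOT TensorFlow; it only illustrates how positional arguments bind.
def cond(pred, true_fn=None, false_fn=None, name=None):
    return true_fn() if pred else false_fn()


def bind_cond_args(*args):
    """Map each positional argument to the parameter it would bind to."""
    return dict(inspect.signature(cond).bind(*args).arguments)


first = lambda: "a"
second = lambda: "b"
third = lambda: "c"

bound = bind_cond_args(True, first, second, third)
print(sorted(bound))           # ['false_fn', 'name', 'pred', 'true_fn']
print(bound["name"] is third)  # True: the third lambda became `name`

# Hypothetical name check (assumption): a regex match against a non-string
# makes Python's `re` module raise a TypeError.
SCOPE_NAME_RE = re.compile(r"^[A-Za-z0-9_.\\/>-]*$")


def name_check_error(name):
    """Return the TypeError message from the regex check, or None if it passes."""
    try:
        SCOPE_NAME_RE.match(name)
    except TypeError as exc:
        return str(exc)
    return None


print(name_check_error(bound["name"]))
```

On the Python version listed in this report, the last line should print the same text as the error in the log. Newer Python versions append the type of the offending object to the message.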

Standalone code to reproduce the issue

import tensorflow as tf

class TestModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.d1 = tf.keras.layers.Dense(64, activation='tanh')
        self.d2 = tf.keras.layers.Dense(32, activation='sigmoid')
        self.d3 = tf.keras.layers.Dense(16)

    def call(self, x):
        x = self.d1(x)
        x = tf.nn.leaky_relu(self.d2(x), alpha=0.1)
        x = tf.cond(tf.reduce_mean(x) > 0, lambda: tf.cast(x, dtype=tf.float32), lambda: tf.cast(x, tf.float32), lambda: tf.cast(x, tf.float64))
        return self.d3(x)

def get_default_model():
    return TestModel()

def get_sample_inputs():
    x = tf.random.normal([8, 64], dtype=tf.float32)
    return (x,)

def main():
    model = get_default_model()
    inputs = get_sample_inputs()
    output = model(*inputs)
    print(f'input shape: {inputs[0].shape}')
    print(f'output shape: {output.shape}')
    @tf.function(jit_compile=True)
    def compiled_forward(*args):
        return model(*args)
    compiled_out = compiled_forward(*inputs)
    print('XLA Output shape:', compiled_out.shape)

if __name__ == '__main__':
    main()

Relevant log output

TypeError: Exception encountered when calling TestModel.call().
    
    expected string or bytes-like object
    
    Arguments received by TestModel.call():
      • x=tf.Tensor(shape=(8, 64), dtype=float32)
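
For comparison, a clearer diagnostic could look roughly like the sketch below. `check_cond_args` is a hypothetical helper written for illustration. It is not part of TensorFlow, and the wording of its messages is made up.

```python
def check_cond_args(true_fn, false_fn, name=None):
    """Hypothetical pre-check that fails with a descriptive message.

    Intended to run before a tf.cond-style call so that a stray third
    callable is reported directly instead of surfacing as a regex error.
    """
    if not callable(true_fn) or not callable(false_fn):
        raise TypeError("true_fn and false_fn must both be callables")
    if callable(name):
        raise TypeError(
            "name must be a string or None but got a callable; "
            "cond takes exactly two branch functions (true_fn, false_fn)"
        )
    if name is not None and not isinstance(name, str):
        raise TypeError(
            f"name must be a string or None, got {type(name).__name__}"
        )


try:
    # Same shape as the failing call: three callables after the predicate.
    check_cond_args(lambda: 1, lambda: 2, lambda: 3)
except TypeError as exc:
    print(exc)
```

If more than two branches are actually needed, tf.case and tf.switch_case are the documented multi-branch alternatives to tf.cond.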

Metadata

Labels

2.20.0 (tensorflow 2.20.0)
comp:ops (OPs related issues)
stale (This label marks the issue/pr stale - to be closed automatically if no activity)
stat:awaiting response (Status - Awaiting response from author)
type:bug (Bug)
