Skip to content

Fix tf.linalg.expm float16 support by computing in float32#122575

Open
arechaithanya wants to merge 5 commits into
tensorflow:masterfrom
arechaithanya:fix-expm-float16
Open

Fix tf.linalg.expm float16 support by computing in float32#122575
arechaithanya wants to merge 5 commits into
tensorflow:masterfrom
arechaithanya:fix-expm-float16

Conversation

@arechaithanya

Copy link
Copy Markdown

Fixes #121912

Summary

tf.linalg.expm documents support for float16, but it fails at runtime because
MatrixSolve has no DT_HALF CPU/GPU kernel in eager execution.

This change computes the matrix exponential in float32 for float16 inputs
and casts the result back to float16, preserving the documented API while
avoiding the unsupported kernel.

Testing

  • Added a regression test covering the float16 path.
  • I was unable to run the TensorFlow Bazel test suite locally because Bazel was
    not available in my development environment. I rely on the project's CI to
    validate the change.

@google-ml-butler google-ml-butler Bot added the size:S CL Change Size: Small label Jul 3, 2026
@google-ml-butler google-ml-butler Bot requested a review from cantonios July 3, 2026 10:22
@google-ml-butler google-ml-butler Bot added the awaiting review Pull request awaiting review label Jul 3, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for float16 inputs in tf.linalg.expm by casting the input to float32 during computation and casting the result back to float16, along with a corresponding unit test. The review feedback recommends extending this support to bfloat16 inputs as well, since they face the same execution limitations, and expanding the unit tests to cover both data types.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread tensorflow/python/ops/linalg/linalg_impl.py
Comment on lines +357 to +358
if original_dtype == dtypes.float16:
result = math_ops.cast(result, original_dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Cast the result back to the original dtype if it was either float16 or bfloat16.

Suggested change
if original_dtype == dtypes.float16:
result = math_ops.cast(result, original_dtype)
if original_dtype in (dtypes.float16, dtypes.bfloat16):
result = math_ops.cast(result, original_dtype)
References
  1. Ensure robustness across different hardware targets and validate correctness of tensor operations and data types. (link)

Comment thread tensorflow/python/kernel_tests/linalg/matrix_exponential_op_test.py Outdated
arechaithanya and others added 2 commits July 3, 2026 15:55
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…st.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@nithyak0204 nithyak0204 requested a review from a team July 3, 2026 17:11
@nithyak0204 nithyak0204 added prtype:bugfix PR to fix a bug python Pull requests that update Python code labels Jul 3, 2026
@github-project-automation github-project-automation Bot moved this to Assigned Reviewer in PR Queue Jul 3, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

awaiting review Pull request awaiting review prtype:bugfix PR to fix a bug python Pull requests that update Python code size:S CL Change Size: Small

Projects

Status: Assigned Reviewer

Development

Successfully merging this pull request may close these issues.

tf.linalg.expm fails for documented float16 input because MatrixSolve has no DT_HALF CPU/GPU kernel

3 participants