Fix tf.linalg.expm float16 support by computing in float32#122575
Fix tf.linalg.expm float16 support by computing in float32#122575arechaithanya wants to merge 5 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request adds support for float16 inputs in tf.linalg.expm by casting the input to float32 during computation and casting the result back to float16, along with a corresponding unit test. The review feedback recommends extending this support to bfloat16 inputs as well, since they face the same execution limitations, and expanding the unit tests to cover both data types.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if original_dtype == dtypes.float16: | ||
| result = math_ops.cast(result, original_dtype) |
There was a problem hiding this comment.
Cast the result back to the original dtype if it was either float16 or bfloat16.
| if original_dtype == dtypes.float16: | |
| result = math_ops.cast(result, original_dtype) | |
| if original_dtype in (dtypes.float16, dtypes.bfloat16): | |
| result = math_ops.cast(result, original_dtype) |
References
- Ensure robustness across different hardware targets and validate correctness of tensor operations and data types. (link)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…st.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Fixes #121912
Summary
tf.linalg.expmdocuments support forfloat16, but it fails at runtime becauseMatrixSolvehas noDT_HALFCPU/GPU kernel in eager execution.This change computes the matrix exponential in
float32forfloat16inputsand casts the result back to
float16, preserving the documented API whileavoiding the unsupported kernel.
Testing
float16path.not available in my development environment. I rely on the project's CI to
validate the change.