
Conversation

pctablet505 (Collaborator)

No description provided.

codecov-commenter commented May 6, 2025

Codecov Report

Attention: Patch coverage is 14.51613% with 53 lines in your changes missing coverage. Please review.

Project coverage is 82.53%. Comparing base (6ddaefb) to head (f60811e).
Report is 1 commit behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/jax/nn.py 14.51% 48 Missing and 5 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21254      +/-   ##
==========================================
- Coverage   82.59%   82.53%   -0.07%     
==========================================
  Files         564      564              
  Lines       54594    54642      +48     
  Branches     8483     8495      +12     
==========================================
+ Hits        45092    45098       +6     
- Misses       7415     7454      +39     
- Partials     2087     2090       +3     
Flag Coverage Δ
keras 82.34% <14.51%> (-0.07%) ⬇️
keras-jax 63.62% <14.51%> (-0.05%) ⬇️
keras-numpy 58.74% <0.00%> (-0.06%) ⬇️
keras-openvino 32.96% <0.00%> (-0.03%) ⬇️
keras-tensorflow 64.03% <0.00%> (-0.06%) ⬇️
keras-torch 63.69% <0.00%> (-0.06%) ⬇️

Flags with carried forward coverage won't be shown.


@fchollet (Collaborator) left a comment

Thanks for the PR!

strides: a sequence of `N` integers, representing the inter-window
    strides (default: `(1, ..., 1)`).

Please add indent

"Sharding along sequence dimension not allowed in tpu kernel "
"attention"
"Sharding along sequence dimension not allowed"
" in tpu kernel attention"

Capitalize "tpu" as "TPU".
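The snippet above re-wraps the long error message across adjacent string literals. As a minimal sketch: Python joins adjacent string literals at compile time, so no `+` is needed, and the message can be split at any word boundary (shown here with the capitalization the review asks for):

```python
# Adjacent string literals are concatenated at compile time,
# so long error messages can wrap cleanly across source lines.
message = (
    "Sharding along sequence dimension not allowed"
    " in TPU kernel attention"
)
```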


Args:
query: Queries with shape `[batch, time, heads,
depth_k]`.

Please use 4-space indent

Corrected indentation in doc string
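The indentation convention being requested can be sketched as follows (a simplified, hypothetical signature for illustration, not the actual Keras function body): the `Args:` section is indented four spaces from the docstring margin, and each argument description four more.

```python
def dot_product_attention(query, key, value):
    """Computes dot-product attention.

    Args:
        query: Queries with shape `[batch, time, heads, depth_k]`.
        key: Keys with shape `[batch, time, heads, depth_k]`.
        value: Values with shape `[batch, time, heads, depth_v]`.
    """
    ...
```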
@fchollet (Collaborator) left a comment

LGTM, thank you

@google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels on May 7, 2025
@fchollet fchollet merged commit d8f3f70 into keras-team:master May 7, 2025
7 checks passed
@google-ml-butler bot removed the awaiting review, ready to pull, and kokoro:force-run labels on May 7, 2025
pctablet505 added a commit to pctablet505/keras that referenced this pull request May 27, 2025
fchollet pushed a commit that referenced this pull request May 27, 2025
pctablet505 added a commit to pctablet505/keras that referenced this pull request May 29, 2025
fchollet pushed a commit that referenced this pull request Jun 10, 2025
…after addressing cuDNN/FlashAttention API updates (#21333)

* Update nn.py

* Update nn.py

* Update nn.py

* Update nn.py

* Update nn.py

Corrected indentation in doc string

* Update nn.py

* Update random_grayscale.py

Fixed issue with passing a single image without batch dimension.
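The fix described here — accepting a single HWC image as well as a batched NHWC tensor — can be sketched in plain NumPy. The function name and the luma weights below are illustrative, not the Keras layer's actual code:

```python
import numpy as np

def rgb_to_grayscale(images):
    """Convert RGB image(s) to grayscale, accepting both a single
    HWC image and a batched NHWC tensor (sketch of the fix above)."""
    images = np.asarray(images, dtype=np.float32)
    unbatched = images.ndim == 3
    if unbatched:
        images = images[np.newaxis, ...]  # add a batch dimension
    # Standard luma coefficients for the RGB channels.
    weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    gray = np.tensordot(images, weights, axes=([-1], [0]))[..., np.newaxis]
    # Strip the batch dimension again if the input was unbatched.
    return gray[0] if unbatched else gray
```

The key point is remembering whether the batch dimension was added, so the output rank matches the input rank.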

* Update keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py

Co-authored-by: Jyotinder Singh <[email protected]>

* Update random_grayscale_test.py

Test case for unbatched inputs

* code reformat

* Update random_grayscale_test.py

Testcase for checking both unbatched and batched single image inputs.

* changed compute_output_spec

There was a bug, and it was causing cycle in graph.

* Update random_grayscale.py

removed the use of tree.map_structure

* Reapply "Fixed issue with dot_product_attention when using TPU.  (#21254)" (#21329)

This reverts commit 81821e0.

* Improve error handling in _can_use_flash_attention for better debugging

Enhanced the _can_use_flash_attention function to provide more detailed
error messages when flash attention compatibility checks fail.

Changes:
- Replace generic exception catching with specific error propagation
- When raise_error=True, directly re-raise original exceptions from
  check_layout() and check_is_flash_attention() functions
- Preserve detailed error context from JAX internal validation functions
- Maintain existing behavior when raise_error=False (returns False)

This improves debugging experience by surfacing specific technical details
about tensor layout incompatibilities, cuDNN version requirements, and
other flash attention compatibility issues.

Relates to keras-hub PR #2257 and addresses flash attention debugging needs.

* Revert "Improve error handling in _can_use_flash_attention for better debugging"

This reverts commit 7a0c547.

* Fix JAX API compatibility and improve error handling in `_can_use_flash_attention`

Changes:
- Add missing q_offsets=None and kv_offsets=None parameters to check_layout()
  call to match updated JAX function signature
- Replace bare `except:` with `except Exception as e:` and `raise e` to
  preserve detailed error messages from JAX validation functions
- Maintain existing fallback behavior when raise_error=False

This resolves compatibility issues with newer JAX versions and improves
debugging experience by surfacing specific technical details about
flash attention compatibility failures.
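The error-handling pattern this commit message describes can be sketched as follows. `check_layout` below is a local stand-in for the JAX-internal validator, not the real API; the point is the control flow: re-raise the validator's detailed error when `raise_error=True`, otherwise fall back quietly by returning `False`.

```python
import numpy as np

def can_use_flash_attention(query, key, raise_error=False):
    """Sketch: run specific compatibility checks, re-raising their
    detailed errors on request instead of swallowing them."""
    def check_layout(q, k):  # stand-in for the JAX-internal validator
        if q.shape[-1] != k.shape[-1]:
            raise ValueError(
                f"query depth {q.shape[-1]} does not match "
                f"key depth {k.shape[-1]}"
            )
    try:
        check_layout(query, key)
        return True
    except Exception as e:
        if raise_error:
            raise e  # preserve the validator's detailed message
        return False
```

Compared with a bare `except:` that always returns `False`, this surfaces the specific incompatibility (layout, version, etc.) to the caller who asked for it.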

* Updated `dot_product_attention`

Simplified the check for `flash_attention` by removing redundant checks that are already done in `_can_use_flash_attention`.

* Update nn.py

* Update nn.py

---------

Co-authored-by: Jyotinder Singh <[email protected]>
fchollet pushed a commit that referenced this pull request Jul 14, 2025
* Update image.py

* Update keras/src/backend/tensorflow/image.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Revert "Update keras/src/backend/tensorflow/image.py"

This reverts commit cb7e955.

* Update image.py

* Update image.py

---------

Co-authored-by: Jyotinder Singh <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>