-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update flash_attention_fwd_benchmark.py
#2265
Open
anmyachev
wants to merge
7
commits into
main
Choose a base branch
from
anmyachev-patch-1
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
+7
−9
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev
force-pushed
the
anmyachev-patch-1
branch
from
September 19, 2024 15:05
80a8f26
to
92998b2
Compare
Signed-off-by: Anatoly Myachev <[email protected]>
Signed-off-by: Anatoly Myachev <[email protected]>
Signed-off-by: Anatoly Myachev <[email protected]>
This reverts commit de9335c.
Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev
commented
Sep 19, 2024
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32) | ||
atol = 1e-1 if N_CTX == 16384 else 1e-2 | ||
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch') | ||
torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using ZE_FLAT_DEVICE_HIERARCHY=COMPOSITE
the available memory is doubled and there is no more out of memory error for upstream pytorch (however, this affects the performance)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
CI:
Error:
torch.OutOfMemoryError: XPU out of memory. Tried to allocate 32.00 GiB. GPU 0 has a total capacity of 64.00 GiB. Of the allocated memory 32.81 GiB is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. Please use `empty_cache` to release all unoccupied cached memory.
It's strange that
a total capacity
is 64.00 GiB. I need to understand why (the expected capacity should be more in my understanding).UPD: Maybe it's related to https://spec.oneapi.io/level-zero/latest/core/PROG.html#environment-variables ZE_FLAT_DEVICE_HIERARCHY