Change persistent reduction threshold to 32 #147899


Open · wants to merge 1 commit into main

Conversation

@PaulZhang12 (Contributor) commented Feb 25, 2025

Summary:

Increasing the persistent reduction threshold used by Inductor's multi-kernel feature from 16 to 32 can lead to significant performance gains. This change is safe because TORCHINDUCTOR_MULTI_KERNEL is disabled by default.
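
For anyone reproducing this, here is a minimal sketch of opting in to multi-kernel. The TORCHINDUCTOR_MULTI_KERNEL env var is named in the summary above; the in-process `triton.multi_kernel` config knob is my assumption about how that env var is surfaced, so verify it against your torch version.

# Sketch, not part of this PR: enable Inductor multi-kernel before compiling.
import os
os.environ["TORCHINDUCTOR_MULTI_KERNEL"] = "1"  # read when the inductor config is imported

import torch
from torch._inductor import config as inductor_config

# Assumed in-process equivalent of the env var above.
inductor_config.triton.multi_kernel = 1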

Example benchmark:

import math

import torch
import torch.nn.functional as F
from triton.testing import do_bench

# NB: the threshold change only matters when multi-kernel is enabled
# (TORCHINDUCTOR_MULTI_KERNEL=1; it is disabled by default, see summary above).

def position_bias_softmax(scores, weight):
    # T5-style relative position bias: bucket pairwise distances, embed the
    # bucket ids, add the bias to the attention scores, then softmax.
    scores = scores.to(torch.float32)
    context_position = torch.arange(2048, dtype=torch.long, device="cuda")[:, None]
    memory_position = torch.arange(2048, dtype=torch.long, device="cuda")[None, :]
    relative_position = memory_position - context_position  # shape (query_length, key_length)
    num_buckets = 32
    max_distance = 128
    # Keep only non-negative distances (unidirectional attention).
    relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    # Half of the buckets cover exact small distances; the rest are log-spaced
    # up to max_distance and clamped to the last bucket.
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )
    relative_buckets = torch.where(is_small, relative_position, relative_position_if_large)
    values = F.embedding(relative_buckets, weight)   # (2048, 2048, 1)
    values = values.permute([2, 0, 1]).unsqueeze(0)  # (1, 1, 2048, 2048)
    scores = scores + values

    return F.softmax(scores, dim=-1).to(torch.float16)


scores = torch.randn(8, 2048, 2048, device="cuda", dtype=torch.float16)
weight = torch.randn(32, 1, device="cuda")
position_bias_softmax(scores, weight)  # eager reference run
compiled = torch.compile(position_bias_softmax)

compiled(scores, weight=weight)  # warm up so compilation stays outside the timing
# Bytes moved: one read and one write of `scores` (fp16 = 2 bytes/element).
gb = 2 * scores.element_size() * scores.numel() / 1e9
sec = do_bench(lambda: compiled(scores, weight=weight)) / 1e3  # do_bench returns ms
print(f"weighted bias gb/s: {gb/sec}")

With this change: 987.08 GB/s
Baseline: 693.34 GB/s (roughly a 1.42x speedup)
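
For reference, the denominator in the GB/s figure works out as follows (constants taken directly from the script above):

# Read + write of `scores` once each; fp16 is 2 bytes per element.
elements = 8 * 2048 * 2048       # 33,554,432 elements
bytes_moved = 2 * 2 * elements   # 134,217,728 bytes
print(bytes_moved / 1e9)         # ~0.134 GB moved per call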

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @shunting314 @eellison

pytorch-bot bot commented Feb 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147899

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure

As of commit 93e71d1 with merge base 1e894d2:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Feb 25, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: PaulZhang12 / name: Paul Zhang (93e71d1)

@shunting314 (Contributor) commented:

More context from this meta internal post: https://fb.workplace.com/groups/1075192433118967/posts/1612836222687916/

@eellison (Contributor) left a comment:

This works, but ideally it would work out of the box, following analysis similar to what we did in #141916, where persistent reductions resulted in less memory traffic.

We should be able to use the memory analysis that @jansel did in #142026.

cc @FindHao

@jansel (Contributor) left a comment:

I agree we should use SIMDKernelFeatures to write a better heuristic here. Though this change seems fine in the shorter term.
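
For readers outside the Inductor codebase, here is a hypothetical sketch of the kind of heuristic being discussed. All names and signatures below are invented for illustration; the real SIMDKernelFeatures API may differ.

# Hypothetical illustration only -- not Inductor's actual code or API.
# Idea: choose persistent vs. looped reduction from estimated memory
# traffic instead of a fixed element-count threshold.
def choose_persistent_reduction(reduction_numel: int,
                                est_bytes_persistent: int,
                                est_bytes_looped: int,
                                threshold: int = 32) -> bool:
    """Return True to emit the persistent-reduction variant."""
    if est_bytes_persistent != est_bytes_looped:
        # Prefer whichever variant is predicted to move fewer bytes.
        return est_bytes_persistent < est_bytes_looped
    # Tie-break with a size threshold like the one this PR raises from 16 to 32.
    return reduction_numel <= threshold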

@PaulZhang12 (Contributor, Author) commented:

/easycla

1 similar comment
@PaulZhang12 (Contributor, Author) commented:

/easycla

github-actions bot commented:

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions bot added the Stale label on Apr 28, 2025