[TensorExpr] Cache use of fallback in kernel invocation #47812
Conversation
💊 CI failures summary and remediations. As of commit f838a25 (more details on the Dr. CI page):
🚧 2 ongoing upstream failures: These were probably caused by upstream breakages that are not fixed yet.
🚧 1 fixed upstream failure: These were probably caused by upstream breakages that were already fixed.
Please rebase on the
    runKernel(stack);
  } catch (...) {
    fallback_ = true;
} else if (!use_fallback_ && allow_fallback_) {
nitpick: I wonder if this structure might be a bit easier to read:
void TensorExprKernel::run(Stack& stack) {
  if (!use_fallback_) {
    try {
      runKernel(stack);
      return;
    } catch (const std::exception&) {
      if (!allow_fallback_) {
        throw; // re-throw the original exception
      }
      // fall through to `fallback`
    }
  }
  fallback(stack);
}
I was trying to optimize for the common case. I guess try/catch doesn't have a performance penalty if exceptions aren't thrown, though (?)
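For what it's worth, here is a minimal standalone sketch (toy code, not from this PR; `runOnce` and the iteration count are made up) of how one could sanity-check that the non-throwing path of a try/catch costs essentially nothing compared to the same loop without it:

```cpp
#include <chrono>
#include <iostream>

// Toy stand-in for a kernel invocation; never throws in this sketch.
int runOnce(int x) {
  return x + 1;
}

int main() {
  constexpr int kIters = 100000000;
  volatile int sink = 0;

  // Plain loop, no exception handling.
  auto t0 = std::chrono::steady_clock::now();
  for (int i = 0; i < kIters; ++i) {
    sink = runOnce(i);
  }
  auto t1 = std::chrono::steady_clock::now();

  // Same loop wrapped in try/catch; the handler is never entered here.
  for (int i = 0; i < kIters; ++i) {
    try {
      sink = runOnce(i);
    } catch (...) {
      sink = -1; // only reached if runOnce actually threw
    }
  }
  auto t2 = std::chrono::steady_clock::now();

  std::cout << "no try/catch:   "
            << std::chrono::duration<double>(t1 - t0).count() << " s\n"
            << "with try/catch: "
            << std::chrono::duration<double>(t2 - t1).count() << " s\n";
  return 0;
}
```

On the common "zero-cost" exception ABIs (e.g. the Itanium C++ ABI used on Linux x86-64), the non-throwing path pays essentially nothing at runtime; the cost is extra unwind tables and a slow path taken only when an exception actually propagates, so the two loops above should time about the same.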
    compile();
    return;
  }

  use_fallback_ = fallbackEnforced();
nitpick: maybe renaming `use_fallback_` to `failed_comp_or_required_fallback_` makes it a bit more obvious in which cases we are using the fallback?
Is fallback actually useful anymore? Since we default it to off, should we consider killing it?
I think I use the fallback for Block Codegen. I can look into removing the dependency on fallback.
This pull request has been merged in 0e666a9.
Stack from ghstack:
Summary:
Previously we were checking the environment every kernel invocation for `tensorExprFuserEnabled`, which checks the environment for `PYTORCH_TENSOREXPR`. This is only a dev-exposed API, so I think it is fine to only check once when the kernel is initialized. The `disable_optimization` flag, which is user-exposed, more or less covers the same functionality.

For fun, some benchmarking. I compared scripted before and after of `def foo(x, y): return x + y` for x, y = torch.tensor([1]). I also removed the prim::TypeCheck node to better isolate the kernel (I cheated). Here is the gist: https://gist.github.com/eellison/39f3bc368f5bd1f25ded4827feecd15e
Without Changes Run 1:
no fusion: sum 6.416894399004377 min: 0.6101883250012179 median 0.6412974080012646
with fusion: sum 6.437897570998757 min: 0.6350401220006461 median 0.6446951820034883
Without Changes Run2:
no fusion: sum 6.601341788002173 min: 0.6292048720024468 median 0.6642187059987918
with fusion: sum 6.734651455997664 min: 0.6365462899993872 median 0.6755226659988693
With Changes Run1:
no fusion: sum 6.097717430002376 min: 0.5977709550024883 median 0.613631643998815
with fusion: sum 6.1299369639964425 min: 0.5857932209983119 median 0.6159247440009494
With Changes Run2:
no fusion: sum 6.5672018059995025 min: 0.6245676209982776 median 0.6386050750006689
with fusion: sum 6.489086147994385 min: 0.6236886289989343 median 0.6535737619997235
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D25286210
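For illustration only, a minimal sketch of the caching pattern the summary describes: read the environment once when the kernel is constructed and reuse the cached flag on every invocation. The class and helper names below (`ToyTEKernel`, `envFlagEnabled`) are hypothetical stand-ins, not the actual PyTorch code, and the flag parsing is simplified.

```cpp
#include <cstdlib>

// Hypothetical helper standing in for tensorExprFuserEnabled(); the real
// flag's parsing rules may differ. Here "1" simply means enabled.
static bool envFlagEnabled(const char* name) {
  const char* val = std::getenv(name);
  return val != nullptr && val[0] == '1';
}

class ToyTEKernel {
 public:
  // The environment is consulted exactly once, at construction time.
  ToyTEKernel() : use_fallback_(!envFlagEnabled("PYTORCH_TENSOREXPR")) {}

  void run() {
    // Hot path: no getenv() per invocation, just a cached member read.
    if (use_fallback_) {
      fallback();
      return;
    }
    runKernel();
  }

 private:
  void runKernel() { /* compiled-kernel path */ }
  void fallback() { /* interpreter fallback path */ }

  bool use_fallback_;
};
```

Compared to calling `std::getenv` on every `run()`, this keeps the check off the hot path, which is the kind of small per-invocation saving reflected in the numbers above.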