`nan` gradient when `tf.where` is used #38349
Comments
I have tried on Colab with TF versions 2.1.0 and 2.2.0-rc2 and was able to reproduce the issue. Please find the gist here. Thanks!
This is due to a limitation in how gradients are calculated. Unfortunately, it is unlikely to be fixed in the foreseeable future. You can find more detail here, along with a recipe for how to avoid it: https://stackoverflow.com/questions/33712178/tensorflow-nan-bug/42497444#42497444 In short, if the input to a tf.where contains NaNs, the gradient will always be NaN, regardless of whether the input is actually used or not, and the workaround is to prevent the inputs from ever containing NaNs.
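(For concreteness, a minimal sketch of the failure mode described above, in the spirit of the linked Stack Overflow answer; this exact snippet is not from the thread:)

```python
import tensorflow as tf

x = tf.Variable(0.0)
with tf.GradientTape() as tape:
    # The log branch is never selected for x <= 0, but it is still evaluated;
    # its infinite derivative at 0 gets multiplied by 0, and 0 * inf = nan.
    y = tf.where(x > 0.0, tf.math.log(x), tf.zeros_like(x))
print(tape.gradient(y, x))  # tf.Tensor(nan, shape=(), dtype=float32)
```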
Shouldn't this be documented with a big warning in the `tf.where` docs?
Indeed it should.
@mdanatg Hello, this is my first time contributing to the TensorFlow lib. From the thread I gather you would require the `tf.where` docstring to be updated with a warning about this behavior.
Hello @0x0badc0de, @mdanatg
Sorry for the delay. Feel free to send a PR - it's only a matter of adding a paragraph to the docstring. The text should be more along the lines of a warning. Something like: Important: if any of the inputs contain NaN values, etc. And yes, it should include the workaround as well, which is something along the lines of: instead of applying the function directly to the raw inputs, sanitize them first so that the branch that is not selected never produces NaN.
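(A sketch of that workaround recipe, continuing the earlier illustrative snippet; the placeholder value fed to the untaken branch is arbitrary, anything finite works:)

```python
import tensorflow as tf

x = tf.Variable(0.0)
with tf.GradientTape() as tape:
    # Feed the dangerous branch a sanitized input so that, when it is not
    # selected, both its value and its gradient stay finite.
    safe_x = tf.where(x > 0.0, x, tf.ones_like(x))
    y = tf.where(x > 0.0, tf.math.log(safe_x), tf.zeros_like(x))
print(tape.gradient(y, x))  # tf.Tensor(0.0, shape=(), dtype=float32)
```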
@mdanatg Thanks for your reply. However, I would like to mention that this behavior also happens when the generated value in the inactive branch is not finite (i.e. `inf`):

```python
import tensorflow as tf

a = tf.Variable(10.)
with tf.GradientTape() as tape:
    out = tf.where(a < 15., a, tf.math.pow(10.0, tf.math.exp(a)))
grads = tape.gradient(out, a)
print(grads)
# tf.Tensor(nan, shape=(), dtype=float32)
```

And also, if we reverse the condition such that the branch with the infinite value is selected, the gradient would be infinite (which is a bit surprising in that it does not generate `nan`):

```python
with tf.GradientTape() as tape:
    out = tf.where(a > 15., a, tf.math.pow(10.0, tf.math.exp(a)))
grads = tape.gradient(out, a)
print(grads)
# tf.Tensor(inf, shape=(), dtype=float32)
```

So this behavior happens for both `nan` and `inf` values in the inactive branch.

CC: @anorak-k for potential consideration in your PR after @mdanatg confirms this.
@mkaze that's true - nan, inf and any other special FP value will disrupt the gradient calculation. What happens internally is that the gradients are aggregated roughly as `selected_grad * 1 + unselected_grad * 0`, and since `0 * nan` and `0 * inf` are both `nan`, the whole sum becomes `nan`.

Moreover, the forward calculation doesn't need to result in a nan or inf. You can also get weird results if the gradient alone is nan or inf. For example, the cube root function is defined and well-behaved everywhere, but its derivative at zero is infinite. So this will give you a nan gradient too:
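(The original example was not captured above; a sketch of what it might look like, with the cube root written as a fractional power for x >= 0:)

```python
import tensorflow as tf

x = tf.Variable(0.0)
with tf.GradientTape() as tape:
    # x ** (1/3) is finite at 0, but its derivative there is infinite.
    # Even though the constant branch is selected, 0 * inf = nan in the
    # gradient aggregation.
    y = tf.where(x > 0.0, tf.pow(x, 1.0 / 3.0), tf.zeros_like(x))
print(tape.gradient(y, x))  # tf.Tensor(nan, shape=(), dtype=float32)
```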
I think the tf.where workaround is useful with infinite values as well, so long as the branch not taken is forced to produce a gradient that can be safely multiplied by 0. For your example, it would be something like this:
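(The concrete snippet was not captured above; a sketch of what that workaround might look like for the earlier example, using an inner `tf.where` to sanitize the input of the untaken branch:)

```python
import tensorflow as tf

a = tf.Variable(10.0)
with tf.GradientTape() as tape:
    # When a < 15, the pow branch is not selected, so feed it a harmless
    # value (0.0) to keep both its result and its gradient finite.
    safe_a = tf.where(a < 15.0, tf.zeros_like(a), a)
    out = tf.where(a < 15.0, a, tf.math.pow(10.0, tf.math.exp(safe_a)))
print(tape.gradient(out, a))  # tf.Tensor(1.0, shape=(), dtype=float32)
```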
I agree that it can sometimes be impractical to do, but in principle it should always be possible as long as you control the inputs to the sensitive functions - all that is needed is to force finite values in all the elements that are dropped.
I want to fix issue #38349.
You can simply have it raise a ValueError if it's getting NaN inputs. Or does it not work like that?
Can I work on this issue if no one else is on it right now?
@tushar-dalal The challenge is that checking for such NaN inputs can be taxing on performance, so it cannot be done by default. When debugging, such checks can be enabled explicitly.

@unicorn-io Feel free to tackle it, but note that it's extremely challenging to solve. That said, there was a PR (#38467) to add a warning message to the docs of tf.where; it would be useful to revive it.
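(A sketch of the kind of opt-in check mentioned above, assuming `tf.debugging.enable_check_numerics`, which is available in TF 2.x and raises as soon as an op produces `nan`/`inf`:)

```python
import tensorflow as tf

# Opt-in numeric checking: op outputs are checked for nan/inf. Far too
# slow to be on by default, but useful when tracking down bad gradients.
tf.debugging.enable_check_numerics()

x = tf.Variable(0.0)
with tf.GradientTape() as tape:
    # With checking enabled, the -inf produced by log(0) in the untaken
    # branch is reported immediately, instead of surfacing later as a
    # silent nan gradient.
    y = tf.where(x > 0.0, tf.math.log(x), tf.zeros_like(x))
grads = tape.gradient(y, x)
```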
I am motivated to do this. Can you give me some tips to start with? I will try my best to understand and resolve this issue.
@unicorn-io You can start by looking at the gradient code and understanding how it works. Then you can reproduce what happens in the case of a `tf.where` with bad gradients.
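(As a starting point, a simplified sketch of what the gradient registered for Select in `tensorflow/python/ops/math_grad.py` effectively hands to the two branches; this is an illustration, not the actual implementation:)

```python
import tensorflow as tf

def select_grad_sketch(cond, upstream_grad):
    """Simplified picture of the gradient tf.where passes to its two branches."""
    zeros = tf.zeros_like(upstream_grad)
    grad_for_true_branch = tf.where(cond, upstream_grad, zeros)
    grad_for_false_branch = tf.where(cond, zeros, upstream_grad)
    return grad_for_true_branch, grad_for_false_branch

cond = tf.constant([True, False])
g_true, g_false = select_grad_sketch(cond, tf.constant([1.0, 1.0]))
# g_true  = [1., 0.]  -> multiplied into the true branch's own gradient
# g_false = [0., 1.]  -> multiplied into the false branch's own gradient

# The zero then flows into the branch's own gradient function. If that
# gradient is nan or inf, the multiplication is nan and propagates back:
print(tf.constant(0.0) * tf.constant(float("nan")))  # tf.Tensor(nan, ...)
```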
Cool, I'll get to it.
Hey, I would like to work on it. Can I also help, please?
This bug cannot be fixed as of now, it seems.
It's indeed very challenging to fix. However, the documentation of affected ops, like `tf.where`, could still be updated with a warning.
You mean #38467? It's closed due to staleness, and it would be useful to revive it. By the looks of it, it's safe to assume no one else is working on it.
Seems like it's been a long time since the last activity. Is this issue still open to be worked on?
I think so. There are two parts to it: (1) updating the docs of tf.where, which is fairly straightforward, and (2) actually trying to address the issue, which is a significant undertaking because it touches a rather fundamental limitation in how gradients are calculated.
Is this issue still addressable?
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Debian GNU/Linux 10 (buster)
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary): binary
Describe the current behavior
A well-defined function with `tf.where` has `nan` gradients at points where the `tf.where` inactive branch is undefined.

Describe the expected behavior
The inactive branch should be ignored in gradient calculations.
Standalone code to reproduce the issue
All 3 functions above are well defined for the positive values used for testing. Still, they show no gradient (`nan`) at point `1.`, while it has to be equal to `1.`.
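(The reporter's three functions are not reproduced above; a minimal sketch consistent with the described behavior, well defined at `1.` with an expected gradient of `1.`, yet returning `nan`:)

```python
import tensorflow as tf

x = tf.Variable(1.0)
with tf.GradientTape() as tape:
    # The inactive branch sqrt(x - 1) has an infinite derivative at x = 1,
    # which poisons the gradient even though the active branch is just x.
    y = tf.where(x > 1.0, tf.sqrt(x - 1.0), x)
print(tape.gradient(y, x))  # tf.Tensor(nan, ...) instead of the expected 1.0
```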