`nan` gradient when `tf.where` is used #38349

Open
0x0badc0de opened this issue Apr 8, 2020 · 27 comments

@0x0badc0de 0x0badc0de commented Apr 8, 2020


System information

  • Have I written custom code (as opposed to using a stock
    example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g.,
    Linux Ubuntu 16.04): Debian GNU/Linux 10 (buster)
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
    the issue happens on mobile device:
  • TensorFlow installed from (source or
    binary): binary
  • TensorFlow version (use command below): v2.1.0-rc2-17-ge5bf8de 2.1.0 / v1.12.1-29016-g38797a1c8b 2.2.0-dev20200407
  • Python version: 3.7.7
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:


Describe the current behavior
A well-defined function built with tf.where has nan gradients at points where the inactive tf.where branch is undefined.

Describe the expected behavior
The inactive branch should be ignored in gradient calculations.

Standalone code to reproduce the issue

import tensorflow as tf

for ex in range(-3, 3):
    x = tf.convert_to_tensor(10.**ex)
    with tf.GradientTape() as g:
        g.watch(x)
        y = tf.where(x >= -1., x, tf.math.log1p(-x))
#         y = tf.where(x >= -1., x, tf.math.log(1.-x))
#         y = tf.where(x >= -1., x, 1./(1.-x))
    dy_dx = g.gradient(y, x)
    print(f'y({x})={y}, dy/dx({x})={dy_dx}')

All three functions above are well defined for the positive values used in the test. Still, the gradient is nan at the point 1.0, while it should be equal to 1:

y(0.0010000000474974513)=0.0010000000474974513, dy/dx(0.0010000000474974513)=1.0
y(0.009999999776482582)=0.009999999776482582, dy/dx(0.009999999776482582)=1.0
y(0.10000000149011612)=0.10000000149011612, dy/dx(0.10000000149011612)=1.0
y(1.0)=1.0, dy/dx(1.0)=nan
y(10.0)=10.0, dy/dx(10.0)=1.0
y(100.0)=100.0, dy/dx(100.0)=1.0

Other info / logs Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.

@ravikyram ravikyram commented Apr 8, 2020

I have tried this on Colab with TF versions 2.1.0 and 2.2.0-rc2 and was able to reproduce the issue. Please find the gist here. Thanks!

@ravikyram ravikyram assigned ymodak and unassigned ravikyram Apr 8, 2020
@mdanatg mdanatg commented Apr 8, 2020

This is due to a limitation in how gradients are calculated. Unfortunately, it is unlikely to be fixed in the foreseeable future.

You can find more detail here, along with a recipe for how to avoid it: https://stackoverflow.com/questions/33712178/tensorflow-nan-bug/42497444#42497444

In short, if either input to a tf.where contains NaNs, the gradient will always be NaN, regardless of whether that input is actually selected, and the workaround is to prevent the inputs from ever containing NaNs.
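
For reference, here is a minimal sketch of that recipe applied to the log1p example from this issue; the safe value -2. is an arbitrary point where log1p(-x) stays well defined:

import tensorflow as tf

x = tf.convert_to_tensor(1.0)
with tf.GradientTape() as g:
    g.watch(x)
    # Clamp the input of the inactive branch to a safe value before it
    # reaches log1p, so the unused branch can never produce nan/inf.
    safe_x = tf.where(x >= -1., -2., x)
    y = tf.where(x >= -1., x, tf.math.log1p(-safe_x))
dy_dx = g.gradient(y, x)
print(dy_dx)  # tf.Tensor(1.0, ...) instead of nan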

@mdanatg mdanatg closed this Apr 8, 2020

@0x0badc0de 0x0badc0de commented Apr 8, 2020

Shouldn't this be documented with a big warning in the tf.where docs, in that case?

@mdanatg mdanatg commented Apr 8, 2020

Indeed it should.

@joemaren joemaren commented Apr 9, 2020

@mdanatg Hello, this is my first time contributing to the TensorFlow lib. From the thread I gather you would like tf.where's docs to be updated. If so, can I work on this?

@ymodak ymodak removed their assignment Apr 9, 2020
@anorak-k anorak-k commented Apr 11, 2020

Hello @0x0badc0de, @mdanatg
Should the updated doc contain something like a warning, or will a small note at the end about the inputs not being NaN do? Also, should the workaround for avoiding it be added to the doc?

@mdanatg mdanatg commented Apr 11, 2020

@joemaren @anorak-k

Sorry for the delay. Feel free to send a PR - it's only a matter of adding a paragraph to the docstring.

The text should be more along the lines of a warning. Something like: Important: if any of the inputs contain NaN values, etc. And yes, it should include the workaround as well, which is something along the lines of: instead of tf.where(x, ops_that_can_nan(z), ...), write tf.where(x, ops_that_can_nan(tf.where(x, z, safe_value)), ...).
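
A runnable sketch of that pattern, using tf.sqrt as a stand-in for ops_that_can_nan (it is nan for negative inputs and has an infinite derivative at zero) and 1. as an assumed safe_value:

import tensorflow as tf

x = tf.constant([0.5, 2.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    cond = x < 1.0
    z = 1.0 - x
    # Inner where: feed the risky op a safe value wherever its branch
    # is not selected, so neither its output nor its gradient can blow up.
    safe_z = tf.where(cond, z, tf.ones_like(z))
    y = tf.where(cond, tf.sqrt(safe_z), x)
print(tape.gradient(y, x))  # finite everywhere, no nan from the unused branch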

@anorak-k anorak-k commented Apr 13, 2020

@mdanatg I have added the change and raised a PR #38467

@mkaze mkaze commented Apr 18, 2020

@mdanatg Thanks for your reply. However, I would like to mention that this behavior also occurs when the value generated in the inactive branch is not finite (i.e., inf or -inf). Here is a minimal reproducible example:

import tensorflow as tf

a = tf.Variable(10.)
with tf.GradientTape() as tape:
  out = tf.where(a < 15., a, tf.math.pow(10.0, tf.math.exp(a)))
grads = tape.gradient(out, a)

print(grads)
# tf.Tensor(nan, shape=(), dtype=float32)

And if we reverse the condition so that the branch with the infinite value is selected, the gradient is infinite (which is a bit surprising, given that it does not produce nan like above):

with tf.GradientTape() as tape:
  out = tf.where(a > 15., a, tf.math.pow(10.0, tf.math.exp(a)))
grads = tape.gradient(out, a)

print(grads)
# tf.Tensor(inf, shape=(), dtype=float32)

So this behavior happens for both nan and infinite values in the inactive branch. I wish it weren't like this, because it is a bit unreasonable and makes it impossible to use user-defined ops/functions that generate extremely large values for some inputs; hence, the inner tf.where workaround may not always be practical (unfortunately, even gradient clipping does not help, because clipping a nan value produces nan in TF).

CC: @anorak-k for potential consideration in your PR after @mdanatg confirms this.

@mdanatg mdanatg commented Apr 19, 2020

@mkaze That's true: nan, inf, and any other special FP value will disrupt the gradient calculation.

What happens internally is that the gradients are aggregated in this fashion: 1 * <grad of branch taken> + 0 * <grad of branch not taken>. In your former case, you have 0 * inf = nan. In the latter case, you have 1 * inf = inf. I agree it's very confusing; unfortunately, a naive fix would add significant overhead to gradient calculations.
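
The arithmetic is easy to check in plain Python; multiplying by zero does not neutralize special FP values:

inf = float("inf")
print(0.0 * inf)  # nan -- the "branch not taken" term poisons the sum
print(1.0 * inf)  # inf -- the "branch taken" term propagates as-is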

Moreover, the forward calculation doesn't need to result in a nan or inf. You can also get weird results if the gradient alone is nan or inf. For example, the cube root function is defined and well-behaved everywhere, but its derivative at zero is infinite. So this will give you a nan gradient too:

a = tf.Variable(0.0)
with tf.GradientTape() as tape:
  out = tf.where(a < 1, a, tf.pow(a, 1.0/3.0))
grads = tape.gradient(out, a)
print(grads)

I think the tf.where workaround is useful with infinite values as well, so long as the branch not taken is forced to take a gradient that can be safely multiplied by 0. For your example, it would be something like this:

dummy_safe_value = 0.
safe_a = tf.where(a > 15., dummy_safe_value, a)
out = tf.where(a > 15., a, tf.math.pow(10.0, tf.math.exp(safe_a)))

I agree that it can sometimes be impractical, but in principle it should always be possible as long as you control the inputs to the sensitive functions: all you have to do is force finite values in all the elements that are dropped.
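
Applying the same idea to your first example (a sketch; the safe value 0. is arbitrary, any input for which the pow stays finite works):

import tensorflow as tf

a = tf.Variable(10.)
with tf.GradientTape() as tape:
    # Feed the explosive op a safe value wherever its branch is inactive.
    safe_a = tf.where(a < 15., 0., a)
    out = tf.where(a < 15., a, tf.math.pow(10.0, tf.math.exp(safe_a)))
grads = tape.gradient(out, a)
print(grads)  # tf.Tensor(1.0, ...) instead of nan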

@kari-554 kari-554 commented May 7, 2020

I want to fix the issue #38349

@tushar-dalal tushar-dalal commented May 14, 2020

This is due to a limitation in how gradients are calculated. Unfortunately, it is unlikely to be fixed in the foreseeable future.

You can find more detail here, along with a recipe for how to avoid it: https://stackoverflow.com/questions/33712178/tensorflow-nan-bug/42497444#42497444

In short, if either input to a tf.where contains NaNs, the gradient will always be NaN, regardless of whether that input is actually selected, and the workaround is to prevent the inputs from ever containing NaNs.

Couldn't it simply raise a ValueError if it's getting NaN inputs? Or does it not work like that?

@unicorn-io unicorn-io commented May 29, 2020

Can I work on this issue if no one else is on it now?

@mdanatg mdanatg commented May 29, 2020

@tushar-dalal The challenge is that checking for such NaN inputs can be taxing on performance. When debugging, tf.debugging.check_numerics can indeed help with that.
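
For example, a minimal sketch of using it while debugging (the message string is arbitrary):

import tensorflow as tf

x = tf.math.log(-1.0)  # nan
# Raises InvalidArgumentError if x contains nan or inf; a debugging
# aid rather than something to leave on hot paths.
x = tf.debugging.check_numerics(x, message="x went non-finite")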

@unicorn-io Feel free to tackle it, but note that it's extremely challenging to solve. That said, there was a PR (#38467) to add a warning message to the docs of tf.where; it would be useful to revive it.

@unicorn-io unicorn-io commented May 29, 2020

I am motivated to do this. Can you give me some tips to start with? I will try my best to understand and resolve this issue.

@unicorn-io unicorn-io commented Jun 2, 2020

I am motivated to do this. Can you give me some tips to start with? I will try my best to understand and resolve this issue. @mdanatg

@mdanatg mdanatg commented Jun 2, 2020

@unicorn-io You can start by looking at the gradient code and understanding how it works. Then you can reproduce what happens in the case of a tf.where with bad gradients.

@unicorn-io unicorn-io commented Jun 2, 2020

Cool, I'll get to it.

@madamalarevanth madamalarevanth commented Jun 8, 2020

Hey, I would like to work on it. Can I also help, please?

@AbhinavTalari AbhinavTalari commented Jun 17, 2020

Cool, I'll get to it.

This bug cannot be fixed as of now, it seems.

@mdanatg mdanatg commented Jun 17, 2020

It's indeed very challenging to fix. However, the documentation of affected ops, like tf.where, can still be updated to alert users about it.

@unicorn-io unicorn-io commented Jun 18, 2020

@mdanatg isn't #38497 addressing this and is closed?

@mdanatg mdanatg commented Jun 18, 2020

You mean #38467? It was closed due to staleness, and it would be useful to revive it. By the looks of it, it's safe to assume no one else is working on it.

@EbiereVO EbiereVO commented Jul 1, 2020

It seems like it's been a long time since the last activity. Is this issue still open to be worked on?

@mdanatg mdanatg commented Jul 1, 2020

I think so. There are two parts to it: (1) updating the docs of tf.where, which is fairly straightforward, and (2) actually trying to fix the behavior, which is a significant undertaking because it involves a rather fundamental limitation.

@iamharshit13 iamharshit13 commented Jul 8, 2020

Is this issue still addressable?
