
Count number of times `tf.function` is traced #37323

Open · ammirato opened this issue Mar 5, 2020 · 18 comments
@ammirato commented Mar 5, 2020


System information

  • TensorFlow version (you are using): 2.1
  • Are you willing to contribute it (Yes/No): No

Describe the feature and the current behavior/state.
I would like to be able to check, inside a function wrapped by tf.function, when that function is being retraced. This is useful for testing that my functions are compatible with tf.function and actually benefit from the speedup (rather than building a new graph every time they are called).

An example; the requested feature would be inserted where REQUEST is marked:

import pytest
import tensorflow as tf

from my_module import my_function  # hypothetical function under test

@pytest.mark.unit
@pytest.mark.parametrize('use_tf_function', [True, False])
def test_my_function(use_tf_function):

    def test(arg1, arg2):
        # REQUEST: Check to make sure this is not being retraced many times.
        return my_function(arg1, arg2)

    # Sometimes test in eager mode for debugging, sometimes test in graph mode.
    test_func = test
    if use_tf_function:
        test_func = tf.function(test_func)

    ##### Test 1
    arg1, arg2 = ..., ...  # some setup stuff
    test_func(arg1, arg2)  # Create the graph (tracing happens).
    results = test_func(arg1, arg2)  # Hopefully tracing does not happen a second time.
    assert results

    ##### Test 2
    arg3, arg4 = ..., ...  # some setup stuff
    results = test_func(arg3, arg4)  # Hopefully tracing does not happen again
                                     # (inputs are tensors and shapes do not change).
    assert results

Will this change the current API? How?
I'm not sure; maybe it is just a matter of adding a function, or maybe one already exists.

Who will benefit from this feature?
Anyone debugging.

Any other info.

@ammirato (Author) commented Mar 5, 2020

One thing I could do is check test_func._call_counter.get_tracing_count(), but I don't like relying on a variable whose name starts with _. Is there a more visible function for doing this?

@mdanatg (Contributor) commented Mar 9, 2020

Indeed, you should avoid using private fields; they may change at any moment.

The easiest and most robust way is to use a global Python counter. Because the Python body only executes while the function is being traced, a counter incremented inside the body counts exactly the traces:

>>> import tensorflow as tf
>>> trace_count = 0
>>> @tf.function
... def f(x):
...   global trace_count
...   trace_count += 1
...   # the rest of your code
>>> trace_count
0
>>> _ = f(tf.constant(1))
>>> trace_count
1
>>> _ = f(tf.constant(2))  # same dtype and shape: no retrace
>>> trace_count
1
>>> _ = f(tf.constant([1, 2]))  # new shape: retrace
>>> trace_count
2
@mdanatg closed this Mar 9, 2020
@ammirato (Author) commented Mar 11, 2020

This only works for certain cases and doesn't scale to every function everywhere; I'd need a different global variable for each function.

What if my function is written in one file and I wrap it in tf.function in another file? I can't use a global variable as described above to test the trace count from the caller code, which is where I would like to test.
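
For illustration, one way to generalize the global-counter workaround across files is to keep the counter in a closure and return it alongside the wrapped function. This is a minimal sketch, not TensorFlow API; counted_tf_function is a hypothetical helper:

import tensorflow as tf

def counted_tf_function(fn):
    """Wrap fn in tf.function and count traces without private fields.

    The body of `counted` only runs while tf.function is tracing,
    so each trace increments the counter exactly once.
    """
    def counted(*args, **kwargs):
        counted.trace_count += 1
        return fn(*args, **kwargs)

    counted.trace_count = 0
    return tf.function(counted), lambda: counted.trace_count

# Usage in a test:
#   test_func, tracing_count = counted_tf_function(my_function)
#   test_func(arg1, arg2)
#   assert tracing_count() == 1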

@mdanatg (Contributor) commented Mar 11, 2020

I missed that this was a feature request. Sounds reasonable to me.

@SSaishruthi (Contributor) commented Mar 12, 2020

Any suggestions on how this can be implemented? I would like to give it a try if we have any pointers.

@ammirato (Author) commented Mar 12, 2020

A simple solution would be to add a public method to the Function class that returns self._call_counter.get_tracing_count().

I'm not sure of the implications of this, though, as _call_counter is a private field.
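
In code, the proposal amounts to something like the following sketch against the Function class in tensorflow/python/eager/def_function.py (the method name is provisional):

def get_tracing_count(self):
    """Returns the number of times this function has been traced."""
    return self._call_counter.get_tracing_count()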

@mdanatg (Contributor) commented Mar 13, 2020

Fields are kept private by default, and we only expose public methods if there is a need for one. This helps keep the API clean - once a method is made public, it's extremely difficult to remove it. So for now, I recommend naming the method/property with the experimental_ prefix to control its adoption.
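
Concretely, if the method lands under an experimental_ name (experimental_get_tracing_count is one plausible spelling), the test from the original post could assert on the trace count directly:

import tensorflow as tf

@tf.function
def add(a, b):
    return a + b

add(tf.constant(1), tf.constant(2))
add(tf.constant(3), tf.constant(4))  # same input signature: no retrace
assert add.experimental_get_tracing_count() == 1

add(tf.constant([1, 2]), tf.constant([3, 4]))  # new shape: retrace
assert add.experimental_get_tracing_count() == 2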

@qaishk commented Mar 13, 2020

I'd like to give this a shot. Can I submit a PR for this?

@SSaishruthi (Contributor) commented Mar 13, 2020

@ammirato @mdanatg Thanks. Will work on this.

@kulsoomzahra commented Mar 15, 2020

Can I work on this? @ammirato

@Ankuraxz commented Mar 18, 2020

@ammirato, can I work on this and submit a PR?

@104H commented Mar 26, 2020

How about a static variable in the Function class?
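
As a sketch of that idea (hypothetical wrapper, not TensorFlow internals), a class-level attribute can count traces across all instances while an instance attribute counts per function:

import tensorflow as tf

class TracedFunction:
    """Illustrative wrapper with a class-level ('static') trace counter."""

    total_traces = 0  # shared across all TracedFunction instances

    def __init__(self, fn):
        self.trace_count = 0  # per-function counter

        def counted(*args, **kwargs):
            # Runs only while tracing, so this counts traces, not calls.
            TracedFunction.total_traces += 1
            self.trace_count += 1
            return fn(*args, **kwargs)

        self._fn = tf.function(counted)

    def __call__(self, *args, **kwargs):
        return self._fn(*args, **kwargs)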

@LGTcoder commented Apr 4, 2020

import pytest
import tensorflow as tf

from my_module import my_function  # hypothetical function under test

@pytest.mark.unit
@pytest.mark.parametrize('use_tf_function', [True, False])
def test_my_function(use_tf_function):

    def test(arg1, arg2):
        # REQUEST: Check to make sure this is not being retraced many times.
        return my_function(arg1, arg2)

    # Sometimes test in eager mode for debugging, sometimes test in graph mode.
    test_func = test
    if use_tf_function:
        test_func = tf.function(test_func)

    ##### Test 1
    arg1, arg2 = ..., ...  # some setup stuff
    test_func(arg1, arg2)  # Create the graph (tracing happens).
    results = test_func(arg1, arg2)  # Hopefully tracing does not happen a second time.
    assert results

    ##### Test 2
    arg3, arg4 = ..., ...  # some setup stuff
    results = test_func(arg3, arg4)  # Hopefully no retrace if shapes do not change.
    assert results

Is this right?

@PratsBhatt commented May 17, 2020

Is this issue closed, or is it open for contributions? I see that the PR is closed. Thanks.

@mihaimaruseac (Collaborator) commented May 18, 2020

The PR has been closed in an unmerged state.

@aicam commented May 29, 2020

If I understand the issue correctly, you want to preserve the state of your graph so you can use an LRU cache. I tested the combination of @tf.function and @lru_cache, both inside the TensorFlow library and in my own main.py, and both worked fine. However, I was only testing a simple graph with one addition, so I don't know the side effects.

@WolframAlph commented Aug 22, 2020

Is this feature still needed? I can see the previous PR was closed in an unmerged state. Thanks.

@LGTcoder commented Aug 24, 2020

I think I have mentioned the right one above.
