
TensorFlow cannot compile shape-checked method with explicit input signature. #1

Open
jesnie opened this issue Jul 6, 2022 · 4 comments
Labels
bug Something isn't working

Comments

@jesnie
Member

jesnie commented Jul 6, 2022

import tensorflow as tf
from gpflow.experimental.check_shapes import check_shapes


class A:

    def f(self, x):
        return x + 2

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def g(self, x):
        return x + 2


a = A()
specs = [tf.TensorSpec(shape=None, dtype=tf.int32)]
f = tf.function(a.f)
f2 = tf.function(a.f, input_signature=specs)
g = tf.function(a.g)
g2 = tf.function(a.g, input_signature=specs)
x = tf.constant(7)

f(x)  # Good
f2(x)  # Good
g(x)  # Good
g2(x)  # Breaks...
@jesnie jesnie added the bug Something isn't working label Jul 6, 2022
@Corwinpro
Contributor

Corwinpro commented Jul 11, 2022

Hi @jesnie ,

I am hitting a similar problem. Could it be related to the one reported in this issue?

Calling a shape-checked method that has an optional argument through a tf.function wrapper fails.

import tensorflow as tf
from gpflow.experimental.check_shapes import check_shapes


class A:
    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def foo(self, x, opt=None):
        return x + 2

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def bar(self, x, opt=None):
        return tf.function(self.foo)(x, opt)

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def foo_2(self, x):
        return x + 2

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def bar_2(self, x):
        return tf.function(self.foo_2)(x)


a = A()

a.foo_2(2.0)  # OK
a.bar_2(2.0)  # OK

a.foo(2.0)  # OK
a.bar(2.0)  # Not OK

I noticed that passing the arguments by keyword fixes it:

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def bar(self, x, opt=None):
        return tf.function(self.foo)(x=x, opt=opt)  # pass arguments by keyword
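The keyword-argument workaround can be generalized with a small helper, which is useful when a method has many parameters. This is a minimal sketch, not part of check_shapes or TensorFlow: `as_keyword_arguments` is a hypothetical name, and it assumes that `inspect.signature(self.foo)` reports the real parameters `(x, opt=None)`. That only holds if the check_shapes decorator preserves the wrapped function's signature (for example via `functools.wraps`), which is an assumption here. It has not been tried against TensorFlow 2.4.

```python
import inspect
from typing import Any, Callable, Dict


def as_keyword_arguments(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper: map a call's arguments onto fn's parameter names.

    Assumes inspect.signature(fn) reports fn's real parameters; for a decorated
    function this requires the decorator to set __wrapped__ (e.g. functools.wraps).
    """
    signature = inspect.signature(fn)
    for param in signature.parameters.values():
        # *args and positional-only parameters cannot be passed by keyword.
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.POSITIONAL_ONLY):
            raise ValueError(f"cannot pass parameter {param.name!r} of {fn!r} by keyword")
    # Raises TypeError if the arguments do not match the signature.
    bound = signature.bind(*args, **kwargs)
    result: Dict[str, Any] = {}
    # Only arguments that were actually supplied are included; defaults stay implicit.
    for name, value in bound.arguments.items():
        if signature.parameters[name].kind is inspect.Parameter.VAR_KEYWORD:
            result.update(value)  # flatten **kwargs back into individual keywords
        else:
            result[name] = value
    return result


def foo(x, opt=None):
    return x + 2


print(as_keyword_arguments(foo, 2.0, None))  # {'x': 2.0, 'opt': None}

# Hypothetical use inside A.bar from the report above:
#     return tf.function(self.foo)(**as_keyword_arguments(self.foo, x, opt))
```

Passing everything by keyword mirrors what fixed `bar` above; whether it avoids the problem in every TensorFlow version is untested.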

@jesnie
Member Author

jesnie commented Jul 11, 2022

Which version of TensorFlow are you using? I know some earlier versions struggle with optional parameters: https://github.com/GPflow/GPflow/blob/fda83683483429de5eda996ba2f98c0400b987cf/tests/gpflow/experimental/check_shapes/test_integration.py#L180

@Corwinpro
Contributor

It's TensorFlow 2.4. Thanks for pointing this out. Are there any known fixes for that?

@jesnie
Member Author

jesnie commented Jul 11, 2022

I haven't made the effort to look into it. 🤷

@jesnie jesnie transferred this issue from GPflow/GPflow Sep 26, 2022
Projects
Status: Todo
Development

No branches or pull requests

2 participants