TensorFlow cannot compile shape-checked method with explicit input signature. #1
Comments
Hi @jesnie, I am hitting a similar problem. Is it somehow related to the one reported in this issue? It shows up when a shape-checked method with an optional argument is called through tf.function:

import tensorflow as tf
from gpflow.experimental.check_shapes import check_shapes


class A:
    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def foo(self, x, opt=None):
        return x + 2

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def bar(self, x, opt=None):
        return tf.function(self.foo)(x, opt)

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def foo_2(self, x):
        return x + 2

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def bar_2(self, x):
        return tf.function(self.foo_2)(x)


a = A()
a.foo_2(2.0)  # OK
a.bar_2(2.0)  # OK
a.foo(2.0)  # OK
a.bar(2.0)  # Not OK

I noticed that changing bar to pass its arguments by keyword fixes it:

@check_shapes(
    "x: [batch...]",
    "return: [batch...]",
)
def bar(self, x, opt=None):
    return tf.function(self.foo)(x=x, opt=opt)  # added keywords
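If the same workaround is needed in several places, the keyword forwarding could be factored into a small helper. The sketch below is only an illustration: `call_with_keywords` is a hypothetical name, not part of GPflow or TensorFlow, and it assumes the reference function has no positional-only or variadic parameters. It uses only the standard library, so it demonstrates the argument rebinding rather than reproducing the TensorFlow failure itself.

```python
import inspect


def call_with_keywords(target, reference, *args, **kwargs):
    """Call `target` with every argument passed by keyword.

    Hypothetical helper: argument names are taken from the signature of
    `reference`, e.g. the undecorated method that `target` wraps.
    """
    signature = inspect.signature(reference)
    for param in signature.parameters.values():
        if param.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            raise TypeError(f"cannot pass parameter {param.name!r} by keyword")
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()  # make optional arguments explicit, like opt=opt
    return target(**bound.arguments)


def foo(x, opt=None):
    return x + 2


def keyword_only_wrapper(**kwargs):
    # Stand-in for a wrapper that only behaves well with keyword arguments.
    return foo(**kwargs)


result = call_with_keywords(keyword_only_wrapper, foo, 2.0)
```

In the example from this thread, the call inside `bar` would then read something like `call_with_keywords(tf.function(self.foo), self.foo, x, opt)`, assuming the decorated method still exposes the original signature to `inspect.signature`.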
Which version of TensorFlow are you using? I know some of the earlier versions struggle with optional parameters: https://github.com/GPflow/GPflow/blob/fda83683483429de5eda996ba2f98c0400b987cf/tests/gpflow/experimental/check_shapes/test_integration.py#L180
2.4 it is. Thanks for pointing this out. Are there any known fixes for that?
I haven't made the effort to look into it. 🤷