diff --git a/onnx_tf/handlers/backend/depth_to_space.py b/onnx_tf/handlers/backend/depth_to_space.py index 5e3043a8..f0747653 100644 --- a/onnx_tf/handlers/backend/depth_to_space.py +++ b/onnx_tf/handlers/backend/depth_to_space.py @@ -40,7 +40,8 @@ def _common(cls, node, **kwargs): if mode == "CRD": # need native computation bsize = attrs.get("blocksize") - batch, channel, height, width = x.shape + x_shape = tf.shape(x) + batch, channel, height, width = x_shape[0], x_shape[1], x_shape[2], x_shape[3] csize = channel // (bsize**2) reshape_node = tf.reshape(x, [batch, csize, bsize, bsize, height, width])