Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert to call_tir with class prefix #11

Merged
merged 2 commits into from
May 11, 2023
Merged

Revert to call_tir with class prefix #11

merged 2 commits into from
May 11, 2023

Conversation

sudeepag
Copy link
Contributor

@sudeepag sudeepag commented May 11, 2023

Per feedback on #10, reverting to call_tir with a cls prefix in the invocation.

Preview

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@sudeepag
Copy link
Contributor Author

@tqchen This works correctly when using call_tir on primitive functions defined within the class, but I get an error when using it with registered runtime functions, for example:

@tvm.script.ir_module
class MyModuleWithExternCall:
    @R.function
    def main(x: R.Tensor((1, 784), "float32"), 
             w0: R.Tensor((128, 784), "float32"), 
             b0: R.Tensor((128,), "float32"), 
             w1: R.Tensor((10, 128), "float32"), 
             b1: R.Tensor((10,), "float32")):
        # block 0
        with R.dataflow():
            lv0 = R.call_tir("env.linear", (x, w0, b0), out_sinfo=R.Tensor((1, 128), dtype="float32"))
            lv1 = R.call_tir("env.relu", (lv0,), out_sinfo=R.Tensor((1, 128), dtype="float32"))
            out = R.call_tir("env.linear", (lv1, w1, b1), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(out)
        return out
error:   Check failed: type_code_ == kTVMObjectHandle (11 vs. 8) : expected Object but got str
 --> [/tmp/ipykernel_261925/1336559420.py:11:19](https://file+.vscode-resource.vscode-cdn.net/tmp/ipykernel_261925/1336559420.py:11:19)
    |  
 11 |              lv0 = R.call_tir("env.linear", (x, w0, b0), out_sinfo=R.Tensor((1, 128), dtype="float32"))

Is there a different syntax for using call_tir with registered functions?

@tqchen
Copy link
Contributor

tqchen commented May 11, 2023

for runtime functions, we should indeed use call_dps_packed

@sudeepag
Copy link
Contributor Author

@tqchen Thanks, ready for review.

Preview

@tqchen tqchen merged commit f96f3cc into mlc-ai:main May 11, 2023
@tqchen
Copy link
Contributor

tqchen commented May 11, 2023

Thanks @sudeepag

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants