Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes a bug when assigning a function to an intermediate variable #327

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

mehdiataei
Copy link
Contributor

Category

  • New feature
  • Bugfix
  • Breaking change
  • Refactoring
  • Documentation
  • Other (please explain)

Description

Fixes a bug when assigning a function to an intermediate variable. Warp's code generator currently doesn't handle this pattern correctly. When a function is returned, the code generator replaces the expression with a special AST node (ast.Name) with an identifier __warp_func__ and attaches the actual function object to it using an attribute warp_func.

In the example repro below from the doc, when the code generator later encounters the name func in output[tid] = func(a, b), it tries to resolve func but fails because it doesn't find __warp_func__ in its symbol table, leading to the error.

NOTE: All tests pass, including the newly added test. However, please review carefully. Fingers crossed my tweaks don't unleash chaos.

import warp as wp

@wp.func
def do_add(a: float, b: float):
    return a + b

@wp.func
def do_sub(a: float, b: float):
    return a - b

@wp.func
def do_mul(a: float, b: float):
    return a * b

op_handlers = {
    "add": do_add,
    "sub": do_sub,
    "mul": do_mul,
}

inputs = wp.array([[1, 2], [3, 0]], dtype=wp.float32)
outputs = wp.empty(2, dtype=wp.float32)

for op in op_handlers.keys():

    @wp.kernel
    def operate(input: wp.array(dtype=inputs.dtype, ndim=2), output: wp.array(dtype=wp.float32)):
        tid = wp.tid()
        a, b = input[tid, 0], input[tid, 1]
        # retrieve the right function to use for the captured dtype variable
        output[tid] = wp.static(op_handlers[op])(a, b) # this works (as per the docs)
        # ERROR: But below does not work unexpectedly (even though it should be equivalent)
        # func = wp.static(op_handlers[op])
        # output[tid] = func(a, b) # this does not work

    wp.launch(operate, dim=2, inputs=[inputs], outputs=[outputs])

print(outputs.numpy())

Changelog

  • Allow functions to be correctly assigned to variables in Warp kernels.

Before your PR is "Ready for review"

  • Do you agree to the terms under which contributions are accepted as described in Section 9 the Warp License?
  • Have you read the Contributor Guidelines?
  • Have you written any new necessary tests?
  • Have you added or updated any necessary documentation?
  • Have you added any files modified by compiling Warp and building the documentation to this PR (.e.g. stubs.py, functions.rst)?
  • Does your code pass ruff check and ruff format --check?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants