diff --git a/warp/codegen.py b/warp/codegen.py index 7db6ecc5..a9490513 100644 --- a/warp/codegen.py +++ b/warp/codegen.py @@ -1659,6 +1659,9 @@ def emit_Name(adj, node): # lookup symbol, if it has already been assigned to a variable then return the existing mapping if node.id in adj.symbols: return adj.symbols[node.id] + # Check if the node has a warp_func attribute + if hasattr(node, 'warp_func'): + return node.warp_func obj = adj.resolve_external_reference(node.id) @@ -2323,6 +2326,18 @@ def emit_Assign(adj, node): raise WarpCodegenError( "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead." ) + elif isinstance(lhs, ast.Name): + # symbol name + name = lhs.id + + # evaluate rhs + rhs = adj.eval(node.value) + + # Check if rhs is a function object + if isinstance(rhs, warp.context.Function): + # Assign the function directly to the symbol table + adj.symbols[name] = rhs + return # handle the case where we are assigning multiple output variables if isinstance(lhs, ast.Tuple):