Skip to content

Commit

Permalink
various fix
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Nov 5, 2024
1 parent d92e550 commit 1646db0
Show file tree
Hide file tree
Showing 6 changed files with 805 additions and 584 deletions.
242 changes: 235 additions & 7 deletions luisa_lang/_builtin_decor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
from types import UnionType
from typing import Any, Callable, List, Optional, Set, TypeVar
import typing
from luisa_lang import hir
import inspect
from luisa_lang.utils import get_full_name, get_union_args
from luisa_lang.classinfo import register_class, class_typeinfo, MethodType, _get_cls_globalns
import functools
from luisa_lang.utils import get_full_name, get_union_args, unique_hash
from luisa_lang.classinfo import MethodType, VarType, GenericInstance, UnionType, _get_cls_globalns, register_class, class_typeinfo
from enum import auto, Enum
from luisa_lang import classinfo, parse
import inspect
from typing import (
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
TypeAlias,
TypeVar,
Union,
Generic,
Literal,
overload,
Any,
)

_T = TypeVar("_T", bound=type)
_F = TypeVar("_F", bound=Callable[..., Any])


def _builtin_type(ty: hir.Type, *args, **kwargs) -> Callable[[_T], _T]:
def builtin_type(ty: hir.Type, *args, **kwargs) -> Callable[[_T], _T]:
def decorator(cls: _T) -> _T:
cls_name = get_full_name(cls)
ctx = hir.GlobalContext.get()
Expand Down Expand Up @@ -117,12 +134,223 @@ def make_builtin():
return decorator


def _builtin(func: _F) -> _F:
return func
def builtin(s: str) -> Callable[[_F], _F]:
def wrapper(func: _F) -> _F:
setattr(func, "__luisa_builtin__", s)
return func
return wrapper


def _intrinsic_impl(*args, **kwargs) -> Any:
raise NotImplementedError(
"intrinsic functions should not be called in host-side Python code. "
"Did you mistakenly called a DSL function?"
)



class _ObjKind(Enum):
BUILTIN_TYPE = auto()
STRUCT = auto()
FUNC = auto()
KERNEL = auto()


def _make_func_template(f: Callable[..., Any], func_name: str, func_globals: Dict[str, Any], self_type: Optional[hir.Type] = None):
# parsing_ctx = _parse.ParsingContext(func_name, func_globals)
# func_sig_parser = _parse.FuncParser(func_name, f, parsing_ctx, self_type)
# func_sig = func_sig_parser.parsed_func
# params = [v.name for v in func_sig_parser.params]
# is_generic = func_sig_parser.p_ctx.type_vars != {}

func_sig = classinfo.parse_func_signature(f, func_globals, [])
func_sig_converted, sig_parser = parse.convert_func_signature(
func_sig, func_name, func_globals, {}, {}, self_type)
implicit_type_params = sig_parser.implicit_type_params
implicit_generic_params: Set[hir.GenericParameter] = set()
for p in implicit_type_params.values():
assert isinstance(p, hir.SymbolicType)
implicit_generic_params.add(p.param)

def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue] = {}
mapped_implicit_type_params: Dict[str,
hir.Type] = dict()
if is_generic:
mapping = hir.match_func_template_args(func_sig_converted, args)
if isinstance(mapping, hir.TypeInferenceError):
raise mapping
if len(mapping) != len(func_sig_converted.generic_params):
raise hir.TypeInferenceError(
None, "not all type parameters are resolved")
for p in func_sig_converted.generic_params:
if p not in mapping:
raise hir.TypeInferenceError(
None, f"type parameter {p} is not resolved")
if p not in implicit_generic_params:
type_var_ns[sig_parser.generic_param_to_type_var[p]
] = mapping[p]

for name, itp, in implicit_type_params.items():
assert isinstance(itp, hir.SymbolicType)
gp = itp.param
mapped_type = mapping[gp]
assert isinstance(mapped_type, hir.Type)
mapped_implicit_type_params[name] = mapped_type
func_sig_instantiated, _p = parse.convert_func_signature(
func_sig, func_name, func_globals, type_var_ns, mapped_implicit_type_params, self_type, mode='instantiate')
assert len(
func_sig_instantiated.generic_params) == 0, f"generic params should be resolved but found {func_sig_instantiated.generic_params}"
func_parser = parse.FuncParser(
func_name, f, func_sig_instantiated, func_globals, type_var_ns, self_type)
return func_parser.parse_body()
params = [v[0] for v in func_sig.args]
is_generic = len(func_sig_converted.generic_params) > 0
# print(f"func {func_name} is_generic: {is_generic}")
return hir.FunctionTemplate(func_name, params, parsing_func, is_generic)
_TT = TypeVar('_TT')

def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
import sourceinspect
assert inspect.isfunction(f), f"{f} is not a function"
# print(hir.GlobalContext.get)

ctx = hir.GlobalContext.get()
func_name = get_full_name(f)
func_globals: Any = getattr(f, "__globals__", {})

if kind == _ObjKind.FUNC:
template = _make_func_template(f, func_name, func_globals)
ctx.functions[f] = template
setattr(f, "__luisa_func__", template)
return typing.cast(_TT, f)
else:
raise NotImplementedError()
# return cast(_T, f)


def _dsl_struct_impl(cls: type[_TT], attrs: Dict[str, Any]) -> type[_TT]:
ctx = hir.GlobalContext.get()

register_class(cls)
cls_info = class_typeinfo(cls)
globalns = _get_cls_globalns(cls)
globalns[cls.__name__] = cls

def get_ir_type(var_ty: VarType) -> hir.Type:
if isinstance(var_ty, (UnionType, classinfo.AnyType, classinfo.SelfType)):
raise RuntimeError("Struct fields cannot be UnionType")
if isinstance(var_ty, TypeVar):
raise NotImplementedError()
if isinstance(var_ty, GenericInstance):
raise NotImplementedError()
return ctx.types[var_ty]

fields: List[Tuple[str, hir.Type]] = []
for name, field in cls_info.fields.items():
fields.append((name, get_ir_type(field)))
ir_ty = hir.StructType(
f'{cls.__name__}_{unique_hash(cls.__qualname__)}', cls.__qualname__, fields)
ctx.types[cls] = ir_ty

for name, method in cls_info.methods.items():
method_object = getattr(cls, name)
template = _make_func_template(
method_object, get_full_name(method_object), globalns, self_type=ir_ty)
ir_ty.methods[name] = template
return cls



def _dsl_decorator_impl(obj: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
if kind == _ObjKind.STRUCT:
assert isinstance(obj, type), f"{obj} is not a type"
return typing.cast(_TT, _dsl_struct_impl(obj, attrs))
elif kind == _ObjKind.FUNC or kind == _ObjKind.KERNEL:
return _dsl_func_impl(obj, kind, attrs)
raise NotImplementedError()


def struct(cls: type[_TT]) -> type[_TT]:
"""
Mark a class as a DSL struct.
Example:
```python
@luisa.struct
class Sphere:
center: luisa.float3
radius: luisa.float
def volume(self) -> float:
return 4.0 / 3.0 * math.pi * self.radius ** 3
```
"""
return _dsl_decorator_impl(cls, _ObjKind.STRUCT, {})

_KernelType = TypeVar("_KernelType", bound=Callable[..., None])

@overload
def kernel(f: _KernelType) -> _KernelType: ...


@overload
def kernel(export: bool = False, **kwargs) -> Callable[[
_KernelType], _KernelType]: ...


def kernel(*args, **kwargs) -> _KernelType | Callable[[_KernelType], _KernelType]:
if len(args) == 1 and len(kwargs) == 0:
f = args[0]
return f

def decorator(f):
return f

return decorator


class InoutMarker:
value: str

def __init__(self, value: str):
self.value = value


inout = InoutMarker("inout")
out = InoutMarker("out")


@overload
def func(f: _F) -> _F: ...


@overload
def func(inline: bool | Literal["always"]
= False, **kwargs) -> Callable[[_F], _F]: ...


def func(*args, **kwargs) -> _F | Callable[[_F], _F]:
"""
Mark a function as a DSL function.
To mark an argument as inout/out, use the `var=inout` syntax in decorator arguments.
Example:
```python
@luisa.func(a=inout, b=inout)
def swap(a: int, b: int):
a, b = b, a
```
"""

def impl(f: _F) -> _F:
return _dsl_decorator_impl(f, _ObjKind.FUNC, kwargs)

if len(args) == 1 and len(kwargs) == 0:
f = args[0]
return impl(f)

def decorator(f):
return impl(f)

return decorator
Loading

0 comments on commit 1646db0

Please sign in to comment.