Skip to content

Commit

Permalink
style: paint it black
Browse files Browse the repository at this point in the history
  • Loading branch information
TheWii committed Feb 23, 2024
1 parent 006750b commit 3447f73
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 47 deletions.
3 changes: 2 additions & 1 deletion bolt_expressions/expose.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
binary_min = binary_operator(Min)
binary_max = binary_operator(Max)


def wrapped_min(f: Any, *args: T, **kwargs: Any) -> Union[T, Any]:
values = args

Expand Down Expand Up @@ -55,5 +56,5 @@ def wrapped_max(f: Any, *args: T, **kwargs: Any) -> Union[T, Any]:
def wrapped_len(f: Any, obj: Any, /) -> Any:
if not isinstance(obj, Source):
return f(obj)

return length(obj)
7 changes: 4 additions & 3 deletions bolt_expressions/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class UnrollHelper:
def provide(self, **kwargs: Any):
prev_data = self.data
self.data = {**self.data, **kwargs}

yield self.data

self.data = prev_data
Expand Down Expand Up @@ -161,6 +161,7 @@ def unroll(
) -> tuple[Iterable[IrOperation], IrSource | IrLiteral]:
...


@dataclass(order=False, eq=False, kw_only=True)
class Unrolled(ExpressionNode):
operations: Iterable[IrOperation] = ()
Expand Down Expand Up @@ -370,7 +371,7 @@ def resolve_branch(self, node: ExpressionNode):

if not isinstance(result, IrSource):
return

result_tuple = result.to_tuple()

if result_tuple in self.lazy_values:
Expand Down Expand Up @@ -406,7 +407,7 @@ def unroll_lazy(
) -> tuple[Iterable[IrOperation], IrSource | IrLiteral] | None:
if source in helper.ignored_sources:
return None

if helper.data.get("ignore_lazy"):
return None

Expand Down
18 changes: 9 additions & 9 deletions bolt_expressions/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def unroll(self, helper: UnrollHelper) -> tuple[Iterable[IrOperation], IrSource]
if self.stores_result:
operation = self.create_operation(target_value)

store = IrChildren((
*operation.store, IrStore(type=StoreType.result, value=temp_var)
))
operation = replace(operation,store=store)
store = IrChildren(
(*operation.store, IrStore(type=StoreType.result, value=temp_var))
)
operation = replace(operation, store=store)
else:
operations.append(IrSet(left=temp_var, right=target_value))
operation = self.create_operation(temp_var)
Expand Down Expand Up @@ -145,7 +145,7 @@ def create_operation(self, left: IrSource, right: IrSource | IrLiteral) -> IrBin
def unroll(self, helper: UnrollHelper) -> tuple[Iterable[IrOperation], IrSource]:
former = convert_node(self.former, self.ctx)
latter = convert_node(self.latter, self.ctx)

with helper.provide(ignore_lazy=not self.evaluates_target):
former_nodes, former_value = former.unroll(helper)
latter_nodes, latter_value = latter.unroll(helper)
Expand All @@ -169,10 +169,10 @@ def unroll(self, helper: UnrollHelper) -> tuple[Iterable[IrOperation], IrSource]
operation = self.create_operation(temp_var, latter_value)

if self.stores_result:
store = IrChildren((
*operation.store, IrStore(type=StoreType.result, value=temp_var)
))
operation = replace(operation,store=store)
store = IrChildren(
(*operation.store, IrStore(type=StoreType.result, value=temp_var))
)
operation = replace(operation, store=store)

operations.append(operation)

Expand Down
40 changes: 19 additions & 21 deletions bolt_expressions/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class IrData(IrSource):
def to_tuple(self) -> "DataTuple":
return DataTuple(self.type, self.target, self.path)


@dataclass(frozen=True, kw_only=True)
class IrDataString(IrData):
range: int | tuple[int | None, int | None]
Expand All @@ -143,12 +144,13 @@ def normalized_range(self) -> tuple[int, int | None]:
else:
start = self.range
end = None if start == -1 else self.range + 1

if start is None:
start = 0

return (start, end)


@dataclass(frozen=True, kw_only=True)
class IrLiteral(IrNode):
value: NbtValue
Expand Down Expand Up @@ -900,6 +902,7 @@ def multiply_divide_by_fraction(nodes: Iterable[IrOperation]):

Location = tuple[int, ...]


def get_source_usage(nodes: Iterable[IrNode]) -> dict[SourceTuple, list[Location]]:
map: dict[SourceTuple, list[Location]] = {}

Expand All @@ -922,12 +925,12 @@ def add(source: Any, i: Location):

for s in node.store:
add(s.value, (i,))

if is_binary(node) and node.store:
add(node.left, (i,))

for operand in node.operands:
add(operand, (i, ))
add(operand, (i,))

if isinstance(node, IrBranch):
children_usage = get_source_usage(node.children)
Expand Down Expand Up @@ -1179,7 +1182,7 @@ def replace_node(node: Any) -> Any:
return node

nodes = [replace_node(node) for node in nodes]

yield from nodes


Expand Down Expand Up @@ -1227,10 +1230,9 @@ def data_string_propagation(nodes: Iterable[IrOperation]) -> Iterable[IrOperatio
defs = get_source_definitions(all_nodes)

for i, node in enumerate(all_nodes):
if (
not is_binary(node, ("append", "prepend", "merge", "insert", "set"))
or not isinstance(node.right, IrSource)
):
if not is_binary(
node, ("append", "prepend", "merge", "insert", "set")
) or not isinstance(node.right, IrSource):
yield node
continue

Expand Down Expand Up @@ -1321,12 +1323,11 @@ def deadcode_elimination(
for store_el in node.store:
source = store_el.value

if (
not opt.is_temp(source)
or any(use_i > (node_i,) for use_i in usage.get(source.to_tuple(), []))
if not opt.is_temp(source) or any(
use_i > (node_i,) for use_i in usage.get(source.to_tuple(), [])
):
store.append(store_el)

if store != node.store:
node = replace(node, store=IrChildren(store))

Expand Down Expand Up @@ -1561,10 +1562,7 @@ def store_result_inlining(nodes: Iterable[IrOperation]) -> Iterable[IrOperation]
removed: set[int] = set()

for i, node in enumerate(nodes):
if (
not isinstance(node, IrCast)
or not isinstance(node.right, IrSource)
):
if not isinstance(node, IrCast) or not isinstance(node.right, IrSource):
continue

source_def_i = get_reaching_definition(defs, node.right, i)
Expand All @@ -1579,15 +1577,15 @@ def store_result_inlining(nodes: Iterable[IrOperation]) -> Iterable[IrOperation]
type=StoreType.result,
value=node.left,
cast_type=node.cast_type,
scale=node.scale
scale=node.scale,
)
source_stores = stores.setdefault(source_def_i, [])
source_stores.append(store)
removed.add(i)

for i, node in enumerate(nodes):
if store := stores.get(i):
node = replace(node, store=IrChildren((*node.store, *store)))

if i not in removed:
yield node
yield node
29 changes: 16 additions & 13 deletions bolt_expressions/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,13 +294,14 @@ def operator_method(
def operator_method(func: Callable[P, T]) -> OperatorMethod[P, T]:
...


@overload
def operator_method(
func: Callable[P, T],
*,
lazy: bool = False,
is_internal: bool = True,
returns: bool = True
returns: bool = True,
) -> OperatorMethod[P, T]:
...

Expand Down Expand Up @@ -338,10 +339,10 @@ def get(self, key: str, default: Any = None) -> OperatorMethod[..., Any] | None:
return cast(OperatorMethod[..., Any], attr)

return default

def get_item(self, key: Any) -> Any:
return None

def set_item(self, key: Any, value: Any):
child = self.target.__getitem__(key)
child.__rebind__(value)
Expand Down Expand Up @@ -562,7 +563,7 @@ class GenericOperatorHandler(OperatorHandler):
__ne__ = binary_operator(NotEqual) # type: ignore
__not__ = unary_operator(Not)
__len__ = length

@operator_method(returns=False)
def insert(self, index: int, value: Any):
return Insert(
Expand Down Expand Up @@ -612,29 +613,31 @@ class StringOperatorHandler(OperatorHandler):
def get_item(self, key: Any):
if isinstance(key, (int, slice)):
return self.slice(key)

return None

@internal
def set_item(self, key: Any, value: Any):
if isinstance(key, (int, slice)):
raise TypeError(f"String data source does not support index/slice assignment.")

raise TypeError(
f"String data source does not support index/slice assignment."
)

return super().set_item(key, value)

@operator_method
def slice(self, value: int | slice):
expr = self.target.expr

range = (value.start, value.stop) if isinstance(value, slice) else value

target = self.target
source = IrDataString(
type=target._type,
target=target._target,
path=target._path,
nbt_type=target.readtype,
range=range
range=range,
)
result = create_result(expr, ResultType.data)[str]
resolve(expr, Unrolled(value=source, ctx=expr), result=result, lazy=True)
Expand All @@ -649,7 +652,7 @@ class SequenceOperatorHandler(OperatorHandler):
__ne__ = binary_operator(NotEqual) # type: ignore
__not__ = unary_operator(Not)
__len__ = length

@operator_method(returns=False)
def insert(self, index: int, value: Any):
return Insert(
Expand Down Expand Up @@ -833,7 +836,7 @@ def __getitem__(
if self.is_lazy():
self.evaluate()
return self[key]

result = self.operator_handler.get_item(key)
if result is not None:
return result
Expand Down

0 comments on commit 3447f73

Please sign in to comment.