diff --git a/flax/builtins.py b/flax/builtins.py index 2e6a01a..7c42fb5 100644 --- a/flax/builtins.py +++ b/flax/builtins.py @@ -515,7 +515,9 @@ "˝": attrdict( condition=lambda links: links and links[0].arity, qlink=lambda links, outermost_links, i: [ - attrdict(arity=2, call=lambda w, x: fold(links, w, x, initial=True, right=True)) + attrdict( + arity=2, call=lambda w, x: fold(links, w, x, initial=True, right=True) + ) ], ), "ˢ": attrdict( @@ -632,9 +634,7 @@ "ᵟ˝": attrdict( condition=lambda links: links and links[0].arity, qlink=lambda links, outermost_links, i: [ - attrdict( - arity=2, call=lambda w, x: fold(links, w, x, initial=True) - ) + attrdict(arity=2, call=lambda w, x: fold(links, w, x, initial=True)) ], ), "ᵟᵂ": attrdict( @@ -671,9 +671,7 @@ "ᵟ‶": attrdict( condition=lambda links: links and links[0].arity, qlink=lambda links, outermost_links, i: [ - attrdict( - arity=2, call=lambda w, x: scan(links, w, x, initial=True) - ) + attrdict(arity=2, call=lambda w, x: scan(links, w, x, initial=True)) ], ), "ᵟⁿ": attrdict( @@ -724,7 +722,9 @@ "‶": attrdict( condition=lambda links: links and links[0].arity, qlink=lambda links, outermost_links, i: [ - attrdict(arity=2, call=lambda w, x: scan(links, w, x, initial=True, right=True)) + attrdict( + arity=2, call=lambda w, x: scan(links, w, x, initial=True, right=True) + ) ], ), "⁰": attrdict(