From 8e7f7b9fdc431a4449defab83340a8c8f85b790d Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Tue, 14 May 2024 10:12:09 -0400 Subject: [PATCH] Do not duplicate so much code --- lib/axon.ex | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/lib/axon.ex b/lib/axon.ex index 7d1d97ee..995650c8 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -745,25 +745,28 @@ defmodule Axon do generated. """ @doc type: :special - def block(fun, opts \\ []) + def block(fun, opts \\ []) when is_function(fun) do + {:arity, arity} = Function.info(fun, :arity) + opts = Keyword.validate!(opts, [:name, :meta]) + block_id = System.unique_integer([:positive, :monotonic]) - for i <- 1..128 do - args = Macro.generate_arguments(i, __MODULE__) + block_fun(arity, fn inputs -> + layer(:block, List.wrap(inputs), + op_name: :block, + name: opts[:name], + meta: opts[:meta], + block_fun: fun, + block_id: block_id + ) + end) + end - @doc false - def block(fun, opts) when is_function(fun, unquote(i)) do - opts = Keyword.validate!(opts, [:name, :meta]) - block_id = System.unique_integer([:positive, :monotonic]) + @doc false + for i <- 0..128 do + args = Macro.generate_arguments(i, __MODULE__) - fn unquote_splicing(args) -> - layer(:block, List.wrap(unquote(args)), - op_name: :block, - name: opts[:name], - meta: opts[:meta], - block_fun: fun, - block_id: block_id - ) - end + def block_fun(unquote(i), callback) do + fn unquote_splicing(args) -> callback.(unquote(args)) end end end