The JAXnet API

JAXnet modules

JAXnet comes with some predefined modules. The tests show how modules can be used. For example, Sequential is defined as

def Sequential(*layers):
    @parametrized
    def sequential(inputs):
        for layer in layers:
            inputs = layer(inputs)
        return inputs

    return sequential

Parameter-free modules like relu, flatten and softmax are plain Python functions:

def relu(x):
    return np.maximum(x, 0)

and usage is seamless:

layer = Sequential(Dense(10), relu)

Parameter sharing

Parameters are shared by using the same module object multiple times:

shared_net = Sequential(layer, layer)
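
To make the effect of sharing concrete, here is a minimal sketch in plain NumPy rather than JAXnet. It assumes a hand-written `dense_relu` function (a hypothetical stand-in for `Sequential(Dense(10), relu)`) and shows that both applications read the same kernel and bias, so there is only one set of weights to train:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0)

def dense_relu(params, inputs):
    # Hypothetical stand-in for Sequential(Dense(10), relu).
    kernel, bias = params
    return relu(np.dot(inputs, kernel) + bias)

def shared_net(params, inputs):
    # The same parameter tuple is used for both applications,
    # which is what reusing one module object amounts to.
    return dense_relu(params, dense_relu(params, inputs))

rng = np.random.default_rng(0)
# Input and output width must match (10) for the layer to be applied twice.
shared_params = (rng.normal(size=(10, 10)), rng.normal(size=(10,)))

inputs = rng.normal(size=(3, 10))
outputs = shared_net(shared_params, inputs)
```

Because the layer is applied to its own output, the sketch only works when input and output widths agree.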

How do modules work?

Parameter is the primitive module from which all modules are built. It is created from an initialization function:

scalar = Parameter(lambda key: np.zeros(()))

The module has a single parameter that is initialized via the given function:

param = scalar.init_parameters(key=PRNGKey(0))
assert np.zeros(()) == param

Independent of any inputs, it returns these parameter values:

assert param == scalar.apply(param)

The Parameter module is roughly equivalent to:

class Parameter:
    def __init__(self, init_parameter): self.init_parameter = init_parameter

    def apply(self, parameters, *inputs): return parameters

    def init_parameters(self, *example_inputs, key): return self.init_parameter(key)

All other modules are composed from this primitive via @parametrized functions:

def Dense(out_dim, kernel_init=glorot(), bias_init=randn()):
    @parametrized
    def dense(inputs):
        kernel = Parameter(lambda key: kernel_init(key, (inputs.shape[-1], out_dim)))()
        bias = Parameter(lambda key: bias_init(key, (out_dim,)))()
        return np.dot(inputs, kernel) + bias

    return dense

The parameter function expresses the same thing more concisely:

def Dense(out_dim, kernel_init=glorot(), bias_init=randn()):
    @parametrized
    def dense(inputs):
        kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
        bias = parameter((out_dim,), bias_init)
        return np.dot(inputs, kernel) + bias

    return dense

@parametrized transforms this function into the equivalent of:

class Dense:
    Params = namedtuple('dense', ['kernel', 'bias'])

    def __init__(self, out_dim, kernel_init=glorot(), bias_init=randn()):
        self.bias_init = bias_init
        self.kernel_init = kernel_init
        self.out_dim = out_dim

    def apply(self, parameters, inputs):
        kernel, bias = parameters
        return np.dot(inputs, kernel) + bias

    def init_parameters(self, example_inputs, key):
        kernel_key, bias_key = random.split(key, 2)
        kernel = self.kernel_init(kernel_key, (example_inputs.shape[-1], self.out_dim))
        bias = self.bias_init(bias_key, (self.out_dim,))
        return Dense.Params(kernel=kernel, bias=bias)
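
The split between `init_parameters` and `apply` can be tried out without JAXnet. The sketch below is an illustration in plain NumPy, not JAXnet's implementation: `DenseSketch` is a hypothetical name, an integer seed with `SeedSequence.spawn` stands in for `PRNGKey` and `random.split`, and a standard normal initializer replaces `glorot()` and `randn()`.

```python
from collections import namedtuple
import numpy as np

DenseParams = namedtuple('dense', ['kernel', 'bias'])

class DenseSketch:
    def __init__(self, out_dim):
        self.out_dim = out_dim

    def init_parameters(self, example_inputs, seed):
        # Derive two independent random streams from a single seed,
        # analogous to splitting a PRNG key.
        kernel_seed, bias_seed = np.random.SeedSequence(seed).spawn(2)
        in_dim = example_inputs.shape[-1]
        kernel = np.random.default_rng(kernel_seed).normal(size=(in_dim, self.out_dim))
        bias = np.random.default_rng(bias_seed).normal(size=(self.out_dim,))
        return DenseParams(kernel=kernel, bias=bias)

    def apply(self, parameters, inputs):
        # Pure function of parameters and inputs; the object holds no weights.
        kernel, bias = parameters
        return np.dot(inputs, kernel) + bias

layer = DenseSketch(3)
example = np.zeros((2, 5))
params = layer.init_parameters(example, seed=0)
outputs = layer.apply(params, np.ones((2, 5)))
```

Only the shape of the example input is used during initialization, and the module object itself stays free of state.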

This allows creation and usage of models as described in the readme.

Parameters can optionally be named (see the next section for the effect):

        kernel = parameter((inputs.shape[-1], out_dim), kernel_init, 'kernel')
        bias = parameter((out_dim,), bias_init, 'bias')

How are parameters named?

JAXnet does not rely on module or weight names. Parameters are initialized to (nested) namedtuples for readability only. Each namedtuple is named after its defining module (the @parametrized function). A parameter is named parameter unless a name is specified as above. If names clash within the same module, indices are added in a fixed order:

net = Sequential(Conv(4, (2, 2)), flatten, relu, Dense(3), relu, Dense(2),
                 Sequential(Dense(2), relu))
inputs = np.zeros((1, 5, 5, 2))

params = net.init_parameters(inputs, key=PRNGKey(0))
assert (4, ) == params.conv.bias.shape
assert (3, ) == params.dense0.bias.shape
assert (3, 2) == params.dense1.kernel.shape
assert (2, ) == params.dense1.bias.shape
assert (2, ) == params.sequential.dense.bias.shape

When init_parameters is called on different modules, the names of parameters belonging to the same shared module can differ (carry different indices) between the two calls. When init_parameters is called twice on the same module, the resulting parameter names are identical.
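
The indexing rule can be illustrated with a small pure-Python sketch. This is not JAXnet's code: `index_clashing_names` is a hypothetical helper that assumes the behaviour visible in the example above, where a unique name stays as it is and clashing names receive indices in order of appearance.

```python
from collections import Counter

def index_clashing_names(names):
    """Append 0, 1, ... to names that occur more than once, in order."""
    counts = Counter(names)
    seen = Counter()
    result = []
    for name in names:
        if counts[name] > 1:
            result.append(f"{name}{seen[name]}")
            seen[name] += 1
        else:
            result.append(name)
    return result

# Names of the parametrized submodules of the network above, in order.
names = index_clashing_names(['conv', 'dense', 'dense', 'sequential'])
```

Here `names` lines up with the fields accessed in the assertions above: `conv`, `dense0`, `dense1` and `sequential`.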

Regularization and reparametrization

JAXnet allows concise regularization for a given loss network:

reg_loss_net = L2Regularized(loss_net, scale=.1)

reg_loss_net is now a module usable like any other.
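
As a rough picture of what such a wrapper computes, the following NumPy sketch adds an L2 penalty over all parameters to an existing loss function. The names `l2_penalty` and `l2_regularized` are hypothetical, and the `0.5 * scale` factor is a common convention assumed here; JAXnet's exact scaling may differ.

```python
import numpy as np

def l2_penalty(params, scale):
    # Assumed convention: 0.5 * scale * sum of squared entries of all parameters.
    return 0.5 * scale * sum(np.sum(p ** 2) for p in params)

def l2_regularized(loss_fn, scale):
    # Returns a new loss function with the same signature as loss_fn.
    def regularized(params, *inputs):
        return loss_fn(params, *inputs) + l2_penalty(params, scale)
    return regularized

def squared_error(params, inputs, targets):
    kernel, bias = params
    return np.mean((np.dot(inputs, kernel) + bias - targets) ** 2)

params = (np.ones((2, 1)), np.zeros((1,)))
inputs = np.array([[1.0, 2.0]])
targets = np.array([[3.0]])

plain = squared_error(params, inputs, targets)
regularized = l2_regularized(squared_error, scale=0.1)(params, inputs, targets)
```

The wrapped function takes the same arguments as the original loss, which mirrors how the regularized module can be used in place of the loss network.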

Reparametrization is similarly simple:

def Scaled():
    @parametrized
    def learnable_scale(params):
        return 2 * parameter((), ones) * params

    return learnable_scale

scaled_net = Reparametrized(net, Scaled)

In this example, every weight vector/matrix is multiplied by a learnable scalar. Variational inference can be implemented as a combination of Reparametrized and Regularized. (Example will be added soon.)

Since Reparametrized just returns another module, it can be applied to parts of your network:

net = Sequential(Conv(2, (3, 3)), relu, Conv(2, (3, 3)), relu, flatten,
                 Reparametrized(Sequential(Dense(2), relu, Dense(2)), Scaled))

Implementing Reparametrized is straightforward:

def Reparametrized(model, reparametrization_factory):
    @parametrized
    def reparametrized(*inputs):
        params = Parameter(lambda key: model.init_parameters(*inputs, key=key))()
        transformed_params = tree_map(lambda param: reparametrization_factory()(param), params)
        return model.apply(transformed_params, *inputs)

    return reparametrized
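
The effect on a concrete parameter tree can be sketched in plain NumPy. This is an illustration under assumptions, not JAXnet code: a flat dict stands in for the nested namedtuple, a dict comprehension stands in for `tree_map`, and each parameter gets its own scalar initialized to one, as in `Scaled` above.

```python
import numpy as np

def learnable_scale(scale, param):
    # Same transformation as in Scaled: 2 * (learnable scalar) * parameter.
    return 2 * scale * param

# Parameters of the wrapped model (a flat dict stands in for a namedtuple).
params = {'kernel': np.ones((2, 3)), 'bias': np.ones((3,))}

# One scalar per parameter, because the factory is called once per leaf.
scales = {name: np.ones(()) for name in params}

# Stand-in for tree_map: transform every leaf before the model is applied.
transformed = {name: learnable_scale(scales[name], p) for name, p in params.items()}
```

Both `params` and `scales` would be trained. The wrapped model only ever sees `transformed`.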

Parameter reuse

If you want to evaluate parts or extended versions of a trained network (to get accuracy, generate samples, do introspection, ...), you can use apply_from:

net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), logsoftmax)

@parametrized
def loss(inputs, targets):
    return -np.mean(net(inputs) * targets)

@parametrized
def accuracy(inputs, targets):
    return np.mean(np.argmax(targets, axis=1) == np.argmax(net(inputs), axis=1))

params = loss.init_parameters(np.zeros((3, 784)), np.zeros((3, 4)), key=PRNGKey(0))

# train params...

test_acc = accuracy.apply_from({loss: params}, *test_batch, jit=True)

It is a shorthand for:

accuracy_params = accuracy.parameters_from({loss: params}, *test_batch)
test_acc = jit(accuracy.apply)(accuracy_params, *test_batch)

This assumes that the inputs for loss are the same as for accuracy. Use shaped to specify deviating input shapes, for example to get predictions from net, which takes no targets (demo):

predictions = net.apply_from({loss.shaped(*next_batch()): params}, test_inputs, jit=True)

If you want to reuse parts of your network while initializing the rest, use init_parameters with reuse:

inputs = np.zeros((1, 2))
net = Dense(5)
net_params = net.init_parameters(inputs, key=PRNGKey(0))

# train net params...

transfer_net = Sequential(net, relu, Dense(2))
transfer_net_params = transfer_net.init_parameters(inputs, key=PRNGKey(1), reuse={net: net_params})

assert net_params == transfer_net_params.dense0

# train transfer_net_params...
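
The semantics of reuse can be pictured with a small NumPy sketch: parameters supplied for a submodule are taken over unchanged, and everything else is freshly initialized. The helpers `init_dense` and `init_transfer_net` are hypothetical, and the string key `'dense0'` is an assumption made for brevity; JAXnet keys the reuse dict by module object, as shown above.

```python
import numpy as np

def init_dense(in_dim, out_dim, seed):
    rng = np.random.default_rng(seed)
    return {'kernel': rng.normal(size=(in_dim, out_dim)),
            'bias': np.zeros((out_dim,))}

def init_transfer_net(seed, reuse=None):
    # Two dense layers (2 -> 5 -> 2); reuse maps layer names to trained params.
    reuse = reuse or {}
    first_seed, second_seed = np.random.SeedSequence(seed).spawn(2)
    # Take reused parameters as they are, initialize the rest from the seed.
    dense0 = reuse['dense0'] if 'dense0' in reuse else init_dense(2, 5, first_seed)
    dense1 = reuse['dense1'] if 'dense1' in reuse else init_dense(5, 2, second_seed)
    return {'dense0': dense0, 'dense1': dense1}

net_params = init_dense(2, 5, seed=0)  # pretend these have been trained
transfer_params = init_transfer_net(seed=1, reuse={'dense0': net_params})
```

The first layer of `transfer_params` is the trained `net_params`, while the second layer is new and still needs training.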

Storing parameters

Store parameters with save and load:

opt = optimizers.Adam()
state = opt.init(params)

# train ...

trained_params = opt.get_parameters(state)

save(trained_params, Path.home() / 'net')
trained_params = load(Path.home() / 'net')

# evaluate etc. ...

You can store the complete optimizer state with the same methods:

save(state, Path.home() / 'net')
state = load(Path.home() / 'net')

# continue training...