
Construct train step from an objective function and optimizer #595

Open

seanmor5 (Contributor) opened this issue Sep 11, 2024 · 0 comments
Right now the only way to construct a train step is using a loss function and an optimizer:

def train_step(model, loss, optimizer, opts \\ []) do

This is suitable for most cases, but in some instances it may be easier to let a user pass an objective function to differentiate through, rather than just a loss function. In the default train step, the constructed objective function is:

  objective_fn = fn trainable_parameters, model_state, loss_scale_state, inp, tar ->
    # Hack: merge the trainable parameters back into the model state so we
    # can take gradients with respect to them alone
    model_state =
      update_in(model_state, [Access.key!(:data)], fn data ->
        tree_merge(data, trainable_parameters, fn _, _, v -> v end)
      end)

    # Forward pass, then compute the loss and scale it (e.g. for mixed precision)
    model_out = forward_model_fn.(model_state, inp)
    unscaled_loss = loss_fn.(tar, model_out.prediction)
    scaled_loss = scale_loss.(unscaled_loss, loss_scale_state)

    {model_out, scaled_loss, unscaled_loss}
  end

If we can clean this form up a bit and get rid of the hack, this could be a useful API for constructing more complex training objectives without needing to re-implement the entire train step.
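Purely for illustration, a cleaned-up version might accept an objective with an explicit signature and let the train step own gradients, loss scaling, and the optimizer update. The names below (`train_step_from_objective/3`, the objective arity, and the return shape) are hypothetical, not an existing Axon API:

  # Hypothetical sketch -- none of these names exist in Axon today.
  # The objective receives the model state and a batch, and returns the
  # loss to differentiate plus any auxiliary outputs for the loop.
  objective_fn = fn model_state, {inp, tar} ->
    model_out = forward_model_fn.(model_state, inp)
    loss = loss_fn.(tar, model_out.prediction)
    {loss, model_out}
  end

  # The train step would then close over differentiation, loss scaling,
  # and the optimizer update, instead of requiring a bare loss function:
  step = train_step_from_objective(model, objective_fn, optimizer)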
