Right now the only way to construct a train step is using a loss function and an optimizer:
```elixir
def train_step(model, loss, optimizer, opts \\ []) do
```
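For context, a sketch of how this is typically driven today. `Axon.Loop.trainer/3` is the public entry point that builds a train step from a loss and optimizer; the model definition here is a minimal placeholder:

```elixir
# Minimal sketch: today's API takes only a loss and an optimizer.
model = Axon.input("features") |> Axon.dense(1)

# Loss and optimizer are passed as atoms (or functions); there is no way
# to hand over a full objective function to differentiate through.
loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)
```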
This is suitable for most cases, but in some instances it may be easier to let a user pass an entire objective function to differentiate through, rather than just the loss function. In the default train step, the constructed objective function is:
```elixir
objective_fn = fn trainable_parameters, model_state, loss_scale_state, inp, tar ->
  # hack to use trainable parameters as grad
  model_state =
    update_in(model_state, [Access.key!(:data)], fn data ->
      tree_merge(data, trainable_parameters, fn _, _, v -> v end)
    end)

  model_out = forward_model_fn.(model_state, inp)
  unscaled_loss = loss_fn.(tar, model_out.prediction)
  scaled_loss = scale_loss.(unscaled_loss, loss_scale_state)
  {model_out, scaled_loss, unscaled_loss}
end
```
If we can clean this form up a bit and get rid of the hack, this could be a useful API for constructing more complex training objectives without needing to re-implement the entire train step.
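One possible shape for the user-facing side, as a rough sketch only: the names (`train_step/3` with an objective, the objective's arity, `forward_model_fn` and `my_custom_loss`) are assumptions for illustration, not a committed interface.

```elixir
# Hypothetical sketch: the caller supplies the objective to differentiate
# through, instead of just a loss function. The train step would handle
# parameter threading and loss scaling internally, removing the hack above.
objective_fn = fn model_state, inp, tar ->
  model_out = forward_model_fn.(model_state, inp)
  # my_custom_loss is a placeholder for any user-defined objective,
  # e.g. a loss plus a regularization penalty on the model output
  loss = my_custom_loss(tar, model_out.prediction)
  {model_out, loss}
end

step = train_step(model, objective_fn, optimizer)
```

The return convention `{model_out, loss}` mirrors the tuple shape of the default objective, which would let the existing gradient and update machinery stay unchanged.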