Are there any demos? #3

Open
Urheen opened this issue Feb 17, 2024 · 2 comments

Comments

@Urheen

Urheen commented Feb 17, 2024

Could you please provide any demos for using this package?

@janEbert
Owner

Oh sorry, the README is kind of incorrect here, I have to fix that! In fact, you have to pass a function to the VeLO.step method that

  1. takes no arguments,
  2. calculates gradients for the model, and
  3. returns the loss.

This can be a bit awkward to achieve in a multi-processing-safe way.

The usage section in the README should be something like:

from pytorch_velo import VeLO

# [...]

train_steps = epochs * len(dataset)  # Assuming `dataset` is already batched.
opt = VeLO(params, num_training_steps=train_steps, weight_decay=0.0)

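def loss_with_backward_builder(inputs, targets):
    # Bind one batch so that the returned function takes no arguments.

    def loss_with_backward():
        # Clear old gradients, run the forward pass, and backpropagate.
        opt.zero_grad()
        preds = model(inputs)
        loss = criterion(preds, targets)
        loss.backward()
        # `VeLO.step` expects the closure to return the loss.
        return loss

    return loss_with_backward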

# [...]

# For each batch of input-target pairs:
closure = loss_with_backward_builder(inputs, targets)
opt.step(closure)
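
To make the three requirements on the closure concrete, here is a minimal, dependency-free sketch of the same pattern. It does not use pytorch_velo or PyTorch: `ToyParam` and `ToyOptimizer` are hypothetical stand-ins invented for this illustration, and the "backward pass" is a hand-written derivative for a one-parameter model. Only the shape of the contract (a zero-argument function that fills in gradients and returns the loss) is taken from the description above.

```python
class ToyParam:
    """Hypothetical stand-in for a tensor parameter: a value plus its gradient."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class ToyOptimizer:
    """Hypothetical stand-in for an optimizer whose step() requires a closure."""

    def __init__(self, params, lr=0.1):
        self.params = list(params)
        self.lr = lr

    def zero_grad(self):
        for p in self.params:
            p.grad = 0.0

    def step(self, closure):
        # The closure is called with no arguments; it must fill in the
        # gradients and return the loss.
        loss = closure()
        for p in self.params:
            p.value -= self.lr * p.grad
        return loss


w = ToyParam(0.0)
opt = ToyOptimizer([w], lr=0.1)


def loss_with_backward_builder(x, y):
    # Bind one (input, target) pair so the inner function needs no arguments.
    def loss_with_backward():
        opt.zero_grad()
        pred = w.value * x
        loss = (pred - y) ** 2
        # Hand-written "backward": d(loss)/dw = 2 * (pred - y) * x.
        w.grad += 2.0 * (pred - y) * x
        return loss

    return loss_with_backward


# Toy data following y = 2 * x, repeated for a few passes.
batches = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)] * 20
losses = [opt.step(loss_with_backward_builder(x, y)) for x, y in batches]
```

A fresh closure is built for every batch, so each call to `step` sees exactly the data it should compute gradients for. In this toy setup `w.value` ends up close to 2.0.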

@Urheen
Author

Urheen commented Feb 21, 2024

Thank you very much for your reply!

Could you please add an attribute so that it can load the pretrained model of the VeLO optimizer?

Urheen closed this as completed Feb 21, 2024
Urheen reopened this Feb 21, 2024