
Why does this lr_finder use training loss instead of validation loss? #29

Open
alleno-yu opened this issue Jul 12, 2020 · 7 comments

@alleno-yu

I have looked into the post "Estimating an Optimal Learning Rate For a Deep Neural Network", which suggests using the training loss to determine the best learning rate, or a range of learning rates, to use. However, in the paper "Cyclical Learning Rates for Training Neural Networks", the author uses validation accuracy to find the learning rate range. So, in my humble opinion, lr_finder should evaluate the validation loss after each batch, record it, and then plot validation loss against learning rate.
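
For illustration, the plot would look something like this (a sketch only; lrs and val_losses stand for whatever values get recorded during the range test, they are not names from the library):

import matplotlib.pyplot as plt

# lrs and val_losses are assumed to have been recorded, one value per batch
plt.plot(lrs, val_losses)
plt.xscale('log')                        # learning rates typically grow geometrically
plt.xlabel('learning rate (log scale)')
plt.ylabel('validation loss')
plt.show()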

@surmenok
Owner

I think your point is valid in general. However, we run only one epoch on the training set. For the first epoch, the training loss should be close to the validation loss, provided the training and validation sets are drawn from the same distribution. So the simplified method (which comes from Jeremy Howard's fast.ai course) can still be valid in many cases.
Would you mind creating a pull request that adds an option to use the validation set?

@alleno-yu
Author

alleno-yu commented Jul 12, 2020

Thank you for your response. I'm in the middle of my MSc final project and I'm new to GitHub, but if it's not too late, I can create a pull request after the project. Right now, my approach is very naive. Add validation_data to the __init__ parameters:

def __init__(self, model, validation_data):
    self.model = model
    self.losses = []                        # one recorded loss per batch
    self.lrs = []                           # learning rate used at each batch
    self.best_loss = 1e9                    # large sentinel, updated as losses improve
    self.validation_data = validation_data  # (x, y) tuple held out for evaluation

Then add the following code inside the on_batch_end function:

    # evaluate on the held-out set after each batch
    x, y = self.validation_data
    val_loss = self.model.evaluate(x, y, verbose=0)[0]  # [0] is the loss when metrics are compiled
    loss = val_loss  # use the validation loss in place of the training loss below
    self.losses.append(loss)
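
Usage might then look like this (hypothetical; it assumes an LRFinder-style class with the fields above and a find() entry point that runs training with on_batch_end hooked in):

finder = LRFinder(model, validation_data=(x_val, y_val))
finder.find(x_train, y_train, start_lr=1e-6, end_lr=1.0, batch_size=64)
# finder.losses now holds one validation loss per batch, ready to plot against finder.lrs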

Hope this helps, and thank you again for your contribution!
One last question: should I close this issue?

@surmenok
Owner

Let's keep it open until it's fixed.

@tarasivashchuk
Contributor

tarasivashchuk commented Jul 13, 2020

I might take a look at this today if I have some free time and submit the pull request. That is, unless you have already started and want to finish it yourself, @alleno-yu; let me know.

Otherwise, I think this is fairly trivial. A potential solution would be: instead of running one epoch, decrease the number of steps per epoch to something like 2-10 batches, increase the number of epochs to number of batches // batches per epoch, and then apply essentially the same logic, except using the on_epoch_end method to append the validation loss to the losses list; a rough sketch follows below. Thoughts?
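
A rough sketch of that idea (everything here is hypothetical and simplified; it assumes tf.keras and a model compiled with at least one metric, so that evaluate() returns [loss, ...]):

from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import Callback

class EpochwiseLRFinder(Callback):
    def __init__(self, validation_data, start_lr=1e-6, end_lr=1.0, num_epochs=100):
        super().__init__()
        self.validation_data = validation_data
        self.start_lr = start_lr
        self.lr_mult = (end_lr / start_lr) ** (1.0 / num_epochs)  # geometric growth per epoch
        self.lrs, self.val_losses = [], []

    def on_train_begin(self, logs=None):
        K.set_value(self.model.optimizer.lr, self.start_lr)

    def on_epoch_end(self, epoch, logs=None):
        x_val, y_val = self.validation_data
        val_loss = self.model.evaluate(x_val, y_val, verbose=0)[0]  # [0] is the loss
        lr = float(K.get_value(self.model.optimizer.lr))
        self.lrs.append(lr)
        self.val_losses.append(val_loss)
        K.set_value(self.model.optimizer.lr, lr * self.lr_mult)  # raise lr for the next short epoch

Training would then be launched with something like model.fit(x_train, y_train, steps_per_epoch=5, epochs=num_batches // 5, callbacks=[finder]), so that each "epoch" is only a handful of batches.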

Also, to @surmenok: what do you think the default behavior should be? If you want me to tackle this, I could do some quick testing to gauge the performance and accuracy differences. It would be far from extensive and far from conclusive, but it would be something to go on. Let me know, thanks!

Thanks, guys. I don't have any professional work right now, so I figured I'd contribute to some open-source projects and work on some of my own.

@alleno-yu
Author

@tarasivashchuk I haven't started it, so feel free to fix this issue.

@tarasivashchuk
Contributor

@alleno-yu OK, I'm going to wait to hear back from @surmenok to make sure he's on board with that solution.

@surmenok
Owner

Sorry for the late response.
I think it's totally fine to add support for using a validation set. It should be optional: the user can pass in a validation set, and if none is passed in, the training set is used.
As for the number of epochs, we could make it configurable instead of hardcoding 1.
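
A minimal sketch of that fallback, reusing the fields from the snippets above (hypothetical, not the final implementation; it assumes Keras passes the batch logs, which include the training loss, to the callback):

def on_batch_end(self, batch, logs):
    if self.validation_data is not None:
        # optional path: measure loss on the held-out validation set
        x_val, y_val = self.validation_data
        loss = self.model.evaluate(x_val, y_val, verbose=0)[0]
    else:
        # default path: keep the current behavior and use the training loss
        loss = logs['loss']
    self.losses.append(loss)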
