
BUG🐛: Fixed scale related bugs in LoKr | Added rank_dropout_scale parameter #2180

Merged · 7 commits · Nov 5, 2024

Conversation

yaswanth19 (Contributor)

This PR adds a new parameter rank_dropout_scale and fixes scale-related bugs in LoKr. Please refer to the following function in LyCORIS:
https://github.com/KohakuBlueleaf/LyCORIS/blob/258387f586beabfca71646a9671027f75ed34597/lycoris/modules/lokr.py#L347
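
For context, below is a minimal sketch of how a rank_dropout_scale flag typically interacts with rank dropout. This is only an illustration of the idea; the function and argument names are assumptions and not the exact PEFT or LyCORIS code.

```python
import torch


def apply_rank_dropout(w: torch.Tensor, rank_dropout: float, rank_dropout_scale: bool, training: bool) -> torch.Tensor:
    # Illustrative sketch: drop entire ranks (rows of the low-rank factor) during training.
    if not training or rank_dropout <= 0.0:
        return w
    # Keep each rank with probability 1 - rank_dropout.
    keep = (torch.rand(w.size(0), device=w.device) > rank_dropout).to(w.dtype)
    if rank_dropout_scale:
        # Rescale the surviving ranks so the expected output magnitude stays roughly constant.
        keep = keep / keep.mean().clamp_min(1e-8)
    # Broadcast the per-rank mask over the remaining dimensions of w.
    return w * keep.view(-1, *([1] * (w.dim() - 1)))
```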

yaswanth19 (Contributor, Author)

@BenjaminBossan Please review the minor corrections to LoKr. The default initialization of the W1 matrix is zeros, and I am not sure whether to change it or not. Also, shall I rewrite this LoKr implementation to remove the lycoris dependency, as I can reuse most of the LycorisLokr implementation code?

BenjaminBossan (Member) left a comment

Thanks a lot for working on these bug fixes for LoKr, this is really appreciated.

I have some small comments, but overall this looks good.

> The default initialization of the W1 matrix is zeros, and I am not sure whether to change it or not.

I'd say leave it as is. What I think could be useful is to add a new option to the config that allows users to use the lycoris way of layer initialization; then they have a choice of which one they want. That could be done in a separate PR.

> Also, shall I rewrite this LoKr implementation to remove the lycoris dependency, as I can reuse most of the LycorisLokr implementation code?

I'm not sure what you mean by this; could you please elaborate?

Review comments (now resolved) were left on src/peft/tuners/lokr/config.py, src/peft/tuners/lokr/layer.py, and src/peft/tuners/lokr/model.py.
yaswanth19 (Contributor, Author) commented on Oct 28, 2024

> I'm not sure what you mean by this; could you please elaborate?

I meant that we still have a lycoris_utils dependency, and the LoKr codebase is somewhat dated. Since I have already updated the LoKr codebase to the latest PEFT standards in LycorisLoKr #2133, it would be a simple copy-paste and we wouldn't need a separate implementation.

BenjaminBossan (Member)

Ah I see what you mean, thanks for explaining. I wasn't sure if you were referring to lycoris_utils.py.

To give a bit of context, lycoris_utils.py was added with the idea of removing some of the boilerplate involved in creating a new adapter by providing a few more abstractions. In that sense, it's not really older compared to the way that, say, LoRA is implemented. However, these new abstractions never really took off; either they did not really fit the new methods being added, or the contributors just preferred to use LoRA as a starting point.

In general, I like that this PR is small and thus easy to review. If you rewrote LoKr to remove the usage of lycoris_utils.py, it would be quite a big change, and I see little benefit, unless it were necessary to fix underlying bugs. To give you an idea of the size of the change, this PR moved the OFT implementation away from lycoris_utils.py.

We could still take this step; I think it would make sense if we also planned to move LoHa away from lycoris_utils.py and could thus remove that module completely. But that exercise should be left for later; we should first focus on fixing the bugs.

yaswanth19 (Contributor, Author) commented on Oct 28, 2024

Hmm, makes sense, given that LoKr/LoHa are used less frequently. I will make the suggested changes then. The big question is what to do with LycorisLoKr #2133 😅. Shall I close it, since we end up using very little of the lycoris package for the feature we had in mind, and rewriting LoKr also yields very little benefit?

yaswanth19 (Contributor, Author)

@BenjaminBossan Done with suggested changes ✅

> I'd say leave it as is. What I think could be useful is to add a new option to the config that allows users to use the lycoris way of layer initialization; then they have a choice of which one they want. That could be done in a separate PR.

A separate flag is not needed; they can pass init_weights=False to initialize with random weights instead of zero weights.

BenjaminBossan (Member)

> A separate flag is not needed; they can pass init_weights=False to initialize with random weights instead of zero weights.

What I meant is that we found that lycoris initializes layers a bit differently than PEFT, as I mentioned on the other PR:

> In PEFT, at the start, we initialize w1 to zeros and w2_a and w2_b randomly. LyCORIS, however, initializes w2_b to zeros and w1 and w2_a randomly.

We could have an init option like init_weights = {True, False, "lycoris"} and if the latter is chosen, we initialize the same way as lycoris does.
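
To illustrate the difference between the schemes, here is a rough sketch based on the description above; it is not the actual PEFT code, and the kaiming initialization for the random branches is an assumption.

```python
import math

import torch.nn as nn


def init_lokr_weights(w1: nn.Parameter, w2_a: nn.Parameter, w2_b: nn.Parameter, init_weights) -> None:
    # Sketch of the three init_weights options discussed above.
    if init_weights is True:
        # PEFT-style: w1 starts at zero, so the Kronecker product w1 ⊗ (w2_a @ w2_b) is zero
        # and the adapter is a no-op at the start of training.
        nn.init.zeros_(w1)
        nn.init.kaiming_uniform_(w2_a, a=math.sqrt(5))
        nn.init.kaiming_uniform_(w2_b, a=math.sqrt(5))
    elif init_weights == "lycoris":
        # LyCORIS-style: w2_b starts at zero instead; the product is still zero initially.
        nn.init.kaiming_uniform_(w1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(w2_a, a=math.sqrt(5))
        nn.init.zeros_(w2_b)
    else:
        # init_weights=False: everything random, so the adapter changes the output immediately.
        nn.init.kaiming_uniform_(w1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(w2_a, a=math.sqrt(5))
        nn.init.kaiming_uniform_(w2_b, a=math.sqrt(5))
```

On the config side, this would presumably be selected with something like LoKrConfig(init_weights="lycoris").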

yaswanth19 (Contributor, Author)

@BenjaminBossan Please review; I added lycoris-style initialization and removed the alpha parameter setting.

yaswanth19 force-pushed the fix-lokr-bugs branch 2 times, most recently from 4cf3832 to 7e3868e on October 29, 2024 at 17:41.
BenjaminBossan (Member) left a comment

Thanks so much for the updates. This PR is almost good to go. I found some nits that I commented on. Apart from that, since a new initialization scheme was added, let's add a test for this.

I think the best place would be in test_initialization.py. Let's create a test class similar to what we have for LoRA. The tests can be very simple: one test for init_weights=True that checks that we get the same output as from the base model, the same for init_weights="lycoris", and finally a test for init_weights=False that checks that the output is different from the base model's output. No need for any stats on the outputs. LMK if you have questions.
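
A rough sketch of what such tests could look like is below; the toy model, the config arguments, and the exact assertions are assumptions for illustration, not the code that was ultimately added to test_initialization.py.

```python
import torch
from torch import nn

from peft import LoKrConfig, get_peft_model


class MLP(nn.Module):
    # Hypothetical toy model, used only to illustrate the test structure.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)


@torch.no_grad()
def check_lokr_init(init_weights, should_match_base):
    torch.manual_seed(0)
    x = torch.randn(4, 10)
    base = MLP().eval()
    base_output = base(x)

    config = LoKrConfig(target_modules=["linear"], init_weights=init_weights)
    peft_model = get_peft_model(base, config).eval()
    assert torch.allclose(base_output, peft_model(x)) == should_match_base


def test_lokr_init_true_is_identity():
    # With init_weights=True the adapter should start as a no-op.
    check_lokr_init(True, should_match_base=True)


def test_lokr_init_lycoris_is_identity():
    # The lycoris-style init should also leave the base output unchanged.
    check_lokr_init("lycoris", should_match_base=True)


def test_lokr_init_false_changes_output():
    # Fully random init should change the output right away.
    check_lokr_init(False, should_match_base=False)
```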

Review comments were left on src/peft/tuners/lokr/config.py and src/peft/tuners/lokr/layer.py.
yaswanth19 (Contributor, Author)

@BenjaminBossan I added the test cases; please review.

BenjaminBossan (Member) left a comment

Fantastic, thanks for the updates. The tests look good too. There are only a handful of small issues left; please take a look.

Five review comments (now resolved) were left on src/peft/tuners/lokr/config.py.
yaswanth19 (Contributor, Author)

@BenjaminBossan The docstring changes are addressed.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

BenjaminBossan (Member) left a comment

Thanks for addressing my last points. This PR LGTM, thanks a lot for your amazing work on LoKr.

I'd like to leave this PR open for a little bit to see if there is more feedback from other folks. If not, I'll probably merge it at the start of next week.

BenjaminBossan merged commit b1fd97d into huggingface:main on Nov 5, 2024. 14 checks passed.
BenjaminBossan (Member)

Again, thanks so much @yaswanth19 for implementing these fixes and enhancements for LoKr. Since there hasn't been any further feedback, I decided to merge now.

LMK if you plan on working on further fixes for LoKr or LoHa. Ideally, we could have them all in the same PEFT release, as users may have to retrain their models, depending on the type of fix.

yaswanth19 (Contributor, Author) commented on Nov 6, 2024

@BenjaminBossan I don't have any plans to work on this immediately; it would probably be best if somebody else picked this up so that the changes land in the same release.

BenjaminBossan (Member)

All right, thanks for letting me know.
