Updated state loading to copy by reference #705

Closed
wants to merge 1 commit

Conversation

bryce13950
Collaborator

Description

This PR updates state loading in TransformerLens to copy the configured state dictionary by reference rather than copying the actual data across memory. Currently, loading a model in TransformerLens requires roughly twice the memory the model will occupy once loading is complete.

e.g. if you are loading google/gemma-2b, it will take about 13.3 GB of memory once loading is complete, but memory usage will peak at around 27 GB when it hits the function call self.load_state_dict(state_dict, strict=False).

This overhead is proportional across the board.

i.e. if you are loading a model that takes 95 GB of memory when sitting idle, you will need about 190 GB of memory to get past loading.

This PR simply sets the assign parameter to True in the call to PyTorch's load_state_dict. This parameter is already set to True when loading a model in 4-bit mode.

I am not sure whether this change is beneficial across the board, but my initial thought is that it is probably fine. The biggest concern is that someone might want the constructed state dictionary to remain available separately from the state of the actual model. At the moment, that is not easily doable without heavily modifying HookedTransformer, and at a glance it seems that in all current uses the memory allocated for the constructed state_dict is deallocated immediately after the call to load_state_dict.

All of this leads me to believe that copying the state dictionary values by reference is probably fine. However, I do not know whether this topic has been discussed in previous conversations that I am not privy to, or whether assign has been purposely left at its default for a reason I am not aware of. It is very possible that no one ever thought to set the parameter, and until recently it did not really matter, since most people were probably using models less than a few gigabytes in size. I am just theorizing, though, and would really appreciate a second look.

Along with the reduction in initial memory requirements, this change also significantly reduces the time needed to load a model. I have not measured it precisely, but I would estimate a reduction somewhere between 33% and 50%.

If there is a reason not to change this across the board, then we should at the very least provide an option to set the parameter when loading a model. I would prefer to have fewer options to keep the API simpler, and to make copying by reference the default. The ability to load by reference seems too powerful to leave turned off as it is now, especially as we continue to support larger models.

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

bryce13950 closed this on Sep 10, 2024