Updated state loading to copy by reference #705
Closed
Description
This PR updates state loading for TransformerLens so that the configured state dictionary is always copied by reference rather than copied across memory. Currently, loading a model in TransformerLens requires roughly 2x the memory the model occupies once loading is complete, and this overhead applies to every model.
This PR simply sets the `assign` parameter to `True` in the call to PyTorch's `load_state_dict`. This parameter is already set to `True` when loading a model in 4-bit mode.

I am not sure whether this is a beneficial change across the board, but my initial thought is that it is probably fine. The biggest concern is that someone might want the constructed state dictionary to remain available separately from the state of the actual model. At the moment, that is not easily doable without heavily modifying `HookedTransformer`, and at a glance it seems that in all current uses the memory allocated for the constructed `state_dict` is deallocated immediately after the call to `load_state_dict`.

All of this leads me to believe that copying the state dictionary values by reference is probably fine. However, I do not know whether this topic has been discussed in previous conversations that I am not privy to, or whether `assign` has been purposely left at its default for a reason I am not aware of. It is quite possible that no one ever thought to set the parameter, and until recently it did not matter much, since most people were probably using models under a few gigabytes in size. This is only a theory, though, and I would really appreciate a second look.

Along with reducing the initial memory requirement, this change also significantly reduces the time needed to load a model. I have not measured the improvement precisely, but I would estimate the load time drops by somewhere between 33% and 50%.
If there is a reason not to change this across the board, then we should at the very least expose an option to set this parameter when loading a model. My preference, though, is to keep the API simpler with fewer options and make copying by reference the default. The ability to load by reference seems too valuable to leave disabled as it is now, especially as we continue to support larger models.