Swapping selected layers with different layers. #3586
-
Hello! I am pretty new to JAX/Flax. I was trying to implement LoRA. Although I can modify the state dict, would it be possible to modify the layers to accommodate the additional matrices at run time? Or is there any alternative approach, so that I could freeze the original weights and train only the low-rank matrices? I am able to modify the state dict by doing something like this:

```python
import functools

import jax
import jax.numpy as jnp

PRNGKey = jnp.ndarray

def get_param(path, param, rank: int, rng: PRNGKey):
    # Leave biases and other 1-D parameters untouched.
    if len(param.shape) == 1:
        return param
    a_dim, b_dim = param.shape
    return {
        "a": jax.random.normal(rng, shape=(a_dim, rank)),
        "b": jnp.zeros(shape=(rank, b_dim)),
        "w": param,
    }

# tree_map_with_path passes (path, leaf) to the callback, so bind the
# extra arguments first (the rank and rng values here are illustrative).
new_params = jax.tree_util.tree_map_with_path(
    functools.partial(get_param, rank=8, rng=jax.random.PRNGKey(0)), params
)
```

Is there any maneuver I could do? Or would I have to make drastic moves like re-defining the architecture? Any suggestions would help! Thank you!
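For what it's worth, a leaf transformed this way could be consumed in a forward pass roughly like this (a minimal sketch; `lora_matmul` and the concrete shapes are made up for illustration):

```python
import jax
import jax.numpy as jnp

def lora_matmul(x, p):
    # Dict leaves carry a frozen base weight "w" plus the low-rank update a @ b.
    if isinstance(p, dict):
        return x @ (p["w"] + p["a"] @ p["b"])
    return x @ p

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (8, 4))
p = {"a": jax.random.normal(key, (8, 2)), "b": jnp.zeros((2, 4)), "w": w}
x = jnp.ones((3, 8))
# With "b" initialized to zeros, the LoRA branch is a no-op at initialization,
# so the output matches the plain x @ w.
out = lora_matmul(x, p)
```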
-
I have a JAX implementation of LoRA here: https://github.com/davisyoshida/lorax. You don't need to touch model code to use it.
-
Thanks a lot for your reply! So, my main question is: if I were to do it myself, how should I go about implementing it? Thanks again!
Ah yeah, so if you want to really customize JAX, you need to write custom interpreters. The JAX tutorials are pretty good (11/10 if we're comparing to most projects, but not 100% of the information you need to be successful).
This is the main one which is relevant: https://jax.readthedocs.io/en/latest/autodidax.html
That being said, if you're looking for something simpler, you can probably write a more restricted version without getting that far into the weeds. I think targeting only Flax `nn.Dense` layers and making the modification at initialization time should be much simpler. I don't know the best way to do that using Flax's machinery, since I'm a bit more familiar with Haiku.