Add optional gate activation histogram logging during eval #641
base: main
Conversation
OK so I kinda want to not make a whole bunch of changes to the model API just yet, and would rather have a guide on how to hack this in, since these things tend to be special snowflakes.
I also wonder if we should just consider using a debug callback (see e.g. jit_log_metrics), which is a bit gross from a functional-purity perspective, but for logging I think it's fine?
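For context, a minimal sketch of the debug-callback idea using `jax.debug.callback`; the helper names below are illustrative (the `jit_log_metrics` mentioned above presumably already fills this role in the codebase), and this is a sketch rather than the PR's implementation:

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side sink; here we just print.
def _log_metrics_host(metrics):
    # metrics arrives on the host as concrete (NumPy) arrays
    for name, value in metrics.items():
        print(f"{name}: {value}")

def log_metrics_from_jit(metrics: dict):
    # jax.debug.callback stages a host callback into the traced computation,
    # letting a jitted function emit logs without threading extras outward
    jax.debug.callback(_log_metrics_host, metrics)

@jax.jit
def loss_fn(x):
    loss = jnp.mean(x ** 2)
    log_metrics_from_jit({"debug/mean_sq": loss})
    return loss

loss_fn(jnp.arange(4.0))  # prints debug/mean_sq from inside jit
```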
activation_function = hf_config.hidden_act
# This is the implementation in huggingface
# https://github.com/huggingface/transformers/blob/12b1620e615592fbf099d4ec44af7b9f2d1b48aa/src/transformers/models/gemma/modeling_gemma.py#L200
activation_function = "gelu_pytorch_tanh"
i swore we already did this
@@ -123,6 +125,17 @@ def eval_callback(step: StepInfo):
        _join_prefix(prefix, "loading_time"): result.total_eval_loading_time,
        _join_prefix(prefix, "total_time"): time_fn(),
    }
    if (gate_hist := result.extras.get("gate_hist", None)) is not None:
so I think I'm gonna have a strong preference for
- extracting this block (and the part in the loop) into a class (sort of like RunningMean); see the sketch after this list
- not actually checking the usage of it in TaggedEvaluator (or in the models) into main, but instead
- making a little guide on how to add it in, since it's something people want to play with sometimes but it kinda adds a bunch of noise
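A rough sketch of what such an extracted accumulator could look like, analogous in spirit to RunningMean; all names here are hypothetical, not existing Levanter classes:

```python
from dataclasses import dataclass

import jax.numpy as jnp
from jax import Array

# Hypothetical accumulator: it owns the running histogram counts and the
# reduction, so the evaluator and the logging callback don't need bespoke
# dict-merging code.
@dataclass(frozen=True)
class GateHistogramAccumulator:
    counts: Array  # shape (num_buckets,)

    @staticmethod
    def zeros(num_buckets: int) -> "GateHistogramAccumulator":
        return GateHistogramAccumulator(jnp.zeros((num_buckets,), dtype=jnp.int32))

    def add(self, batch_counts: Array) -> "GateHistogramAccumulator":
        # per-batch counts from the eval loop are summed into the running total
        return GateHistogramAccumulator(self.counts + batch_counts)

    def items(self, prefix: str) -> dict:
        # what the eval callback would hand to the tracker/logger
        return {f"{prefix}/gate_hist": self.counts}
```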
@@ -193,16 +203,20 @@ def init(
    if isinstance(activation_fn, str):
        activation_fn = ACT2FN[activation_fn]
    act = activation_fn  # type: ignore
    return LlamaMlp(gate_proj, up_proj, down_proj, act)
    get_bins()  # initialize bins
rm?
if extras:
    for key in extras:
        curr = total_extras.get(key, jnp.zeros_like(extras[key]))
        total_extras[key] = extras[key] + curr
is summing always going to be the right reduction here?
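One hedged way to address that would be to let each extra key declare its own reduction instead of hard-coding a sum; the keys and table below are illustrative, not part of the PR:

```python
import jax.numpy as jnp

# Hypothetical per-key reductions; summing stays the default (and is right
# for histogram counts), but other extras might want max, min, etc.
REDUCTIONS = {
    "gate_hist": lambda acc, new: acc + new,   # counts: sum is correct
    "gate_act_max": jnp.maximum,               # running maximum
    # a running mean would need a (sum, count) pair rather than one array
}

def accumulate_extras(total_extras: dict, extras: dict) -> dict:
    for key, value in extras.items():
        reduce_fn = REDUCTIONS.get(key, lambda acc, new: acc + new)
        curr = total_extras.get(key, jnp.zeros_like(value))
        total_extras[key] = reduce_fn(curr, value)
    return total_extras
```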
NBINS = 2 * NSIDE + 3


@jax.jit
generally speaking it's not worth putting jit around helpers, though sometimes it is
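For illustration, the pattern this comment points at: leave the helper un-jitted and let the jitted caller trace it. The names here are generic, not the PR's:

```python
import jax
import jax.numpy as jnp

# Un-jitted helper: when it's only ever called from inside a jitted function,
# it gets traced and fused into that computation anyway, so a separate
# @jax.jit just adds another dispatch/compile boundary.
def _bucket_counts(a, bins):
    return jnp.histogram(a, bins=bins)[0]

@jax.jit
def eval_step(acts, bins):
    # the helper is inlined into this single compiled function
    return _bucket_counts(acts, bins)
```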
return _BINS


BIN_AX = Axis("bins", NBINS - 1)
rm?
@jax.jit
def histogram(a: Array, bins: Array) -> Array:
add a reference to that git issue about why we need this?
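The issue number isn't given in the PR, so it's left unfilled here. For reference, a common TPU-friendlier workaround is to express the histogram as searchsorted plus a one-hot sum instead of calling `jnp.histogram` directly; this is a sketch with assumed bin semantics (right-open intervals, clipping to the outer buckets), not necessarily what the PR implements:

```python
import jax
import jax.numpy as jnp
from jax import Array

# Histogram built from searchsorted + one-hot sums, a usual workaround when
# jnp.histogram is slow or awkward to shard on the target backend.
def histogram_counts(a: Array, bins: Array) -> Array:
    a = a.reshape(-1)
    num_buckets = bins.shape[0] - 1
    # bucket index of each element, clipped into the valid range
    idx = jnp.clip(jnp.searchsorted(bins, a, side="right") - 1, 0, num_buckets - 1)
    return jax.nn.one_hot(idx, num_buckets, dtype=jnp.int32).sum(axis=0)
```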
@jax.jit
def sharded_histogram(a: Array, bins: Array) -> Array:
let's maybe just call this histogram and the other thing _histogram?
@jax.jit
def get_bins() -> Array:
make this take a number of bins (with the current default?)
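A sketch of the suggested signature, keeping the current module-level count as the default; the symmetric log-spaced construction is only a guess at what NSIDE/NBINS encode, not the PR's actual bin edges:

```python
import jax.numpy as jnp
from jax import Array

DEFAULT_NSIDE = 64                     # placeholder; the PR's NSIDE would go here
DEFAULT_NBINS = 2 * DEFAULT_NSIDE + 3

def get_bins(num_bins: int = DEFAULT_NBINS) -> Array:
    # hypothetical construction: log-spaced positive edges mirrored for the
    # negative side, with 0 in the middle and +/- inf as outer edges, which
    # is one way to arrive at the 2 * NSIDE + 3 edge count
    nside = (num_bins - 3) // 2
    pos = jnp.logspace(-8, 4, nside)
    return jnp.concatenate(
        [jnp.array([-jnp.inf]), -pos[::-1], jnp.array([0.0]), pos, jnp.array([jnp.inf])]
    )
```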
The main goal of this PR is to add the ability to log activation statistics of the MLPs of models.

In its current state, this involves one big, slightly inconvenient change: every model's `compute_loss` function now returns a tuple of `(loss: Array, extras: dict)`, where `extras` can contain any auxiliary data to log. Thus all the upstream code had to be modified to accommodate this change.

Currently there's code to measure the activation statistics of Llama models during eval only, since computing the histograms is incredibly inefficient on TPUs. For Llama 7B, computing the histograms takes roughly 4x as long as the rest of the forward pass. AFAIK there's no faster way to do this, but it's eval-only, so 🤷.
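As a sketch of the new contract (the toy model and keys below are illustrative, not Levanter's actual model API):

```python
import jax.numpy as jnp
from jax import Array

# Every model's compute_loss now returns (loss, extras) instead of a bare loss.
def compute_loss(params: Array, batch: dict) -> tuple[Array, dict]:
    preds = batch["inputs"] @ params
    loss = jnp.mean((preds - batch["targets"]) ** 2)
    extras: dict = {}
    # during eval, a model can stash auxiliary arrays here,
    # e.g. extras["gate_hist"] = per-batch histogram counts
    return loss, extras

# Upstream callers now unpack the tuple:
# loss, extras = compute_loss(params, batch)
```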
The code's a little messy, so some review would be appreciated.