Add optional gate activation histogram logging during eval #641
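This PR threads an `extras` dict from the model forward pass back through `compute_loss` and the tagged evaluator, so histograms of the MLP gate activations can be logged during eval. A rough usage sketch, assuming all other `LlamaConfig` fields keep their defaults and that the metrics land under the evaluator's usual prefix (both assumptions, not part of the diff):

```python
# Hypothetical sketch of turning the new stats on; `measure_act_stats` is the
# flag this PR adds to LlamaConfig (it defaults to True in the diff).
from levanter.models.llama import LlamaConfig

config = LlamaConfig(measure_act_stats=True)

# During eval, the callback changed below then logs, per the diff:
#   <prefix>/gate_hist/all       - histogram of gate activations summed over layers
#   <prefix>/gate_gt0/all        - fraction of gate activations above zero
#   <prefix>/gate_hist/layer{i}  - per-layer histogram (1-indexed)
#   <prefix>/gate_gt0/layer{i}   - per-layer fraction above zero
```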
Changed file 1 of 3: the tagged-eval callback module.
@@ -15,6 +15,7 @@
from levanter.data import Dataset, ReplicatedBatchLoader
from levanter.logging import LoadingTimeTrackerIterator
from levanter.models.lm_model import LmExample, LmHeadModel
+from levanter.tracker.histograms import NBINS
from levanter.trainer import StepInfo
from levanter.utils.stat_utils import RunningMean
from levanter.utils.tree_utils import inference_mode
@@ -34,6 +35,7 @@ class EvalResult:
    tag_macro_losses: dict[str, float]  # per tag average-per-token loss
    tag_micro_losses: dict[str, float]  # per tag total loss, for "parent" tags
    total_eval_loading_time: float
+    extras: dict[str, float]


class DomainTaggedDataset(Dataset[tuple[T, hax.NamedArray]]):
@@ -123,6 +125,17 @@ def eval_callback(step: StepInfo):
        _join_prefix(prefix, "loading_time"): result.total_eval_loading_time,
        _join_prefix(prefix, "total_time"): time_fn(),
    }
+    if (gate_hist := result.extras.get("gate_hist", None)) is not None:
+        pos_idx = NBINS // 2 + 1
+        log_dict[_join_prefix(prefix, "gate_hist/all")] = np.array(gate_hist.sum(axis=0))
+        num_gt0 = gate_hist[:, pos_idx:].sum().item()
+        total = gate_hist.sum().item()
+        log_dict[_join_prefix(prefix, "gate_gt0/all")] = num_gt0 / total
+        for i in range(gate_hist.shape[0]):
+            log_dict[_join_prefix(prefix, f"gate_hist/layer{i+1}")] = np.array(gate_hist[i])
+            num_gt0 = gate_hist[i, pos_idx:].sum().item()
+            total = gate_hist[i].sum().item()
+            log_dict[_join_prefix(prefix, f"gate_gt0/layer{i+1}")] = num_gt0 / total

    logger.info(f"{prefix} loss: {result.micro_avg_loss:.3f}")
    for tag, loss in result.tag_macro_losses.items():
@@ -185,12 +198,12 @@ def __init__(

        @hax.named_jit(out_axis_resources=axis_mapping)
        def accum_for_batch(
-            m: LmHeadModel, state: tuple[RunningMean, RunningMean], batch: LmExample, tags: hax.NamedArray
+            m: LmHeadModel, state: tuple[RunningMean, RunningMean, dict], batch: LmExample, tags: hax.NamedArray
        ):
            m = inference_mode(m, True)
            with hax.axis_mapping(axis_mapping):
-                total_mean, mean_per_tag = state
-                losses = m.compute_loss(batch, reduction=None, reduction_axis=())
+                total_mean, mean_per_tag, total_extras = state
+                losses, extras = m.compute_loss(batch, reduction=None, reduction_axis=())
                mask = batch.loss_mask  # [Batch, Token]
                this_tokens = hax.einsum("->", mask)
                this_loss = hax.einsum("->", losses, mask)  # to scalar
@@ -203,23 +216,32 @@ def accum_for_batch(
                safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0)
                mean_per_tag = mean_per_tag.add(safe_mean, this_tokens_per_tag)

-                return mean, mean_per_tag
+                if extras:
+                    for key in extras:
+                        curr = total_extras.get(key, jnp.zeros_like(extras[key]))
+                        total_extras[key] = extras[key] + curr
[Review comment] is summing always going to be the right reduction here?
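A possible way to address this, sketched as a reply rather than as part of the PR: dispatch the reduction on the key instead of hard-coding a sum. `EXTRA_REDUCTIONS` and `accumulate_extras` are hypothetical names, not existing Levanter APIs:

```python
import jax.numpy as jnp

# Hypothetical per-key reduction table: histogram counts add up, but other
# kinds of extras (e.g. a running max) could register a different reduction.
EXTRA_REDUCTIONS = {
    "gate_hist": lambda acc, new: acc + new,
    # "gate_max": jnp.maximum,
}

def accumulate_extras(total_extras: dict, extras: dict) -> dict:
    for key, value in extras.items():
        reduce_fn = EXTRA_REDUCTIONS.get(key, lambda acc, new: acc + new)
        curr = total_extras.get(key, jnp.zeros_like(value))
        total_extras[key] = reduce_fn(curr, value)
    return total_extras
```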
+
+                return mean, mean_per_tag, total_extras

        self.accum_for_batch = accum_for_batch

    def evaluate(self, m: LmHeadModel):
        total_loss = jnp.zeros(())
        mean_losses_per_tag = hax.zeros(self.dataset.Tag, dtype=np.float32)

-        state = (RunningMean.zeros_like(total_loss), RunningMean.zeros_like(mean_losses_per_tag))
+        state: tuple[RunningMean, RunningMean, dict] = (
+            RunningMean.zeros_like(total_loss),
+            RunningMean.zeros_like(mean_losses_per_tag),
+            {},
+        )
        state = hax.shard(state)

        iterator = LoadingTimeTrackerIterator(self.loader)

        for batch, tags in tqdm.tqdm(iterator, "eval"):
            state = self.accum_for_batch(m, state, batch, tags)

-        total_loss, losses_per_tag = state
+        total_loss, losses_per_tag, extras = state

        micro_avg_loss = total_loss.mean.item()
        tag_avg_loss = losses_per_tag.mean
@@ -252,4 +274,4 @@ def evaluate(self, m: LmHeadModel):
            tag_micro_loss[tag] = mean_loss_per_tag_cpu[index]
            # no macro loss for the leaf tags

-        return EvalResult(micro_avg_loss, macro_avg_loss, tag_macro_loss, tag_micro_loss, iterator.total_time)
+        return EvalResult(micro_avg_loss, macro_avg_loss, tag_macro_loss, tag_micro_loss, iterator.total_time, extras)
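Reading the logging block above: `pos_idx = NBINS // 2 + 1` treats the upper half of the bins as strictly positive activations, so `gate_gt0` is just a ratio of histogram mass. A toy illustration; the bin count and the symmetry of the bins around zero are assumptions about `levanter.tracker.histograms`, not verified here:

```python
import numpy as np

NBINS = 8  # illustrative only; the real constant comes from levanter.tracker.histograms
pos_idx = NBINS // 2 + 1

# pretend histograms accumulated over an eval run, one row per layer
gate_hist = np.random.randint(0, 100, size=(2, NBINS))

gate_gt0_all = gate_hist[:, pos_idx:].sum() / gate_hist.sum()        # logged as gate_gt0/all
gate_gt0_layer1 = gate_hist[0, pos_idx:].sum() / gate_hist[0].sum()  # logged as gate_gt0/layer1
```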
Changed file 2 of 3: the Gemma model.
@@ -30,6 +30,7 @@
from levanter.models.lm_model import LmConfig, LmHeadModel
from levanter.types import BlockFoldable
from levanter.utils.flop_utils import lm_flops_per_token
+from levanter.utils.py_utils import cached_classproperty


silence_transformer_nag()
@@ -75,7 +76,6 @@ class GemmaConfig(HFCompatConfig):
    vocab_size: int = 256_000
    num_layers: int = 18
    num_heads: int = 8
-    head_dim: int = 256
    num_kv_heads: int = 1
    attn_dropout = 0.0
    norm_eps = 1e-6
@@ -106,10 +106,14 @@ class GemmaConfig(HFCompatConfig):
    Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim))
    HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads))

+    @property
+    def head_dim(self) -> int: return self.hidden_dim // self.num_heads
+
    def __post_init__(self):
        assert (
            self.num_heads % self.num_kv_heads == 0
        ), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}."
+        assert (self.head_dim * self.num_heads) == self.hidden_dim, "head_dim * num_heads must equal hidden_dim."

    def hf_checkpoint_converter(self) -> HFCheckpointConverter["GemmaConfig"]:  # type: ignore
        return HFCheckpointConverter(
@@ -129,7 +133,9 @@ def from_hf_config(cls, hf_config: HfConfig):
        if hf_config.hidden_activation:
            activation_function = hf_config.hidden_activation
        else:
-            activation_function = hf_config.hidden_act
+            # This is the implementation in huggingface
+            # https://github.com/huggingface/transformers/blob/12b1620e615592fbf099d4ec44af7b9f2d1b48aa/src/transformers/models/gemma/modeling_gemma.py#L200
+            activation_function = "gelu_pytorch_tanh"

[Review comment] i swore we already did this

        if activation_function == "gelu_pytorch_tanh":
            activation_function = "gelu_new"
@@ -168,7 +174,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
            num_hidden_layers=self.num_layers,
            num_attention_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
-            head_dim=self.hidden_dim // self.num_heads,
+            head_dim=self.head_dim,
            hidden_activation=(
                "gelu_pytorch_tanh" if self.activation_function == "gelu_new" else self.activation_function
            ),
@@ -263,9 +269,9 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,
        # MLP and skip connection
        residual = x
        x = self.post_attention_layernorm(x)
-        mlp_output = self.mlp(x, key=k_mlp)
+        mlp_output, extras = self.mlp(x, key=k_mlp)
        output = residual + mlp_output
-        return output
+        return output, extras


class GemmaTransformer(StateDictSerializationMixin, eqx.Module):
@@ -292,10 +298,10 @@ def init(config: GemmaConfig, *, key) -> "GemmaTransformer":
    @named_call
    def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray | AttentionMask], *, key) -> NamedArray:
        keys = maybe_rng_split(key, self.config.num_layers) if key is not None else None
-        x = self.layers.fold(x, mask=attn_mask, key=keys)
+        x, extras = self.layers.scan(x, mask=attn_mask, key=keys)
        x = self.norm(x)

-        return x
+        return x, extras

    def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
        if isinstance(self.layers, Stacked):
@@ -358,9 +364,9 @@ def __call__(
        The attn_mask from training pipeline may be an AttentionMask object instead of NamedArray
        """
        x = self.embeddings.embed(input_ids)
-        x = self.transformer(x, attn_mask=attn_mask, key=key)
+        x, extras = self.transformer(x, attn_mask=attn_mask, key=key)
        lm_logits = self.embeddings.unembed(x)
-        return lm_logits
+        return lm_logits, extras

    def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[GemmaConfig]":
        new_embeddings = self.embeddings.resize_embeddings(new_size, key=key)
Changed file 3 of 3: the Llama model.
@@ -30,6 +30,8 @@
from levanter.models.lm_model import LmConfig, LmHeadModel
from levanter.types import BlockFoldable
from levanter.utils.flop_utils import lm_flops_per_token
+from levanter.utils.py_utils import cached_classproperty
+from levanter.tracker.histograms import get_bins, sharded_histogram


silence_transformer_nag()
@@ -78,6 +80,7 @@ class LlamaConfig(HFCompatConfig):
    use_bias: bool = False
    use_layer_norm_weight: bool = True
    rope_scaling: Optional[dict] = None
+    measure_act_stats: bool = True

    reference_checkpoint: str = "meta-llama/Llama-2-7b-hf"
    tokenizer: Optional[str] = None
@@ -181,10 +184,17 @@ class LlamaMlp(eqx.Module, StateDictSerializationMixin):
    up_proj: hnn.Linear  # projection from Embed to Mlp
    down_proj: hnn.Linear  # projection from Mlp to Embed
    act: Callable = eqx.static_field()
+    measure_act_stats: bool = False

    @staticmethod
    def init(
-        Embed: Axis, Mlp: Axis, activation_fn: Union[str, Callable], *, key, use_bias: bool = False
+        Embed: Axis,
+        Mlp: Axis,
+        activation_fn: Union[str, Callable],
+        *,
+        key,
+        use_bias: bool = False,
+        measure_act_stats=False,
    ) -> "LlamaMlp":
        k_fc, k_up_proj, k_down_proj = jrandom.split(key, 3)
        gate_proj = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias, out_first=True)
@@ -193,16 +203,20 @@ def init(
        if isinstance(activation_fn, str):
            activation_fn = ACT2FN[activation_fn]
        act = activation_fn  # type: ignore
-        return LlamaMlp(gate_proj, up_proj, down_proj, act)
+        get_bins()  # initialize bins

[Review comment] rm?

+        return LlamaMlp(gate_proj, up_proj, down_proj, act, measure_act_stats)

    @named_call
    def __call__(self, x: NamedArray, *, key=None) -> NamedArray:
        k_gate, k_up, k_down = maybe_rng_split(key, 3)
        hidden_states = self.gate_proj(x, key=k_gate)
+        extras = {}
+        if self.measure_act_stats:
+            extras["gate_hist"] = sharded_histogram(hidden_states.array, bins=get_bins())
        hidden_states = self.act(hidden_states)
        hidden_states = hidden_states * self.up_proj(x, key=k_up)
        outputs = self.down_proj(hidden_states, key=k_down)
-        return outputs
+        return outputs, extras

    def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
        # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp
@@ -402,6 +416,7 @@ def init(config: LlamaConfig, *, key) -> "LlamaDecoderLayer":
            config.activation_function,
            key=k_mlp,
            use_bias=config.use_bias,
+            measure_act_stats=config.measure_act_stats,
        )
        ln_1 = config.mk_LayerNorm(config.Embed)
        ln_2 = config.mk_LayerNorm(config.Embed)
@@ -420,9 +435,9 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,
        # MLP and skip connection
        residual = x
        x = self.post_attention_layernorm(x)
-        mlp_output = self.mlp(x, key=k_mlp)
+        mlp_output, extras = self.mlp(x, key=k_mlp)
        output = residual + mlp_output
-        return output
+        return output, extras


class LlamaTransformer(StateDictSerializationMixin, eqx.Module):
@@ -449,10 +464,10 @@ def init(config: LlamaConfig, *, key) -> "LlamaTransformer":
    @named_call
    def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray | AttentionMask], *, key) -> NamedArray:
        keys = maybe_rng_split(key, self.config.num_layers) if key is not None else None
-        x = self.layers.fold(x, mask=attn_mask, key=keys)
+        x, extras = self.layers.scan(x, mask=attn_mask, key=keys)
        x = self.norm(x)

-        return x
+        return x, extras

    def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
        if isinstance(self.layers, Stacked):
@@ -544,9 +559,9 @@ def __call__(
        """
        k_t, k_head = maybe_rng_split(key, 2)
        x = self.embeddings.embed(input_ids)
-        x = self.transformer(x, attn_mask=attn_mask, key=k_t)
+        x, extras = self.transformer(x, attn_mask=attn_mask, key=k_t)
        lm_logits = self.lm_head(x, key=k_head)
-        return lm_logits
+        return lm_logits, extras

    def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
        new_Vocab = self.Vocab.resize(new_size)
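Both transformers switch from `fold` to `scan` so each layer's extras survive the layer loop. A minimal analogy using plain `jax.lax.scan` (sizes are illustrative; that Haliax's `Stacked.scan` stacks the extras the same way, giving the eval callback one `gate_hist` row per layer, is an assumption based on this diff):

```python
import jax
import jax.numpy as jnp

num_layers, num_bins = 4, 8  # illustrative sizes only

def layer(x, _):
    # stand-in for a decoder layer returning (output, extras)
    hist = jnp.ones((num_bins,), dtype=jnp.int32)
    return x, {"gate_hist": hist}

_, extras = jax.lax.scan(layer, jnp.zeros(()), xs=jnp.arange(num_layers))
# scan stacks per-step outputs, so there is one histogram row per layer
assert extras["gate_hist"].shape == (num_layers, num_bins)
```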
[Review comment] so i think i'm gonna have a strong preference for