
SmoothQuant doesn't respect ignored modules for VLMs #687

Open
mgoin opened this issue Sep 26, 2024 · 2 comments
Labels: bug (Something isn't working)

Comments

@mgoin
Collaborator

mgoin commented Sep 26, 2024

I am trying to apply SmoothQuant during W8A8 quantization of meta-llama/Llama-3.2-11B-Vision-Instruct, where I ignore all of the modules except for language_model. However, I find that it crashes when it reaches the vision model that I have chosen to ignore:

ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"]
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore),
]
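
For reference, the ignore entries are "re:"-prefixed regex patterns matched against module names, so the vision tower should be excluded. A minimal check (illustrative only, not llm-compressor's internal target resolution) confirms that the layer named in the KeyError below matches the vision_model pattern:

import re

ignore = ["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"]

def is_ignored(name):
    # "re:"-prefixed entries are treated as regex patterns here; this mirrors
    # the intent of the ignore list, not llm-compressor's exact matching code.
    return any(re.match(p[len("re:"):], name) for p in ignore if p.startswith("re:"))

print(is_ignored("vision_model.transformer.layers.0.input_layernorm"))  # True
print(is_ignored("language_model.model.layers.0.input_layernorm"))      # False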

Error:

  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/smoothquant/base.py", line 276, in _apply_smoothing
    self.scales_[mapping.smooth_name].max_channel_vals
KeyError: 'vision_model.transformer.layers.0.input_layernorm'
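
Since the calibration data here is text-only, the vision_model layers presumably never see a forward pass and never get scales recorded, yet SmoothQuant still resolves mappings for them despite the ignore list. As a rough sketch of the missing guard (the names scales_, resolved_mappings_, smooth_name and max_channel_vals come from the traceback and recipe dump below; the surrounding loop is a standalone mock, not a patch against llm-compressor):

from types import SimpleNamespace

# Mocked state standing in for the modifier's attributes, just to show the skip
# logic. The real fix should also respect `ignore` when mappings are resolved.
scales_ = {
    "language_model.model.layers.0.input_layernorm": SimpleNamespace(max_channel_vals=[1.0, 2.0]),
}
resolved_mappings_ = [
    SimpleNamespace(smooth_name="language_model.model.layers.0.input_layernorm"),
    SimpleNamespace(smooth_name="vision_model.transformer.layers.0.input_layernorm"),
]

for mapping in resolved_mappings_:
    if mapping.smooth_name not in scales_:
        # Ignored / never-calibrated layer (e.g. the vision tower): skip it
        # instead of raising KeyError.
        continue
    activation_scales = scales_[mapping.smooth_name].max_channel_vals
    print(mapping.smooth_name, activation_scales)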

Code to trigger:

from datasets import load_dataset
from transformers import AutoTokenizer, MllamaForConditionalGeneration

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.transformers import oneshot, wrap_hf_model_class

# Select model and load it.
MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model_class = wrap_hf_model_class(MllamaForConditionalGeneration)
model = model_class.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)
processor = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 4
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    return {
        "text": processor.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return processor(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)
print(ds)

# Configure algorithms. In this case, we:
#   * apply SmoothQuant to make the activations easier to quantize
#   * quantize the weights to int8 with GPTQ (static per channel)
#   * quantize the activations to int8 (dynamic per token)
# Note: set sequential_update: true in the recipe to reduce memory
ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"]
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = processor("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(processor.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
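
One possible workaround until this is fixed (untested, just a sketch) would be to scope the SmoothQuant mappings to the language model explicitly, so the vision-tower layernorms are never resolved at all. The nested-list format below is copied from the default mappings visible in the recipe dump further down; whether the mappings argument accepts exactly this structure should be double-checked against the installed version:

# Hypothetical workaround sketch: restrict the default SmoothQuant mappings to
# language_model.* so the vision tower is never touched.
recipe = [
    SmoothQuantModifier(
        smoothing_strength=0.8,
        ignore=ignore,
        mappings=[
            [
                ["re:language_model.*q_proj", "re:language_model.*k_proj", "re:language_model.*v_proj"],
                "re:language_model.*input_layernorm",
            ],
            [
                ["re:language_model.*gate_proj", "re:language_model.*up_proj"],
                "re:language_model.*post_attention_layernorm",
            ],
        ],
    ),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore),
]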

Full log and error:

python llama3.2_vision_example.py
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00,  1.19s/it]
Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 4
})
2024-09-26T21:08:35.974072+0000 | main | WARNING - Process rank: 0, device: cuda:0, n_gpu: 1, distributed training: True, 16-bits training: False
2024-09-26T21:08:35.975208+0000 | main | INFO - Training/evaluation parameters TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
clear_sparse_session=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_oneshot=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_batches=True,
eval_on_start=False,
eval_steps=None,
eval_strategy=no,
eval_use_gather_object=False,
evaluation_strategy=None,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_always_push=False,
hub_model_id=None,
hub_private_repo=False,
hub_strategy=every_save,
hub_token=<HUB_TOKEN>,
ignore_data_skip=False,
include_inputs_for_metrics=False,
include_num_input_tokens_seen=False,
include_tokens_per_second=False,
jit_mode_eval=False,
label_names=None,
label_smoothing_factor=0.0,
learning_rate=5e-05,
length_column_name=length,
load_best_model_at_end=False,
local_rank=0,
log_level=passive,
log_level_replica=warning,
log_on_each_node=True,
logging_dir=./output/runs/Sep26_21-08-35_beaker,
logging_first_step=False,
logging_nan_inf_filter=True,
logging_steps=500,
logging_strategy=steps,
lr_scheduler_kwargs={},
lr_scheduler_type=linear,
max_grad_norm=1.0,
max_steps=-1,
metric_for_best_model=None,
mp_parameters=,
neftune_noise_alpha=None,
no_cuda=False,
num_train_epochs=3.0,
oneshot_device=cuda:0,
optim=adamw_torch,
optim_args=None,
optim_target_modules=None,
output_dir=./output,
overwrite_output_dir=False,
past_index=-1,
per_device_eval_batch_size=8,
per_device_train_batch_size=8,
prediction_loss_only=False,
push_to_hub=False,
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
ray_scope=last,
recipe=[SmoothQuantModifier(index=None, group=None, start=None, end=None, update=None, initialized_structure_=False, initialized_=False, finalized_=False, started_=False, ended_=False, smoothing_strength=0.8, mappings=[[['re:.*q_proj', 're:.*k_proj', 're:.*v_proj'], 're:.*input_layernorm'], [['re:.*gate_proj', 're:.*up_proj'], 're:.*post_attention_layernorm']], ignore=['re:.*lm_head', 're:multi_modal_projector.*', 're:vision_model.*'], num_calibration_steps=None, calibration_function=None, hooks_=None, resolved_mappings_=None, scales_=None), GPTQModifier(index=None, group=None, start=None, end=None, update=None, initialized_structure_=False, initialized_=False, finalized_=False, started_=False, ended_=False, sequential_update=True, targets='Linear', sequential_targets=None, block_size=128, quantize=True, dampening_frac=0.01, config_groups=None, ignore=['re:.*lm_head', 're:multi_modal_projector.*', 're:vision_model.*'], disable_quantization_observer_epoch=None, num_calibration_steps=None, scheme='W8A8', model=None, layer_compressors_=None, compressible_layers_=None, quantization_modifier_=None)],
recipe_args=None,
remove_unused_columns=True,
report_to=[],
restore_callback_states_from_checkpoint=False,
resume_from_checkpoint=None,
run_name=./output,
run_stages=False,
save_compressed=True,
save_on_each_node=False,
save_only_model=False,
save_safetensors=True,
save_steps=500,
save_strategy=steps,
save_total_limit=None,
seed=42,
skip_memory_metrics=True,
split_batches=None,
tf32=None,
torch_compile=False,
torch_compile_backend=None,
torch_compile_mode=None,
torch_empty_cache_steps=None,
torchdynamo=None,
tpu_metrics_debug=False,
tpu_num_cores=None,
use_cpu=False,
use_ipex=False,
use_legacy_prediction_loop=False,
use_liger_kernel=False,
use_mps_device=False,
warmup_ratio=0.0,
warmup_steps=0,
weight_decay=0.0,
)
2024-09-26T21:08:36.353807+0000 | _check_create_state | INFO - State created for compression lifecycle
2024-09-26T21:08:36.354551+0000 | pre_initialize_structure | INFO - Compression lifecycle structure pre-initialized for 0 modifiers
2024-09-26T21:08:36.354703+0000 | pre_initialize_structure | INFO - Compression lifecycle structure pre-initialized for 0 modifiers
2024-09-26T21:08:36.391736+0000 | one_shot | INFO - *** One Shot ***
2024-09-26T21:08:36.395763+0000 | from_modifiers | INFO - Creating recipe from modifiers
2024-09-26T21:08:36.427963+0000 | _check_compile_recipe | INFO - Recipe compiled and 1 modifiers created
2024-09-26T21:08:41.103626+0000 | _calibrate | INFO - Running SmoothQuantModifier calibration with 4 samples...
  0%|                                        | 0/4 [00:00<?, ?it/s]
torch.Size([1, 2048, 4096])    (repeated 32 times, sample 1)
 25%|██████████                              | 1/4 [00:00<00:00,  3.39it/s]
torch.Size([1, 2048, 4096])    (repeated 32 times, sample 2)
torch.Size([1, 977, 4096])     (repeated 32 times, sample 3)
 75%|██████████████████████████████          | 3/4 [00:00<00:00,  7.69it/s]
torch.Size([1, 1119, 4096])    (repeated 32 times, sample 4)
100%|████████████████████████████████████████| 4/4 [00:00<00:00,  8.13it/s]
2024-09-26T21:08:41.598186+0000 | _apply_smoothing | INFO - Smoothing activation scales...
Traceback (most recent call last):
  File "/home/mgoin/code/llm-compressor/examples/quantization_w8a8_int8/llama3.2_vision_example.py", line 70, in <module>
    oneshot(
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/text_generation.py", line 76, in oneshot
    main(model_args, data_args, training_args)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/text_generation.py", line 364, in main
    stage_runner.one_shot()
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/runner.py", line 171, in one_shot
    self.trainer.one_shot(calibration_data=calib_data, stage=stage)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/transformers/finetune/session_mixin.py", line 401, in one_shot
    apply(
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/session_functions.py", line 184, in apply
    return active_session().apply(
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/session.py", line 210, in apply
    self.initialize(**kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/session.py", line 156, in initialize
    mod_data = self._lifecycle.initialize(
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/core/lifecycle.py", line 126, in initialize
    data = mod.initialize(state=self.state, **extras)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/stage.py", line 124, in initialize
    modifier.initialize(state, **kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/modifier.py", line 118, in initialize
    initialized = self.on_initialize(state=state, **kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/smoothquant/base.py", line 135, in on_initialize
    self._apply_smoothing(state.model)
  File "/home/mgoin/venvs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/mgoin/code/llm-compressor/src/llmcompressor/modifiers/smoothquant/base.py", line 276, in _apply_smoothing
    self.scales_[mapping.smooth_name].max_channel_vals
KeyError: 'vision_model.transformer.layers.0.input_layernorm'
mgoin added the bug label on Sep 26, 2024
@markurtz
Collaborator

@dsikka @kylesayrs is this one included in the bug fixes for the vision pipelines we're working on?

@mohammhn

mohammhn commented Nov 14, 2024

@markurtz @dsikka @kylesayrs Hey, just wanted to follow up on this. I'm facing the same issue on my end. Is this issue supposed to be resolved, or is it currently unsupported?
