support qqq(w4a8) for lmdeploy
HandH1998 committed Aug 29, 2024
1 parent d04b37f commit 754ef80
Showing 43 changed files with 2,718 additions and 276 deletions.
9 changes: 5 additions & 4 deletions lmdeploy/cli/utils.py
@@ -106,10 +106,11 @@ def model_format(parser, default: str = None):
'--model-format',
type=str,
default=default,
choices=['hf', 'llama', 'awq', 'gptq'],
help='The format of input model. `hf` means `hf_llama`, `llama` '
'means `meta_llama`, `awq` represents the quantized model by AWQ,'
' and `gptq` refers to the quantized model by GPTQ')
choices=['hf', 'llama', 'awq', 'gptq', 'qqq'],
help='The format of input model. `hf` means `hf_llama`, `llama` '
'means `meta_llama`, `awq` represents the quantized model by AWQ, '
'`gptq` refers to the quantized model by GPTQ, and '
'`qqq` refers to the quantized model by QQQ')

@staticmethod
def revision(parser, default: str = None):
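With `qqq` added to the `--model-format` choices, a QQQ checkpoint can be passed through the existing subcommands that reuse this parser helper. A minimal usage sketch (paths are placeholders, not part of this commit, and assume a llama model already quantized with QQQ):

    # chat with a QQQ-quantized model via TurboMind (hypothetical path)
    lmdeploy chat /path/to/llama-qqq --model-format qqq

    # or expose it through the OpenAI-compatible server
    lmdeploy serve api_server /path/to/llama-qqq --model-format qqq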
5 changes: 3 additions & 2 deletions lmdeploy/messages.py
@@ -123,8 +123,9 @@ class TurbomindEngineConfig:
"""TurboMind Engine config.
Args:
model_format (str): the layout of the deployed model. It can be one of the following values [hf, meta_llama, awq],
`hf` meaning huggingface model(.bin, .safetensors), `meta_llama` being meta llama's format(.pth), awq` meaning the quantized model by AWQ.
model_format (str): the layout of the deployed model. It can be one of the following values [hf, meta_llama, awq, qqq],
`hf` means a huggingface model (.bin, .safetensors), `meta_llama` is meta llama's format (.pth), `awq` means the quantized model by AWQ,
and `qqq` means the quantized model by QQQ.
tp (int): the number of GPU cards used in tensor parallelism, default to 1
session_len (int): the max session length of a sequence, default to None
max_batch_size (int): the max batch size during inference, default to 128
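For reference, a minimal Python sketch of selecting the new format through the engine config (the model path is a placeholder and assumes a checkpoint already quantized with QQQ):

    from lmdeploy import pipeline, TurbomindEngineConfig

    # '/path/to/llama-qqq' is a hypothetical QQQ-quantized llama checkpoint.
    engine_config = TurbomindEngineConfig(model_format='qqq', tp=1)
    pipe = pipeline('/path/to/llama-qqq', backend_config=engine_config)
    print(pipe(['Hi, please introduce yourself.']))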
48 changes: 33 additions & 15 deletions lmdeploy/turbomind/deploy/converter.py
@@ -18,7 +18,7 @@
from .source_model.base import INPUT_MODELS
from .target_model.base import OUTPUT_MODELS, TurbomindModelConfig

SUPPORTED_FORMATS = ['meta_llama', 'hf', 'awq', 'gptq', None]
SUPPORTED_FORMATS = ['meta_llama', 'hf', 'awq', 'gptq', 'qqq', None]
logger = get_logger('lmdeploy')


@@ -29,7 +29,7 @@ def get_input_model_registered_name(model_path: str, model_format: str):
Args:
model_path (str): the path of the input model
model_format (str): the format of the model, which can be one of
['meta_llama', 'hf', 'awq']
['meta_llama', 'hf', 'awq', 'gptq', 'qqq']
"""
arch = get_model_arch(model_path)[0]
register_name = SUPPORTED_ARCHS[arch]
@@ -93,12 +93,14 @@ def get_output_model_registered_name_and_config(model_path: str,
Args:
model_path (str): the path of the input model
model_format (str): the format of the model, which can be one of
['meta_llama', 'hf', 'awq']
group_size (int): the size of group used by awq model
['meta_llama', 'hf', 'awq', 'gptq', 'qqq']
group_size (int): the size of group used by quantization methods,
including `awq`, `gptq` and `qqq`
"""
register_name = 'tm'
turbomind_model_arch = 'llama'
weight_type = 'fp16'
quantization = ''

config = TurbomindModelConfig.from_dict({}, allow_none=True)

@@ -111,6 +113,9 @@
if model_format in ['awq', 'gptq']:
weight_type = 'int4'
group_size = 128 if group_size == 0 else group_size
elif model_format == 'qqq':
weight_type = 'int4'
group_size = -1 if group_size == 0 else group_size
else:
torch_dtype = getattr(model_config, 'torch_dtype', 'float16')
TORCH_DTYPE_MAP = {torch.bfloat16: 'bf16', torch.float16: 'fp16'}
@@ -128,10 +133,12 @@
config.session_len = session_len + 8
config.weight_type = weight_type
config.group_size = group_size
quantization = '' if weight_type in ['bf16', 'fp16'] else model_format
config.quantization = quantization

lora_type = 'plora' if turbomind_model_arch == 'xcomposer2' else ''

exporter_factory = get_exporter_factory(weight_type, lora_type)
exporter_factory = get_exporter_factory(quantization, lora_type)

return register_name, config, exporter_factory

@@ -225,6 +232,8 @@ def get_tm_model(model_path,
assert not quant_config.get('desc_act', False) and \
quant_config.get('sym', True), \
f'unsupported quant config: {quant_config}'
elif quant_method == 'qqq':
pass
else:
assert 0, f'unsupported quant_config: {quant_config}'

@@ -237,13 +246,20 @@
f'model format is "{engine_config.model_format}" ' \
f'but group_size is {group_size}. Currently, only 128 ' \
'is supported'
if engine_config.model_format == 'qqq':
assert group_size in [-1, 128], \
f'model format is "{engine_config.model_format}" ' \
f'but group_size is {group_size}. Currently, only -1 and 128 ' \
'are supported'

input_model_name = get_input_model_registered_name(
model_path, engine_config.model_format)
input_policy = get_input_policy(engine_config.model_format)
input_model = INPUT_MODELS.get(input_model_name)(model_path=model_path,
tokenizer_path=model_path,
input_policy=input_policy)
input_model = INPUT_MODELS.get(input_model_name)(
model_path=model_path,
tokenizer_path=model_path,
input_policy=input_policy,
model_format=engine_config.model_format)

output_model_name, cfg, exporter_factory = \
get_output_model_registered_name_and_config(
@@ -291,17 +307,19 @@ def main(model_name: str,
model_name (str): unused any longer
model_path (str): the directory path of the model
model_format (str): the format of the model, should choose from
['meta_llama', 'hf', 'awq', None]. 'meta_llama' stands for META's
llama format, 'hf' means huggingface llama format, and 'awq' means
llama(hf) model quantized by lmdeploy/lite/quantization/awq.py.
The default value is None
chat_template (str): the name of the built-in chat template.
['meta_llama', 'hf', 'awq', 'qqq', None]. 'meta_llama' stands for
META's llama format, 'hf' means huggingface llama format,
'awq' means a llama(hf) model quantized by
lmdeploy/lite/quantization/awq.py,
and 'qqq' means a llama(hf) model quantized by the repo
https://github.com/HandH1998/QQQ.
The default value is None
tokenizer_path (str): the path of tokenizer model
dst_path (str): the destination path that saves outputs
tp (int): the number of GPUs used for tensor parallelism, should be 2^n
quant_path (str): Path of the quantized model, which can be None.
group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits
group_size (int): a parameter used in AWQ or QQQ to quantize fp16
weights to 4 bits
revision (str): The specific model version to use. It can be a branch
name, a tag name, or a commit id. If unspecified, will use
the default version.
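As a usage sketch of the extended converter (model name and paths are placeholders, not part of this commit, and assume a llama(hf) checkpoint already quantized with https://github.com/HandH1998/QQQ):

    # convert a QQQ-quantized checkpoint into a TurboMind workspace;
    # per the assertion above, only group_size -1 or 128 is accepted for qqq
    lmdeploy convert llama2 /path/to/llama-qqq \
        --model-format qqq \
        --group-size -1 \
        --dst-path ./workspace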
