Commit 0dfe81f
bugfix
Signed-off-by: Zhang, Weiwei1 <[email protected]>
WeiweiZhang1 committed Jun 19, 2024
1 parent 77320b0 commit 0dfe81f
Showing 3 changed files with 13 additions and 8 deletions.
auto_round/autoround.py: 3 changes (2 additions, 1 deletion)

@@ -891,7 +891,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
                 output_q = block_forward(
                     block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device
                 )
-                if self.amp and not check_is_cpu(device):
+                if self.amp:
                     with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
                         loss = mse_loss(output_q, current_output)  # pylint: disable=not-callable
                 else:
@@ -1416,3 +1416,4 @@ def __init__(
             optimizer,
             **kwargs,
         )
+
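
Note on the autoround.py change above: dropping the check_is_cpu guard relies on torch.autocast also accepting device_type="cpu" (typically with bfloat16), so the same autocast branch can serve CPU and accelerator runs. A minimal standalone sketch of that pattern follows; the helper name and toy tensors are illustrative only, not the repository's code.

import torch
from torch.nn.functional import mse_loss

def autocast_mse(output_q, current_output, device="cpu", amp_dtype=torch.bfloat16):
    # torch.autocast accepts both "cuda" and "cpu" as device_type, so one
    # code path can cover either target when amp is enabled.
    with torch.autocast(device_type=device.split(":")[0], dtype=amp_dtype):
        return mse_loss(output_q, current_output)

# Toy tensors standing in for the quantized and reference block outputs.
a, b = torch.randn(4, 8), torch.randn(4, 8)
print(autocast_mse(a, b, device="cpu"))
if torch.cuda.is_available():
    print(autocast_mse(a.cuda(), b.cuda(), device="cuda:0", amp_dtype=torch.float16))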

examples/language-modeling/eval_042/evaluation.py: 9 changes (8 additions, 1 deletion)

@@ -573,6 +573,10 @@ def evaluate(
     parser.add_argument(
         "--eval_bs", default=1,
     )
+    parser.add_argument(
+        "--trust_remote_code", action='store_true',
+        help="Whether to enable trust_remote_code"
+    )
     parser.add_argument("--tasks",
                         default="lambada_openai,hellaswag,winogrande,piqa,mmlu,truthfulqa_mc1," \
                                 "openbookqa,boolq,rte,arc_easy,arc_challenge",
@@ -582,7 +586,7 @@ def evaluate(
     s = time.time()
     from transformers import AutoConfig
 
-    config = AutoConfig.from_pretrained(args.model_name)
+    config = AutoConfig.from_pretrained(args.model_name, trust_remote_code=args.trust_remote_code)
 
     if hasattr(config, "quantization_config"):
         quantization_config = config.quantization_config
@@ -599,10 +603,13 @@ def evaluate(
     if config.torch_dtype == torch.float32:
         model_args += ",dtype=float16"
+        # model_args += ",dtype=float16"
+    if args.trust_remote_code:
+        model_args += f",trust_remote_code=True"
     result = simple_evaluate(model="hf",
                              model_args=model_args,
                              tasks=test_tasks,
                              batch_size=args.eval_bs)
     print(make_table(result))
 
     print("cost time: ", time.time() - s)
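
The evaluation.py change threads a single --trust_remote_code switch through both AutoConfig.from_pretrained and the model_args string handed to lm-eval's simple_evaluate. A small sketch of that plumbing, assuming model_args starts with the usual pretrained=<path> entry for lm-eval's "hf" backend; the model path below is a placeholder.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="facebook/opt-125m")
parser.add_argument("--trust_remote_code", action="store_true",
                    help="Whether to enable trust_remote_code")
# Simulate `--trust_remote_code` being passed on the command line.
args = parser.parse_args(["--trust_remote_code"])

model_args = f"pretrained={args.model_name}"
if args.trust_remote_code:
    # lm-eval's "hf" backend forwards this to transformers' from_pretrained.
    model_args += ",trust_remote_code=True"
print(model_args)  # pretrained=facebook/opt-125m,trust_remote_code=True

An invocation would then look something like `python eval_042/evaluation.py --model_name <model_dir> --trust_remote_code`, with the flag simply omitted for models that ship no custom modeling code.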

examples/language-modeling/main.py: 9 changes (3 additions, 6 deletions)

@@ -231,16 +231,12 @@ def get_library_version(library_name):
     args.seqlen = seqlen
 
     excel_name = f"{model_name}_{args.bits}_{args.group_size}"
-    pt_dtype = torch.float16
     if (hasattr(model, 'config') and (model.dtype is torch.bfloat16 or model.config.torch_dtype is torch.bfloat16)):
         dtype = 'bfloat16'
-        pt_dtype = torch.bfloat16
     else:
-        if str(args.device) != "cpu":
-            pt_dtype = torch.float16
+        if "cpu" not in device_str:
             dtype = 'float16'
         else:
-            pt_dtype = torch.float32
             dtype = 'float32'
 
     excel_name = f"{model_name}_{args.bits}_{args.group_size}"
@@ -313,7 +309,7 @@ def get_library_version(library_name):
     model_name = args.model_name.rstrip("/")
 
     model.eval()
-    if args.device != "cpu":
+    if "cpu" not in device_str:
         torch.cuda.empty_cache()
 
     export_dir = args.output_dir + "/" + model_name.split('/')[-1] + f"-autoround-w{args.bits}g{args.group_size}"
@@ -357,3 +353,4 @@ def get_library_version(library_name):
     from lm_eval.utils import make_table
 
     print(make_table(res))
+
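
For the main.py change: the per-branch pt_dtype bookkeeping is removed, and both remaining device checks now test for the substring "cpu" in device_str, so strings such as "cuda:0" fall on the non-CPU path. A sketch of the resulting dtype selection, using a hypothetical helper name and a dummy model object without a config attribute.

import torch

def pick_eval_dtype(model, device_str):
    # Mirrors the post-fix logic: keep bfloat16 models in bfloat16,
    # otherwise use float16 on accelerators and float32 on CPU.
    if hasattr(model, "config") and (model.dtype is torch.bfloat16
                                     or model.config.torch_dtype is torch.bfloat16):
        return "bfloat16"
    if "cpu" not in device_str:
        return "float16"
    return "float32"

class _NoConfig:  # stand-in object lacking a `config` attribute
    pass

print(pick_eval_dtype(_NoConfig(), "cpu"))      # float32
print(pick_eval_dtype(_NoConfig(), "cuda:0"))   # float16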
