From 93799fc9bc6710da606471f974b9084677a858db Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 13 Feb 2024 16:01:34 -0800 Subject: [PATCH] fixes --- src/autora/doc/pipelines/train.py | 6 +++--- src/autora/doc/runtime/predict_hf.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/autora/doc/pipelines/train.py b/src/autora/doc/pipelines/train.py index 870a016..ad81518 100644 --- a/src/autora/doc/pipelines/train.py +++ b/src/autora/doc/pipelines/train.py @@ -33,7 +33,7 @@ def fine_tune(base_model: str, new_model_name: str, dataset: Dataset) -> None: model = AutoModelForCausalLM.from_pretrained( base_model, - *kwargs, + **kwargs, ) model.config.use_cache = False model.config.pretraining_tp = 1 @@ -61,8 +61,8 @@ def fine_tune(base_model: str, new_model_name: str, dataset: Dataset) -> None: logging_steps=1, # TODO: Increase once there's more data learning_rate=2e-4, weight_decay=0.001, - fp16=False, - bf16=cuda_available, + fp16=cuda_available, + bf16=False, max_grad_norm=0.3, max_steps=-1, warmup_ratio=0.03, diff --git a/src/autora/doc/runtime/predict_hf.py b/src/autora/doc/runtime/predict_hf.py index 97a7f4f..d8bf424 100644 --- a/src/autora/doc/runtime/predict_hf.py +++ b/src/autora/doc/runtime/predict_hf.py @@ -107,5 +107,5 @@ def get_quantization_config() -> Any: load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_compute_dtype=torch.float16, )