Quantization Support #46
Changes from 3 commits
```diff
@@ -188,6 +188,10 @@ def create_isvc(
                     client.V1EnvVar(
                         name="NAI_MAX_TOKENS", value=str(model_params["max_new_tokens"])
                     ),
+                    client.V1EnvVar(
+                        name="NAI_QUANTIZATION",
+                        value=str(model_params["quantize_bits"]),
+                    ),
                 ],
                 resources=client.V1ResourceRequirements(
                     limits={
```
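This hunk only wires the chosen precision into the inference service as an environment variable; how the serving container consumes `NAI_QUANTIZATION` is not shown in this diff. A minimal sketch of a plausible consumer, assuming the runtime loads models with `transformers` and `bitsandbytes` (only the variable name comes from the diff; the loading code and repo id are hypothetical):

```python
# Hypothetical consumer of NAI_QUANTIZATION inside the serving container;
# not part of this PR. Assumes transformers + bitsandbytes are installed.
import os

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantize_bits = os.environ.get("NAI_QUANTIZATION", "")

bnb_config = None
if quantize_bits == "4":
    bnb_config = BitsAndBytesConfig(load_in_4bit=True)
elif quantize_bits == "8":
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# With bnb_config left as None, the model loads at its default
# (16-bit) precision.
model = AutoModelForCausalLM.from_pretrained(
    "placeholder/repo-id",  # illustrative; the real repo_id comes from model_params
    quantization_config=bnb_config,
)
```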
```diff
@@ -364,6 +368,7 @@ def execute(params: argparse.Namespace) -> None:
     input_path = params.data
     mount_path = params.mount_path
     model_timeout = params.model_timeout
+    quantize_bits = params.quantize_bits

     check_if_path_exists(mount_path, "local nfs mount", is_dir=True)
     if not nfs_path or not nfs_server:
```

**Review:** white space on top
**Author:** fixed
```diff
@@ -382,6 +387,15 @@ def execute(params: argparse.Namespace) -> None:
     model_info["repo_id"] = model_params["repo_id"]
     model_info["repo_version"] = check_if_valid_version(model_info, mount_path)

+    if quantize_bits and int(quantize_bits) not in [4, 8]:
+        print("## Quantization precision bits should be either 4 or 8")
+        sys.exit(1)
+    elif quantize_bits and not deployment_resources["gpus"]:
+        print("## BitsAndBytes Quantization requires GPUs")
+        sys.exit(1)
+    else:
+        model_params["quantize_bits"] = quantize_bits
+
     config.load_kube_config()
     core_api = client.CoreV1Api()
```

**Review:** There can be a question why it's not taking 16 as well. Add text: `print("## Quantization precision bits should be either 4 or 8. Default precision used is 16")`
**Author:** changed the mentioned message
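The validation logic above is easy to exercise in isolation. A sketch of the same checks pulled into a helper, using the wording the reviewer asked for (the function name and signature are illustrative, not from the PR):

```python
# Illustrative refactor of the quantize_bits validation; not part of the PR.
import sys


def validate_quantize_bits(quantize_bits: str, has_gpus: bool) -> str:
    """Exit with an error for unsupported settings, else return the value."""
    if quantize_bits and int(quantize_bits) not in [4, 8]:
        # Reviewer-suggested wording: call out the 16-bit default.
        print(
            "## Quantization precision bits should be either 4 or 8. "
            "Default precision used is 16"
        )
        sys.exit(1)
    if quantize_bits and not has_gpus:
        print("## BitsAndBytes Quantization requires GPUs")
        sys.exit(1)
    return quantize_bits
```

Note that `int(quantize_bits)` raises `ValueError` for non-numeric input such as `--quantize_bits foo`; checking `quantize_bits.isdigit()` first would make the check total.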
```diff
@@ -434,6 +448,12 @@ def execute(params: argparse.Namespace) -> None:
         default=None,
         help="HuggingFace Hub token to download LLAMA(2) models",
     )
+    parser.add_argument(
+        "--quantize_bits",
+        type=str,
+        default="",
+        help="BitsAndBytes Quantization Precision (4 or 8)",
+    )
     # Parse the command-line arguments
     args = parser.parse_args()
     execute(args)
```
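Since `--quantize_bits` is declared with `type=str` and `default=""`, the value reaches `execute` as a string, and an empty string means quantization is disabled. A standalone snippet mirroring the declaration above to show the round trip:

```python
# Mirrors the parser declaration in the diff; standalone and runnable.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--quantize_bits",
    type=str,
    default="",
    help="BitsAndBytes Quantization Precision (4 or 8)",
)

args = parser.parse_args(["--quantize_bits", "4"])
assert args.quantize_bits == "4"   # arrives as a string, cast later with int()

args = parser.parse_args([])
assert args.quantize_bits == ""    # empty string means "no quantization"
```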
```diff
@@ -1,4 +1,4 @@
 torch-model-archiver==0.8.1
 kubernetes==28.1.0
 kserve==0.11.1
-huggingface-hub==0.17.1
+huggingface-hub==0.20.1
```
**Review:** Keep it simple here; you are doing the check just to type cast.
**Author:** changed as suggested