add extra options for export.py (#91)
1. Add an extra_options argument for export.py. It allows users to easily
specify additional parameters, such as for quantization.
2. Re-order the README a little bit.

---------

Co-authored-by: kunal-vaishnavi <[email protected]>
yufenglee and kunal-vaishnavi authored Feb 15, 2024
1 parent 924ef8f commit 8c76ac8
Showing 2 changed files with 79 additions and 48 deletions.
71 changes: 35 additions & 36 deletions README.md
@@ -6,44 +6,10 @@ This library provides the generative AI loop for ONNX models, including inference

Users can call a high-level `generate()` method, or run each iteration of the model in a loop.

* Search techniques like greedy/beam search to generate token sequences
* Built in scoring tools like repetition penalties
* Supports greedy/beam search and top-p/top-k sampling to generate token sequences
* Built-in logits processing like repetition penalties
* Easy custom scoring

## Sample code for phi-2 in Python

Install onnxruntime-genai.

(Temporary: build and install from source according to the instructions below.)


```python
import onnxruntime_genai as og

device_type = og.DeviceType.CPU  # or og.DeviceType.CUDA (device enum assumed here)
model = og.Model('models/microsoft/phi-2', device_type)

tokenizer = model.create_tokenizer()

prompt = '''def print_prime(n):
    """
    Print all primes between 1 and n
    """'''

tokens = tokenizer.encode(prompt)

params = og.SearchParams(model)
params.max_length = 200
params.input_ids = tokens

output_tokens = model.generate(params)

text = tokenizer.decode(output_tokens)

print("Output:")
print(text)
```


## Features

* Supported model architectures:
@@ -126,6 +92,39 @@ huggingface-cli login --token <your HuggingFace token>
python export.py -m microsoft/phi-2 -p int4 -e cpu -o phi2-int4-cpu.onnx
```
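
The new `--extra_options` flag takes space-separated `KEY=VALUE` pairs; for example, to tune int4 quantization (the values here are illustrative, see the flag's help text for the supported keys):

```
python export.py -m microsoft/phi-2 -p int4 -e cpu -o phi2-int4-cpu.onnx --extra_options int4_block_size=64 int4_accuracy_level=4
```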

## Sample code for phi-2 in Python

Install onnxruntime-genai.

(Temporary: build and install from source according to the instructions below.)


```python
import onnxruntime_genai as og

device_type = og.DeviceType.CPU  # or og.DeviceType.CUDA (device enum assumed here)
model = og.Model('models/microsoft/phi-2', device_type)

tokenizer = model.create_tokenizer()

prompt = '''def print_prime(n):
    """
    Print all primes between 1 and n
    """'''

tokens = tokenizer.encode(prompt)

params = og.SearchParams(model)
params.max_length = 200
params.input_ids = tokens

output_tokens = model.generate(params)

text = tokenizer.decode(output_tokens)

print("Output:")
print(text)
```
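
The feature list above mentions top-p/top-k sampling. A minimal sketch of how that could look, assuming `SearchParams` exposes sampling knobs as attributes the same way it exposes `max_length` (the attribute names below are assumptions, not part of the documented API):

```python
params = og.SearchParams(model)
params.max_length = 200
params.input_ids = tokens
# Hypothetical sampling attributes -- names assumed for illustration only
params.top_k = 50   # sample from the 50 most likely tokens at each step
params.top_p = 0.9  # restrict sampling to the top 90% probability mass
output_tokens = model.generate(params)
```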


## Contributing

56 changes: 44 additions & 12 deletions src/python/models/export.py
@@ -15,7 +15,7 @@
import os

class Model:
    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size
        self.num_kv_heads = config.num_key_value_heads
@@ -31,6 +31,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
        self.onnx_dtype = onnx_dtype # {"int4", "fp16", "fp32"}
        self.ep = ep
        self.cache_dir = cache_dir
        self.extra_options = extra_options

        self.inputs = []
        self.outputs = []
@@ -102,8 +103,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
        # Quantization-specific variables (INT4, INT8, etc.)
        self.quant_attrs = {
            "int4": {
                "block_size": 32,
                "accuracy_level": None,
                "block_size": int(extra_options["int4_block_size"]) if "int4_block_size" in extra_options else 32,
                "accuracy_level": int(extra_options["int4_accuracy_level"]) if "int4_accuracy_level" in extra_options else None,
            }
        }
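        # Per the --extra_options help text below: accuracy_level 1=fp32, 2=fp16,
        # 3=bf16, 4=int8 for the activations of the int4-quantized MatMul; None
        # leaves the choice to the quantizer (assumed default behavior).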

@@ -1028,8 +1029,8 @@ def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for
        return expand_name

class LlamaModel(Model):
    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir)
    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
        self.model_inputs = ["input_ids", "attention_mask", "position_ids"]

    def make_attention_mask_reformatting(self):
@@ -1078,8 +1079,8 @@ def make_attention_mask_reformatting(self):


class MistralModel(LlamaModel):
    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir)
    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
        self.position_ids_name = self.make_position_ids_reformatting()

    def make_position_ids_reformatting(self):
@@ -1122,8 +1123,8 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):


class PhiModel(LlamaModel):
    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir)
    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
        # self.input_shapes["position_ids"] = [1] # Note: This is optional and only needed if you want position_ids to be an int instead of a 2D tensor
        self.layernorm_attrs["simple"] = False
        self.rotemb_attrs["num_heads"] = self.num_attn_heads
@@ -1232,16 +1233,31 @@ def make_layer(self, layer_id, layer):
        self.layernorm_attrs["skip_input"] = f"{residual_add_name}/output_0"


def parse_extra_options(kv_items):
    """
    Parse key-value pairs that are separated by '='.
    """
    kv_pairs = {}

    if kv_items:
        for kv_str in kv_items:
            kv = kv_str.split('=')
            kv_pairs[kv[0].strip()] = kv[1].strip()
    print(f"Extra options: {kv_pairs}")
    return kv_pairs
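# Example: parse_extra_options(["int4_block_size=64", "int4_accuracy_level=4"])
# returns {"int4_block_size": "64", "int4_accuracy_level": "4"}; values stay
# strings and are cast where consumed (see quant_attrs in Model.__init__).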


def create_model(args):
    extra_kwargs = {} if os.path.exists(args.model_name_or_path) else {"cache_dir": args.cache_dir, "use_auth_token": True}
    config = AutoConfig.from_pretrained(args.model_name_or_path, **extra_kwargs)

    extra_options = parse_extra_options(args.extra_options)
    if config.architectures[0] == "LlamaForCausalLM":
        onnx_model = LlamaModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir)
        onnx_model = LlamaModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir, extra_options)
    elif config.architectures[0] == "MistralForCausalLM":
        onnx_model = MistralModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir)
        onnx_model = MistralModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir, extra_options)
    elif config.architectures[0] == "PhiForCausalLM":
        onnx_model = PhiModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir)
        onnx_model = PhiModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir, extra_options)
    else:
        raise NotImplementedError(f"The {args.model_name_or_path} model is not currently supported.")

@@ -1251,6 +1267,7 @@ def create_model(args):
    # Save ONNX model
    onnx_model.save(args.output)


def get_args():
    parser = argparse.ArgumentParser()

@@ -1294,6 +1311,21 @@ def get_args():
help="Model cache directory (if providing model name and not folder path)",
)

    parser.add_argument(
        "--extra_options",
        required=False,
        metavar="KEY=VALUE",
        nargs='+',
        default=[],
        help="""
            Key-value pairs for various options. Currently supported:
                int4_block_size = 16/32/64/128/256: Specify the block size for int4 quantization.
                int4_accuracy_level = 1/2/3/4: Specify the minimum accuracy level for activations of MatMul in int4 quantization.
                    4 is int8, which means input A of an int4 quantized MatMul is quantized to int8 and input B is upcast to int8 for computation.
                    1 is fp32, 2 is fp16, and 3 is bf16.
            """,
    )

    args = parser.parse_args()

    print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, INT4 CPU, INT4 CUDA")
