add extra options for export.py #91

Merged (4 commits, Feb 15, 2024)
README.md (71 changes: 35 additions & 36 deletions)
@@ -6,44 +6,10 @@ This library provides the generative AI loop for ONNX models, including inference

Users can call a high-level `generate()` method, or run each iteration of the model in a loop.

* Search techniques like greedy/beam search to generate token sequences
* Built in scoring tools like repetition penalties
* Support for greedy/beam search and TopP/TopK sampling to generate token sequences
* Built-in logits processing like repetition penalties
* Easy custom scoring

## Sample code for phi-2 in Python

Install onnxruntime-genai.

(Temporary) Build and install from source according to the instructions below.


```python
import onnxruntime_genai as og

# device_type is assumed to be defined earlier (it selects the device the model runs on); it is not set in this snippet
model = og.Model('models/microsoft/phi-2', device_type)

tokenizer = model.create_tokenizer()

prompt = '''def print_prime(n):
"""
Print all primes between 1 and n
"""'''

tokens = tokenizer.encode(prompt)

params=og.SearchParams(model)
params.max_length = 200
params.input_ids = tokens

output_tokens=model.generate(params)

text = tokenizer.decode(output_tokens)

print("Output:")
print(text)
```


## Features

* Supported model architectures:
@@ -126,6 +92,39 @@ huggingface-cli login --token <your HuggingFace token>
python export.py -m microsoft/phi-2 -p int4 -e cpu -o phi2-int4-cpu.onnx
```

## Sample code for phi-2 in Python

Install onnxruntime-genai.

(Temporary) Build and install from source according to the instructions below.


```python
import onnxruntime_genai as og

# device_type is assumed to be defined earlier (it selects the device the model runs on); it is not set in this snippet
model = og.Model('models/microsoft/phi-2', device_type)

tokenizer = model.create_tokenizer()

prompt = '''def print_prime(n):
"""
Print all primes between 1 and n
"""'''

tokens = tokenizer.encode(prompt)

params=og.SearchParams(model)
params.max_length = 200
params.input_ids = tokens

output_tokens=model.generate(params)

text = tokenizer.decode(output_tokens)

print("Output:")
print(text)
```


## Contributing

Expand Down
src/python/models/export.py (56 changes: 44 additions & 12 deletions)
@@ -15,7 +15,7 @@
import os

class Model:
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
self.intermediate_size = config.intermediate_size
self.hidden_size = config.hidden_size
self.num_kv_heads = config.num_key_value_heads
@@ -31,6 +31,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
self.onnx_dtype = onnx_dtype # {"int4", "fp16", "fp32"}
self.ep = ep
self.cache_dir = cache_dir
self.extra_options = extra_options

self.inputs = []
self.outputs = []
@@ -102,8 +103,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
# Quantization-specific variables (INT4, INT8, etc.)
self.quant_attrs = {
"int4": {
"block_size": 32,
"accuracy_level": None,
"block_size": int(extra_options["int4_block_size"]) if "int4_block_size" in extra_options else 32,
"accuracy_level": int(extra_options["int4_accuracy_level"]) if "int4_accuracy_level" in extra_options else None,
}
}
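
The two changed lines above now read the int4 settings from `extra_options`, falling back to the previous hard-coded defaults. As a minimal sketch (not part of the diff), this is how the lookups behave for a hypothetical parsed options dict:

```python
# Minimal sketch: extra_options is assumed to be the plain string-valued dict
# produced by parse_extra_options further down in this file.
extra_options = {"int4_block_size": "64"}  # hypothetical; int4_accuracy_level left unset

block_size = int(extra_options["int4_block_size"]) if "int4_block_size" in extra_options else 32
accuracy_level = int(extra_options["int4_accuracy_level"]) if "int4_accuracy_level" in extra_options else None

print(block_size, accuracy_level)  # -> 64 None
```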

@@ -1028,8 +1029,8 @@ def make_common_mask_reformat_subgraph(self, basename, root_input, unsqueeze_for
return expand_name

class LlamaModel(Model):
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir)
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
self.model_inputs = ["input_ids", "attention_mask", "position_ids"]

def make_attention_mask_reformatting(self):
@@ -1078,8 +1079,8 @@ def make_attention_mask_reformatting(self):


class MistralModel(LlamaModel):
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir)
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
self.position_ids_name = self.make_position_ids_reformatting()

def make_position_ids_reformatting(self):
@@ -1122,8 +1123,8 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):


class PhiModel(LlamaModel):
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir)
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
# self.input_shapes["position_ids"] = [1] # Note: This is optional and only needed if you want position_ids to be an int instead of a 2D tensor
self.layernorm_attrs["simple"] = False
self.rotemb_attrs["num_heads"] = self.num_attn_heads
@@ -1232,16 +1233,31 @@ def make_layer(self, layer_id, layer):
self.layernorm_attrs["skip_input"] = f"{residual_add_name}/output_0"


def parse_extra_options(kv_items):
"""
Parse key value pairs that are separated by '='
"""
kv_pairs = {}

if kv_items:
for kv_str in kv_items:
kv = kv_str.split('=')
kv_pairs[kv[0].strip()] = kv[1].strip()
print(kv_pairs)
return kv_pairs
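
As a quick illustration (not part of the diff), here is what `parse_extra_options` produces. The new `--extra_options` argument below uses `nargs='+'`, so argparse hands this function a list of `KEY=VALUE` strings:

```python
# Hypothetical invocation (option names taken from the help text added below):
#   python export.py -m microsoft/phi-2 -p int4 -e cpu -o phi2-int4-cpu.onnx \
#       --extra_options int4_block_size=64 int4_accuracy_level=4
kv_items = ["int4_block_size=64", "int4_accuracy_level=4"]
options = parse_extra_options(kv_items)
# The function prints the dict and returns it; values stay strings until cast by the caller:
# {'int4_block_size': '64', 'int4_accuracy_level': '4'}
```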


def create_model(args):
extra_kwargs = {} if os.path.exists(args.model_name_or_path) else {"cache_dir": args.cache_dir, "use_auth_token": True}
config = AutoConfig.from_pretrained(args.model_name_or_path, **extra_kwargs)

extra_options = parse_extra_options(args.extra_options)
if config.architectures[0] == "LlamaForCausalLM":
onnx_model = LlamaModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir)
onnx_model = LlamaModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir, extra_options)
elif config.architectures[0] == "MistralForCausalLM":
onnx_model = MistralModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir)
onnx_model = MistralModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir, extra_options)
elif config.architectures[0] == "PhiForCausalLM":
onnx_model = PhiModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir)
onnx_model = PhiModel(config, args.io_dtype, args.precision, args.execution_provider, args.cache_dir, extra_options)
else:
raise NotImplementedError(f"The {args.model_name_or_path} model is not currently supported.")

@@ -1251,6 +1267,7 @@ def create_model(args):
# Save ONNX model
onnx_model.save(args.output)


def get_args():
parser = argparse.ArgumentParser()

@@ -1294,6 +1311,21 @@ def get_args():
help="Model cache directory (if providing model name and not folder path)",
)

parser.add_argument(
"--extra_options",
required=False,
metavar="KEY=VALUE",
nargs='+',
default=[],  # empty list by default; the path default here appears copied from the cache-dir argument above and would be mis-parsed by parse_extra_options
help="""
Key value pairs for various options. Currently support:
int4_block_size = 16/32/64/128/256: Specify the block_size for int4 quantization.
int4_accuracy_level = 1/2/3/4: Specify the minimum accuracy level for activation of matmul in int4 quantization.
4 is int8, which means input A of int4 quantized matmul is quantized to int8 and input B is upcasted to int8 for computation.
1 is fp32, 2 is fp16, and 3 is bf16.
""",
)

args = parser.parse_args()

print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, INT4 CPU, INT4 CUDA")