Update chat app examples
kunal-vaishnavi committed Dec 19, 2024
1 parent 6518b82 commit 744a1e1
Showing 4 changed files with 59 additions and 32 deletions.
19 changes: 15 additions & 4 deletions examples/chat_app/app.py
@@ -31,12 +31,12 @@ def change_model_listener(new_model_name):
     if "vision" in new_model_name:
         print("Configuring for multi-modal model")
         interface = MultiModal_ONNXModel(
-            model_path=d["model_dir"]
+            model_path=d["model_dir"], execution_provider=d["provider"],
         )
     else:
         print("Configuring for language-only model")
         interface = ONNXModel(
-            model_path=d["model_dir"]
+            model_path=d["model_dir"], execution_provider=d["provider"],
         )

     # interface.initialize()
@@ -74,15 +74,26 @@ def interface_retry(*args):
     yield from res


+def get_ep_name(name):
+    new_name = name.lower().replace("directml", "dml")
+    if "cpu" in new_name:
+        return "cpu"
+    elif "cuda" in new_name:
+        return "cuda"
+    elif "dml" in new_name:
+        return "dml"
+    raise ValueError(f"{new_name} is not recognized.")
+
+
 def launch_chat_app(expose_locally: bool = False, model_name: str = "", model_path: str = ""):
     if os.path.exists(optimized_directory):
         for ep_name in os.listdir(optimized_directory):
             sub_optimized_directory = os.path.join(optimized_directory, ep_name)
             for model_name in os.listdir(sub_optimized_directory):
-                available_models[model_name] = {"model_dir": os.path.join(sub_optimized_directory, model_name)}
+                available_models[model_name] = {"model_dir": os.path.join(sub_optimized_directory, model_name), "provider": get_ep_name(ep_name)}

     if model_path:
-        available_models[model_name] = {"model_dir": model_path}
+        available_models[model_name] = {"model_dir": model_path, "provider": get_ep_name(ep_name)}

Check failure: Code scanning / CodeQL
Potentially uninitialized local variable (Error): Local variable 'ep_name' may be used before it is initialized.

     with gr.Blocks(css=custom_css, theme=small_and_beautiful_theme) as demo:
         history = gr.State([])
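Note on the CodeQL finding above: `ep_name` is only bound inside the `for ep_name in os.listdir(optimized_directory)` loop, so if `optimized_directory` is missing or empty and an explicit `model_path` is passed, the new `get_ep_name(ep_name)` call can raise a `NameError` (or silently reuse whatever value the last loop iteration left behind). A minimal sketch of one way to close that gap, assuming the module-level names from app.py (`optimized_directory`, `available_models`, `get_ep_name`); the `execution_provider` parameter and the `found_model` loop variable are hypothetical additions, not part of this commit:

def launch_chat_app(expose_locally: bool = False, model_name: str = "",
                    model_path: str = "", execution_provider: str = "cpu"):
    if os.path.exists(optimized_directory):
        for ep_name in os.listdir(optimized_directory):
            sub_optimized_directory = os.path.join(optimized_directory, ep_name)
            for found_model in os.listdir(sub_optimized_directory):
                # Register every pre-optimized model under the provider folder it was found in.
                available_models[found_model] = {
                    "model_dir": os.path.join(sub_optimized_directory, found_model),
                    "provider": get_ep_name(ep_name),
                }

    if model_path:
        # Use an explicitly supplied provider instead of relying on the loop
        # variable `ep_name` still being bound after the loop above.
        available_models[model_name] = {
            "model_dir": model_path,
            "provider": get_ep_name(execution_provider),
        }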
20 changes: 15 additions & 5 deletions examples/chat_app/interface/hddr_llm_onnx_interface.py
@@ -10,12 +10,24 @@
 class ONNXModel():
     """A wrapper for OnnxRuntime-GenAI to run ONNX LLM model."""

-    def __init__(self, model_path):
+    def __init__(self, model_path, execution_provider):
         self.og = og
-        self.model = og.Model(f'{model_path}')
+
+        logging.info("Loading model...")
+        if hasattr(og, "Config"):
+            self.config = og.Config(model_path)
+            self.config.clear_providers()
+            if execution_provider != "cpu":
+                self.config.append_provider(execution_provider)
+            self.model = og.Model(self.config)
+        else:
+            self.model = og.Model(model_path)
+        logging.info("Loaded model...")
+
         self.tokenizer = og.Tokenizer(self.model)
         self.tokenizer_stream = self.tokenizer.create_stream()
         self.model_path = model_path

         if "phi" in self.model_path:
             self.template_header = ""
             self.enable_history_max = 10 if "mini" in self.model_path else 2
@@ -69,17 +81,15 @@ def search(
         output_tokens = []

         params = og.GeneratorParams(self.model)
-        params.input_ids = input_ids

         search_options = {"max_length" : max_length}
         params.set_search_options(**search_options)

         generator = og.Generator(self.model, params)
+        generator.append_tokens(input_ids)

         idx = 0
         while not generator.is_done():
             idx += 1
-            generator.compute_logits()
             generator.generate_next_token()
             next_token = generator.get_next_tokens()[0]
             output_tokens.append(next_token)
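Taken together, the changes to ONNXModel track the newer onnxruntime-genai API: the execution provider is selected through og.Config, prompt tokens are fed to the generator with append_tokens() rather than assigned to GeneratorParams.input_ids, and the separate compute_logits() step is gone. A condensed sketch of the same flow outside the wrapper class, assuming a package version that exposes og.Config and a hypothetical local model folder:

import onnxruntime_genai as og

model_path = "models/my-model"        # hypothetical model folder
execution_provider = "cuda"           # "cpu", "cuda", or "dml"

# Provider selection via og.Config, with a fallback for older package versions.
if hasattr(og, "Config"):
    config = og.Config(model_path)
    config.clear_providers()
    if execution_provider != "cpu":
        config.append_provider(execution_provider)
    model = og.Model(config)
else:
    model = og.Model(model_path)

tokenizer = og.Tokenizer(model)
tokenizer_stream = tokenizer.create_stream()

params = og.GeneratorParams(model)
params.set_search_options(max_length=256)

# Prompt tokens go to the generator directly, not to GeneratorParams.
generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode("What is an ONNX model?"))

# Token-by-token decode loop, mirroring ONNXModel.search() above.
while not generator.is_done():
    generator.generate_next_token()
    token = generator.get_next_tokens()[0]
    print(tokenizer_stream.decode(token), end="", flush=True)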
48 changes: 27 additions & 21 deletions examples/chat_app/interface/multimodal_onnx_interface.py
@@ -10,10 +10,20 @@ class MultiModal_ONNXModel():

     """A wrapper for ONNXRuntime GenAI to run ONNX Multimodal model"""

-    def __init__(self, model_path):
+    def __init__(self, model_path, execution_provider):
         self.og = og
-        logging.info("Loading model ...")
-        self.model = og.Model(f'{model_path}')
+
+        logging.info("Loading model...")
+        if hasattr(og, "Config"):
+            self.config = og.Config(model_path)
+            self.config.clear_providers()
+            if execution_provider != "cpu":
+                self.config.append_provider(execution_provider)
+            self.model = og.Model(self.config)
+        else:
+            self.model = og.Model(model_path)
+        logging.info("Loaded model ...")
+
         self.processor = self.model.create_multimodal_processor()
         self.tokenizer = self.processor.create_stream()

@@ -23,7 +33,6 @@ def __init__(self, model_path):
         self.chat_template = "<|user|>\n{tags}\n{input}<|end|>\n<|assistant|>\n"

     def generate_prompt_with_history(self, images, history, text=default_prompt, max_length=3072):
-
         prompt = ""

         for dialog in history[-self.enable_history_max:]:
@@ -43,16 +52,14 @@
         self.images = og.Images.open(*images)

         logging.info("Preprocessing images and prompt ...")
-        input_ids = self.processor(prompt, images=self.images)
-
-        return input_ids
+        inputs = self.processor(prompt, images=self.images)
+        return inputs


-    def search(self, input_ids, max_length: int = 3072, token_printing_step: int = 1):
-
+    def search(self, inputs, max_length: int = 3072, token_printing_step: int = 1):
         output = ""
         params = og.GeneratorParams(self.model)
-        params.set_inputs(input_ids)
+        params.set_inputs(inputs)

         search_options = {"max_length": max_length}
         params.set_search_options(**search_options)
@@ -61,7 +68,6 @@ def search(self, input_ids, max_length: int = 3072, token_printing_step: int = 1
         idx = 0
         while not generator.is_done():
             idx += 1
-            generator.compute_logits()
             generator.generate_next_token()
             next_token = generator.get_next_tokens()[0]
             output += self.tokenizer.decode(next_token)
@@ -74,18 +80,18 @@ def predict(self, text, chatbot, history, max_length_tokens, max_context_length_
             yield chatbot, history, "Empty context"
             return

-        input_ids = self.generate_prompt_with_history(
-            text=text,
-            history=history,
-            images=args[0],
-            max_length=max_context_length_tokens
+        inputs = self.generate_prompt_with_history(
+            text=text,
+            history=history,
+            images=args[0],
+            max_length=max_context_length_tokens
         )

         sentence = self.search(
-            input_ids,
-            max_length=max_length_tokens,
-            token_printing_step=token_printing_step,
-        )
+            inputs,
+            max_length=max_length_tokens,
+            token_printing_step=token_printing_step,
+        )

         sentence = sentence.strip()
         a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [[text, convert_to_markdown(sentence)]], [
@@ -103,7 +109,7 @@ def predict(self, text, chatbot, history, max_length_tokens, max_context_length_
         except Exception as e:
             print(type(e).__name__, e)

-        del input_ids
+        del inputs
         gc.collect()

         try:
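The multimodal wrapper follows the same pattern, except that the processor returns a bundle of model inputs (hence the rename from input_ids to inputs) which is handed to GeneratorParams.set_inputs(). A rough end-to-end sketch using only the calls that appear in this diff; the model folder, image file, and Phi-3-vision-style prompt template are assumptions for illustration:

import onnxruntime_genai as og

model_path = "models/phi-3-vision"    # hypothetical multimodal model folder
image_path = "example.png"            # hypothetical input image

# Older package versions without og.Config would fall back to og.Model(model_path).
config = og.Config(model_path)
config.clear_providers()               # keep CPU; call append_provider("cuda"/"dml") otherwise
model = og.Model(config)

processor = model.create_multimodal_processor()
tokenizer_stream = processor.create_stream()

images = og.Images.open(image_path)
prompt = "<|user|>\n<|image_1|>\nDescribe this image.<|end|>\n<|assistant|>\n"
inputs = processor(prompt, images=images)  # returns the full set of model inputs

params = og.GeneratorParams(model)
params.set_inputs(inputs)
params.set_search_options(max_length=3072)

generator = og.Generator(model, params)
while not generator.is_done():
    generator.generate_next_token()
    print(tokenizer_stream.decode(generator.get_next_tokens()[0]), end="", flush=True)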
4 changes: 2 additions & 2 deletions examples/python/README.md
@@ -1,6 +1,6 @@
-# ONNX Runtime GenAI API Python Examples
+# ONNX Runtime GenAI Python Examples

-## Install the onnxruntime-genai library
+## Install ONNX Runtime GenAI

 Install the python package according to the [installation instructions](https://onnxruntime.ai/docs/genai/howto/install).
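As context for the renamed headings: the Python package is published on PyPI as onnxruntime-genai, with onnxruntime-genai-cuda and onnxruntime-genai-directml variants for the CUDA and DirectML execution providers (so, for example, pip install onnxruntime-genai installs the CPU build); the linked installation instructions remain the authoritative reference.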
