Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FC alignment #413

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import numpy as np
from custom_exception import BadAPIStatusError
from eval_checker_constant import FILENAME_INDEX_MAPPING
from model_handler.handler_map import handler_map
from tqdm import tqdm
from eval_checker_constant import FILENAME_INDEX_MAPPING

REST_API_GROUND_TRUTH_FILE_PATH = "api_status_check_ground_truth_REST.json"
EXECTUABLE_API_GROUND_TRUTH_FILE_PATH = "api_status_check_ground_truth_executable.json"
Expand Down Expand Up @@ -230,6 +230,12 @@
"Fireworks",
"Apache 2.0",
],
"llama-v3-70b-instruct-hf": [
"FireFunction-v1 (FC)",
"https://huggingface.co/fireworks-ai/firefunction-v1",
"Fireworks",
"Apache 2.0",
],
"gemini-1.0-pro": [
"Gemini-1.0-Pro (FC)",
"https://deepmind.google/technologies/gemini/#introduction",
Expand Down Expand Up @@ -325,7 +331,7 @@
"https://huggingface.co/Snowflake/snowflake-arctic-instruct",
"Snowflake",
"apache-2.0",
]
],
}

INPUT_PRICE_PER_MILLION_TOKEN = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import json
import os

from model_handler.gpt_handler import OpenAIHandler
from model_handler.model_style import ModelStyle
import os, json
from openai import OpenAI


class FireworkAIHandler(OpenAIHandler):
def __init__(self, model_name, temperature=0.7, top_p=1, max_tokens=1000) -> None:
super().__init__(model_name, temperature, top_p, max_tokens)
self.model_name = "accounts/fireworks/models/firefunction-v1-FC"
# self.model_name = "accounts/fireworks/models/firefunction-v1-FC"
# self.model_name = "accounts/fireworks/models/fc-pawel-v2-14-FC"
self.model_name = "accounts/fireworks/models/dt-fc-rc-v5-FC"
self.temperature = 0
self.model_style = ModelStyle.FIREWORK_AI

self.client = OpenAI(
Expand Down
54 changes: 35 additions & 19 deletions berkeley-function-call-leaderboard/model_handler/gpt_handler.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import json
import os
import time

from model_handler.constant import (
GORILLA_TO_OPENAPI,
GORILLA_TO_PYTHON,
SYSTEM_PROMPT_FOR_CHAT_MODEL,
USER_PROMPT_FOR_CHAT_MODEL,
)
from model_handler.handler import BaseHandler
from model_handler.model_style import ModelStyle
from model_handler.utils import (
convert_to_tool,
convert_to_function_call,
ast_parse,
augment_prompt_by_languge,
convert_to_function_call,
convert_to_tool,
language_specific_pre_processing,
ast_parse,
)
from model_handler.constant import (
GORILLA_TO_OPENAPI,
GORILLA_TO_PYTHON,
USER_PROMPT_FOR_CHAT_MODEL,
SYSTEM_PROMPT_FOR_CHAT_MODEL,
)
from openai import OpenAI
import os, time, json


class OpenAIHandler(BaseHandler):
Expand All @@ -23,10 +26,12 @@ def __init__(self, model_name, temperature=0.7, top_p=1, max_tokens=1000) -> Non
self.model_style = ModelStyle.OpenAI
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def inference(self, prompt,functions,test_category):
def inference(self, prompt, functions, test_category):
if "FC" not in self.model_name:
prompt = augment_prompt_by_languge(prompt,test_category)
functions = language_specific_pre_processing(functions,test_category,False)
prompt = augment_prompt_by_languge(prompt, test_category)
functions = language_specific_pre_processing(
functions, test_category, False
)
message = [
{
"role": "system",
Expand Down Expand Up @@ -55,10 +60,19 @@ def inference(self, prompt,functions,test_category):
functions = language_specific_pre_processing(functions, test_category, True)
if type(functions) is not list:
functions = [functions]
# message = [
# {
# "role": "system",
# "content": "You will be presented with a question. Respond with a single or multiple function calls",
# },
# {"role": "user", "content": "Questions:" + prompt},
# ]
message = [{"role": "user", "content": "Questions:" + prompt}]
oai_tool = convert_to_tool(
functions, GORILLA_TO_OPENAPI, self.model_style, test_category, True
)
print(f"DEBUG: message: {message}")
print(f"DEBUG: tools: {oai_tool}")
start_time = time.time()
if len(oai_tool) > 0:
response = self.client.chat.completions.create(
Expand All @@ -68,6 +82,7 @@ def inference(self, prompt,functions,test_category):
max_tokens=self.max_tokens,
top_p=self.top_p,
tools=oai_tool,
frequency_penalty=0.8,
)
else:
response = self.client.chat.completions.create(
Expand All @@ -76,6 +91,7 @@ def inference(self, prompt,functions,test_category):
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
frequency_penalty=0.8,
)
latency = time.time() - start_time
try:
Expand All @@ -89,11 +105,11 @@ def inference(self, prompt,functions,test_category):
metadata["input_tokens"] = response.usage.prompt_tokens
metadata["output_tokens"] = response.usage.completion_tokens
metadata["latency"] = latency
return result,metadata
def decode_ast(self,result,language="Python"):
return result, metadata

def decode_ast(self, result, language="Python"):
if "FC" not in self.model_name:
decoded_output = ast_parse(result,language)
decoded_output = ast_parse(result, language)
else:
decoded_output = []
for invoked_function in result:
Expand All @@ -107,8 +123,8 @@ def decode_ast(self,result,language="Python"):
params[key] = str(params[key])
decoded_output.append({name: params})
return decoded_output
def decode_execute(self,result):

def decode_execute(self, result):
if "FC" not in self.model_name:
decoded_output = ast_parse(result)
execution_list = []
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from model_handler.arctic_handler import ArcticHandler
from model_handler.claude_fc_handler import ClaudeFCHandler
from model_handler.claude_prompt_handler import ClaudePromptingHandler
from model_handler.cohere_handler import CohereHandler
from model_handler.databricks_handler import DatabricksHandler
from model_handler.deepseek_handler import DeepseekHandler
from model_handler.firework_ai_handler import FireworkAIHandler
from model_handler.fireworks_llama import FireworksLlamaHandler
from model_handler.functionary_handler import FunctionaryHandler
from model_handler.gemini_handler import GeminiHandler
from model_handler.gemma_handler import GemmaHandler
Expand All @@ -14,8 +17,6 @@
from model_handler.mistral_handler import MistralHandler
from model_handler.nexus_handler import NexusHandler
from model_handler.oss_handler import OSSHandler
from model_handler.cohere_handler import CohereHandler
from model_handler.arctic_handler import ArcticHandler

handler_map = {
"gorilla-openfunctions-v0": GorillaHandler,
Expand Down Expand Up @@ -67,4 +68,5 @@
"command-r-plus-FC-optimized": CohereHandler,
"command-r-plus-optimized": CohereHandler,
"snowflake/arctic": ArcticHandler,
"llama-v3-70b-instruct-hf": FireworksLlamaHandler,
}
10 changes: 7 additions & 3 deletions berkeley-function-call-leaderboard/openfunctions_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import argparse, json, os
from tqdm import tqdm
import argparse
import json
import os

from model_handler.constant import USE_COHERE_OPTIMIZATION
from model_handler.handler_map import handler_map
from model_handler.model_style import ModelStyle
from model_handler.constant import USE_COHERE_OPTIMIZATION
from tqdm import tqdm


def get_args():
Expand Down Expand Up @@ -31,6 +34,7 @@ def get_args():
"simple": "gorilla_openfunctions_v1_test_simple.json",
"relevance": "gorilla_openfunctions_v1_test_relevance.json",
"parallel_function": "gorilla_openfunctions_v1_test_parallel_function.json",
"parallel_function_local": "gorilla_openfunctions_v1_test_parallel_function_local.json",
"multiple_function": "gorilla_openfunctions_v1_test_multiple_function.json",
"parallel_multiple_function": "gorilla_openfunctions_v1_test_parallel_multiple_function.json",
"java": "gorilla_openfunctions_v1_test_java.json",
Expand Down