Skip to content

Commit

Permalink
add api call
Browse files Browse the repository at this point in the history
  • Loading branch information
drcege committed Oct 24, 2024
1 parent 71f0fec commit 63d430a
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 1 deletion.
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/image_pair_similarity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self,
*args,
**kwargs):
"""
Initialization method.
Initialization method.
:param hf_clip: clip model name on huggingface to compute
the similarity between image and text.
Expand Down
91 changes: 91 additions & 0 deletions data_juicer/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional, Union

import multiprocess as mp
import requests
import wget
from loguru import logger

Expand Down Expand Up @@ -589,6 +590,95 @@ def prepare_opencv_classifier(model_path):
return model


class APIModel:

def __init__(self,
*,
api_url=None,
api_key=None,
response_path='choices.0.message.content'):
if api_url is None:
api_url = os.getenv('DJ_API_URL')
if api_url is None:
base_url = os.getenv('OPENAI_BASE_URL',
'https://api.openai.com/v1')
api_url = base_url.rstrip('/') + '/chat/completions'
self.api_url = api_url

if api_key is None:
api_key = os.getenv('DJ_API_KEY') or os.getenv('OPENAI_API_KEY')
self.api_key = api_key

self.headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}'
}
self.response_path = response_path

def __call__(self, *, messages, model, **kwargs):
"""Sends messages to the configured API model and returns the parsed response.
:param messages: The messages to send to the API.
:param model: The model to be used for generating responses.
:param kwargs: Additional parameters for the API request.
:return: The parsed response from the API, or None if an error occurs.
"""
payload = {
'model': model,
'messages': messages,
**kwargs,
}
try:
response = requests.post(self.api_url,
json=payload,
headers=self.headers)
response.raise_for_status()
result = response.json()
return self.nested_access(result, self.response_path)
except Exception as e:
logger.exception(e)
return None

@staticmethod
def nested_access(data, path):
"""Access nested data using a dot-separated path.
:param data: The data structure to access.
:param path: A dot-separated string representing the path to access.
:return: The value at the specified path, if it exists.
"""
keys = path.split('.')
for key in keys:
# Convert string keys to integers if they are numeric
key = int(key) if key.isdigit() else key
data = data[key]
return data


def prepare_api_model(*,
api_url=None,
api_key=None,
response_path='choices.0.message.content'):
"""Creates a callable API model for interacting with the OpenAI-compatible API.
This callable object supports custom result parsing and is suitable for use
with incompatible proxy servers.
:param api_url: The URL of the API. If not provided, it will fallback
to the environment variable or a default OpenAI URL.
:param api_key: The API key for authorization. If not provided, it will
fallback to the environment variable.
:param response_path: The path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:return: A callable API model object that can be used to send messages
and receive responses.
"""
return APIModel(api_url=api_url,
api_key=api_key,
response_path=response_path)


MODEL_FUNCTION_MAPPING = {
'fasttext': prepare_fasttext_model,
'sentencepiece': prepare_sentencepiece_for_lang,
Expand All @@ -602,6 +692,7 @@ def prepare_opencv_classifier(model_path):
'recognizeAnything': prepare_recognizeAnything_model,
'vllm': prepare_vllm_model,
'opencv_classifier': prepare_opencv_classifier,
'api': prepare_api_model,
}


Expand Down

0 comments on commit 63d430a

Please sign in to comment.