From 63d430a7c9b785aa18c659416ff50eea5f8a6a73 Mon Sep 17 00:00:00 2001 From: null <3213204+drcege@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:32:32 +0800 Subject: [PATCH] add api call --- .../filter/image_pair_similarity_filter.py | 2 +- data_juicer/utils/model_utils.py | 91 +++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) diff --git a/data_juicer/ops/filter/image_pair_similarity_filter.py b/data_juicer/ops/filter/image_pair_similarity_filter.py index 3299f9ad2..de576f07e 100644 --- a/data_juicer/ops/filter/image_pair_similarity_filter.py +++ b/data_juicer/ops/filter/image_pair_similarity_filter.py @@ -30,7 +30,7 @@ def __init__(self, *args, **kwargs): """ - Initialization method. + Initialization method. :param hf_clip: clip model name on huggingface to compute the similarity between image and text. diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index cda046b81..08a0c37c2 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -5,6 +5,7 @@ from typing import Optional, Union import multiprocess as mp +import requests import wget from loguru import logger @@ -589,6 +590,95 @@ def prepare_opencv_classifier(model_path): return model +class APIModel: + + def __init__(self, + *, + api_url=None, + api_key=None, + response_path='choices.0.message.content'): + if api_url is None: + api_url = os.getenv('DJ_API_URL') + if api_url is None: + base_url = os.getenv('OPENAI_BASE_URL', + 'https://api.openai.com/v1') + api_url = base_url.rstrip('/') + '/chat/completions' + self.api_url = api_url + + if api_key is None: + api_key = os.getenv('DJ_API_KEY') or os.getenv('OPENAI_API_KEY') + self.api_key = api_key + + self.headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {self.api_key}' + } + self.response_path = response_path + + def __call__(self, *, messages, model, **kwargs): + """Sends messages to the configured API model and returns the parsed response. + + :param messages: The messages to send to the API. + :param model: The model to be used for generating responses. + :param kwargs: Additional parameters for the API request. + + :return: The parsed response from the API, or None if an error occurs. + """ + payload = { + 'model': model, + 'messages': messages, + **kwargs, + } + try: + response = requests.post(self.api_url, + json=payload, + headers=self.headers) + response.raise_for_status() + result = response.json() + return self.nested_access(result, self.response_path) + except Exception as e: + logger.exception(e) + return None + + @staticmethod + def nested_access(data, path): + """Access nested data using a dot-separated path. + + :param data: The data structure to access. + :param path: A dot-separated string representing the path to access. + :return: The value at the specified path, if it exists. + """ + keys = path.split('.') + for key in keys: + # Convert string keys to integers if they are numeric + key = int(key) if key.isdigit() else key + data = data[key] + return data + + +def prepare_api_model(*, + api_url=None, + api_key=None, + response_path='choices.0.message.content'): + """Creates a callable API model for interacting with the OpenAI-compatible API. + + This callable object supports custom result parsing and is suitable for use + with incompatible proxy servers. + + :param api_url: The URL of the API. If not provided, it will fallback + to the environment variable or a default OpenAI URL. + :param api_key: The API key for authorization. If not provided, it will + fallback to the environment variable. + :param response_path: The path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :return: A callable API model object that can be used to send messages + and receive responses. + """ + return APIModel(api_url=api_url, + api_key=api_key, + response_path=response_path) + + MODEL_FUNCTION_MAPPING = { 'fasttext': prepare_fasttext_model, 'sentencepiece': prepare_sentencepiece_for_lang, @@ -602,6 +692,7 @@ def prepare_opencv_classifier(model_path): 'recognizeAnything': prepare_recognizeAnything_model, 'vllm': prepare_vllm_model, 'opencv_classifier': prepare_opencv_classifier, + 'api': prepare_api_model, }