Skip to content

Commit

Permalink
[Ready] Add API Call & Example OPs (#463)
Browse files Browse the repository at this point in the history
* add api call

* add call_api ops

* clean

* minor update

* more tests

* update tests

* update prompts

* fix unittest

* update tests

* add docs

* minor fix

* add API processor

* refine API  processor

* refine

* fix bugs

* fix tests

* refine tests
  • Loading branch information
drcege authored Nov 7, 2024
1 parent 6badfa8 commit fe2b4cf
Show file tree
Hide file tree
Showing 27 changed files with 776 additions and 224 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/docker/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ services:
- TORCH_HOME=/data/torch
- NLTK_DATA=/data/nltk
- DATA_JUICER_CACHE_HOME=/data/dj
- EASYOCR_MODULE_PATH=/data/EasyOCR
- RAY_ADDRESS=auto
working_dir: /workspace
networks:
Expand Down Expand Up @@ -39,6 +40,7 @@ services:
- TORCH_HOME=/data/torch
- NLTK_DATA=/data/nltk
- DATA_JUICER_CACHE_HOME=/data/dj
- EASYOCR_MODULE_PATH=/data/EasyOCR
working_dir: /workspace
volumes:
- huggingface_cache:/data
Expand Down
14 changes: 14 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ hpo_config: null # path to a configur
process:
# Mapper ops. Most of these ops need no arguments.
- audio_ffmpeg_wrapped_mapper: # simple wrapper for FFmpeg audio filters
- calibrate_qa_mapper: # calibrate question-answer pairs based on reference text.
api_model: 'gpt-4o' # API model name.
api_url: null # API URL. Defaults to DJ_API_URL environment variable.
api_key: null # API key. Defaults to DJ_API_KEY environment variable.
response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
system_prompt: null # System prompt for the calibration task.
input_template: null # Template for building the model input.
reference_template: null # Template for formatting the reference text.
qa_pair_template: null # Template for formatting question-answer pairs.
output_pattern: null # Regular expression for parsing model output.
model_params: null # Parameters for initializing the model.
sampling_params: null # Extra parameters passed to the API call.
- calibrate_query_mapper: # calibrate query in question-answer pairs based on reference text.
- calibrate_response_mapper: # calibrate response in question-answer pairs based on reference text.
- chinese_convert_mapper: # convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji.
mode: 's2t' # choose the mode to convert Chinese: ['s2t', 't2s', 's2tw', 'tw2s', 's2hk', 'hk2s', 's2twp', 'tw2sp', 't2tw', 'tw2t', 'hk2t', 't2hk', 't2jp', 'jp2t']
- clean_email_mapper: # remove emails from text.
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def map(self, *args, **kargs):
if callable(getattr(
called_func.__self__,
'is_batched_op')) and called_func.__self__.is_batched_op(
) or not called_func.__self__.turbo:
) or not getattr(called_func.__self__, 'turbo', False):
kargs['batched'] = True
kargs['batch_size'] = kargs.pop('batch_size', 1) if hasattr(
called_func.__self__, 'is_batched_op'
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/image_pair_similarity_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self,
*args,
**kwargs):
"""
Initialization method.
Initialization method.
:param hf_clip: clip model name on huggingface to compute
the similarity between image and text.
Expand Down
6 changes: 5 additions & 1 deletion data_juicer/ops/mapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .audio_ffmpeg_wrapped_mapper import AudioFFmpegWrappedMapper
from .calibrate_qa_mapper import CalibrateQAMapper
from .calibrate_query_mapper import CalibrateQueryMapper
from .calibrate_response_mapper import CalibrateResponseMapper
from .chinese_convert_mapper import ChineseConvertMapper
from .clean_copyright_mapper import CleanCopyrightMapper
from .clean_email_mapper import CleanEmailMapper
Expand Down Expand Up @@ -53,7 +56,8 @@
from .whitespace_normalization_mapper import WhitespaceNormalizationMapper

__all__ = [
'AudioFFmpegWrappedMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper',
'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper',
'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper',
'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper',
'ExpandMacroMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper',
'GenerateQAFromTextMapper', 'ImageBlurMapper',
Expand Down
113 changes: 113 additions & 0 deletions data_juicer/ops/mapper/calibrate_qa_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import re
from typing import Dict, Optional

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'calibrate_qa_mapper'


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class CalibrateQAMapper(Mapper):
"""
Mapper to calibrate question-answer pairs based on reference text.
"""

# avoid leading whitespace
DEFAULT_SYSTEM_PROMPT = ('请根据提供的【参考信息】对【问题】和【回答】进行校准,使其更加详细、准确。\n'
'按照以下格式输出:\n'
'【问题】\n'
'校准后的问题\n'
'【回答】\n'
'校准后的回答')
DEFAULT_INPUT_TEMPLATE = '{reference}\n{qa_pair}'
DEFAULT_REFERENCE_TEMPLATE = '【参考信息】\n{}'
DEFAULT_QA_PAIR_TEMPLATE = '【问题】\n{}\n【回答】\n{}'
DEFAULT_OUTPUT_PATTERN = r'【问题】\s*(.*?)\s*【回答】\s*(.*)'

def __init__(self,
api_model: str = 'gpt-4o',
*,
api_url: Optional[str] = None,
api_key: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt: Optional[str] = None,
input_template: Optional[str] = None,
reference_template: Optional[str] = None,
qa_pair_template: Optional[str] = None,
output_pattern: Optional[str] = None,
model_params: Optional[Dict] = None,
sampling_params: Optional[Dict] = None,
**kwargs):
"""
Initialization method.
:param api_model: API model name.
:param api_url: API URL. Defaults to DJ_API_URL environment variable.
:param api_key: API key. Defaults to DJ_API_KEY environment variable.
:param response_path: Path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:param system_prompt: System prompt for the calibration task.
:param input_template: Template for building the model input.
:param reference_template: Template for formatting the reference text.
:param qa_pair_template: Template for formatting question-answer pairs.
:param output_pattern: Regular expression for parsing model output.
:param model_params: Parameters for initializing the model.
:param sampling_params: Extra parameters passed to the API call.
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)

self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
self.reference_template = reference_template or \
self.DEFAULT_REFERENCE_TEMPLATE
self.qa_pair_template = qa_pair_template or \
self.DEFAULT_QA_PAIR_TEMPLATE
self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN

self.model_params = model_params or {}
self.sampling_params = sampling_params or {}
self.model_key = prepare_model(model_type='api',
api_model=api_model,
api_url=api_url,
api_key=api_key,
response_path=response_path,
**model_params)

def build_input(self, sample):
reference = self.reference_template.format(sample[self.text_key])
qa_pair = self.qa_pair_template.format(sample[self.query_key],
sample[self.response_key])
input_prompt = self.input_template.format(reference=reference,
qa_pair=qa_pair)
return input_prompt

def parse_output(self, raw_output):
match = re.match(self.output_pattern, raw_output)
if match:
return match.group(1).strip(), match.group(2).strip()
else:
return None, None

def process_single(self, sample=None, rank=None):
client = get_model(self.model_key, rank=rank)

messages = [{
'role': 'system',
'content': self.system_prompt
}, {
'role': 'user',
'content': self.build_input(sample)
}]
output = client(messages, **self.sampling_params)

parsed_q, parsed_a = self.parse_output(output)
if parsed_q:
sample[self.query_key] = parsed_q
if parsed_a:
sample[self.response_key] = parsed_a

return sample
19 changes: 19 additions & 0 deletions data_juicer/ops/mapper/calibrate_query_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from data_juicer.ops.base_op import OPERATORS, UNFORKABLE
from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper

OP_NAME = 'calibrate_query_mapper'


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class CalibrateQueryMapper(CalibrateQAMapper):
"""
Mapper to calibrate query in question-answer pairs based on reference text.
"""

DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【问题】进行校准,\
使其更加详细、准确,且仍可以由原答案回答。只输出校准后的问题,不要输出多余内容。'

def parse_output(self, raw_output):
return raw_output.strip(), None
19 changes: 19 additions & 0 deletions data_juicer/ops/mapper/calibrate_response_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from data_juicer.ops.base_op import OPERATORS, UNFORKABLE
from data_juicer.ops.mapper.calibrate_qa_mapper import CalibrateQAMapper

OP_NAME = 'calibrate_response_mapper'


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class CalibrateResponseMapper(CalibrateQAMapper):
"""
Mapper to calibrate response in question-answer pairs based on reference text.
""" # noqa: E501

DEFAULT_SYSTEM_PROMPT = '请根据提供的【参考信息】对问答对中的【回答】进行校准,\
使其更加详细、准确,且仍可以回答原问题。只输出校准后的回答,不要输出多余内容。'

def parse_output(self, raw_output):
return None, raw_output.strip()
5 changes: 2 additions & 3 deletions data_juicer/ops/mapper/optimize_qa_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,8 @@ def build_input(self, sample):

def parse_output(self, raw_output):
logger.debug(raw_output)
matches = re.findall(self.output_pattern, raw_output, re.DOTALL)
if matches:
match = matches[0]
match = re.match(self.output_pattern, raw_output, re.DOTALL)
if match:
return match.group(1).strip(), match.group(2).strip()
else:
return None, None
Expand Down
1 change: 0 additions & 1 deletion data_juicer/utils/mm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,6 @@ def parse_string_to_roi(roi_string, roi_type='pixel'):
'format of "x1, y1, x2, y2", "(x1, y1, x2, y2)", or '
'"[x1, y1, x2, y2]".')
return None
return None


def close_video(container: av.container.InputContainer):
Expand Down
Loading

0 comments on commit fe2b4cf

Please sign in to comment.