[Ready] Basic Service Implementation (#468)
* initial implementation

* support params checking and dataset converting

* support filters and tags

* allow partial execution of op.run

* Service/match api (#428)

* temp

* Regress model preloading (#426)

* fix param definition (#424)

* fix param def

* add param check

* use pydantic types (#422)

* use pydantic types

* change config unittest

* fix GenerateInstructionMapper

* update

* Fix (#427)

Fix some words

* Add new OP: image_tagging_mapper (#423)

* * init image tagging mapper

* + Add unittest for image_tagging_mapper
* support specified tag field names for all tagging OPs

* * fix problems of unittest

* + add docs

* * update docs

* * skip two unittests which require ram

* * minor fix for gece's comments

* * merge main into this branch

* + add type hint

* match api call

* match api call

* pre-commit

* decouple API args & add type hints

---------

Co-authored-by: Ce Ge (戈策) <[email protected]>
Co-authored-by: Cathy0908 <[email protected]>
Co-authored-by: co63oc <[email protected]>
Co-authored-by: Yilun Huang <[email protected]>
Co-authored-by: null <[email protected]>

* Service/match api (#431)

* temp

* match api call

* match api call

* pre-commit

* decouple API args & add type hints

* agentscope demo

* update demos

* update pre-commit

* add yaml

* update notebooks

* add py

* refine

---------

Co-authored-by: null <[email protected]>
Co-authored-by: gece.gc <[email protected]>

* update reqs

---------

Co-authored-by: BeachWang <[email protected]>
Co-authored-by: Cathy0908 <[email protected]>
Co-authored-by: co63oc <[email protected]>
Co-authored-by: Yilun Huang <[email protected]>
5 people authored Nov 11, 2024
1 parent d2e92f9 commit 528b8a9
Showing 18 changed files with 1,778 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -38,7 +38,7 @@ exclude: |
   (?x)^(
     docs/.*|
     tests/.*|
-    demos/.*|
+    demos/(?!api_service/).*|
     tools/mm_eval/inception_metrics/.*|
     thirdparty/easy_animate/.*|
     .*\.md
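
The tightened pattern relies on a negative lookahead: everything under demos/ stays excluded from pre-commit, except the new demos/api_service/ directory, which is now linted. A quick, self-contained check of the regex's effect (the file paths are illustrative):

import re

# Simplified form of the new exclude entry: match (i.e. exclude) any demos/
# path that does not start with demos/api_service/. The (?x) verbose flag
# mirrors the YAML layout and ignores the embedded whitespace.
pattern = re.compile(r'(?x)^( demos/(?!api_service/).* )$')

assert pattern.match('demos/other_demo/app.py')           # still excluded
assert not pattern.match('demos/api_service/wrapper.py')  # now checked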
5 changes: 2 additions & 3 deletions data_juicer/config/__init__.py
@@ -2,7 +2,6 @@
                      merge_config, prepare_side_configs)
 
 __all__ = [
-    'init_configs',
-    'export_config',
-    'merge_config',
+    'init_configs', 'get_init_configs', 'export_config', 'merge_config',
    'prepare_side_configs'
 ]
36 changes: 18 additions & 18 deletions data_juicer/config/config.py
@@ -4,12 +4,12 @@
 import shutil
 import tempfile
 import time
-from argparse import ArgumentError, Namespace
-from typing import Dict, List, Union
+from argparse import ArgumentError
+from typing import Dict, List, Optional, Union
 
 import yaml
-from jsonargparse import (ActionConfigFile, ArgumentParser, dict_to_namespace,
-                          namespace_to_dict)
+from jsonargparse import (ActionConfigFile, ArgumentParser, Namespace,
+                          dict_to_namespace, namespace_to_dict)
 from jsonargparse.typehints import ActionTypeHint
 from jsonargparse.typing import ClosedUnitInterval, NonNegativeInt, PositiveInt
 from loguru import logger
@@ -22,7 +22,7 @@
 global_parser = None
 
 
-def init_configs(args=None):
+def init_configs(args: Optional[List[str]] = None):
     """
     initialize the jsonargparse parser and parse configs from one of:
     1. POSIX-style commands line args;
@@ -357,7 +357,7 @@ def update_ds_cache_dir_and_related_vars(new_ds_cache_path):
                          config.DEFAULT_EXTRACTED_DATASETS_PATH)
 
 
-def init_setup_from_cfg(cfg):
+def init_setup_from_cfg(cfg: Namespace):
     """
     Do some extra setup tasks after parsing config file or command line.
@@ -628,7 +628,7 @@ def namespace_to_arg_list(namespace, prefix='', includes=None, excludes=None):
     return arg_list
 
 
-def config_backup(cfg):
+def config_backup(cfg: Namespace):
     cfg_path = cfg.config[0].absolute
     work_dir = cfg.work_dir
     target_path = os.path.join(work_dir, os.path.basename(cfg_path))
@@ -638,7 +638,7 @@ def config_backup(cfg):
     shutil.copyfile(cfg_path, target_path)
 
 
-def display_config(cfg):
+def display_config(cfg: Namespace):
     import pprint
 
     from tabulate import tabulate
@@ -658,13 +658,13 @@ def display_config(cfg):
     print(table)
 
 
-def export_config(cfg,
-                  path,
-                  format='yaml',
-                  skip_none=True,
-                  skip_check=True,
-                  overwrite=False,
-                  multifile=True):
+def export_config(cfg: Namespace,
+                  path: str,
+                  format: str = 'yaml',
+                  skip_none: bool = True,
+                  skip_check: bool = True,
+                  overwrite: bool = False,
+                  multifile: bool = True):
     """
     Save the config object, some params are from jsonargparse
@@ -700,7 +700,7 @@ def export_config(cfg,
     logger.info(f'Saved the configuration in {path}')
 
 
-def merge_config(ori_cfg, new_cfg: Dict):
+def merge_config(ori_cfg: Namespace, new_cfg: Namespace):
     """
     Merge configuration from new_cfg into ori_cfg
@@ -758,7 +758,7 @@ def merge_config(ori_cfg, new_cfg: Dict):
     logger.error('Config merge failed')
 
 
-def prepare_side_configs(ori_config):
+def prepare_side_configs(ori_config: Union[str, Namespace, Dict]):
     """
     parse the config if ori_config is a string of a config file path with
     yaml, yml or json format
@@ -790,7 +790,7 @@ def prepare_side_configs(ori_config):
     return config
 
 
-def get_init_configs(cfg):
+def get_init_configs(cfg: Union[Namespace, Dict]):
     """
     set init configs of datajucer for cfg
     """
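
The new type hints make the expected inputs explicit; in particular, `get_init_configs` now advertises that it accepts either a jsonargparse `Namespace` or a plain `Dict`. A minimal sketch of how an API caller might use it; the dict keys and file paths below are illustrative, not mandated by this commit:

from data_juicer.config import get_init_configs

# A plain dict carrying the same keys a YAML config would; passing a dict is
# allowed per the new Union[Namespace, Dict] annotation.
raw_cfg = {
    'project_name': 'demo',
    'dataset_path': './input.jsonl',   # hypothetical input file
    'export_path': './output.jsonl',   # hypothetical output file
    'process': [],
}
cfg = get_init_configs(raw_cfg)  # a fully initialized config Namespace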
16 changes: 12 additions & 4 deletions data_juicer/core/analyzer.py
@@ -1,6 +1,9 @@
 import os
+from typing import Optional
 
+from jsonargparse import Namespace
 from loguru import logger
+from pydantic import PositiveInt
 
 from data_juicer.analysis import ColumnWiseAnalysis, OverallAnalysis
 from data_juicer.config import init_configs
@@ -22,11 +25,11 @@ class Analyzer:
     dataset better.
     """
 
-    def __init__(self, cfg=None):
+    def __init__(self, cfg: Optional[Namespace] = None):
         """
         Initialization method.
 
-        :param cfg: optional config dict.
+        :param cfg: optional jsonargparse Namespace dict.
         """
         self.cfg = init_configs() if cfg is None else cfg
 
@@ -65,12 +68,16 @@ def __init__(self, cfg=None):
         self.overall_single_plot_path = None
         self.analysis_path = os.path.join(self.cfg.work_dir, 'analysis')
 
-    def run(self, load_data_np=None, skip_export=False):
+    def run(self,
+            load_data_np: Optional[PositiveInt] = None,
+            skip_export: bool = False,
+            skip_return: bool = False):
         """
         Running the dataset analysis pipeline.
 
         :param load_data_np: number of workers when loading the dataset.
         :param skip_export: whether export the results into disk
+        :param skip_return: skip return for API called.
         :return: analyzed dataset.
         """
         # 1. format data
@@ -129,4 +136,5 @@ def run(self, load_data_np=None, skip_export=False):
         )
         column_wise_analysis.analyze(skip_export=skip_export)
 
-        return dataset
+        if not skip_return:
+            return dataset
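
The new `skip_return` flag exists for the API service: a long-running endpoint can trigger analysis and rely on the exported artifacts rather than receiving a potentially large dataset object back. A sketch of both call styles, assuming `cfg` was prepared elsewhere (e.g. via `get_init_configs`) and the usual top-level export:

from data_juicer.core import Analyzer

analyzer = Analyzer(cfg)        # cfg: a jsonargparse Namespace
dataset = analyzer.run()        # library use: the analyzed dataset is returned
analyzer.run(skip_return=True)  # service use: results are exported to disk, nothing returned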
16 changes: 12 additions & 4 deletions data_juicer/core/executor.py
@@ -1,7 +1,10 @@
 import os
 from time import time
+from typing import Optional
 
+from jsonargparse import Namespace
 from loguru import logger
+from pydantic import PositiveInt
 
 from data_juicer.config import init_configs
 from data_juicer.core.data import Dataset
@@ -27,11 +30,11 @@ class Executor:
     ops in the config file in order and generate a processed dataset.
     """
 
-    def __init__(self, cfg=None):
+    def __init__(self, cfg: Optional[Namespace] = None):
         """
         Initialization method.
 
-        :param cfg: optional config dict.
+        :param cfg: optional jsonargparse Namespace.
         """
         self.cfg = init_configs() if cfg is None else cfg
 
@@ -135,11 +138,14 @@ def sample_data(self,
         else:
             raise ValueError(f'Unsupported sample_algo: {sample_algo}')
 
-    def run(self, load_data_np=None):
+    def run(self,
+            load_data_np: Optional[PositiveInt] = None,
+            skip_return=False):
         """
         Running the dataset process pipeline.
 
         :param load_data_np: number of workers when loading the dataset.
+        :param skip_return: skip return for API called.
         :return: processed dataset.
         """
         # 1. format data
@@ -176,4 +182,6 @@ def run(self, load_data_np=None):
         if self.cfg.use_cache and self.cfg.cache_compress:
             from data_juicer.utils.compress import compress
             compress(dataset)
-        return dataset
+
+        if not skip_return:
+            return dataset
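
`Executor.run` gets the same treatment, and `load_data_np` is now validated as a `PositiveInt`. A sketch of the service-style invocation, again assuming a prepared `cfg`:

from data_juicer.core import Executor

executor = Executor(cfg)
executor.run(load_data_np=4,    # 4 dataset-loading workers; must be positive per the new hint
             skip_return=True)  # processed data lands at cfg.export_path; no return value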
34 changes: 20 additions & 14 deletions data_juicer/ops/base_op.py
@@ -365,7 +365,7 @@ def process_single(self, sample):
         """
         raise NotImplementedError
 
-    def run(self, dataset, *, exporter=None, tracer=None):
+    def run(self, dataset, *, exporter=None, tracer=None, reduce=True):
         dataset = super(Filter, self).run(dataset)
         if Fields.stats not in dataset.features:
             from data_juicer.core.data import add_same_content_to_new_column
@@ -384,13 +384,16 @@ def run(self, dataset, *, exporter=None, tracer=None):
                                desc=self._name + '_compute_stats')
         if exporter and self.stats_export_path is not None:
             exporter.export_compute_stats(dataset, self.stats_export_path)
-        new_dataset = dataset.filter(self.process,
-                                     num_proc=self.runtime_np(),
-                                     batch_size=self.batch_size,
-                                     desc=self._name + '_process')
-        if tracer:
-            tracer.trace_filter(self._name, dataset, new_dataset)
-        return new_dataset
+        if reduce:
+            new_dataset = dataset.filter(self.process,
+                                         num_proc=self.runtime_np(),
+                                         batch_size=self.batch_size,
+                                         desc=self._name + '_process')
+            if tracer:
+                tracer.trace_filter(self._name, dataset, new_dataset)
+            return new_dataset
+        else:
+            return dataset
 
 
 class Deduplicator(OP):
@@ -436,17 +439,20 @@ def process(self, dataset, show_num=0):
         """
         raise NotImplementedError
 
-    def run(self, dataset, *, exporter=None, tracer=None):
+    def run(self, dataset, *, exporter=None, tracer=None, reduce=True):
         dataset = super(Deduplicator, self).run(dataset)
         dataset = dataset.map(self.compute_hash,
                               num_proc=self.runtime_np(),
                               with_rank=self.use_cuda(),
                               desc=self._name + '_compute_hash')
-        show_num = tracer.show_num if tracer else 0
-        new_dataset, dup_pairs = self.process(dataset, show_num)
-        if tracer:
-            tracer.trace_deduplicator(self._name, dup_pairs)
-        return new_dataset
+        if reduce:
+            show_num = tracer.show_num if tracer else 0
+            new_dataset, dup_pairs = self.process(dataset, show_num)
+            if tracer:
+                tracer.trace_deduplicator(self._name, dup_pairs)
+            return new_dataset
+        else:
+            return dataset
 
 
 class Selector(OP):
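
This is the "partial execution of op.run" from the commit log: with `reduce=False`, a Filter stops after computing per-sample stats and a Deduplicator after computing hashes, so a service can inspect or export the annotated dataset without dropping any rows. A sketch, where `op` is any Filter instance and `dataset` a data-juicer dataset:

# Stats-only pass: samples gain a Fields.stats column, none are removed.
annotated = op.run(dataset, reduce=False)

# Default behavior: stats computation followed by the actual filtering pass.
filtered = op.run(dataset)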
13 changes: 13 additions & 0 deletions demos/api_service/configs/dj_config_template.yaml
@@ -0,0 +1,13 @@
+# data-juicer config template
+
+# global parameters
+project_name: 'dj_agent'
+dataset_path: ''  # path to your dataset directory or file, specified in the agent
+np: 4  # number of subprocess to process your dataset
+
+export_path: ''  # path to the output path, specified in the agent
+export_original_dataset: true
+
+# process schedule
+# a list of several process operators with their arguments, specified in the agent
+process: []
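
The template deliberately leaves `dataset_path`, `export_path`, and `process` empty; the agent fills them in per request. A hedged sketch of that flow, where the paths and the chosen op parameters are illustrative (`language_id_score_filter` is an existing data-juicer op, but the demo may wire this differently):

import yaml

# Load the template and inject the request-specific fields.
with open('demos/api_service/configs/dj_config_template.yaml') as f:
    dj_cfg = yaml.safe_load(f)

dj_cfg['dataset_path'] = '/data/input.jsonl'   # hypothetical input
dj_cfg['export_path'] = '/data/output.jsonl'   # hypothetical output
dj_cfg['process'] = [
    {'language_id_score_filter': {'lang': 'en', 'min_score': 0.8}},
]

with open('dj_config_filled.yaml', 'w') as f:
    yaml.safe_dump(dj_cfg, f)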
21 changes: 21 additions & 0 deletions demos/api_service/configs/model_configs.json
@@ -0,0 +1,21 @@
+[
+  {
+    "config_name": "gpt-4",
+    "model_type": "openai-chat",
+    "model_name": "gpt-4",
+    "api_key": "your API key",
+    "organization": "your organization name",
+    "generate_args": {
+      "temperature": 0.5
+    }
+  },
+  {
+    "config_name": "dashscope_chat-qwen-max",
+    "model_type": "dashscope_chat",
+    "model_name": "qwen-max",
+    "api_key": "your API key",
+    "generate_args": {
+      "temperature": 0.0
+    }
+  }
+]
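
These entries follow AgentScope's model-config schema (`config_name`, `model_type`, `model_name`, ...), matching the agentscope demo added in this PR. Assuming the demo registers them in the usual AgentScope way, roughly:

import agentscope

# Register the model configs once at startup; agents then refer to a model
# by its config_name, e.g. 'gpt-4' or 'dashscope_chat-qwen-max'.
agentscope.init(model_configs='demos/api_service/configs/model_configs.json')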
(Diffs for the remaining 10 changed files were not loaded.)
