Commit

Dev/llm info extract (#481)
* add api call

* add call_api ops

* clean

* minor update

* more tests

* update tests

* update prompts

* fix unittest

* update tests

* add docs

* minor fix

* add API processor

* refine API processor

* refine

* chunk and extract events

* fix bugs

* fix tests

* refine tests

* extract nickname

* nickname test done

* lightRAG to OP

* doc done

* remove extra test

* relavant -> relevant

* fix minor error

* ValueError -> Exception

* fix config_all error

* fix prepare_api_model

* fix rank sample None

* constant fix key

* refine args

---------

Co-authored-by: null <[email protected]>
Co-authored-by: gece.gc <[email protected]>
3 people authored Nov 15, 2024
1 parent d761af5 commit c279a3d
Showing 24 changed files with 2,173 additions and 71 deletions.
99 changes: 90 additions & 9 deletions configs/config_all.yaml

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions data_juicer/ops/common/__init__.py
@@ -1,11 +1,13 @@
 from .helper_func import (get_sentences_from_document, get_words_from_document,
                           merge_on_whitespace_tab_newline,
                           split_on_newline_tab_whitespace, split_on_whitespace,
-                          strip, words_augmentation, words_refinement)
+                          split_text_by_punctuation, strip, words_augmentation,
+                          words_refinement)
 from .special_characters import SPECIAL_CHARACTERS
 
 __all__ = [
     'get_sentences_from_document', 'get_words_from_document',
     'merge_on_whitespace_tab_newline', 'split_on_newline_tab_whitespace',
-    'split_on_whitespace', 'strip', 'words_augmentation', 'words_refinement'
+    'split_on_whitespace', 'strip', 'words_augmentation', 'words_refinement',
+    'split_text_by_punctuation'
 ]
16 changes: 16 additions & 0 deletions data_juicer/ops/common/helper_func.py
@@ -198,3 +198,19 @@ def get_sentences_from_document(document, model_func=None):
     else:
         sentences = document.splitlines()
     return '\n'.join(sentences)
+
+
+def split_text_by_punctuation(text):
+    """
+    Split text by any Chinese (zh) or English (en) punctuation.
+    :param text: text to be split.
+    :return: sub-texts split by any zh or en punctuation.
+    """
+    # any zh and en punctuation
+    punctuation_pattern = r'[\u3000-\u303f\uff00-\uffef]|[!"#$%&\'()*+,-./:;<=>?@[\\\]^_`{|}~]'  # noqa: E501
+
+    result = re.split(punctuation_pattern, text)
+    result = [s.strip() for s in result if s.strip()]
+
+    return result
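As a quick illustration (not part of the diff), the new helper splits mixed Chinese/English text on punctuation from either script:

# Illustrative usage of the helper added above; output shown in a comment.
from data_juicer.ops.common import split_text_by_punctuation

print(split_text_by_punctuation('你好,世界!Hello, world.'))
# -> ['你好', '世界', 'Hello', 'world']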
22 changes: 15 additions & 7 deletions data_juicer/ops/mapper/__init__.py
@@ -9,6 +9,11 @@
 from .clean_ip_mapper import CleanIpMapper
 from .clean_links_mapper import CleanLinksMapper
 from .expand_macro_mapper import ExpandMacroMapper
+from .extract_entity_attribute_mapper import ExtractEntityAttributeMapper
+from .extract_entity_relation_mapper import ExtractEntityRelationMapper
+from .extract_event_mapper import ExtractEventMapper
+from .extract_keyword_mapper import ExtractKeywordMapper
+from .extract_nickname_mapper import ExtractNicknameMapper
 from .fix_unicode_mapper import FixUnicodeMapper
 from .generate_qa_from_examples_mapper import GenerateQAFromExamplesMapper
 from .generate_qa_from_text_mapper import GenerateQAFromTextMapper
@@ -37,6 +42,7 @@
     RemoveWordsWithIncorrectSubstringsMapper
 from .replace_content_mapper import ReplaceContentMapper
 from .sentence_split_mapper import SentenceSplitMapper
+from .text_chunk_mapper import TextChunkMapper
 from .video_captioning_from_audio_mapper import VideoCaptioningFromAudioMapper
 from .video_captioning_from_frames_mapper import \
     VideoCaptioningFromFramesMapper
@@ -59,18 +65,20 @@
     'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper',
     'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper',
     'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper',
-    'ExpandMacroMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper',
-    'GenerateQAFromTextMapper', 'ImageBlurMapper',
-    'ImageCaptioningFromGPT4VMapper', 'ImageCaptioningMapper',
-    'ImageDiffusionMapper', 'ImageFaceBlurMapper', 'ImageTaggingMapper',
-    'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper',
-    'OptimizeQueryMapper', 'OptimizeResponseMapper',
+    'ExpandMacroMapper', 'ExtractEntityAttributeMapper',
+    'ExtractEntityRelationMapper', 'ExtractEventMapper',
+    'ExtractKeywordMapper', 'ExtractNicknameMapper', 'FixUnicodeMapper',
+    'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper',
+    'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper',
+    'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper',
+    'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper',
+    'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper',
     'PunctuationNormalizationMapper', 'RemoveBibliographyMapper',
     'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper',
     'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper',
     'RemoveSpecificCharsMapper', 'RemoveTableTextMapper',
     'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper',
-    'SentenceSplitMapper', 'VideoCaptioningFromAudioMapper',
+    'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper',
     'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper',
     'VideoCaptioningFromVideoMapper', 'VideoFFmpegWrappedMapper',
     'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper',
35 changes: 24 additions & 11 deletions data_juicer/ops/mapper/calibrate_qa_mapper.py
@@ -1,6 +1,9 @@
 import re
 from typing import Dict, Optional
 
+from loguru import logger
+from pydantic import PositiveInt
+
 from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
 from data_juicer.utils.model_utils import get_model, prepare_model

@@ -30,21 +33,22 @@ class CalibrateQAMapper(Mapper):
     def __init__(self,
                  api_model: str = 'gpt-4o',
                  *,
-                 api_url: Optional[str] = None,
+                 api_endpoint: Optional[str] = None,
                  response_path: Optional[str] = None,
                  system_prompt: Optional[str] = None,
                  input_template: Optional[str] = None,
                  reference_template: Optional[str] = None,
                  qa_pair_template: Optional[str] = None,
                  output_pattern: Optional[str] = None,
-                 model_params: Optional[Dict] = None,
-                 sampling_params: Optional[Dict] = None,
+                 try_num: PositiveInt = 3,
+                 model_params: Dict = {},
+                 sampling_params: Dict = {},
                  **kwargs):
         """
         Initialization method.
         :param api_model: API model name.
-        :param api_url: URL endpoint for the API.
+        :param api_endpoint: URL endpoint for the API.
         :param response_path: Path to extract content from the API response.
             Defaults to 'choices.0.message.content'.
         :param system_prompt: System prompt for the calibration task.
@@ -54,6 +58,7 @@ def __init__(self,
         :param output_pattern: Regular expression for parsing model output.
         :param model_params: Parameters for initializing the API model.
         :param sampling_params: Extra parameters passed to the API call.
+            e.g. {'temperature': 0.9, 'top_p': 0.95}
         :param kwargs: Extra keyword arguments.
         """
         super().__init__(**kwargs)
@@ -65,15 +70,17 @@ def __init__(self,
         self.qa_pair_template = qa_pair_template or \
             self.DEFAULT_QA_PAIR_TEMPLATE
         self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
-        self.sampling_params = sampling_params or {}
 
-        model_params = model_params or {}
+        self.sampling_params = sampling_params
+
         self.model_key = prepare_model(model_type='api',
                                        model=api_model,
-                                       url=api_url,
+                                       endpoint=api_endpoint,
                                        response_path=response_path,
                                        **model_params)
 
+        self.try_num = try_num
+
     def build_input(self, sample):
         reference = self.reference_template.format(sample[self.text_key])
         qa_pair = self.qa_pair_template.format(sample[self.query_key],
@@ -89,7 +96,7 @@ def parse_output(self, raw_output):
         else:
             return None, None
 
-    def process_single(self, sample=None, rank=None):
+    def process_single(self, sample, rank=None):
         client = get_model(self.model_key, rank=rank)
 
         messages = [{
@@ -99,9 +106,15 @@ def process_single(self, sample=None, rank=None):
             'role': 'user',
             'content': self.build_input(sample)
         }]
-        output = client(messages, **self.sampling_params)
-
-        parsed_q, parsed_a = self.parse_output(output)
+        parsed_q, parsed_a = None, None
+        for i in range(self.try_num):
+            try:
+                output = client(messages, **self.sampling_params)
+                parsed_q, parsed_a = self.parse_output(output)
+                if parsed_q or parsed_a:
+                    break
+            except Exception as e:
+                logger.warning(f'Exception: {e}')
         if parsed_q:
             sample[self.query_key] = parsed_q
         if parsed_a:
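For orientation, a minimal usage sketch of the updated operator (illustrative only, not part of the commit; it assumes valid credentials for an OpenAI-compatible API and the default 'text'/'query'/'response' sample keys):

# Hypothetical usage of CalibrateQAMapper after this commit's renames.
from data_juicer.ops.mapper import CalibrateQAMapper

op = CalibrateQAMapper(
    api_model='gpt-4o',
    api_endpoint=None,  # renamed from `api_url` in this commit
    try_num=3,          # new: retry the API call / output parsing up to 3 times
    sampling_params={'temperature': 0.9, 'top_p': 0.95},
)
sample = {
    'text': 'Reference document the QA pair is calibrated against...',
    'query': 'Original question?',
    'response': 'Original answer.',
}
calibrated = op.process_single(sample)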
199 changes: 199 additions & 0 deletions data_juicer/ops/mapper/extract_entity_attribute_mapper.py
@@ -0,0 +1,199 @@
import re
from itertools import chain
from typing import Dict, List, Optional

from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'extract_entity_attribute_mapper'


# TODO: LLM-based inference.
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractEntityAttributeMapper(Mapper):
    """
    Extract attributes for given entities from the text.
    """

    _batched_op = True

    DEFAULT_SYSTEM_PROMPT_TEMPLATE = (
        '给定一段文本,从文本中总结{entity}的{attribute},并且从原文摘录最能说明该{attribute}的代表性示例。\n'
        '要求:\n'
        '- 摘录的示例应该简短。\n'
        '- 遵循如下的回复格式:\n'
        '## {attribute}:\n'
        '{entity}的{attribute}描述...\n'
        '### 代表性示例1:\n'
        '说明{entity}该{attribute}的原文摘录1...\n'
        '### 代表性示例2:\n'
        '说明{entity}该{attribute}的原文摘录2...\n'
        '...\n')

    DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n'
    DEFAULT_ATTR_PATTERN_TEMPLATE = r'\#\#\s*{attribute}:\s*(.*?)(?=\#\#\#|\Z)'
    DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例(\d+):\s*(.*?)(?=\#\#\#|\Z)'

    def __init__(self,
                 query_entities: List[str] = [],
                 query_attributes: List[str] = [],
                 api_model: str = 'gpt-4o',
                 *,
                 entity_key: str = Fields.main_entity,
                 attribute_key: str = Fields.attribute,
                 attribute_desc_key: str = Fields.attribute_description,
                 support_text_key: str = Fields.attribute_support_text,
                 api_endpoint: Optional[str] = None,
                 response_path: Optional[str] = None,
                 system_prompt_template: Optional[str] = None,
                 input_template: Optional[str] = None,
                 attr_pattern_template: Optional[str] = None,
                 demo_pattern: Optional[str] = None,
                 try_num: PositiveInt = 3,
                 drop_text: bool = False,
                 model_params: Dict = {},
                 sampling_params: Dict = {},
                 **kwargs):
        """
        Initialization method.
        :param query_entities: Entity list to be queried.
        :param query_attributes: Attribute list to be queried.
        :param api_model: API model name.
        :param entity_key: The field name to store the given main entity for
            attribute extraction. It's "__dj__entity__" by default.
        :param attribute_key: The field name to store the given attribute to
            be extracted. It's "__dj__attribute__" by default.
        :param attribute_desc_key: The field name to store the extracted
            attribute description. It's "__dj__attribute_description__" by
            default.
        :param support_text_key: The field name to store the attribute
            support text extracted from the raw text. It's
            "__dj__support_text__" by default.
        :param api_endpoint: URL endpoint for the API.
        :param response_path: Path to extract content from the API response.
            Defaults to 'choices.0.message.content'.
        :param system_prompt_template: System prompt template for the task.
            It needs to be formatted with the given entity and attribute.
        :param input_template: Template for building the model input.
        :param attr_pattern_template: Pattern for parsing the attribute from
            the output. It needs to be formatted with the given attribute.
        :param demo_pattern: Pattern for parsing the demonstration from the
            output that supports the attribute.
        :param try_num: The number of retry attempts when there is an API
            call error or output parsing error.
        :param drop_text: Whether to drop the text in the output.
        :param model_params: Parameters for initializing the API model.
        :param sampling_params: Extra parameters passed to the API call.
            e.g. {'temperature': 0.9, 'top_p': 0.95}
        :param kwargs: Extra keyword arguments.
        """
        super().__init__(**kwargs)

        self.query_entities = query_entities
        self.query_attributes = query_attributes

        self.entity_key = entity_key
        self.attribute_key = attribute_key
        self.attribute_desc_key = attribute_desc_key
        self.support_text_key = support_text_key

        self.system_prompt_template = system_prompt_template \
            or self.DEFAULT_SYSTEM_PROMPT_TEMPLATE
        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
        self.attr_pattern_template = attr_pattern_template \
            or self.DEFAULT_ATTR_PATTERN_TEMPLATE
        self.demo_pattern = demo_pattern or self.DEFAULT_DEMON_PATTERN

        self.sampling_params = sampling_params
        self.model_key = prepare_model(model_type='api',
                                       model=api_model,
                                       endpoint=api_endpoint,
                                       response_path=response_path,
                                       **model_params)

        self.try_num = try_num
        self.drop_text = drop_text

    def parse_output(self, raw_output, attribute_name):
        # parse the description for the given attribute name
        attribute_pattern = self.attr_pattern_template.format(
            attribute=attribute_name)
        pattern = re.compile(attribute_pattern, re.VERBOSE | re.DOTALL)
        matches = pattern.findall(raw_output)
        if matches:
            attribute = matches[0].strip()
        else:
            attribute = ''

        # parse the supporting examples excerpted from the raw text
        pattern = re.compile(self.demo_pattern, re.VERBOSE | re.DOTALL)
        matches = pattern.findall(raw_output)
        demos = [demo.strip() for _, demo in matches if demo.strip()]

        return attribute, demos

    def _process_single_sample(self, text='', rank=None):
        client = get_model(self.model_key, rank=rank)

        entities, attributes, descs, demo_lists = [], [], [], []
        for entity in self.query_entities:
            for attribute in self.query_attributes:
                system_prompt = self.system_prompt_template.format(
                    entity=entity, attribute=attribute)
                input_prompt = self.input_template.format(text=text)
                messages = [{
                    'role': 'system',
                    'content': system_prompt
                }, {
                    'role': 'user',
                    'content': input_prompt
                }]

                desc, demos = '', []
                for i in range(self.try_num):
                    try:
                        output = client(messages, **self.sampling_params)
                        desc, demos = self.parse_output(output, attribute)
                        if desc and len(demos) > 0:
                            break
                    except Exception as e:
                        logger.warning(f'Exception: {e}')
                entities.append(entity)
                attributes.append(attribute)
                descs.append(desc)
                demo_lists.append(demos)

        return entities, attributes, descs, demo_lists

    def process_batched(self, samples, rank=None):

        sample_num = len(samples[self.text_key])

        entities, attributes, descs, demo_lists = [], [], [], []
        for text in samples[self.text_key]:
            res = self._process_single_sample(text, rank=rank)
            cur_ents, cur_attrs, cur_descs, cur_demos = res
            entities.append(cur_ents)
            attributes.append(cur_attrs)
            descs.append(cur_descs)
            demo_lists.append(cur_demos)

        if self.drop_text:
            samples.pop(self.text_key)

        # duplicate each original field value once per (entity, attribute)
        # result so that every output row stays aligned
        for key in samples:
            samples[key] = [[samples[key][i]] * len(descs[i])
                            for i in range(sample_num)]
        samples[self.entity_key] = entities
        samples[self.attribute_key] = attributes
        samples[self.attribute_desc_key] = descs
        samples[self.support_text_key] = demo_lists

        # flatten the nested per-sample lists into one flat batch
        for key in samples:
            samples[key] = list(chain(*samples[key]))

        return samples
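For orientation, a minimal usage sketch of the new operator (illustrative only, not part of the commit; it assumes OpenAI-compatible API credentials are configured in the environment, and the entity/attribute strings below are made up):

# Hypothetical end-to-end use of ExtractEntityAttributeMapper.
from data_juicer.ops.mapper import ExtractEntityAttributeMapper

op = ExtractEntityAttributeMapper(
    query_entities=['李莲花'],                # entities to describe
    query_attributes=['性格', '语言风格'],    # attributes to extract
    api_model='gpt-4o',
    try_num=3,
    sampling_params={'temperature': 0.9, 'top_p': 0.95},
)

# One input row fans out into one output row per (entity, attribute) pair,
# with the extracted description and its supporting excerpts attached.
samples = {'text': ['李莲花淡淡一笑:"人生在世,不过图个心安。"']}
result = op.process_batched(samples)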
