[Feature] Hybrid Data Pipeline #495

Status: Open · wants to merge 7 commits into base: main
Conversation

pppppM (Collaborator) commented Mar 19, 2024:

No description provided.

pppppM (Collaborator, Author) commented Mar 19, 2024:

@fanqiNO1

xiaohangguo (Contributor) commented:

I put together a one-shot example of a python interpreter call by imitating the format. Could you check whether it is correct?
@pppppM @fanqiNO1

[
    {
        "messages": [
            {
                "role": "user",
                "content": "帮我用scipy计算一个矩阵的逆"
            },
            {
                "role": "assistant",
                "content": "Sure, I will perform the matrix inversion using scipy.",
                "function_call": {
                    "name": "python_interpreter",
                    "parameters": {
                        "code": "import scipy.linalg\nscipy.linalg.inv([[1, 2], [3, 4]])"
                    }
                }
            },
            {
                "role": "function",
                "name": "python_interpreter",
                "content": "array([[-2. ,  1. ],\n       [ 1.5, -0.5]])"
            },
            {
                "role": "assistant",
                "content": "使用 scipy 计算出的矩阵的逆是 [[-2. , 1. ], [1.5, -0.5]]"
            }
        ],
        "functions": [
            {
                "name": "python_interpreter",
                "description": "Execute Python code and return the result",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "code": {
                            "type": "string",
                            "description": "Python code to be executed"
                        },
                        "required": ["code"]
                    }
                }
            }
        ]
    }
]
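For reference, the `functions` schema in the sample can be sanity-checked with plain Python. Note that in JSON Schema, `required` belongs beside `properties` inside `parameters`, not nested inside `properties`. This is an illustrative sketch, not part of the PR:

```python
import json

# Corrected function schema from the sample above.
SAMPLE = """
{
    "name": "python_interpreter",
    "description": "Execute Python code and return the result",
    "parameters": {
        "type": "object",
        "properties": {
            "code": {"type": "string", "description": "Python code to be executed"}
        },
        "required": ["code"]
    }
}
"""

def check_function_schema(fn: dict) -> None:
    """Assert an OpenAI-style function schema is well-formed."""
    params = fn["parameters"]
    assert params["type"] == "object"
    # `required` must live next to `properties`, not inside it.
    assert "required" not in params["properties"]
    for key in params.get("required", []):
        assert key in params["properties"], f"required key {key!r} is undeclared"

check_function_schema(json.loads(SAMPLE))
```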



model = dict(
type=HybridFinetune,
Collaborator: This name is a bit odd; maybe call it HybridFinetuneModel? One more concern: with "finetune" in the name, will users assume the model can only be fine-tuned and not pretrained?

chat_template=chat_template,
max_length=max_length,
pack_to_max_length=True,
num_workers = dataloader_num_workers,
Collaborator: There is a dataloader_num_workers here as well?

type=HybridDataset,
data_dir=data_dir,
data_files=data_files,
data_cached='cached_llava',
Collaborator: Support automatic caching: once the user specifies a data_cached path, cache the processed data there if it does not exist; if it does exist, load it directly and tell the user.
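A minimal sketch of the suggested auto-cache behaviour. The helper name is hypothetical, and the real dataset class would more likely use `datasets.Dataset.save_to_disk` / `load_from_disk` than pickle:

```python
import os
import pickle

def load_or_build_cache(data_cached: str, build_fn):
    """Load the processed dataset from `data_cached` if it exists,
    otherwise build it with `build_fn` and cache it for next time."""
    if os.path.exists(data_cached):
        # Cache hit: load directly and tell the user.
        print(f'Found cache at {data_cached}, loading it directly.')
        with open(data_cached, 'rb') as f:
            return pickle.load(f)
    # Cache miss: process, then persist for subsequent runs.
    dataset = build_fn()
    with open(data_cached, 'wb') as f:
        pickle.dump(dataset, f)
    print(f'Processed dataset cached to {data_cached}.')
    return dataset
```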

"role": "user",
"content": [
{
"type": "image_url",
Collaborator: This feels hard to make generic, because models may insert special tokens to delimit images; in most cases the tokenizer logic would have to be rewritten.

Collaborator: Same issue here. Is there a way to stay compatible with the following conventions?

<image>
Picture X: <image>
<IMG><image></IMG>
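The three conventions above could be handled by a configurable formatter along these lines. The style names are made up for this sketch; a real implementation would likely hang this off the chat-template config:

```python
def decorate_image_token(style: str, index: int = 1) -> str:
    """Render the image placeholder in one of the observed conventions."""
    if style == 'plain':
        return '<image>'
    if style == 'numbered':
        # Some templates number each picture, e.g. "Picture 2: <image>".
        return f'Picture {index}: <image>'
    if style == 'wrapped':
        return '<IMG><image></IMG>'
    raise ValueError(f'unknown image token style: {style}')
```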

self.dataset = dataset

self._ori_img_urls = dataset['image_urls']
self._ori_img_rngs = dataset['image_ranges']
Collaborator: These need comments; otherwise it is unclear what the fields mean.

'pixel_values': pixel_values,
'cumulative_len': cumulative_len,
'image_ranges': image_ranges,
'image_belong': image_belong
Collaborator: This should note that some fields are only needed in certain modes; without comments, customization will be very hard.

from xtuner.types import HybridChatTemplate
from xtuner.utils import build_tokenizer

os.environ['TOKENIZERS_PARALLELISM'] = 'true'
Collaborator: Could this environment variable ever be false? If so, it would be better to let users set it via the environment, with a default of true.

Collaborator: Consider adding a prefix? XTUNER_XXXXXXX
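Combining both suggestions, it could look like this. The variable name `XTUNER_TOKENIZERS_PARALLELISM` is hypothetical; only the `XTUNER_` prefix and the default of `'true'` come from the review:

```python
import os

# Respect a user-provided XTUNER_-prefixed override, defaulting to 'true'.
default = os.environ.get('XTUNER_TOKENIZERS_PARALLELISM', 'true')
os.environ['TOKENIZERS_PARALLELISM'] = default
```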

added_keys=dict(tokens=int),
)
def _register_tokens(data, tokenizer=None, chat_template=None):
data['tokens'] = len(data['input_ids'])
Collaborator: The name tokens is hard to understand; token_len would be much clearer.

Collaborator: How about just calling it length? That aligns with transformers' defaults and also works out of the box with LengthGroupedSampler:
https://huggingface.co/docs/transformers/v4.39.1/en/main_classes/trainer#transformers.TrainingArguments.length_column_name

added_keys=dict(position_ids=list),
)
def _register_position_ids(data, tokenizer=None, chat_template=None):
data['position_ids'] = [i for i in range(len(data['input_ids']))]
Collaborator:
Suggested change
data['position_ids'] = [i for i in range(len(data['input_ids']))]
data['position_ids'] = list(range(len(data['input_ids'])))

input_keys=dict(input_ids=list),
added_keys=dict(cumulative_len=list),
)
def _register_cumulative_len(data, tokenizer=None, chat_template=None):
Collaborator: This function is too simple; consider dropping the wrapper, otherwise it reads as over-engineered.
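If the wrapper were dropped, the computation itself is essentially a one-liner. The semantics are assumed here (prefix sums marking where each packed sequence starts); the body of `_register_cumulative_len` is not shown in this diff:

```python
from itertools import accumulate

def cumulative_len(seq_lens):
    # Prefix sums with a leading 0: offsets at which each packed
    # sequence starts, e.g. [3, 2, 4] -> [0, 3, 5, 9].
    return [0] + list(accumulate(seq_lens))
```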

Comment on lines +21 to +22
if not isinstance(data[key], _type):
breakpoint()
Collaborator:
Suggested change
if not isinstance(data[key], _type):
breakpoint()
assert isinstance(data[key], _type)


for url, ind in zip(image_urls, img_token_inds):
# image = self.load_image(url)
h, w = 336 // 14, 336 // 14
LZHgrla (Collaborator) commented Mar 27, 2024: How can other resolutions and patch sizes be supported? Pass these values in via the config?
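A sketch of what config-driven sizing might look like instead of the hardcoded `336 // 14`. The parameter names are illustrative; the values could plausibly come from the vision encoder's config (e.g. `visual_encoder.config.image_size` / `.patch_size`):

```python
def num_image_tokens(image_size: int = 336, patch_size: int = 14) -> int:
    """Visual token count from configurable image/patch size."""
    assert image_size % patch_size == 0, 'image size must be divisible by patch size'
    h = w = image_size // patch_size
    return h * w

num_image_tokens()  # 24 * 24 = 576 for the 336 px / patch-14 default
```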


img_ranges = []
for i, _ in enumerate(zip(input_ids, labels)):
if isinstance(input_ids[i], list):
Collaborator:
Suggested change
if isinstance(input_ids[i], list):
if isinstance(input_ids[i], list): # image pad tokens

new_ids.extend(input_ids[i])
new_labels.extend(labels[i])

else:
Collaborator:
Suggested change
else:
else: # text token
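Taken together, the two suggested comments describe a loop like the following self-contained sketch, assuming `input_ids` mixes ints (single text tokens) with lists (runs of image pad tokens):

```python
def flatten_packed(input_ids, labels):
    """Flatten a mixed token sequence, recording the [start, end)
    range each run of image pad tokens occupies."""
    new_ids, new_labels, img_ranges = [], [], []
    for ids_i, label_i in zip(input_ids, labels):
        if isinstance(ids_i, list):  # image pad tokens
            img_ranges.append([len(new_ids), len(new_ids) + len(ids_i)])
            new_ids.extend(ids_i)
            new_labels.extend(label_i)
        else:  # text token
            new_ids.append(ids_i)
            new_labels.append(label_i)
    return new_ids, new_labels, img_ranges
```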

)
def llava_to_openai(data, tokenizer=None, chat_template=None):

image_token = '<image>'
LZHgrla (Collaborator) commented Mar 27, 2024: Allow customizing this? I understand it is set up specifically for processing the llava dataset, but it also constrains users' data to use '<image>'. Is this related to the chat_template's image_token?

Comment on lines +55 to +60
projector_config = ProjectorConfig(
visual_hidden_size=self.visual_encoder.config.hidden_size,
llm_hidden_size=self.llm.config.hidden_size,
depth=projector_depth)
self.projector = ProjectorModel(projector_config).to(
self.visual_encoder.dtype)
Collaborator: Could this initialization break in the pure-LLM case?
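One way to address this would be to guard projector construction on the presence of a visual encoder, so a text-only run never touches vision configs. This is a toy sketch: `DummyProjector` stands in for the PR's `ProjectorModel`, and only the guard itself is the suggestion:

```python
class DummyProjector:
    """Stand-in for ProjectorModel, just to make the sketch runnable."""
    def __init__(self, visual_hidden_size, llm_hidden_size, depth):
        self.cfg = (visual_hidden_size, llm_hidden_size, depth)

def build_projector(visual_encoder, llm_hidden_size, depth=2):
    # Pure-LLM setup: no visual encoder, so skip the projector entirely.
    if visual_encoder is None:
        return None
    return DummyProjector(visual_encoder['hidden_size'], llm_hidden_size, depth)
```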

elif self.role == 'user':
if len(self.files) > 0:
stop_word = chat_template.stop_words[0]
text += f'\n{stop_word}\n{chat_template.decorate_files(self.files)}'
Collaborator: Why is stop_word introduced here? And why take only the 0th one?

Comment on lines +38 to +39
elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
eos_token_ids = tokenizer.eos_token_id
Collaborator:
Suggested change
elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
eos_token_ids = tokenizer.eos_token_id

xiaohangguo (Contributor) commented Mar 27, 2024:

@pppppM @LZHgrla
Does the code_interpreter_call part not need to participate in training?
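For context, whether a segment "participates in training" usually comes down to whether its labels keep the token ids or are set to the ignore index. A minimal sketch of that convention (illustrative, not the PR's actual code); marking the function-call span trainable would include it in the loss:

```python
IGNORE_INDEX = -100  # standard "ignore" label for cross-entropy loss

def build_labels(token_ids, trainable_mask):
    """Keep the token id as the label where trainable, else IGNORE_INDEX."""
    return [tid if train else IGNORE_INDEX
            for tid, train in zip(token_ids, trainable_mask)]
```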

4 participants