fix forward error
pppppM committed Mar 22, 2024
1 parent b36cbb5 commit e954d5c
Showing 5 changed files with 164 additions and 109 deletions.
31 changes: 19 additions & 12 deletions xtuner/dataset/hybrid/collate.py
@@ -16,9 +16,9 @@ def hybrid_collate_fn(instances: Sequence[Dict],
pixel_values = []
cumulative_len = []
image_ranges = []
image_belong = []
image_belongs = []
position_ids = []

for i, data in enumerate(instances):
input_ids.append(torch.LongTensor(data['input_ids']))
labels.append(torch.LongTensor(data['labels']))
@@ -27,28 +27,33 @@ def hybrid_collate_fn(instances: Sequence[Dict],
if 'cumulative_len' in data:
cumulative_len.append(torch.IntTensor(data['cumulative_len']))

image_belong.append(i)
pixel_values.extend(data['pixel_values'])
image_ranges.extend(torch.IntTensor(data['image_ranges']))


_values = data['pixel_values']
_ranges = data['image_ranges']

assert len(_values) == len(_ranges)
for v, rng in zip(_values, _ranges):
pixel_values.append(v)
image_ranges.append(rng)
image_belongs.append(i)

if len(pixel_values) > 0:
assert len(image_ranges) > 0
assert len(image_belong) > 0
assert len(image_belongs) > 0

pixel_values = torch.stack(pixel_values)
image_ranges = torch.stack(image_ranges)
image_belong = torch.IntTensor(image_belong)
# image_belongs = torch.IntTensor(image_belongs)
else:
pixel_values = None
image_ranges = None
image_belong = None
image_belongs = None

if len(instances) > 1:
input_ids = pad_sequence(
input_ids, batch_first=True, padding_value=pad_index)
labels = pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX)
position_ids = pad_sequence(labels, batch_first=True, padding_value=0)
position_ids = pad_sequence(position_ids, batch_first=True, padding_value=0)
else:
input_ids = torch.stack(input_ids)
labels = torch.stack(labels)
@@ -57,6 +62,7 @@ def hybrid_collate_fn(instances: Sequence[Dict],
if len(cumulative_len) == 0:
cumulative_len = None

data_dict = {
'input_ids': input_ids,
'position_ids': position_ids,
@@ -65,8 +71,9 @@ def hybrid_collate_fn(instances: Sequence[Dict],
'pixel_values': pixel_values,
'cumulative_len': cumulative_len,
'image_ranges': image_ranges,
'image_belong': image_belong
'image_belongs': image_belongs
}


if return_hf_format:
return data_dict
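The collate change keeps `pixel_values`, `image_ranges`, and `image_belongs` aligned: for the i-th image in the batch, its crop, its token span, and the index of the sample it came from are appended together, so downstream code can recover the mapping without guessing. A minimal sketch of reading that mapping back from a collated batch (shapes are assumptions, not taken from the repo):

def describe_images(batch):
    """Print, for every image in a collated batch, which sample it belongs
    to and which token span it is expected to fill."""
    pixel_values = batch['pixel_values']    # (num_images, ...) tensor, or None
    image_ranges = batch['image_ranges']    # (num_images, 2) token spans
    image_belongs = batch['image_belongs']  # list of per-image sample indices

    if pixel_values is None:
        print('text-only batch')
        return

    assert pixel_values.size(0) == image_ranges.size(0) == len(image_belongs)
    for i, (rng, sample_idx) in enumerate(zip(image_ranges, image_belongs)):
        start, end = rng.tolist()
        print(f'image {i}: sample {sample_idx}, tokens [{start}, {end})')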
68 changes: 0 additions & 68 deletions xtuner/dataset/hybrid/hybrid.py

This file was deleted.

128 changes: 100 additions & 28 deletions xtuner/model/hybrid.py
@@ -4,17 +4,18 @@
import torch
from mmengine.model import BaseModel
from peft import LoraConfig
from mmengine import print_log
from torch import nn

import math
from xtuner.registry import BUILDER
from xtuner.utils.config import build_from_cfg_or_obj
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
from .utils import (LoadWoInit, enable_hf_model_gradient_checkpointing,
get_peft_model_state_dict, prepare_for_llm_lora,
prepare_for_vision_lora,
smart_tokenizer_and_embedding_resize)


import torch.distributed as dist
from mmengine import runner
class HybridFinetune(BaseModel):

def __init__(
@@ -106,54 +107,125 @@ def forward(self, data, data_samples=None, mode='loss'):
"""Overload parent class method, only support training."""

if mode == 'loss':
return self.compute_loss(data, data_samples)
return self.compute_loss(data)
else:
raise NotImplementedError(
f"{type(self)}'s forward is only supported for use during "
'training. If you want to get predictions or chat, please '
"directly use `llm`'s forward.")

def compute_loss(self, data, data_samples=None):



def _get_vision_embeds_and_ranges(self, data):

input_ids = data['input_ids']
labels = data['labels']
position_ids = data['position_ids']
attention_mask = data['attention_mask']
pixel_values = data['pixel_values']
img_rngs = data['image_ranges']
img_belong = data['image_belong']

input_embeds = self.llm.get_input_embeddings()(input_ids)
img_belongs = data['image_belongs']

bs, tokens = input_ids.shape

img_embeds = []
ranges_in_flat_batch = []

if pixel_values is not None:
assert isinstance(pixel_values, torch.Tensor)
assert len(img_rngs) == len(img_belongs) == pixel_values.size(0)

batch_total_imgs = len(img_rngs)

visual_outputs = self.visual_encoder(
pixel_values, output_hidden_states=True)
img_embeds = self.projector(
features = self.projector(
visual_outputs.hidden_states[self.visual_select_layer][:, 1:])

empty_embs = torch.zeros_like(input_embeds)
for emb, rng, b_id in zip(img_embeds, img_rngs, img_belong):
left, right = rng
if emb.size(0) == right - left:
empty_embs[b_id, left:right, :] = emb
elif not emb.size(0) == right - left and left == 0:
empty_embs[b_id, left:right, :] = emb[-right:]
elif not emb.size(
0) == right - left and right == empty_embs.size(1):
empty_embs[b_id, left:right, :] = emb[:right - left]
batch_total_imgs, actual_img_tokens, _ = features.shape


for i in range(batch_total_imgs):
img_start, img_end = img_rngs[i]
expect_img_tokens = img_end - img_start
img_emb = features[i]
img_bs_ind = img_belongs[i]

if actual_img_tokens == expect_img_tokens:
img_embeds.append(img_emb)
elif actual_img_tokens != expect_img_tokens and img_start == 0:
img_embeds.append(img_emb[actual_img_tokens - img_end:])
elif actual_img_tokens != expect_img_tokens and img_end == tokens:
img_embeds.append(img_emb[:expect_img_tokens])
else:
raise RuntimeError(
f'Image at range ({img_start}, {img_end}) expects '
f'{expect_img_tokens} tokens, but the projector produced '
f'{actual_img_tokens}.')

flat_offset = tokens * img_bs_ind

left = flat_offset + img_start
right = flat_offset + img_end
ranges_in_flat_batch.append((left, right))

return img_embeds, ranges_in_flat_batch


def _insert_mm_embeddings(self, flat_embeds, mm_embeds, ranges):

assert len(mm_embeds) == len(ranges)
if len(mm_embeds) == 0:
return flat_embeds

chunk_embeds = []
chunk_sizes = []
mm_chunk_ids = []

cursor = 0
_empty_embeds = torch.zeros_like(flat_embeds)
for (start, end), emb in zip(ranges, mm_embeds):
_empty_embeds[start: end] += emb
# if start - cursor > 0:
# chunk_sizes.append(start - cursor)
# cursor = start

# mm_chunk_ids.append(len(chunk_sizes))


# chunk_embeds.append(emb)
# chunk_sizes.append(end - start)
# cursor = end

# tokens = flat_embeds.size(0)
# if sum(chunk_sizes) < tokens :
# chunk_sizes.append(tokens - sum(chunk_sizes))

# chunk_embs = list(torch.split(flat_embeds, chunk_sizes))
# for ind, mm_emb in zip(mm_chunk_ids, mm_embeds) :
# chunk_embs[ind] = mm_emb

# flat_embeds = torch.cat(chunk_embs, dim=0)
flat_embeds = flat_embeds * (_empty_embeds == 0)

return flat_embeds + _empty_embeds

def compute_loss(self, data):

non_img_mask = (empty_embs == 0)
input_embeds = input_embeds * non_img_mask + empty_embs
input_ids = data['input_ids']
labels = data['labels']
position_ids = data['position_ids']
attention_mask = data['attention_mask']

input_embeds = self.llm.get_input_embeddings()(input_ids)

bs, tokens, dim = input_embeds.shape
flat_embeds = input_embeds.flatten(0,1)

img_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data)
flat_embeds = self._insert_mm_embeddings(flat_embeds, img_embs, flat_bs_img_rngs)
input_embeds = flat_embeds.reshape(bs, tokens, dim)

outputs = self.llm(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=input_embeds,
labels=labels)

loss_dict = {'loss': outputs.loss}
return loss_dict

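The rewritten forward path works in a flattened coordinate system: the text embeddings of shape (bs, tokens, dim) are flattened to (bs * tokens, dim), each image's per-sample range (img_start, img_end) is shifted by tokens * sample_index, and `_insert_mm_embeddings` zeroes the text slots inside those ranges before adding the projected image features. A self-contained toy version of that offset arithmetic and masked add (shapes and values are made up for illustration):

import torch

bs, tokens, dim = 2, 8, 4
input_embeds = torch.randn(bs, tokens, dim)           # stand-in for the LLM token embeddings
flat_embeds = input_embeds.flatten(0, 1)              # (bs * tokens, dim)

# One fake image: 3 projected tokens that belong to sample 1 and span tokens [2, 5).
img_emb = torch.ones(3, dim)
img_start, img_end, sample_idx = 2, 5, 1
left = tokens * sample_idx + img_start                # shift into flat coordinates
right = tokens * sample_idx + img_end

scatter = torch.zeros_like(flat_embeds)
scatter[left:right] += img_emb
flat_embeds = flat_embeds * (scatter == 0) + scatter  # zero the text slots, add the image slots
input_embeds = flat_embeds.reshape(bs, tokens, dim)   # back to (bs, tokens, dim) for the LLM

assert torch.equal(input_embeds[1, 2:5], img_emb)

Note that the mask is element-wise (`scatter == 0`), so a position where the projected feature happens to be exactly zero keeps its original text embedding; that behaviour mirrors the masked add in the diff above.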
44 changes: 44 additions & 0 deletions xtuner/types/chat.py
@@ -92,6 +92,30 @@ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
return chat_template.decorate_function_result(self.content)


class CodeInterpreterCallMsg(BaseModel):

role: Literal['assistant']
content: str
code_interpreter_call: Union[str, Dict]

def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:

return chat_template.decorate_code_interpreter_call(
self.content, self.code_interpreter_call)



class CodeInterpreterResultMsg(BaseModel):
role: Literal['function']
name: str
content: Union[str, Dict]

def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
return chat_template.decorate_code_interpreter_result(self.content)




class Functions(BaseModel):

# class Parameters(BaseModel):
@@ -108,6 +132,26 @@ class Functions(BaseModel):
name: str
description: Union[str, Dict]
parameters: Union[str, Dict]



class CodeInterpreter(BaseModel):

# class Parameters(BaseModel):

# class Property(BaseModel):
# type: str
# description: str
# enum: Optional[List] = None

# type: Literal['object']
# properties: Dict[str, Property]
# required: List[str]

name: str
description: Union[str, Dict]




HybridChatMsgType = Union[ChatMsg, FunctionCallMsg, FunctionResultMsg]
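The two new message types mirror the existing function-call messages: `CodeInterpreterCallMsg` carries the assistant's code snippet and `CodeInterpreterResultMsg` carries its output, and each renders itself through a `HybridChatTemplate`. A hypothetical usage sketch (the template instance and the exact decorate_* behaviour live elsewhere in xtuner and are assumed here):

call_msg = CodeInterpreterCallMsg(
    role='assistant',
    content='Let me run that in the interpreter.',
    code_interpreter_call='print(2 ** 10)',
)
result_msg = CodeInterpreterResultMsg(
    role='function',
    name='code_interpreter',
    content='1024',
)

# chat_template is assumed to be a configured HybridChatTemplate instance.
prompt = call_msg.apply_chat_template(chat_template)
prompt += result_msg.apply_chat_template(chat_template)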
2 changes: 1 addition & 1 deletion xtuner/utils/config.py
@@ -124,7 +124,7 @@ def build_from_cfg_or_obj(cfg_or_obj: Union[dict, OBJ_T],
raise TypeError(
f'Expect an object of {accept}, but there is an object of '
f'{type(obj)}.')
return BUILDER.build(cfg_or_obj)
return obj

else:
raise TypeError(f'cfg_or_obj must be a dict, or {accept}, but got '
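The config.py change is a one-line bug fix: when `cfg_or_obj` is already an instance of an accepted type, the function used to send it back through `BUILDER.build`, which expects a config dict; it now returns the object untouched. A rough illustration of the contract this restores (the second argument's name and the config-dict layout are assumptions, not checked against the repo):

from torch import nn

from xtuner.utils.config import build_from_cfg_or_obj

layer_cfg = dict(type=nn.Linear, in_features=16, out_features=16)

built = build_from_cfg_or_obj(layer_cfg, nn.Linear)   # dict -> built via BUILDER
same = build_from_cfg_or_obj(built, nn.Linear)        # instance -> returned as-is after the fix
assert built is same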
