Skip to content

Commit

Permalink
Update huggingface_model.py
Browse files (browse the repository at this point in the history)
Fixed an issue related to the `format` method.
  • Loading branch information
zyzhang1130 committed Apr 26, 2024
1 parent b4f4f40 commit e33b3de
Showing 1 changed file with 70 additions and 26 deletions.
96 changes: 70 additions & 26 deletions examples/load_finetune_huggingface_model/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
- Finetune_DialogAgent: Extends DialogAgent for model fine-tuning.
"""
from typing import Optional, List, Dict, Any
from typing import Sequence, Any, Union, List, Optional, Dict
import os

import torch
Expand All @@ -23,6 +23,8 @@

from agentscope.agents import DialogAgent
from agentscope.models import ModelWrapperBase, ModelResponse
from agentscope.message import MessageBase
from agentscope.utils.tools import _convert_to_str


class HuggingFaceWrapper(ModelWrapperBase):
Expand Down Expand Up @@ -127,11 +129,11 @@ def __call__(

try:
# Tokenize the input text
concatenated_input = "\n".join(
[f"{d.get('name', 'System')}: {d['content']}" for d in inputs],
concatenated_input = "\n ".join(
[f"{d.get('role')}: {d['content']}" for d in inputs],
)
input_ids = self.tokenizer.encode(
f"{concatenated_input}\nAssistent: ",
f"{concatenated_input}\n assistent: ",
return_tensors="pt",
)
# Generate response using the model
Expand All @@ -150,18 +152,56 @@ def __call__(
logger.error(f"Generation error: {e}")
raise

def format(self, data: Any) -> Any:
"""
Pass-through for data formatting. Assume
data is already in the correct format.
def format(
self,
*args: Union[MessageBase, Sequence[MessageBase]],
) -> List[dict]:
"""A basic strategy to format the input into the required format of
Hugging Face models.
Arguments:
data (Any): Data to be formatted.
Args:
args (`Union[MessageBase, Sequence[MessageBase]]`):
The input arguments to be formatted, where each argument
should be a `Msg` object, or a list of `Msg` objects.
In distribution, placeholder is also allowed.
Returns:
Any: The input data unchanged.
`List[dict]`:
The formatted messages.
"""
return data
huggingface_msgs = []
for msg in args:
if msg is None:
continue
if isinstance(msg, MessageBase):
# content shouldn't be empty string
if msg.content == "":
logger.warning(
"The content field cannot be "
"empty string. To avoid error, the empty string is "
"replaced by a blank space automatically, but the "
"model may not work as expected.",
)
msg.content = " "

huggingface_msg = {
"role": msg.role,
"content": _convert_to_str(msg.content),
}

# image url
if msg.url is not None:
huggingface_msg["images"] = [msg.url]

huggingface_msgs.append(huggingface_msg)
elif isinstance(msg, list):
huggingface_msgs.extend(self.format(*msg))
else:
raise TypeError(
f"Invalid message type: {type(msg)}, `Msg` is expected.",
)

return huggingface_msgs

def load_model(
self,
Expand Down Expand Up @@ -191,20 +231,22 @@ def load_model(
token=self.huggingface_token,
device_map="auto",
)
print("new model")
info_msg = (
f"Successfully loaded new model '{model_id}' from "
f"Hugging Face"
)
else:
self.model = AutoModelForCausalLM.from_pretrained(
local_model_path,
local_files_only=True,
device_map="auto",
)
print("local model")
info_msg = (
f"Successfully loaded new model '{model_id}' from "
f"'{local_model_path}'"
)

# log the successful model loading
info_msg = (
f"Successfully loaded new model '{model_id}' from "
f"'{local_model_path}'"
)
logger.info(info_msg)

except Exception as e:
Expand Down Expand Up @@ -244,18 +286,20 @@ def load_tokenizer(
model_id,
token=self.huggingface_token,
)
print("new tokenizer")
# log the successful tokenizer loading
logger.info(
f"Successfully loaded new tokenizer for model "
f"'{model_id}' from Hugging Face",
)
else:
self.tokenizer = AutoTokenizer.from_pretrained(
local_tokenizer_path,
)
print("local tokenizer")

# log the successful tokenizer loading
logger.info(
f"Successfully loaded new tokenizer for model '{model_id}' "
f"from '{local_tokenizer_path}'",
)
# log the successful tokenizer loading
logger.info(
f"Successfully loaded new tokenizer for model "
f"'{model_id}' from '{local_tokenizer_path}'",
)

except Exception as e:
# Handle exceptions during model loading,
Expand Down

0 comments on commit e33b3de

Please sign in to comment.