Commit

update
goliaro committed Oct 15, 2024
1 parent 98588f2 commit fe40a7b
Showing 3 changed files with 214 additions and 39 deletions.
3 changes: 1 addition & 2 deletions inference/python/ff_peft.py
@@ -162,7 +162,7 @@ def main():
ff.Request(
ff.RequestType.REQ_INFERENCE,
prompt=prompt,
max_sequence_length=128,
max_new_tokens=128,
peft_model_id=llm.get_ff_peft_id(lora_inference_config),
)
for prompt in prompts
@@ -172,7 +172,6 @@ def main():
if len(configs.finetuning_dataset) > 0:
finetuning_request = ff.Request(
ff.RequestType.REQ_FINETUNING,
max_sequence_length=128,
peft_model_id=llm.get_ff_peft_id(lora_finetuning_config),
dataset_filepath=configs.finetuning_dataset,
max_training_steps=2,
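For reference, a minimal sketch of the two request kinds in ff_peft.py after this change, assuming the surrounding script still provides llm, configs, and the two LoRA configs (the prompt string is a placeholder): inference requests now bound generation with max_new_tokens, while finetuning requests omit any token limit and fall back to the compile-time maximum sequence length.

inference_request = ff.Request(
    ff.RequestType.REQ_INFERENCE,
    prompt="What is the capital of France?",  # placeholder prompt
    max_new_tokens=128,
    peft_model_id=llm.get_ff_peft_id(lora_inference_config),
)
finetuning_request = ff.Request(
    ff.RequestType.REQ_FINETUNING,
    peft_model_id=llm.get_ff_peft_id(lora_finetuning_config),
    dataset_filepath=configs.finetuning_dataset,
    max_training_steps=2,
)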
122 changes: 91 additions & 31 deletions python/flexflow/core/flexflow_cffi.py
@@ -1795,7 +1795,7 @@ def __init__(
raise ValueError(
"Target modules can only be specified when trainable=True"
)

# Check rank, lora_alpha, lora_dropout values
if rank is not None or lora_alpha is not None or lora_dropout is not None:
if not trainable or not init_lora_weights:
@@ -1805,15 +1805,15 @@
rank = rank if rank is not None else 8
lora_alpha = lora_alpha if lora_alpha is not None else 8.0
lora_dropout = lora_dropout if lora_dropout is not None else 0.0

# If passed, check if the values of rank, lora_alpha, and lora_dropout are valid
if rank < 1 or type(rank) != int:
raise ValueError("Rank must be >= 1 and an integer")
if lora_alpha <= 0:
raise ValueError("Lora_alpha must be > 0")
if lora_dropout < 0 or lora_dropout > 1:
raise ValueError("Lora_dropout must be in the interval [0, 1]")

self.ff_initialized = False
self._cache_folder = cache_folder
self._peft_model_id = peft_model_id
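The defaults and checks above can be summarized in a standalone sketch; this mirrors the validation logic in the diff rather than calling LoraLinearConfig itself, the function name is made up, and the additional requirement that trainable and init_lora_weights be set when any of these values is passed is omitted here.

def _check_lora_hyperparams(rank=None, lora_alpha=None, lora_dropout=None):
    # Apply the same fallbacks as the constructor: rank=8, lora_alpha=8.0, lora_dropout=0.0
    rank = rank if rank is not None else 8
    lora_alpha = lora_alpha if lora_alpha is not None else 8.0
    lora_dropout = lora_dropout if lora_dropout is not None else 0.0
    # Same bounds as the diff: integer rank >= 1, positive alpha, dropout in [0, 1]
    if type(rank) != int or rank < 1:
        raise ValueError("Rank must be >= 1 and an integer")
    if lora_alpha <= 0:
        raise ValueError("Lora_alpha must be > 0")
    if lora_dropout < 0 or lora_dropout > 1:
        raise ValueError("Lora_dropout must be in the interval [0, 1]")
    return rank, lora_alpha, lora_dropout

print(_check_lora_hyperparams())                  # (8, 8.0, 0.0)
print(_check_lora_hyperparams(lora_dropout=0.1))  # (8, 8.0, 0.1)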
@@ -2051,13 +2051,15 @@ def no_id_handle():
# Request
# -----------------------------------------------------------------------


@dataclass
class Request:
"""A class to record the metadata of an inference or finetuning request."""

req_type: RequestType
prompt: Optional[str] = None
max_length: int = -1
max_new_tokens: int = 128
max_new_tokens: int = -1
peft_model_id: Optional[PEFTModelID] = None
dataset_filepath: Optional[str] = None
max_training_steps: int = 1
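With this change both limits on a Request default to -1, which the generation paths treat as "unset": FFModel.generate() below requires exactly one of the two to be set for inference requests, while the higher-level LLM.generate() in serve.py fills in max_length from the compile-time max_seq_length when both are left at -1. A small sketch, assuming the class is imported the way ff_peft.py imports it (as ff.Request), with a placeholder prompt:

req = ff.Request(ff.RequestType.REQ_INFERENCE, prompt="placeholder prompt")
assert req.max_length == -1 and req.max_new_tokens == -1  # both start out unset
req.max_new_tokens = 128  # choose one limit before handing it to FFModel.generate()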
@@ -4650,26 +4652,64 @@ def get_output_tensor(self, ffmodel, data_type):
assert ret_val == True
return np_array

def generate_inf_only(self, prompt_list: List[str], max_length: int = -1, max_new_tokens: int = 128):
def _estimate_max_num_tokens(
max_length: int, max_new_tokens: int, prompt: Optional[str]
):
if prompt is None:
assert max_new_tokens == -1
return (
math.ceil(max_new_tokens + len(prompt.split()) * 1.5)
if max_new_tokens != -1
else max_length
)

def _estimate_max_num_chars(
max_length: int, max_new_tokens: int, prompt: Optional[str]
):
return (
5 * FFModel._estimate_max_num_tokens(max_length, max_new_tokens, prompt)
+ 100
)

def generate_inf_only(
self,
prompt_list: List[str],
max_length: int,
max_new_tokens: int,
):
if max_length != -1 and max_new_tokens != -1:
warnings.warn(f"Both `max_new_tokens` (={self.max_new_tokens}) and `max_length`(={self.max_length}) seem to have been set. `max_new_tokens` will take precedence.")
raise ValueError(
f"Both `max_new_tokens` (={max_new_tokens}) and `max_length`(={max_length}) seem to have been set."
)
if max_length == -1 and max_new_tokens == -1:
raise ValueError(
f"Both `max_new_tokens` (={max_new_tokens}) and `max_length`(={max_length}) were left unset."
)
assert isinstance(prompt_list, list)
c_input_texts = [get_c_name(prompt) for prompt in prompt_list]
estimated_max_tokens = math.ceil(max_new_tokens + max([len(prompt.split()) for prompt in prompt_list])*1.5) if max_new_tokens != -1 else max_length
max_num_chars = 5 * (estimated_max_tokens + 100)
c_output_texts = [ffi.new("char[]", max_num_chars) for prompt in prompt_list]
c_output_texts = [
ffi.new(
"char[]",
FFModel._estimate_max_num_chars(max_length, max_new_tokens, prompt),
)
for prompt in prompt_list
]
c_output_length_and_tokens = [
ffi.new("int[]", estimated_max_tokens + 100) for prompt in prompt_list
ffi.new(
"int[]",
FFModel._estimate_max_num_tokens(max_length, max_new_tokens, prompt)
+ 100,
)
for prompt in prompt_list
]
c_request_types = [
enum_to_int(RequestType, RequestType.REQ_INFERENCE)
for prompt in prompt_list
enum_to_int(RequestType, RequestType.REQ_INFERENCE) for _ in prompt_list
]
max_lengths = [max_length for prompt in prompt_list]
max_new_tokens_ = [max_new_tokens for prompt in prompt_list]
peft_model_ids = [PEFTModelID.no_id_handle() for prompt in prompt_list]
dataset_filepaths = [ffi.NULL for prompt in prompt_list]
training_steps = [0 for prompt in prompt_list]
max_lengths = [max_length for _ in prompt_list]
max_new_tokens_ = [max_new_tokens for _ in prompt_list]
peft_model_ids = [PEFTModelID.no_id_handle() for _ in prompt_list]
dataset_filepaths = [ffi.NULL for _ in prompt_list]
training_steps = [0 for _ in prompt_list]
num_finetuning_losses = ffi.new("int *")
c_finetuning_losses = ffi.new("float[]", 0)
ffc().flexflow_model_generate(
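The new helpers size the C output buffers from the request limits; a worked example of that arithmetic, restated standalone with a made-up prompt (no FlexFlow internals are needed to follow the numbers):

import math

def estimate_max_num_tokens(max_length, max_new_tokens, prompt):
    # When max_new_tokens is set, budget the new tokens plus ~1.5 tokens per prompt word;
    # otherwise the caller's max_length is already a bound on the total length.
    return (
        math.ceil(max_new_tokens + len(prompt.split()) * 1.5)
        if max_new_tokens != -1
        else max_length
    )

prompt = "What is the capital of France?"           # 6 whitespace-separated words
tokens = estimate_max_num_tokens(-1, 128, prompt)   # ceil(128 + 6 * 1.5) = 137
chars = 5 * tokens + 100                            # 785-byte output text buffer
print(tokens, chars)                                # 137 785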
@@ -4698,34 +4738,55 @@ def generate_inf_only(self, prompt_list: List[str], max_length: int = -1, max_new_tokens: int = 128):

def generate(self, requests_list: List[Request]):
assert isinstance(requests_list, list)
for request in requests_list:
assert isinstance(request, Request)
if request.max_length != -1 and request.max_new_tokens != -1:
raise ValueError(
f"Both `max_new_tokens` (={request.max_new_tokens}) and `max_length`(={request.max_length}) seem to have been set."
)
if request.max_length == -1 and request.max_new_tokens == -1:
raise ValueError(
f"Both `max_new_tokens` (={request.max_new_tokens}) and `max_length`(={request.max_length}) were left unset."
)
if (
request.req_type == RequestType.REQ_FINETUNING
and request.max_new_tokens != -1
):
raise ValueError(
f"Finetuning requests should not have `max_new_tokens` set."
)
c_input_texts = [
get_c_name(request.prompt) for request in requests_list
] # entry will be None for finetuning requests
c_output_texts = [
(
ffi.new("char[]", 5 * (request.max_sequence_length + 100))
ffi.new(
"char[]",
FFModel._estimate_max_num_chars(
request.max_length, request.max_new_tokens, request.prompt
),
)
if request.req_type == RequestType.REQ_INFERENCE
else ffi.NULL
)
for request in requests_list
]
c_output_length_and_tokens = [
ffi.new("int[]", request.max_sequence_length + 100)
ffi.new(
"int[]",
FFModel._estimate_max_num_tokens(
request.max_length, request.max_new_tokens, request.prompt
)
+ 100,
)
for request in requests_list
]
c_request_types = [
enum_to_int(RequestType, request.req_type) for request in requests_list
]
max_lengths = [
request.max_length for request in requests_list
]
max_new_tokens_ = [
request.max_new_tokens for request in requests_list
]
for i in range(len(requests_list)):
if max_lengths[i] != -1 and max_new_tokens_[i] != -1:
warnings.warn(f"Both `max_new_tokens` (={max_new_tokens_[i]}) and `max_length`(={max_lengths[i]}) seem to have been set. `max_new_tokens` will take precedence.")

max_lengths = [request.max_length for request in requests_list]
max_new_tokens_ = [request.max_new_tokens for request in requests_list]

peft_model_ids = [
(
request.peft_model_id
@@ -4742,7 +4803,7 @@ def generate(self, requests_list: List[Request]):
# c_finetuning_losses = ffi.new("float**")
# TODO: set this value automatically
c_finetuning_losses = ffi.new("float[]", 10000)

ffc().flexflow_model_generate(
self.handle,
len(requests_list),
@@ -4774,7 +4835,6 @@ def generate(self, requests_list: List[Request]):
finetuning_losses=finetuning_losses,
)
)
return results

def set_position_offset(self, offset):
ffc().flexflow_model_set_position_offset(self.handle, offset)
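FFModel.generate() now rejects ill-specified requests up front. A hedged illustration of the three checks, assuming ff.Request and ff.RequestType as in the earlier examples and a compiled FFModel named model (all placeholders); the commented calls indicate which combinations raise ValueError:

ok_req = ff.Request(ff.RequestType.REQ_INFERENCE, prompt="Summarize this.", max_new_tokens=64)
both_set = ff.Request(ff.RequestType.REQ_INFERENCE, prompt="Summarize this.", max_length=256, max_new_tokens=64)
neither_set = ff.Request(ff.RequestType.REQ_INFERENCE, prompt="Summarize this.")
ft_with_tokens = ff.Request(ff.RequestType.REQ_FINETUNING, dataset_filepath="data.json", max_new_tokens=64)  # placeholder path

# results = model.generate([ok_req])         # accepted: exactly one limit set
# model.generate([both_set])                 # ValueError: both max_new_tokens and max_length set
# model.generate([neither_set])              # ValueError: both limits left unset
# model.generate([ft_with_tokens])           # ValueError: finetuning requests must not set max_new_tokens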
128 changes: 122 additions & 6 deletions python/flexflow/serve/serve.py
@@ -34,8 +34,11 @@
from typing import Union, List
from huggingface_hub import snapshot_download


class _SupportedModels:
def __init__(self,):
def __init__(
self,
):
self.supported_models = {
"LlamaForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA, LLAMAConfig),
"LLaMAForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA, LLAMAConfig),
@@ -292,8 +295,8 @@ def download_peft_weights():

weights_path = get_weights_path(peft_model_id)
refresh_cache_if_needed(peft_model_id)
ff_revision, ff_revision_file, latest_revision = self.__get_revision_hashes(
peft_model_id, weights_path
ff_revision, ff_revision_file, latest_revision = (
self.__get_revision_hashes(peft_model_id, weights_path)
)

if ff_revision != latest_revision:
@@ -350,11 +353,19 @@ def download_hf_tokenizer_if_needed(self):
f"'{self.model_name}' tokenizer needs updating! Downloading tokenizer now..."
)
# Load/download the tokenizer files
target_tokenizer_files = ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "vocab.json", "merges.txt"]
target_tokenizer_files = [
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"vocab.json",
"merges.txt",
]
if os.path.exists(self.model_name):
hf_tokenizer_path = self.model_name
else:
hf_tokenizer_path = snapshot_download(repo_id=self.model_name, allow_patterns=target_tokenizer_files)
hf_tokenizer_path = snapshot_download(
repo_id=self.model_name, allow_patterns=target_tokenizer_files
)
for file in target_tokenizer_files:
src_path = os.path.join(hf_tokenizer_path, file)
dst_path = os.path.join(self.tokenizer_path, file)
Expand Down Expand Up @@ -424,6 +435,8 @@ def compile(
model_specific_pipeline_parallelism_degree
)

self.max_seq_length = max_seq_length

# Create request manager and set serving configuration
self.rm = RequestManager()
self.rm.set_max_requests_per_batch(max_requests_per_batch)
@@ -506,7 +519,7 @@ def generate(
self,
requests_or_prompts: Union[str, List[str], Request, List[Request]],
max_length: int = -1,
max_new_tokens: int = 128,
max_new_tokens: int = -1,
):
"""Generate tokens based on the input prompt(s)
@@ -519,6 +532,109 @@
:return: the generation results
:rtype: GenerationResult
"""

def inf_only() -> bool:
if type(requests_or_prompts) == str:
return True
if type(requests_or_prompts) == list:
if type(requests_or_prompts[0]) == str:
return True
return False

# Inference only (type is str or List[str])
# - if max_length and max_new_tokens are both unset, set max_length to max_sequence_length
# - if both are set, give precedence to max_new_tokens
# - if only one of them is set, all good; just check that we don't exceed the max_sequence_length
# Inference + finetuning (type is Request or List[Request]):
# - inference requests: same conditions as above
# - finetuning requests: return an error if max_new_tokens is set. If max_length is unset, set it to max_sequence_length. If it's set, check that it doesn't exceed max_sequence_length
if inf_only():
# single prompt (str) or list of prompts in str format
if max_length == -1 and max_new_tokens == -1:
max_length = self.max_seq_length
elif max_length != -1 and max_new_tokens != -1:
warnings.warn(
f"Both `max_new_tokens` (={max_new_tokens}) and `max_length`(={max_length}) seem to have been set. `max_new_tokens` will take precedence."
)
max_length = -1
if max_length > self.max_seq_length or max_new_tokens > self.max_seq_length:
raise ValueError(
f"max_length ({max_length}) or max_new_tokens ({max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})"
)
elif type(requests_or_prompts) == Request:
# single Request object (inference or finetuning)
if max_length != -1 or max_new_tokens != -1:
warnings.warn(
f"max_length (={max_length}) and max_new_tokens (={max_new_tokens}) are not used for Request objects."
)
if requests_or_prompts.req_type == RequestType.REQ_INFERENCE:
# check max_length and max_new_tokens parameters
if (
requests_or_prompts.max_length == -1
and requests_or_prompts.max_new_tokens == -1
):
requests_or_prompts.max_length = self.max_seq_length
elif (
requests_or_prompts.max_length != -1
and requests_or_prompts.max_new_tokens != -1
):
warnings.warn(
f"Both `max_new_tokens` (={requests_or_prompts.max_new_tokens}) and `max_length`(={requests_or_prompts.max_length}) seem to have been set. `max_new_tokens` will take precedence."
)
requests_or_prompts.max_length = -1
if (
requests_or_prompts.max_length > self.max_seq_length
or requests_or_prompts.max_new_tokens > self.max_seq_length
):
raise ValueError(
f"max_length ({requests_or_prompts.max_length}) or max_new_tokens ({requests_or_prompts.max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})"
)
else:
if requests_or_prompts.max_new_tokens != -1:
raise ValueError(
f"max_new_tokens ({requests_or_prompts.max_new_tokens}) is not allowed for finetuning requests."
)
if requests_or_prompts.max_length == -1:
requests_or_prompts.max_length = self.max_seq_length
if requests_or_prompts.max_length > self.max_seq_length:
raise ValueError(
f"max_length ({requests_or_prompts.max_length}) exceeds the maximum sequence length ({self.max_seq_length})"
)
else:
# list of Request objects (inference or finetuning)
if max_length != -1 or max_new_tokens != -1:
warnings.warn(
f"max_length (={max_length}) and max_new_tokens (={max_new_tokens}) are not used for Request objects."
)
for req in requests_or_prompts:
if req.req_type == RequestType.REQ_INFERENCE:
# check max_length and max_new_tokens parameters
if req.max_length == -1 and req.max_new_tokens == -1:
req.max_length = self.max_seq_length
elif req.max_length != -1 and req.max_new_tokens != -1:
warnings.warn(
f"Both `max_new_tokens` (={req.max_new_tokens}) and `max_length`(={req.max_length}) seem to have been set. `max_new_tokens` will take precedence."
)
req.max_length = -1
if (
req.max_length > self.max_seq_length
or req.max_new_tokens > self.max_seq_length
):
raise ValueError(
f"max_length ({req.max_length}) or max_new_tokens ({req.max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})"
)
else:
if req.max_new_tokens != -1:
raise ValueError(
f"max_new_tokens ({req.max_new_tokens}) is not allowed for finetuning requests."
)
if req.max_length == -1:
req.max_length = self.max_seq_length
if req.max_length > self.max_seq_length:
raise ValueError(
f"max_length ({req.max_length}) exceeds the maximum sequence length ({self.max_seq_length})"
)

if type(requests_or_prompts) == str:
if len(requests_or_prompts) == 0:
return None
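Taken together, the new validation gives LLM.generate() the calling conventions sketched below, under stated assumptions: llm is an already-compiled, running flexflow.serve LLM whose compile-time max_seq_length was 2048 (a made-up figure), and the prompts, LoRA config, and dataset path are placeholders. The same rules are applied per request when a list of Request objects is passed.

llm.generate("Three tips for staying healthy.", max_new_tokens=128)  # one limit set: used as-is
llm.generate(["prompt one", "prompt two"], max_length=512, max_new_tokens=128)
# both set: warns, max_new_tokens takes precedence and max_length is reset to -1
llm.generate("prompt three")  # both unset: max_length falls back to max_seq_length (2048)

ft_request = ff.Request(
    ff.RequestType.REQ_FINETUNING,
    peft_model_id=llm.get_ff_peft_id(lora_finetuning_config),  # placeholder LoRA config
    dataset_filepath="finetune_dataset.json",                  # placeholder path
    max_training_steps=2,
)
llm.generate(ft_request)  # max_length is filled in from max_seq_length; setting max_new_tokens would raise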
