diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml index 6ca50027d1..9ee4693f91 100644 --- a/.github/workflows/gpu-ci.yml +++ b/.github/workflows/gpu-ci.yml @@ -56,7 +56,7 @@ jobs: CONDA: "3" needs: gpu-ci-concierge container: - image: ghcr.io/flexflow/flexflow-environment-cuda-11.8:latest + image: ghcr.io/flexflow/flexflow-environment-cuda-12.1:latest options: --gpus all --shm-size=8192m steps: - name: Keep alive @@ -75,7 +75,7 @@ jobs: CONDA: "3" needs: gpu-ci-concierge container: - image: ghcr.io/flexflow/flexflow-environment-cuda-11.8:latest + image: ghcr.io/flexflow/flexflow-environment-cuda-12.1:latest options: --gpus all --shm-size=8192m steps: - name: Install updated git version @@ -151,7 +151,7 @@ jobs: HUGGINGFACE_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }} needs: gpu-ci-concierge container: - image: ghcr.io/flexflow/flexflow-environment-cuda-11.8:latest + image: ghcr.io/flexflow/flexflow-environment-cuda-12.1:latest options: --gpus all --shm-size=8192m steps: - name: Install updated git version @@ -239,7 +239,7 @@ jobs: CONDA: "3" needs: inference-tests container: - image: ghcr.io/flexflow/flexflow-environment-cuda-11.8:latest + image: ghcr.io/flexflow/flexflow-environment-cuda-12.1:latest options: --gpus all --shm-size=8192m steps: - name: Install updated git version diff --git a/cmake/nccl.cmake b/cmake/nccl.cmake index 82cf3b4122..abb4864588 100644 --- a/cmake/nccl.cmake +++ b/cmake/nccl.cmake @@ -36,11 +36,12 @@ if(NCCL_LIBRARY AND NCCL_INCLUDE_DIR) string(REGEX MATCH "([0-9]+)" NCCL_MAJOR ${NCCL_VERSION_DEFINES}) string(REGEX MATCH "([0-9]+)" NCCL_MINOR ${NCCL_VERSION_DEFINES2}) set(NCCL_VERSION "${NCCL_MAJOR}.${NCCL_MINOR}") - if(NCCL_VERSION VERSION_LESS 2.23) - set(NCCL_OLD TRUE) - else() - set(NCCL_OLD FALSE) - endif() + set(NCCL_OLD FALSE) + # if(NCCL_VERSION VERSION_LESS 2.23) + # set(NCCL_OLD TRUE) + # else() + # set(NCCL_OLD FALSE) + # endif() message(STATUS "Found NCCL version: ${NCCL_VERSION}") else() message(WARNING "NCCL header not found, unable to determine version") diff --git a/docker/flexflow-environment/Dockerfile b/docker/flexflow-environment/Dockerfile index ee13a07375..7028fc4b2e 100644 --- a/docker/flexflow-environment/Dockerfile +++ b/docker/flexflow-environment/Dockerfile @@ -55,18 +55,18 @@ ENV CUDA_DIR /usr/local/cuda ARG FF_GPU_BACKEND "cuda" # Update NCCL if FF_GPU_BACKEND is cuda -RUN /bin/bash -c 'if [ "$FF_GPU_BACKEND" = "cuda" ]; then \ - echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Updating NCCL"; \ - ubuntu_version=$(lsb_release -rs); \ - ubuntu_version=${ubuntu_version//./}; \ - wget "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${ubuntu_version}/x86_64/cuda-keyring_1.0-1_all.deb"; \ - DEBIAN_FRONTEND=noninteractive dpkg -i cuda-keyring_1.0-1_all.deb; \ - DEBIAN_FRONTEND=noninteractive apt-get update -y --allow-change-held-packages; \ - rm -f cuda-keyring_1.0-1_all.deb; \ - DEBIAN_FRONTEND=noninteractive apt install -y --allow-change-held-packages libnccl2 libnccl-dev; \ - else \ - echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Skipping updating NCCL"; \ - fi' +# RUN /bin/bash -c 'if [ "$FF_GPU_BACKEND" = "cuda" ]; then \ +# echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. 
Updating NCCL"; \ +# ubuntu_version=$(lsb_release -rs); \ +# ubuntu_version=${ubuntu_version//./}; \ +# wget "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${ubuntu_version}/x86_64/cuda-keyring_1.0-1_all.deb"; \ +# DEBIAN_FRONTEND=noninteractive dpkg -i cuda-keyring_1.0-1_all.deb; \ +# DEBIAN_FRONTEND=noninteractive apt-get update -y --allow-change-held-packages; \ +# rm -f cuda-keyring_1.0-1_all.deb; \ +# DEBIAN_FRONTEND=noninteractive apt install -y --allow-change-held-packages libnccl2 libnccl-dev; \ +# else \ +# echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Skipping updating NCCL"; \ +# fi' # Install hip dependencies if FF_GPU_BACKEND is hip_cuda or hip_rocm # Note that amd's docs say to also install the `hip-runtime-nvidia` package. This diff --git a/include/flexflow/ops/kernels/lora_linear_kernels.h b/include/flexflow/ops/kernels/lora_linear_kernels.h index 5360b5f8ea..eee9875d30 100644 --- a/include/flexflow/ops/kernels/lora_linear_kernels.h +++ b/include/flexflow/ops/kernels/lora_linear_kernels.h @@ -8,7 +8,8 @@ #include "flexflow/ops/lora_linear.h" namespace FlexFlow { - +using Legion::Context; +using Legion::Runtime; struct LoraLinearWeight { // weights void *w0_ptr, *w1_ptr; @@ -46,7 +47,9 @@ void inference_kernel_wrapper(LoraLinearMeta *m, BatchConfig const *bc, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output); -void peft_bwd_kernel_wrapper(LoraLinearMeta *m, +void peft_bwd_kernel_wrapper(Context ctx, + Runtime *runtime, + LoraLinearMeta *m, BatchConfig const *bc, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &output_grad); @@ -63,7 +66,9 @@ void inference_kernel(LoraLinearMeta *m, int out_dim, ffStream_t stream); template -void peft_bwd_kernel(LoraLinearMeta *m, +void peft_bwd_kernel(Context ctx, + Runtime *runtime, + LoraLinearMeta *m, BatchConfig const *bc, DT *input_grad_ptr, DT const *output_grad_ptr, diff --git a/include/flexflow/optimizer.h b/include/flexflow/optimizer.h index bab7e6e4ed..4917df73c3 100644 --- a/include/flexflow/optimizer.h +++ b/include/flexflow/optimizer.h @@ -20,7 +20,8 @@ #include "legion.h" namespace FlexFlow { - +using Legion::Context; +using Legion::Runtime; class FFModel; class OpMeta; @@ -60,7 +61,9 @@ class SGDOptimizer : public Optimizer { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); - static void nccl_update_task_gpu(SGDOptimizer const *op, + static void nccl_update_task_gpu(Context ctx, + Runtime *runtime, + SGDOptimizer const *op, OpMeta const *meta, float const *w_grad_ptr, size_t size, @@ -103,7 +106,9 @@ class AdamOptimizer : public Optimizer { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); - static void nccl_update_task_gpu(AdamOptimizer const *op, + static void nccl_update_task_gpu(Context ctx, + Runtime *runtime, + AdamOptimizer const *op, OpMeta const *meta, float const *w_grad_ptr, size_t size, diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index 36a56012fc..94bfc74244 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -68,7 +68,7 @@ struct Request { BatchConfig::RequestGuid guid; PEFTModelID peft_model_id = PEFTModelID::NO_ID; int max_length = -1; - int max_new_tokens = 128; + int max_new_tokens = -1; int initial_len; int ssm_cache_size = 0; int llm_cache_size = 0; @@ -302,6 +302,7 @@ class RequestManager { ModelType model_type; int bos_token_id; int eos_token_id; + bool old_llama_tokenizer = false; std::string output_filepath; std::queue 
pending_infr_request_queue; std::queue pending_peft_request_queue; diff --git a/inference/peft/peft.cc b/inference/peft/peft.cc index ee5bd1b460..14fc653eba 100644 --- a/inference/peft/peft.cc +++ b/inference/peft/peft.cc @@ -340,7 +340,7 @@ void FlexFlow::top_level_task(Task const *task, printf("Inference prompt[%d]: %s\n", total_num_requests, text.c_str()); Request inference_req; inference_req.prompt = text; - inference_req.max_length = 128; + inference_req.max_new_tokens = 128; inference_req.peft_model_id = (peft_model_id != nullptr) ? *peft_model_id : PEFTModelID::NO_ID; requests.push_back(inference_req); diff --git a/inference/python/ff_peft.py b/inference/python/ff_peft.py index a7d38a66b6..13da7aee20 100644 --- a/inference/python/ff_peft.py +++ b/inference/python/ff_peft.py @@ -162,7 +162,7 @@ def main(): ff.Request( ff.RequestType.REQ_INFERENCE, prompt=prompt, - max_sequence_length=128, + max_new_tokens=128, peft_model_id=llm.get_ff_peft_id(lora_inference_config), ) for prompt in prompts @@ -172,7 +172,6 @@ def main(): if len(configs.finetuning_dataset) > 0: finetuning_request = ff.Request( ff.RequestType.REQ_FINETUNING, - max_sequence_length=128, peft_model_id=llm.get_ff_peft_id(lora_finetuning_config), dataset_filepath=configs.finetuning_dataset, max_training_steps=2, diff --git a/inference/python/incr_decoding.py b/inference/python/incr_decoding.py index 1df5a05a8f..232ef1699c 100644 --- a/inference/python/incr_decoding.py +++ b/inference/python/incr_decoding.py @@ -51,12 +51,12 @@ def get_configs(): "tensor_parallelism_degree": 1, "pipeline_parallelism_degree": 2, "offload": False, - "offload_reserve_space_size": 8 * 1024, # 8GB + "offload_reserve_space_size": 8 * 1024, # 8GB "use_4bit_quantization": False, "use_8bit_quantization": False, "enable_peft": False, - "peft_activation_reserve_space_size": 1024, # 1GB - "peft_weight_reserve_space_size": 1024, # 1GB + "peft_activation_reserve_space_size": 1024, # 1GB + "peft_weight_reserve_space_size": 1024, # 1GB "profiling": False, "benchmarking": False, "inference_debugging": False, @@ -71,6 +71,7 @@ def get_configs(): "full_precision": False, "prompt": "", "output_file": "", + "max_length": 128, } # Merge dictionaries ff_init_configs.update(llm_configs) @@ -106,9 +107,9 @@ def main(): max_seq_length=256, max_tokens_per_batch=64, ) - + llm.start_server() - + if len(configs.prompt) > 0: prompts = [s for s in json.load(open(configs.prompt))] if "max_length" not in configs_dict: @@ -119,8 +120,10 @@ def main(): if "max_length" not in configs_dict: result = llm.generate("Three tips for staying healthy are: ") else: - result = llm.generate("Three tips for staying healthy are: ", max_length=configs.max_length) - + result = llm.generate( + "Three tips for staying healthy are: ", max_length=configs.max_length + ) + llm.stop_server() diff --git a/inference/python/spec_infer.py b/inference/python/spec_infer.py index 39529abda3..7ae752cffc 100644 --- a/inference/python/spec_infer.py +++ b/inference/python/spec_infer.py @@ -51,12 +51,12 @@ def get_configs(): "tensor_parallelism_degree": 1, "pipeline_parallelism_degree": 2, "offload": False, - "offload_reserve_space_size": 8 * 1024, # 8GB + "offload_reserve_space_size": 8 * 1024, # 8GB "use_4bit_quantization": False, "use_8bit_quantization": False, "enable_peft": False, - "peft_activation_reserve_space_size": 1024, # 1GB - "peft_weight_reserve_space_size": 1024, # 1GB + "peft_activation_reserve_space_size": 1024, # 1GB + "peft_weight_reserve_space_size": 1024, # 1GB "profiling": False, 
"benchmarking": False, "inference_debugging": False, @@ -81,6 +81,7 @@ def get_configs(): ], "prompt": "", "output_file": "", + "max_length": 128, } # Merge dictionaries ff_init_configs.update(llm_configs) @@ -144,17 +145,26 @@ def main(): max_tokens_per_batch=64, ssms=ssms, ) - + llm.start_server() if len(configs.prompt) > 0: prompts = [s for s in json.load(open(configs.prompt))] - results = llm.generate(prompts) + if "max_length" not in configs_dict: + results = llm.generate(prompts) + else: + results = llm.generate(prompts, max_length=configs.max_length) else: - result = llm.generate("Three tips for staying healthy are: ") - + if "max_length" not in configs_dict: + result = llm.generate("Three tips for staying healthy are: ") + else: + result = llm.generate( + "Three tips for staying healthy are: ", max_length=configs.max_length + ) + llm.stop_server() + if __name__ == "__main__": print("flexflow inference example (speculative inference)") main() diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index 9b35b249d9..e2240f0b4f 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -1795,7 +1795,7 @@ def __init__( raise ValueError( "Target modules can only be specified when trainable=True" ) - + # Check rank, lora_alpha, lora_dropout values if rank is not None or lora_alpha is not None or lora_dropout is not None: if not trainable or not init_lora_weights: @@ -1805,7 +1805,7 @@ def __init__( rank = rank if rank is not None else 8 lora_alpha = lora_alpha if lora_alpha is not None else 8.0 lora_dropout = lora_dropout if lora_dropout is not None else 0.0 - + # If passed, check if the values of rank, lora_alpha, and lora_dropout are valid if rank < 1 or type(rank) != int: raise ValueError("Rank must be >= 1 and an integer") @@ -1813,7 +1813,7 @@ def __init__( raise ValueError("Lora_alpha must be > 0") if lora_dropout < 0 or lora_dropout > 1: raise ValueError("Lora_dropout must be in the interval [0, 1]") - + self.ff_initialized = False self._cache_folder = cache_folder self._peft_model_id = peft_model_id @@ -2051,13 +2051,15 @@ def no_id_handle(): # Request # ----------------------------------------------------------------------- + @dataclass class Request: """A class to record the metadata of an inference or finetuning request.""" + req_type: RequestType prompt: Optional[str] = None max_length: int = -1 - max_new_tokens: int = 128 + max_new_tokens: int = -1 peft_model_id: Optional[PEFTModelID] = None dataset_filepath: Optional[str] = None max_training_steps: int = 1 @@ -4650,26 +4652,65 @@ def get_output_tensor(self, ffmodel, data_type): assert ret_val == True return np_array - def generate_inf_only(self, prompt_list: List[str], max_length: int = -1, max_new_tokens: int = 128): + def _estimate_max_num_tokens( + max_length: int, max_new_tokens: int, prompt: Optional[str] + ): + if prompt is None: + assert max_new_tokens == -1 + return ( + math.ceil(max_new_tokens + len(prompt.split()) * 1.5) + if max_new_tokens != -1 + else max_length + ) + + def _estimate_max_num_chars( + max_length: int, max_new_tokens: int, prompt: Optional[str] + ): + return ( + 5 * FFModel._estimate_max_num_tokens(max_length, max_new_tokens, prompt) + + 100 + ) + + # deprecated + def generate_inf_only( + self, + prompt_list: List[str], + max_length: int, + max_new_tokens: int, + ): if max_length != -1 and max_new_tokens != -1: - warnings.warn(f"Both `max_new_tokens` (={self.max_new_tokens}) and `max_length`(={self.max_length}) seem to have 
been set. `max_new_tokens` will take precedence.") + raise ValueError( + f"Both `max_new_tokens` (={max_new_tokens}) and `max_length`(={max_length}) seem to have been set." + ) + if max_length == -1 and max_new_tokens == -1: + raise ValueError( + f"Both `max_new_tokens` (={max_new_tokens}) and `max_length`(={max_length}) were left unset." + ) assert isinstance(prompt_list, list) c_input_texts = [get_c_name(prompt) for prompt in prompt_list] - estimated_max_tokens = math.ceil(max_new_tokens + max([len(prompt.split()) for prompt in prompt_list])*1.5) if max_new_tokens != -1 else max_length - max_num_chars = 5 * (estimated_max_tokens + 100) - c_output_texts = [ffi.new("char[]", max_num_chars) for prompt in prompt_list] + c_output_texts = [ + ffi.new( + "char[]", + FFModel._estimate_max_num_chars(max_length, max_new_tokens, prompt), + ) + for prompt in prompt_list + ] c_output_length_and_tokens = [ - ffi.new("int[]", estimated_max_tokens + 100) for prompt in prompt_list + ffi.new( + "int[]", + FFModel._estimate_max_num_tokens(max_length, max_new_tokens, prompt) + + 100, + ) + for prompt in prompt_list ] c_request_types = [ - enum_to_int(RequestType, RequestType.REQ_INFERENCE) - for prompt in prompt_list + enum_to_int(RequestType, RequestType.REQ_INFERENCE) for _ in prompt_list ] - max_lengths = [max_length for prompt in prompt_list] - max_new_tokens_ = [max_new_tokens for prompt in prompt_list] - peft_model_ids = [PEFTModelID.no_id_handle() for prompt in prompt_list] - dataset_filepaths = [ffi.NULL for prompt in prompt_list] - training_steps = [0 for prompt in prompt_list] + max_lengths = [max_length for _ in prompt_list] + max_new_tokens_ = [max_new_tokens for _ in prompt_list] + peft_model_ids = [PEFTModelID.no_id_handle() for _ in prompt_list] + dataset_filepaths = [ffi.NULL for _ in prompt_list] + training_steps = [0 for _ in prompt_list] num_finetuning_losses = ffi.new("int *") c_finetuning_losses = ffi.new("float[]", 0) ffc().flexflow_model_generate( @@ -4698,34 +4739,55 @@ def generate_inf_only(self, prompt_list: List[str], max_length: int = -1, max_ne def generate(self, requests_list: List[Request]): assert isinstance(requests_list, list) + for request in requests_list: + assert isinstance(request, Request) + if request.max_length != -1 and request.max_new_tokens != -1: + raise ValueError( + f"Both `max_new_tokens` (={request.max_new_tokens}) and `max_length`(={request.max_length}) seem to have been set." + ) + if request.max_length == -1 and request.max_new_tokens == -1: + raise ValueError( + f"Both `max_new_tokens` (={request.max_new_tokens}) and `max_length`(={request.max_length}) were left unset." + ) + if ( + request.req_type == RequestType.REQ_FINETUNING + and request.max_new_tokens != -1 + ): + raise ValueError( + f"Finetuning requests should not have `max_new_tokens` set." 
+ ) c_input_texts = [ get_c_name(request.prompt) for request in requests_list ] # entry will be None for finetuning requests c_output_texts = [ ( - ffi.new("char[]", 5 * (request.max_sequence_length + 100)) + ffi.new( + "char[]", + FFModel._estimate_max_num_chars( + request.max_length, request.max_new_tokens, request.prompt + ), + ) if request.req_type == RequestType.REQ_INFERENCE else ffi.NULL ) for request in requests_list ] c_output_length_and_tokens = [ - ffi.new("int[]", request.max_sequence_length + 100) + ffi.new( + "int[]", + FFModel._estimate_max_num_tokens( + request.max_length, request.max_new_tokens, request.prompt + ) + + 100, + ) for request in requests_list ] c_request_types = [ enum_to_int(RequestType, request.req_type) for request in requests_list ] - max_lengths = [ - request.max_length for request in requests_list - ] - max_new_tokens_ = [ - request.max_new_tokens for request in requests_list - ] - for i in range(len(requests_list)): - if max_lengths[i] != -1 and max_new_tokens_[i] != -1: - warnings.warn(f"Both `max_new_tokens` (={max_new_tokens_[i]}) and `max_length`(={max_lengths[i]}) seem to have been set. `max_new_tokens` will take precedence.") - + max_lengths = [request.max_length for request in requests_list] + max_new_tokens_ = [request.max_new_tokens for request in requests_list] + peft_model_ids = [ ( request.peft_model_id @@ -4742,7 +4804,7 @@ def generate(self, requests_list: List[Request]): # c_finetuning_losses = ffi.new("float**") # TODO: set this value automatically c_finetuning_losses = ffi.new("float[]", 10000) - + ffc().flexflow_model_generate( self.handle, len(requests_list), @@ -4774,7 +4836,6 @@ def generate(self, requests_list: List[Request]): finetuning_losses=finetuning_losses, ) ) - return results def set_position_offset(self, offset): ffc().flexflow_model_set_position_offset(self.handle, offset) diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py index e3b6b47466..c8540a6ed3 100644 --- a/python/flexflow/serve/serve.py +++ b/python/flexflow/serve/serve.py @@ -27,15 +27,18 @@ MPTConfig, ) from flexflow.core import * -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer +from transformers import AutoConfig, AutoModelForCausalLM from peft import PeftModel, PeftConfig, LoraConfig from huggingface_hub import HfApi import torch, shutil, hashlib, json, gc from typing import Union, List +from huggingface_hub import snapshot_download class _SupportedModels: - def __init__(self,): + def __init__( + self, + ): self.supported_models = { "LlamaForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA, LLAMAConfig), "LLaMAForCausalLM": (ModelType.LLAMA, FlexFlowLLAMA, LLAMAConfig), @@ -292,8 +295,8 @@ def download_peft_weights(): weights_path = get_weights_path(peft_model_id) refresh_cache_if_needed(peft_model_id) - ff_revision, ff_revision_file, latest_revision = self.__get_revision_hashes( - peft_model_id, weights_path + ff_revision, ff_revision_file, latest_revision = ( + self.__get_revision_hashes(peft_model_id, weights_path) ) if ff_revision != latest_revision: @@ -349,10 +352,25 @@ def download_hf_tokenizer_if_needed(self): print( f"'{self.model_name}' tokenizer needs updating! Downloading tokenizer now..." 
) - # Download tokenizer from HuggingFace, or load it from the local folder - hf_tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) - # Save tokenizer - hf_tokenizer.save_pretrained(self.tokenizer_path) + # Load/download the tokenizer files + target_tokenizer_files = [ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "vocab.json", + "merges.txt", + ] + if os.path.exists(self.model_name): + hf_tokenizer_path = self.model_name + else: + hf_tokenizer_path = snapshot_download( + repo_id=self.model_name, allow_patterns=target_tokenizer_files + ) + for file in target_tokenizer_files: + src_path = os.path.join(hf_tokenizer_path, file) + dst_path = os.path.join(self.tokenizer_path, file) + if os.path.exists(src_path): + shutil.copy(src_path, dst_path) print("Done updating HF tokenizer.") # Save new revision hash to file with open(ff_revision_file, "w+") as f: @@ -417,6 +435,8 @@ def compile( model_specific_pipeline_parallelism_degree ) + self.max_seq_length = max_seq_length + # Create request manager and set serving configuration self.rm = RequestManager() self.rm.set_max_requests_per_batch(max_requests_per_batch) @@ -495,11 +515,44 @@ def compile( atexit.register(self.rm.stop_server) + def _generate(self, requests: List[Request]): + if len(requests) == 0: + return [] + for req in requests: + if req.req_type == RequestType.REQ_INFERENCE: + # check max_length and max_new_tokens parameters + if req.max_length == -1 and req.max_new_tokens == -1: + req.max_length = self.max_seq_length -1 + elif req.max_length != -1 and req.max_new_tokens != -1: + warnings.warn( + f"Both `max_new_tokens` (={req.max_new_tokens}) and `max_length`(={req.max_length}) seem to have been set. `max_new_tokens` will take precedence." + ) + req.max_length = -1 + if ( + req.max_length >= self.max_seq_length + or req.max_new_tokens >= self.max_seq_length + ): + raise ValueError( + f"max_length ({req.max_length}) or max_new_tokens ({req.max_new_tokens}) exceeds the maximum sequence length ({self.max_seq_length})" + ) + else: + if req.max_new_tokens != -1: + raise ValueError( + f"max_new_tokens ({req.max_new_tokens}) is not allowed for finetuning requests." 
+ ) + if req.max_length == -1: + req.max_length = self.max_seq_length -1 + if req.max_length >= self.max_seq_length: + raise ValueError( + f"max_length ({req.max_length}) exceeds the maximum sequence length ({self.max_seq_length})" + ) + return self.model.ffmodel.generate(requests) + def generate( self, requests_or_prompts: Union[str, List[str], Request, List[Request]], max_length: int = -1, - max_new_tokens: int = 128, + max_new_tokens: int = -1, ): """Generate tokens based on the input prompt(s) @@ -514,24 +567,35 @@ def generate( """ if type(requests_or_prompts) == str: if len(requests_or_prompts) == 0: - return None - return self.model.ffmodel.generate_inf_only( - [requests_or_prompts], max_length, max_new_tokens + return [] + request = Request( + req_type=RequestType.REQ_INFERENCE, + prompt=requests_or_prompts, + max_length=max_length, + max_new_tokens=max_new_tokens, ) + return self._generate([request]) elif type(requests_or_prompts) == Request: - return self.model.ffmodel.generate(requests_or_prompts) + return self._generate([requests_or_prompts]) elif type(requests_or_prompts) == list: if len(requests_or_prompts) == 0: return [] if type(requests_or_prompts[0]) == str: - return self.model.ffmodel.generate_inf_only( - requests_or_prompts, max_length, max_new_tokens - ) + requests = [ + Request( + req_type=RequestType.REQ_INFERENCE, + prompt=req, + max_length=max_length, + max_new_tokens=max_new_tokens, + ) + for req in requests_or_prompts + ] + return self._generate(requests) else: print(requests_or_prompts) - return self.model.ffmodel.generate(requests_or_prompts) + return self._generate(requests_or_prompts) else: - assert False, "Please pass a non-empty string or list of strings" + assert False, "Please pass a string, list of strings, Request, or list of Requests" def start_server(self): self.rm.start_server(self.model.ffmodel) diff --git a/src/ops/fused.cc b/src/ops/fused.cc index 720d678a4a..984691fa66 100644 --- a/src/ops/fused.cc +++ b/src/ops/fused.cc @@ -476,7 +476,6 @@ void FusedOp::init(FFModel const &ff) { false /*must*/, 0 /*mapper_id*/, outputs[0]->machine_view.hash()); - launcher.concurrent = true; FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); switch (domain.get_dim()) { @@ -571,7 +570,6 @@ void FusedOp::init_inference(FFModel const &ff, false /*must*/, 0 /*mapper_id*/, machine_view_hash); - launcher.concurrent = true; FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); switch (domain.get_dim()) { diff --git a/src/ops/fused.cpp b/src/ops/fused.cpp index 2cede662f3..dfb524d206 100644 --- a/src/ops/fused.cpp +++ b/src/ops/fused.cpp @@ -612,8 +612,10 @@ __host__ void assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + runtime->concurrent_task_barrier(ctx); Kernels::AllReduce::inference_kernel_wrapper( m, bc, my_input_accessor[0], my_output_accessor[0]); + runtime->concurrent_task_barrier(ctx); break; } case OP_PARALLEL_IDENTITY: { @@ -870,7 +872,12 @@ __host__ void FusedOp::peft_bwd_task(Task const *task, // since we ``inplace'' the output for LoRA assert(my_input_grad_accessor[1].ptr == my_output_grad_accessor[0].ptr); Kernels::LoraLinear::peft_bwd_kernel_wrapper( - m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]); + ctx, + runtime, + m, + bc, + my_input_grad_accessor[0], + my_output_grad_accessor[0]); break; } case OP_BATCHMATMUL: { @@ -1129,8 +1136,10 @@ __host__ void FusedOp::peft_bwd_task(Task const 
*task, assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); ParallelIdentityMeta const *m = (ParallelIdentityMeta *)metas->meta[op]; + runtime->concurrent_task_barrier(ctx); Kernels::ParallelIdentity::peft_bwd_kernel_wrapper( m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]); + runtime->concurrent_task_barrier(ctx); break; } default: { diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 5aed2cd69a..62845c0f8e 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -623,8 +623,10 @@ __host__ void assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + runtime->concurrent_task_barrier(ctx); Kernels::AllReduce::inference_kernel_wrapper( m, bc, my_input_accessor[0], my_output_accessor[0]); + runtime->concurrent_task_barrier(ctx); break; } case OP_PARALLEL_IDENTITY: { @@ -888,7 +890,12 @@ __host__ void FusedOp::peft_bwd_task(Task const *task, // since we ``inplace'' the output for LoRA assert(my_input_grad_accessor[1].ptr == my_output_grad_accessor[0].ptr); Kernels::LoraLinear::peft_bwd_kernel_wrapper( - m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]); + ctx, + runtime, + m, + bc, + my_input_grad_accessor[0], + my_output_grad_accessor[0]); break; } case OP_BATCHMATMUL: { @@ -1149,8 +1156,10 @@ __host__ void FusedOp::peft_bwd_task(Task const *task, assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_outputs[op] == 1); ParallelIdentityMeta const *m = (ParallelIdentityMeta *)metas->meta[op]; + runtime->concurrent_task_barrier(ctx); Kernels::ParallelIdentity::peft_bwd_kernel_wrapper( m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]); + runtime->concurrent_task_barrier(ctx); break; } default: { diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index a4604a11a2..8818cd9673 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -147,7 +147,7 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta *m, int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + bc->requestsInfo[i].num_tokens_in_batch; - int max_peft_tokens = bc->requestsInfo[i].max_sequence_length; + int max_peft_tokens = bc->requestsInfo[i].max_length; // Copy query to m->query_activation_buffer if we need to compute // PEFT backward if (bc->requestsInfo[i].peft_bwd) { diff --git a/src/ops/kernels/lora_linear_kernels.cu b/src/ops/kernels/lora_linear_kernels.cu index 93e5820f9c..638cee8cae 100644 --- a/src/ops/kernels/lora_linear_kernels.cu +++ b/src/ops/kernels/lora_linear_kernels.cu @@ -96,7 +96,9 @@ void inference_kernel_wrapper(LoraLinearMeta *m, } } -void peft_bwd_kernel_wrapper(LoraLinearMeta *m, +void peft_bwd_kernel_wrapper(Context ctx, + Runtime *runtime, + LoraLinearMeta *m, BatchConfig const *bc, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &output_grad) { @@ -111,7 +113,9 @@ void peft_bwd_kernel_wrapper(LoraLinearMeta *m, int in_dim = input_grad.domain.hi()[0] - input_grad.domain.lo()[0] + 1; int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1; if (m->input_type[0] == DT_FLOAT) { - Internal::peft_bwd_kernel(m, + Internal::peft_bwd_kernel(ctx, + runtime, + m, bc, input_grad.get_float_ptr(), output_grad.get_float_ptr(), @@ -119,7 +123,9 @@ void peft_bwd_kernel_wrapper(LoraLinearMeta *m, out_dim, stream); } else if (m->input_type[0] == DT_HALF) { 
- Internal::peft_bwd_kernel<half>(m, + Internal::peft_bwd_kernel<half>(ctx, + runtime, + m, bc, input_grad.get_half_ptr(), output_grad.get_half_ptr(), @@ -361,7 +367,9 @@ __global__ void sgd_update(size_t count, } template <typename DT> -void peft_bwd_kernel(LoraLinearMeta *m, +void peft_bwd_kernel(Context ctx, + Runtime *runtime, + LoraLinearMeta *m, BatchConfig const *bc, DT *input_grad_ptr, DT const *output_grad_ptr, @@ -543,13 +551,15 @@ void peft_bwd_kernel(LoraLinearMeta *m, // and sum first #ifdef FF_USE_NCCL ncclDataType_t nccl_data_type = ff_to_nccl_datatype(m->output_type[0]); - checkCUDA(ncclAllReduce(static_cast<DT *>
(weight.w1_grad_ptr), + runtime->concurrent_task_barrier(ctx); + checkNCCL(ncclAllReduce(static_cast<DT *>
(weight.w1_grad_ptr), static_cast<DT *>
(weight.w1_grad_ptr), w1_num_elements, nccl_data_type, ncclSum, m->handle.ncclComm, stream)); + runtime->concurrent_task_barrier(ctx); #else assert(false && "Must enable FF_USE_NCCL to use AllReduce operators"); #endif diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc index 513147f3b7..3749cce994 100644 --- a/src/ops/lora_linear.cc +++ b/src/ops/lora_linear.cc @@ -296,7 +296,6 @@ void LoraLinear::init_inference( false /*must*/, 0 /*mapper_id*/, machine_view_hash); - launcher.concurrent = true; launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, READ_ONLY, @@ -1066,7 +1065,7 @@ void LoraLinear::peft_bwd_task(Task const *task, int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1; // int num_infr_tokens = bc->num_active_infr_tokens(); // int num_peft_tokens = bc->num_active_peft_tokens(); - peft_bwd_kernel_wrapper(m, bc, input_grad, output_grad); + peft_bwd_kernel_wrapper(ctx, runtime, m, bc, input_grad, output_grad); save_peft_weights_if_needed(m, bc, in_dim, out_dim, shard_id); diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index aa74ecc6f5..6b2a4be507 100644 --- a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -170,7 +170,7 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( Layer const *layer, std::vector const &inputs) { - std::cout << "spec create operator: " << layer->name << "\n"; + // std::cout << "spec create operator: " << layer->name << "\n"; long long value; layer->get_int_property("embed_dim", value); int embed_dim = value; @@ -182,10 +182,10 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( int kdim = value; layer->get_int_property("vdim", value); int vdim = value; - float dropout; - layer->get_float_property("dropout", dropout); layer->get_int_property("add_zero_attn", value); bool add_zero_attn = (bool)value; + float dropout; + layer->get_float_property("dropout", dropout); RotaryEmbeddingMeta rotary_embedding_meta; layer->get_int_property("apply_rotary_embedding", value); rotary_embedding_meta.apply_rotary_embedding = (bool)value; diff --git a/src/ops/tree_inc_multihead_self_attention.cc b/src/ops/tree_inc_multihead_self_attention.cc index ae0795ac1e..ac0011d9eb 100644 --- a/src/ops/tree_inc_multihead_self_attention.cc +++ b/src/ops/tree_inc_multihead_self_attention.cc @@ -163,6 +163,7 @@ Tensor FFModel::inc_multiquery_self_attention_verify( rotary_embedding_meta.original_max_position_embeddings); li->add_int_property("scaling_query", scaling_query); li->add_float_property("scaling_factor", scaling_factor); + li->add_int_property("qk_prod_scaling", qk_prod_scaling); li->add_int_property("position_bias", position_bias); li->add_int_property("quantization_type", quantization_type); li->add_int_property("offload", offload); @@ -187,10 +188,10 @@ Op *TreeIncMultiHeadSelfAttention::create_operator_from_layer( int kdim = value; layer->get_int_property("vdim", value); int vdim = value; - float dropout; - layer->get_float_property("dropout", dropout); layer->get_int_property("add_zero_attn", value); bool add_zero_attn = (bool)value; + float dropout; + layer->get_float_property("dropout", dropout); RotaryEmbeddingMeta rotary_embedding_meta; layer->get_int_property("apply_rotary_embedding", value); rotary_embedding_meta.apply_rotary_embedding = (bool)value; @@ -203,6 +204,7 @@ Op *TreeIncMultiHeadSelfAttention::create_operator_from_layer( 
rotary_embedding_meta.high_freq_factor); layer->get_int_property("original_max_position_embeddings", value); rotary_embedding_meta.original_max_position_embeddings = (int)value; + layer->get_int_property("scaling_query", value); bool scaling_query = (bool)value; float scaling_factor; layer->get_float_property("scaling_factor", scaling_factor); diff --git a/src/parallel_ops/allreduce.cc b/src/parallel_ops/allreduce.cc index a4443c4066..6611a6bb1f 100644 --- a/src/parallel_ops/allreduce.cc +++ b/src/parallel_ops/allreduce.cc @@ -197,7 +197,9 @@ void AllReduce::forward_task(Task const *task, m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); assert(input.data_type == output.data_type); + // runtime->concurrent_task_barrier(ctx); forward_kernel_wrapper(m, input, output); + // runtime->concurrent_task_barrier(ctx); } void AllReduce::backward(FFModel const &ff) { @@ -347,7 +349,9 @@ void AllReduce::inference_task(Task const *task, m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); assert(input.data_type == output.data_type); + // runtime->concurrent_task_barrier(ctx); inference_kernel_wrapper(m, bc, input, output); + // runtime->concurrent_task_barrier(ctx); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; diff --git a/src/parallel_ops/parallel_identity.cc b/src/parallel_ops/parallel_identity.cc index 7d68036709..2f76897712 100644 --- a/src/parallel_ops/parallel_identity.cc +++ b/src/parallel_ops/parallel_identity.cc @@ -245,7 +245,9 @@ void ParallelIdentity::backward_task(Task const *task, m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); assert(input_grad.data_type == output_grad.data_type); + // runtime->concurrent_task_barrier(ctx); backward_kernel_wrapper(m, input_grad, output_grad); + // runtime->concurrent_task_barrier(ctx); } void ParallelIdentity::init_inference( @@ -270,7 +272,6 @@ void ParallelIdentity::init_inference( false /*must*/, 0 /*mapper_id*/, machine_view_hash); - launcher.concurrent = true; launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, READ_ONLY, @@ -422,7 +423,9 @@ void ParallelIdentity::peft_bwd_task(Task const *task, m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); assert(input_grad.data_type == output_grad.data_type); + // runtime->concurrent_task_barrier(ctx); peft_bwd_kernel_wrapper(m, bc, input_grad, output_grad); + // runtime->concurrent_task_barrier(ctx); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 69fe3b598d..417cd2c056 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -1677,6 +1677,7 @@ void FFModel::finish_nccl_comms() { false /*must*/, 0 /*mapper_id*/, comm.first); + index_launcher.concurrent = true; FutureMap fm = runtime->execute_index_space(ctx, index_launcher); fm.wait_all_results(); } @@ -6899,7 +6900,6 @@ void register_flexflow_internal_tasks(Runtime *runtime, TaskVariantRegistrar registrar(LORA_LINEAR_INIT_TASK_ID, "LoraLinear Init"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); - registrar.set_concurrent(); if (pre_register) { Runtime::preregister_task_variant( registrar, "LoraLinear Init Task"); @@ -6932,6 +6932,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); 
registrar.set_leaf(); registrar.set_concurrent(); + registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "LoraLinear PEFT Backward Task"); @@ -6963,7 +6964,6 @@ void register_flexflow_internal_tasks(Runtime *runtime, TaskVariantRegistrar registrar(FUSEDOP_INIT_TASK_ID, "FusedOp Init"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); - registrar.set_concurrent(); if (pre_register) { Runtime::preregister_task_variant( registrar, "FusedOp Init Task"); @@ -6979,6 +6979,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); registrar.set_concurrent(); + registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "FusedOp Inference Task"); @@ -6995,6 +6996,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); registrar.set_concurrent(); + registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "FusedOp PEFT Backward Task"); @@ -7011,6 +7013,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); registrar.set_concurrent(); + registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "FusedOp Forward Task"); @@ -7026,6 +7029,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); registrar.set_concurrent(); + registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "FusedOp Backward Task"); @@ -7262,7 +7266,6 @@ void register_flexflow_internal_tasks(Runtime *runtime, TaskVariantRegistrar registrar(ALLREDUCE_INIT_TASK_ID, "AllReduce Init"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); - registrar.set_concurrent(); if (pre_register) { Runtime::preregister_task_variant( registrar, "AllReduce init Task"); @@ -7280,6 +7283,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, // AllReduce forward and backward must run concurrently since they // use ncclAllReduce internally registrar.set_concurrent(); + // registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "AllReduce Forward Task"); @@ -7294,9 +7298,6 @@ void register_flexflow_internal_tasks(Runtime *runtime, TaskVariantRegistrar registrar(ALLREDUCE_BWD_TASK_ID, "AllReduce Backward"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); - // AllReduce forward and backward must run concurrently since they - // use ncclAllReduce internally - // registrar.set_concurrent(); if (pre_register) { Runtime::preregister_task_variant( registrar, "AllReduce Backward Task"); @@ -7315,6 +7316,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, // AllReduce forward and backward must run concurrently since they // use ncclAllReduce internally registrar.set_concurrent(); + // registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "AllReduce Inference Task"); @@ -7330,9 +7332,6 @@ void register_flexflow_internal_tasks(Runtime *runtime, "AllReduce PEFT Backward"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); 
registrar.set_leaf(); - // AllReduce forward and backward must run concurrently since they - // use ncclAllReduce internally - // registrar.set_concurrent(); if (pre_register) { Runtime::preregister_task_variant( registrar, "AllReduce PEFT Backward Task"); @@ -7349,7 +7348,6 @@ void register_flexflow_internal_tasks(Runtime *runtime, "ParallelIdentity Init"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); - registrar.set_concurrent(); if (pre_register) { Runtime::preregister_task_variant( registrar, "ParallelIdentity init Task"); @@ -7382,6 +7380,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); registrar.set_concurrent(); + // registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "ParallelIdentity Backward Task"); @@ -7415,6 +7414,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); registrar.set_concurrent(); + // registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "ParallelIdentity PEFT Backward Task"); @@ -7433,6 +7433,8 @@ void register_flexflow_internal_tasks(Runtime *runtime, "FusedParallel Forward"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); + registrar.set_concurrent(); + registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "FusedParallel Forward Task"); @@ -7448,6 +7450,8 @@ void register_flexflow_internal_tasks(Runtime *runtime, "FusedParallel Backward"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); + registrar.set_concurrent(); + registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "FusedParallel Backward Task"); @@ -7496,6 +7500,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); registrar.set_concurrent(); + registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "SGD NCCL Update Task", 111 /*variant ID*/); @@ -7511,6 +7516,8 @@ void register_flexflow_internal_tasks(Runtime *runtime, TaskVariantRegistrar registrar(ADAM_UPD_NCCL_TASK_ID, "Adam NCCL Update"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); + registrar.set_concurrent(); + registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "Adam NCCL Update Task", 111 /*variant ID*/); @@ -7648,6 +7655,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); registrar.set_concurrent(); + // registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "NCCL Init Communicators Task", 111 /*variant ID*/); @@ -7664,6 +7672,8 @@ void register_flexflow_internal_tasks(Runtime *runtime, "NCCL Finish Communicators"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); + registrar.set_concurrent(); + // registrar.set_concurrent_barrier(); if (pre_register) { Runtime::preregister_task_variant( registrar, "NCCL Finish Communicators Task", 111 /*variant ID*/); diff --git a/src/runtime/optimizer.cc 
b/src/runtime/optimizer.cc index c42a0c9aa6..96b735803c 100644 --- a/src/runtime/optimizer.cc +++ b/src/runtime/optimizer.cc @@ -311,7 +311,7 @@ void SGDOptimizer::nccl_update_task(Task const *task, } } - nccl_update_task_gpu(op, meta, w_grad_ptr, size, w_ptr, v_ptr); + nccl_update_task_gpu(ctx, runtime, op, meta, w_grad_ptr, size, w_ptr, v_ptr); } #endif @@ -603,7 +603,8 @@ void AdamOptimizer::nccl_update_task(Task const *task, } } - nccl_update_task_gpu(op, meta, w_grad_ptr, size, w_ptr, v_ptr, m_ptr); + nccl_update_task_gpu( + ctx, runtime, op, meta, w_grad_ptr, size, w_ptr, v_ptr, m_ptr); } #endif diff --git a/src/runtime/optimizer_kernel.cpp b/src/runtime/optimizer_kernel.cpp index 59efaf5256..9b0d3c8892 100644 --- a/src/runtime/optimizer_kernel.cpp +++ b/src/runtime/optimizer_kernel.cpp @@ -86,7 +86,9 @@ __host__ void SGDOptimizer::ps_update_task_gpu(SGDOptimizer const *op, } #ifdef FF_USE_NCCL -__host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op, +__host__ void SGDOptimizer::nccl_update_task_gpu(Context ctx, + Runtime *runtime, + SGDOptimizer const *op, OpMeta const *meta, float const *w_grad_ptr, size_t size, @@ -96,6 +98,7 @@ __host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op, // fprintf(stderr, "weight(%p) Before ncclAllReduce...\n", w_grad_ptr); hipStream_t stream; checkCUDA(get_legion_stream(&stream)); + runtime->concurrent_task_barrier(ctx); checkNCCL(ncclAllReduce(w_grad_ptr, (float *)w_grad_ptr, size, @@ -103,6 +106,7 @@ __host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op, ncclSum, meta->handle.ncclComm, stream)); + runtime->concurrent_task_barrier(ctx); // fprintf(stderr, "weight(%p) After ncclAllReduce...\n", w_grad_ptr); // Step 2: SGD update @@ -208,7 +212,9 @@ __host__ void AdamOptimizer::ps_update_task_gpu(AdamOptimizer const *op, } #ifdef FF_USE_NCCL -__host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op, +__host__ void AdamOptimizer::nccl_update_task_gpu(Context ctx, + Runtime *runtime, + AdamOptimizer const *op, OpMeta const *meta, float const *w_grad_ptr, size_t size, @@ -218,6 +224,7 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op, // Use NCCL to sync gradients hipStream_t stream; checkCUDA(get_legion_stream(&stream)); + runtime->concurrent_task_barrier(ctx); checkNCCL(ncclAllReduce(w_grad_ptr, (float *)w_grad_ptr, size, @@ -225,6 +232,7 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op, ncclSum, meta->handle.ncclComm, stream)); + runtime->concurrent_task_barrier(ctx); // fprintf(stderr, "alpha = %.8lf alpha_t = %.8lf decay = %.8lf\n", // op->alpha, op->alpha_t, op->weight_decay); // Step 2: Adam update diff --git a/src/runtime/optimizer_kernel.cu b/src/runtime/optimizer_kernel.cu index df37e3b135..72ee74940f 100644 --- a/src/runtime/optimizer_kernel.cu +++ b/src/runtime/optimizer_kernel.cu @@ -75,7 +75,9 @@ __host__ void SGDOptimizer::ps_update_task_gpu(SGDOptimizer const *op, } #ifdef FF_USE_NCCL -__host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op, +__host__ void SGDOptimizer::nccl_update_task_gpu(Context ctx, + Runtime *runtime, + SGDOptimizer const *op, OpMeta const *meta, float const *w_grad_ptr, size_t size, @@ -85,6 +87,7 @@ __host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op, // fprintf(stderr, "weight(%p) Before ncclAllReduce...\n", w_grad_ptr); cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); + runtime->concurrent_task_barrier(ctx); checkNCCL(ncclAllReduce(w_grad_ptr, 
(float *)w_grad_ptr, size, @@ -92,6 +95,7 @@ __host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op, ncclSum, meta->handle.ncclComm, stream)); + runtime->concurrent_task_barrier(ctx); // fprintf(stderr, "weight(%p) After ncclAllReduce...\n", w_grad_ptr); // print_tensor((float*)w_grad_ptr, 16, "[After ncclAllReduce]"); @@ -183,7 +187,9 @@ __host__ void AdamOptimizer::ps_update_task_gpu(AdamOptimizer const *op, } #ifdef FF_USE_NCCL -__host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op, +__host__ void AdamOptimizer::nccl_update_task_gpu(Context ctx, + Runtime *runtime, + AdamOptimizer const *op, OpMeta const *meta, float const *w_grad_ptr, size_t size, @@ -193,6 +199,7 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op, // Use NCCL to sync gradients cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); + runtime->concurrent_task_barrier(ctx); checkNCCL(ncclAllReduce(w_grad_ptr, (float *)w_grad_ptr, size, @@ -200,6 +207,7 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op, ncclSum, meta->handle.ncclComm, stream)); + runtime->concurrent_task_barrier(ctx); // fprintf(stderr, "alpha = %.8lf alpha_t = %.8lf decay = %.8lf\n", // op->alpha, op->alpha_t, op->weight_decay); // Step 2: Adam update diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 44b181fcb3..5fbee65e6d 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -186,28 +186,35 @@ void RequestManager::register_tokenizer(ModelType type, std::filesystem::path tokenizer_folder(path); if (model_type == ModelType::LLAMA) { - std::filesystem::path tokenizer_model_path; + // try with tokenizer.json first + std::filesystem::path tokenizer_json_path; if (std::filesystem::is_directory(tokenizer_folder)) { - tokenizer_model_path = - std::filesystem::path(tokenizer_folder) / "tokenizer.model"; + tokenizer_json_path = + std::filesystem::path(tokenizer_folder) / "tokenizer.json"; } else { - tokenizer_model_path = tokenizer_folder; + tokenizer_json_path = tokenizer_folder; } - if (std::filesystem::exists(tokenizer_model_path)) { - // load from tokenizer.model - this->tokenizer_ = Tokenizer::FromBlobSentencePiece( - LoadBytesFromFile(tokenizer_model_path.string())); - } else { + if (std::filesystem::exists(tokenizer_json_path)) { // load from tokenizer.json - std::filesystem::path tokenizer_json_path = - tokenizer_folder / "tokenizer.json"; - if (!std::filesystem::exists(tokenizer_json_path)) { - std::cerr << "Failed to open file: " << tokenizer_json_path + this->tokenizer_ = Tokenizer::FromBlobJSON( + LoadBytesFromFile(tokenizer_json_path.string())); + } else { + // load from tokenizer.model + std::filesystem::path tokenizer_model_path; + if (std::filesystem::is_directory(tokenizer_folder)) { + tokenizer_model_path = + std::filesystem::path(tokenizer_folder) / "tokenizer.model"; + } else { + tokenizer_model_path = tokenizer_folder; + } + if (!std::filesystem::exists(tokenizer_model_path)) { + std::cerr << "Failed to open file: " << tokenizer_model_path << std::endl; assert(false); } - this->tokenizer_ = Tokenizer::FromBlobJSON( - LoadBytesFromFile(tokenizer_json_path.string())); + old_llama_tokenizer = true; + this->tokenizer_ = Tokenizer::FromBlobSentencePiece( + LoadBytesFromFile(tokenizer_model_path.string())); } } else if (model_type == ModelType::OPT) { std::filesystem::path vocab_file = tokenizer_folder / "vocab.json"; @@ -264,7 +271,13 @@ RequestManager::RequestGuid request.guid = 
next_available_guid++; request.max_length = request_.max_length; request.max_new_tokens = request_.max_new_tokens; + // both unset + if (request.max_length == -1 && request.max_new_tokens == -1) { + request.max_length = get_max_sequence_length() - 1; + } + // both set if (request.max_length != -1 && request.max_new_tokens != -1) { + request.max_length = -1; std::cout << "Both `max_new_tokens` (=" << request.max_new_tokens << ") and `max_length`(=" << request.max_length @@ -365,15 +378,14 @@ RequestManager::RequestGuid request.initial_len = 0; request.max_length = request_.max_length; request.max_new_tokens = request_.max_new_tokens; - if (request.max_length != -1) { - std::cout << "Warning: max_length is set for PEFT finetuning, but it will " - "be ignored." - << std::endl; - } if (request.max_new_tokens != -1) { - std::cout << "Warning: max_new_tokens is set for PEFT finetuning, but " - "it will be ignored." - << std::endl; + std::cerr + << "Error: max_new_tokens is not allowed for PEFT finetuning requests" + << std::endl; + assert(false); + } + if (request.max_length == -1) { + request.max_length = get_max_sequence_length() - 1; } request.peft_model_id = request_.peft_model_id; request.req_type = RequestType::REQ_FINETUNING; @@ -660,7 +672,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, std::string output = this->tokenizer_->Decode(request.tokens); // Unlike Huggingface, the sentencepiece C++ library automatically // removes the BOS token - if (model_type == ModelType::LLAMA && + if (model_type == ModelType::LLAMA && old_llama_tokenizer && request.tokens.at(0) == bos_token_id) { output = " " + output; } @@ -1121,7 +1133,7 @@ BeamSearchBatchConfig std::string output = this->tokenizer_->Decode(request.tokens); // Unlike Huggingface, the sentencepiece C++ library automatically // removes the BOS token - if (model_type == ModelType::LLAMA && + if (model_type == ModelType::LLAMA && old_llama_tokenizer && request.tokens.at(0) == bos_token_id) { output = " " + output; } @@ -1264,7 +1276,7 @@ BeamSearchBatchConfig std::string output = this->tokenizer_->Decode(request.tokens); // Unlike Huggingface, the sentencepiece C++ library automatically // removes the BOS token - if (model_type == ModelType::LLAMA && + if (model_type == ModelType::LLAMA && old_llama_tokenizer && request.tokens.at(0) == bos_token_id) { output = " " + output; } @@ -1312,7 +1324,7 @@ BeamSearchBatchConfig std::string output = this->tokenizer_->Decode(request.tokens); // Unlike Huggingface, the sentencepiece C++ library automatically removes // the BOS token - if (model_type == ModelType::LLAMA && + if (model_type == ModelType::LLAMA && old_llama_tokenizer && request.tokens.at(0) == bos_token_id) { output = " " + output; } diff --git a/tests/inference/python_test_configs/generate_configs.py b/tests/inference/python_test_configs/generate_configs.py index 0a745c7984..2720304d4f 100644 --- a/tests/inference/python_test_configs/generate_configs.py +++ b/tests/inference/python_test_configs/generate_configs.py @@ -34,6 +34,7 @@ "full_precision": True, "prompt": "", "output_file": "", + "max_length": 128, } ssm_configs = { "ssms": [