diff --git a/SERVE.md b/SERVE.md index f6e34750cd..e64756e8f4 100644 --- a/SERVE.md +++ b/SERVE.md @@ -187,9 +187,6 @@ We provide five prompt datasets for evaluating FlexFlow Serve: [Chatbot instruct FlexFlow Serve is still under active development. We currently focus on the following tasks and strongly welcome all contributions from bug fixes to new features and extensions. * AMD benchmarking. We are actively working on benchmarking FlexFlow Serve on AMD GPUs and comparing it with the performance on NVIDIA GPUs. -* Chatbot prompt templates and Multi-round conversations -* Support for FastAPI server -* Integration with LangChain for document question answering ## Acknowledgements This project is initiated by members from CMU, Stanford, and UCSD. We will be continuing developing and supporting FlexFlow Serve. Please cite FlexFlow Serve as: diff --git a/docs/source/chatbot.rst b/docs/source/chatbot.rst new file mode 100644 index 0000000000..fc6f616fae --- /dev/null +++ b/docs/source/chatbot.rst @@ -0,0 +1,64 @@ +:tocdepth: 1 +******** +Chatbot +******** + +The chatbot use case involves setting up a conversational AI model using FlexFlow Serve, capable of engaging in interactive dialogues with users. + +Requirements +============ + +- FlexFlow Serve setup with required configurations. +- Gradio or any interactive interface tool. + +Implementation +============== + +1. FlexFlow Initialization + Initialize FlexFlow Serve with desired configurations and specific LLM model. + +2. Gradio Interface Setup + Define a function for response generation based on user inputs. Setup Gradio Chat Interface for interaction. + + .. code-block:: python + + def generate_response(user_input): + result = llm.generate(user_input) + return result.output_text.decode('utf-8') + + +3. Running the Interface + Launch the Gradio interface and interact with the model by entering text inputs. + + .. image:: /imgs/gradio_interface.png + :alt: Gradio Chatbot Interface + :align: center + +4. Shutdown + Stop the FlexFlow server after interaction. + +Example +======= + +Complete code example can be found here: + +1. `Chatbot Example with incremental decoding `__ + +2. `Chatbot Example with speculative inference `__ + + +Example Implementation: + + .. code-block:: python + + import gradio as gr + import flexflow.serve as ff + + ff.init(num_gpus=2, memory_per_gpu=14000, ...) + + def generate_response(user_input): + result = llm.generate(user_input) + return result.output_text.decode('utf-8') + + iface = gr.ChatInterface(fn=generate_response) + iface.launch() \ No newline at end of file diff --git a/docs/source/imgs/gradio_api.png b/docs/source/imgs/gradio_api.png new file mode 100644 index 0000000000..7bf1b99a5e Binary files /dev/null and b/docs/source/imgs/gradio_api.png differ diff --git a/docs/source/imgs/gradio_interface.png b/docs/source/imgs/gradio_interface.png new file mode 100644 index 0000000000..9584d76fb3 Binary files /dev/null and b/docs/source/imgs/gradio_interface.png differ diff --git a/docs/source/index.rst b/docs/source/index.rst index a7ea2ff3ac..6aa47d157b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -18,6 +18,8 @@ Welcome to FlexFlow's documentation! :caption: FlexFlow Serve serve_overview + serve_usecases + serve_api .. 
toctree:: :caption: FlexFlow Train diff --git a/docs/source/prompt_template.rst b/docs/source/prompt_template.rst new file mode 100644 index 0000000000..4e0f1beab5 --- /dev/null +++ b/docs/source/prompt_template.rst @@ -0,0 +1,55 @@ +:tocdepth: 1 +**************** +Prompt Template +**************** + +Prompt templates guide the model's response generation. This use case demonstrates setting up FlexFlow Serve to integrate with Langchain and using prompt templates to handle dynamic prompts. + +Requirements +============ + +- FlexFlow Serve setup with appropriate configurations. +- Langchain integration with templates for prompt management. + +Implementation +============== + +1. FlexFlow Initialization + Initialize and configure FlexFlow Serve. + +2. LLM Setup + Compile and start the server for text generation. + +3. Prompt Template Setup + Set up a prompt template to guide the model's responses. + +4. Response Generation + Use the LLM with the prompt template to generate a response. + +5. Shutdown + Stop the FlexFlow server after generating the response. + +Example +======= + +Complete code examples can be found here: + +1. `Prompt Template Example with incremental decoding `__ + +2. `Prompt Template Example with speculative inference `__ + + +Example Implementation: + + .. code-block:: python + + import flexflow.serve as ff + from langchain.prompts import PromptTemplate + + ff_llm = FlexFlowLLM(...) + ff_llm.compile_and_start(...) + + template = "Question: {question}\nAnswer:" + prompt = PromptTemplate(template=template, input_variables=["question"]) + + response = ff_llm.generate(prompt.format(question="Who was the US president in 1997?")) diff --git a/docs/source/rag.rst b/docs/source/rag.rst new file mode 100644 index 0000000000..4b869c2352 --- /dev/null +++ b/docs/source/rag.rst @@ -0,0 +1,90 @@ +:tocdepth: 1 +******** +RAG Q&A +******** + +Retrieval Augmented Generation (RAG) combines language models with external knowledge. This use case integrates RAG with FlexFlow Serve for Q&A with documents. + +Requirements +============ + +- FlexFlow Serve setup. +- Retriever setup for RAG. + +Implementation +============== + +1. FlexFlow Initialization + Initialize and configure FlexFlow Serve. + +2. Data Retrieval Setup + Set up a retriever for sourcing information relevant to user queries. + +3. RAG Integration + Integrate the retriever with FlexFlow Serve. + +4. Response Generation + Use the LLM with RAG to generate responses based on the model's knowledge and the retrieved information. + +5. Shutdown + Stop the FlexFlow server after generating the response. + +Example +======= + +A complete code example for a web-document Q&A using FlexFlow can be found here: + +1. `RAG Q&A Example with incremental decoding `__ + +2. `RAG Q&A Example with speculative inference `__ + + +Example Implementation: + + .. code-block:: python + + # imports + + # compile and start server + ff_llm = FlexFlowLLM(...) + gen_config = ff.GenerationConfig(...) + ff_llm.compile_and_start(...) + ff_llm_wrapper = FF_LLM_wrapper(flexflow_llm=ff_llm) + + + # Load web page content + loader = WebBaseLoader("https://example.com/data") + data = loader.load() + + # Split text + text_splitter = RecursiveCharacterTextSplitter(...) + all_splits = text_splitter.split_documents(data) + + # Initialize embeddings + embeddings = OpenAIEmbeddings(...)
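# Note on the embeddings step above: OpenAIEmbeddings expects an OpenAI API key
# (the full script inference/python/usecases/rag_incr.py passes
# openai_api_key=os.getenv('OPENAI_API_KEY')). As an illustrative alternative only,
# a local embedding model from langchain could be swapped in here, e.g.:
# embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")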
+ + # Create VectorStore + vectorstore = Chroma.from_documents(all_splits, embeddings) + + # Use VectorStore as a retriever + retriever = vectorstore.as_retriever() + + # Apply similarity search + question = "Example Question" + docs = vectorstore.similarity_search(question) + max_chars_per_doc = 100 + docs_text = ''.join([docs[i].page_content[:max_chars_per_doc] for i in range(len(docs))]) + + # Using a Prompt Template + prompt_rag = PromptTemplate.from_template( + "Summarize the main themes in these retrieved docs: {docs_text}" + ) + + # Build Chain + llm_chain_rag = LLMChain(llm=ff_llm_wrapper, prompt=prompt_rag) + + # Run + rag_result = llm_chain_rag(docs_text) + + # Stop the server + ff_llm.stop_server() \ No newline at end of file diff --git a/docs/source/serve_api.rst b/docs/source/serve_api.rst new file mode 100644 index 0000000000..6a607cbf0c --- /dev/null +++ b/docs/source/serve_api.rst @@ -0,0 +1,7 @@ +************************** +FlexFlow Serve Python API +************************** + +.. toctree:: + serve_fastapi + serve_gradioapi \ No newline at end of file diff --git a/docs/source/serve_fastapi.rst b/docs/source/serve_fastapi.rst new file mode 100644 index 0000000000..0aa6634670 --- /dev/null +++ b/docs/source/serve_fastapi.rst @@ -0,0 +1,106 @@ +:tocdepth: 1 +*********************** +FlexFlow Serve FastAPI +*********************** + +Introduction +============ + +The Python API for FlexFlow Serve enables users to initialize, manage and interact with large language models (LLMs) via FastAPI or Gradio. + +Requirements +------------ + +- FlexFlow Serve setup with necessary configurations. +- FastAPI and Uvicorn for running the API server. + +API Configuration +================= + +Users can configure the API using FastAPI to handle requests and manage the model. + +1. FastAPI Application Initialization + Initialize the FastAPI application to create API endpoints. + +2. Request Model Definition + Define the model for API requests using Pydantic. + +3. Global Variable for LLM Model + Declare a global variable to store the LLM model. + +Example +------- + +.. code-block:: python + + from fastapi import FastAPI + from pydantic import BaseModel + import flexflow.serve as ff + + app = FastAPI() + + class PromptRequest(BaseModel): + prompt: str + + llm = None + +Endpoint Creation +================= + +Create API endpoints for LLM interactions to handle generation requests. + +1. Initialize Model on Startup + Use the FastAPI event handler to initialize and compile the LLM model when the API server starts. + +2. Generate Response Endpoint + Create a POST endpoint to generate responses based on the user's prompt. + +Example +------- + +.. code-block:: python + + @app.on_event("startup") + async def startup_event(): + global llm + # Initialize and compile the LLM model + llm.compile( + generation_config, + # ... other params as needed + ) + llm.start_server() + + @app.post("/generate/") + async def generate(prompt_request: PromptRequest): + # ... exception handling + full_output = llm.generate([prompt_request.prompt])[0].output_text.decode('utf-8') + # ... split prompt and response text for returning results + return {"prompt": prompt_request.prompt, "response": full_output} + +Running and Testing +=================== + +Instructions for running and testing the FastAPI server. + +1. Run the FastAPI Server + Use Uvicorn to run the FastAPI server with specified host and port. + +2. Testing the API + Make requests to the API endpoints and verify the responses. + +Example +------- + +.. 
code-block:: bash + + # Running within the inference/python folder: + uvicorn entrypoint.fastapi_incr:app --reload --port 3000 + +Full API Entrypoint Code +========================= + +The complete FastAPI entrypoint code can be found here: + +1. `FastAPI Example with incremental decoding `__ + +2. `FastAPI Example with speculative inference `__ diff --git a/docs/source/serve_gradioapi.rst b/docs/source/serve_gradioapi.rst new file mode 100644 index 0000000000..ed19e05347 --- /dev/null +++ b/docs/source/serve_gradioapi.rst @@ -0,0 +1,30 @@ +:tocdepth: 1 +************************* +FlexFlow Serve Gradio API +************************* + +Introduction +============ + +Users can also set up the API endpoints with a Gradio Chatbot Interface. + +Requirements +------------ + +- FlexFlow Serve setup with necessary configurations. +- A running Gradio chatbot interface. + +Example +======== + +In a running Gradio chatbot interface, click the "Use via API" button at the bottom left. + + .. image:: /imgs/gradio_interface.png + :alt: Gradio Chatbot Interface + :align: center + +Users can easily access an API endpoint for sending prompts to the model. + + .. image:: /imgs/gradio_api.png + :alt: Gradio API + :align: center \ No newline at end of file diff --git a/docs/source/serve_usecases.rst b/docs/source/serve_usecases.rst new file mode 100644 index 0000000000..4aa3fd2807 --- /dev/null +++ b/docs/source/serve_usecases.rst @@ -0,0 +1,8 @@ +******************* +Serving Use Cases +******************* + +.. toctree:: + chatbot + prompt_template + rag \ No newline at end of file diff --git a/inference/.gitignore b/inference/.gitignore index 8ab99cb1eb..1da34a668b 100644 --- a/inference/.gitignore +++ b/inference/.gitignore @@ -3,3 +3,4 @@ weights tokenizers prompt output +.env \ No newline at end of file diff --git a/inference/python/entrypoint/fastapi_incr.py b/inference/python/entrypoint/fastapi_incr.py new file mode 100644 index 0000000000..34f61739fb --- /dev/null +++ b/inference/python/entrypoint/fastapi_incr.py @@ -0,0 +1,162 @@ +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Running Instructions: +- To run this FastAPI application, make sure you have FastAPI and Uvicorn installed. +- Save this script as 'fastapi_incr.py'. +- Run the application using the command: `uvicorn fastapi_incr:app --reload --port PORT_NUMBER` +- The server will start on `http://localhost:PORT_NUMBER`. Use this base URL to make API requests. +- Go to `http://localhost:PORT_NUMBER/docs` for API documentation.
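- Example request (illustrative; assumes the server was started on port 3000, as in the uvicorn commands at the bottom of this file):
      curl -X POST http://localhost:3000/generate/ \
           -H "Content-Type: application/json" \
           -d '{"prompt": "What is machine learning?"}'
  The endpoint returns a JSON object of the form {"prompt": ..., "response": ...}.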
+""" + + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import flexflow.serve as ff +import uvicorn +import json, os, argparse +from types import SimpleNamespace + +# Initialize FastAPI application +app = FastAPI() + +# Define the request model +class PromptRequest(BaseModel): + prompt: str + +# Global variable to store the LLM model +llm = None + + +def get_configs(): + + # Fetch configuration file path from environment variable + config_file = os.getenv("CONFIG_FILE", "") + + # Load configs from JSON file (if specified) + if config_file: + if not os.path.isfile(config_file): + raise FileNotFoundError(f"Config file {config_file} not found.") + try: + with open(config_file) as f: + return json.load(f) + except json.JSONDecodeError as e: + print("JSON format error:") + print(e) + else: + # Define sample configs + ff_init_configs = { + # required parameters + "num_gpus": 2, + "memory_per_gpu": 14000, + "zero_copy_memory_per_node": 40000, + # optional parameters + "num_cpus": 4, + "legion_utility_processors": 4, + "data_parallelism_degree": 1, + "tensor_parallelism_degree": 1, + "pipeline_parallelism_degree": 2, + "offload": False, + "offload_reserve_space_size": 1024**2, + "use_4bit_quantization": False, + "use_8bit_quantization": False, + "profiling": False, + "inference_debugging": False, + "fusion": True, + } + llm_configs = { + # required parameters + "llm_model": "tiiuae/falcon-7b", + # optional parameters + "cache_path": "", + "refresh_cache": False, + "full_precision": False, + "prompt": "", + "output_file": "", + } + # Merge dictionaries + ff_init_configs.update(llm_configs) + return ff_init_configs + + +# Initialize model on startup +@app.on_event("startup") +async def startup_event(): + global llm + + # Initialize your LLM model configuration here + configs_dict = get_configs() + configs = SimpleNamespace(**configs_dict) + ff.init(configs_dict) + + ff_data_type = ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF + llm = ff.LLM( + configs.llm_model, + data_type=ff_data_type, + cache_path=configs.cache_path, + refresh_cache=configs.refresh_cache, + output_file=configs.output_file, + ) + + generation_config = ff.GenerationConfig( + do_sample=False, temperature=0.9, topp=0.8, topk=1 + ) + llm.compile( + generation_config, + max_requests_per_batch=1, + max_seq_length=256, + max_tokens_per_batch=64, + ) + llm.start_server() + +# API endpoint to generate response +@app.post("/generate/") +async def generate(prompt_request: PromptRequest): + if llm is None: + raise HTTPException(status_code=503, detail="LLM model is not initialized.") + + # Call the model to generate a response + full_output = llm.generate([prompt_request.prompt])[0].output_text.decode('utf-8') + + # Separate the prompt and response + split_output = full_output.split('\n', 1) + if len(split_output) > 1: + response_text = split_output[1] + else: + response_text = "" + + # Return the prompt and the response in JSON format + return { + "prompt": prompt_request.prompt, + "response": response_text + } + +# Shutdown event to stop the model server +@app.on_event("shutdown") +async def shutdown_event(): + global llm + if llm is not None: + llm.stop_server() + +# Main function to run Uvicorn server +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) + +# Running within the entrypoint folder: +# uvicorn fastapi_incr:app --reload --port + +# Running within the python folder: +# uvicorn entrypoint.fastapi_incr:app --reload --port 3000 diff --git 
a/inference/python/entrypoint/fastapi_specinfer.py b/inference/python/entrypoint/fastapi_specinfer.py new file mode 100644 index 0000000000..416aee6dc5 --- /dev/null +++ b/inference/python/entrypoint/fastapi_specinfer.py @@ -0,0 +1,202 @@ +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Running Instructions: +- To run this FastAPI application, make sure you have FastAPI and Uvicorn installed. +- Save this script as 'fastapi_specinfer.py'. +- Run the application using the command: `uvicorn fastapi_specinfer:app --reload --port PORT_NUMBER` +- The server will start on `http://localhost:PORT_NUMBER`. Use this base URL to make API requests. +- Go to `http://localhost:PORT_NUMBER/docs` for API documentation. +""" + + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import flexflow.serve as ff +import uvicorn +import json, os, argparse +from types import SimpleNamespace + +# Initialize FastAPI application +app = FastAPI() + +# Define the request model +class PromptRequest(BaseModel): + prompt: str + +# Global variable to store the LLM model +llm = None + +def get_configs(): + # Fetch configuration file path from environment variable + config_file = os.getenv("CONFIG_FILE", "") + + # Load configs from JSON file (if specified) + if config_file: + if not os.path.isfile(config_file): + raise FileNotFoundError(f"Config file {config_file} not found.") + try: + with open(config_file) as f: + return json.load(f) + except json.JSONDecodeError as e: + print("JSON format error:") + print(e) + else: + # Define sample configs + ff_init_configs = { + # required parameters + "num_gpus": 2, + "memory_per_gpu": 14000, + "zero_copy_memory_per_node": 40000, + # optional parameters + "num_cpus": 4, + "legion_utility_processors": 4, + "data_parallelism_degree": 1, + "tensor_parallelism_degree": 1, + "pipeline_parallelism_degree": 2, + "offload": False, + "offload_reserve_space_size": 1024**2, + "use_4bit_quantization": False, + "use_8bit_quantization": False, + "profiling": False, + "inference_debugging": False, + "fusion": True, + } + llm_configs = { + # required llm arguments + "llm_model": "meta-llama/Llama-2-7b-hf", + # optional llm parameters + "cache_path": "", + "refresh_cache": False, + "full_precision": False, + "ssms": [ + { + # required ssm parameter + "ssm_model": "JackFram/llama-160m", + # optional ssm parameters + "cache_path": "", + "refresh_cache": False, + "full_precision": False, + } + ], + # "prompt": "", + "output_file": "", + } + # Merge dictionaries + ff_init_configs.update(llm_configs) + return ff_init_configs + +# Initialize model on startup +@app.on_event("startup") +async def startup_event(): + global llm + + # Initialize your LLM model configuration here + configs_dict = get_configs() + configs = SimpleNamespace(**configs_dict) + ff.init(configs_dict) + + # Create the FlexFlow LLM + ff_data_type = ( + ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF + ) + 
llm = ff.LLM( + configs.llm_model, + data_type=ff_data_type, + cache_path=configs.cache_path, + refresh_cache=configs.refresh_cache, + output_file=configs.output_file, + ) + + # Create the SSMs + ssms = [] + for ssm_config in configs.ssms: + ssm_config = SimpleNamespace(**ssm_config) + ff_data_type = ( + ff.DataType.DT_FLOAT if ssm_config.full_precision else ff.DataType.DT_HALF + ) + ssm = ff.SSM( + ssm_config.ssm_model, + data_type=ff_data_type, + cache_path=ssm_config.cache_path, + refresh_cache=ssm_config.refresh_cache, + output_file=configs.output_file, + ) + ssms.append(ssm) + + # Create the sampling configs + generation_config = ff.GenerationConfig( + do_sample=False, temperature=0.9, topp=0.8, topk=1 + ) + + # Compile the SSMs for inference and load the weights into memory + for ssm in ssms: + ssm.compile( + generation_config, + max_requests_per_batch=1, + max_seq_length=256, + max_tokens_per_batch=64, + ) + + # Compile the LLM for inference and load the weights into memory + llm.compile( + generation_config, + max_requests_per_batch=1, + max_seq_length=256, + max_tokens_per_batch=64, + ssms=ssms, + ) + + llm.start_server() + +# API endpoint to generate response +@app.post("/generate/") +async def generate(prompt_request: PromptRequest): + if llm is None: + raise HTTPException(status_code=503, detail="LLM model is not initialized.") + + # Call the model to generate a response + full_output = llm.generate([prompt_request.prompt])[0].output_text.decode('utf-8') + + # Separate the prompt and response + split_output = full_output.split('\n', 1) + if len(split_output) > 1: + response_text = split_output[1] + else: + response_text = "" + + # Return the prompt and the response in JSON format + return { + "prompt": prompt_request.prompt, + "response": response_text + } + +# Shutdown event to stop the model server +@app.on_event("shutdown") +async def shutdown_event(): + global llm + if llm is not None: + llm.stop_server() + +# Main function to run Uvicorn server +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) + +# Running within the entrypoint folder: +# uvicorn fastapi_specinfer:app --reload --port + +# Running within the python folder: +# uvicorn entrypoint.fastapi_specinfer:app --reload --port 3000 diff --git a/inference/python/incr_decoding.py b/inference/python/incr_decoding.py index 6706cf3c29..f7707816c8 100644 --- a/inference/python/incr_decoding.py +++ b/inference/python/incr_decoding.py @@ -41,7 +41,7 @@ def get_configs(): # Define sample configs ff_init_configs = { # required parameters - "num_gpus": 4, + "num_gpus": 2, "memory_per_gpu": 14000, "zero_copy_memory_per_node": 40000, # optional parameters @@ -49,7 +49,7 @@ def get_configs(): "legion_utility_processors": 4, "data_parallelism_degree": 1, "tensor_parallelism_degree": 1, - "pipeline_parallelism_degree": 4, + "pipeline_parallelism_degree": 2, "offload": False, "offload_reserve_space_size": 1024**2, "use_4bit_quantization": False, @@ -64,7 +64,7 @@ def get_configs(): # optional parameters "cache_path": "", "refresh_cache": False, - "full_precision": True, + "full_precision": False, "prompt": "", "output_file": "", } diff --git a/inference/python/spec_infer.py b/inference/python/spec_infer.py index 8b9a116dc5..fcb1b8f891 100644 --- a/inference/python/spec_infer.py +++ b/inference/python/spec_infer.py @@ -41,14 +41,14 @@ def get_configs(): # Define sample configs ff_init_configs = { # required parameters - "num_gpus": 4, + "num_gpus": 2, "memory_per_gpu": 14000, "zero_copy_memory_per_node": 40000, # 
optional parameters "num_cpus": 4, "legion_utility_processors": 4, "data_parallelism_degree": 1, - "tensor_parallelism_degree": 2, + "tensor_parallelism_degree": 1, "pipeline_parallelism_degree": 2, "offload": False, "offload_reserve_space_size": 1024**2, @@ -75,7 +75,7 @@ def get_configs(): "full_precision": False, } ], - "prompt": "", + # "prompt": "", "output_file": "", } # Merge dictionaries diff --git a/inference/python/usecases/gradio_incr.py b/inference/python/usecases/gradio_incr.py new file mode 100644 index 0000000000..2735b665bb --- /dev/null +++ b/inference/python/usecases/gradio_incr.py @@ -0,0 +1,162 @@ +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functionality: +1. Configuration Handling: + - Parses command-line arguments to get a configuration file path. + - Loads configuration settings from a JSON file if provided, or uses default settings. + +2. FlexFlow Model Initialization: + - Initializes FlexFlow with the provided or default configurations. + - Sets up the LLM with the specified model and configurations. + - Compiles the model with generation settings and starts the FlexFlow server. + +3. Gradio Interface Setup: + - Defines a function to generate responses based on user input using FlexFlow. + - Sets up a Gradio Chat Interface to interact with the model in a conversational format. + +4. Main Execution: + - Calls the main function to initialize configurations, set up the FlexFlow LLM, and launch the Gradio interface. + - Stops the FlexFlow server after the Gradio interface is closed. + +Usage: +1. Run the script with an optional configuration file argument for custom settings. +2. Interact with the FlexFlow model through the Gradio web interface. +3. Enter text inputs to receive generated responses from the model. +4. The script will stop the FlexFlow server automatically upon closing the Gradio interface. +""" + +import gradio as gr +import flexflow.serve as ff +import argparse, json, os +from types import SimpleNamespace + + +def get_configs(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-config-file", + help="The path to a JSON file with the configs. 
If omitted, a sample model and configs will be used instead.", + type=str, + default="", + ) + args = parser.parse_args() + + # Load configs from JSON file (if specified) + if len(args.config_file) > 0: + if not os.path.isfile(args.config_file): + raise FileNotFoundError(f"Config file {args.config_file} not found.") + try: + with open(args.config_file) as f: + return json.load(f) + except json.JSONDecodeError as e: + print("JSON format error:") + print(e) + else: + # Define sample configs + ff_init_configs = { + # required parameters + "num_gpus": 2, + "memory_per_gpu": 14000, + "zero_copy_memory_per_node": 40000, + # optional parameters + "num_cpus": 4, + "legion_utility_processors": 4, + "data_parallelism_degree": 1, + "tensor_parallelism_degree": 1, + "pipeline_parallelism_degree": 2, + "offload": False, + "offload_reserve_space_size": 1024**2, + "use_4bit_quantization": False, + "use_8bit_quantization": False, + "profiling": False, + "inference_debugging": False, + "fusion": True, + } + llm_configs = { + # required parameters + "llm_model": "tiiuae/falcon-7b", + # optional parameters + "cache_path": "", + "refresh_cache": False, + "full_precision": False, + "prompt": "", + "output_file": "", + } + # Merge dictionaries + ff_init_configs.update(llm_configs) + return ff_init_configs + + +# def generate_response(user_input): +# result = llm.generate(user_input) +# return result.output_text.decode('utf-8') + +def generate_response(message, history): + user_input = message + results = llm.generate(user_input) + if isinstance(results, list): + result_txt = results[0].output_text.decode('utf-8') + else: + result_txt = results.output_text.decode('utf-8') + return result_txt + + +def main(): + + global llm + + configs_dict = get_configs() + configs = SimpleNamespace(**configs_dict) + + ff.init(configs_dict) + + ff_data_type = ( + ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF + ) + llm = ff.LLM( + configs.llm_model, + data_type=ff_data_type, + cache_path=configs.cache_path, + refresh_cache=configs.refresh_cache, + output_file=configs.output_file, + ) + + generation_config = ff.GenerationConfig( + do_sample=False, temperature=0.9, topp=0.8, topk=1 + ) + llm.compile( + generation_config, + max_requests_per_batch=1, + max_seq_length=256, + max_tokens_per_batch=64, + ) + + # # interface version 1 + # iface = gr.Interface( + # fn=generate_response, + # inputs="text", + # outputs="text" + # ) + + # interface version 2 + iface = gr.ChatInterface(fn=generate_response) + llm.start_server() + iface.launch() + llm.stop_server() + +if __name__ == "__main__": + print("flexflow inference example with gradio interface") + main() \ No newline at end of file diff --git a/inference/python/usecases/gradio_specinfer.py b/inference/python/usecases/gradio_specinfer.py new file mode 100644 index 0000000000..08cde3f00b --- /dev/null +++ b/inference/python/usecases/gradio_specinfer.py @@ -0,0 +1,205 @@ +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functionality: +1. Configuration Handling: + - Parses command-line arguments to get a configuration file path. + - Loads configuration settings from a JSON file if provided, or uses default settings. + +2. FlexFlow Model Initialization: + - Initializes FlexFlow with the provided or default configurations. + - Sets up the LLM with the specified model and configurations. + - Compiles the model with generation settings and starts the FlexFlow server. + +3. Gradio Interface Setup: + - Defines a function to generate responses based on user input using FlexFlow. + - Sets up a Gradio Chat Interface to interact with the model in a conversational format. + +4. Main Execution: + - Calls the main function to initialize configurations, set up the FlexFlow LLM, and launch the Gradio interface. + - Stops the FlexFlow server after the Gradio interface is closed. + +Usage: +1. Run the script with an optional configuration file argument for custom settings. +2. Interact with the FlexFlow model through the Gradio web interface. +3. Enter text inputs to receive generated responses from the model. +4. The script will stop the FlexFlow server automatically upon closing the Gradio interface. +""" + +""" +TODO: fix current issue: model init is stuck at "prepare next batch init" and "prepare next batch verify" +""" + +import gradio as gr +import flexflow.serve as ff +import argparse, json, os +from types import SimpleNamespace + +def get_configs(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-config-file", + help="The path to a JSON file with the configs. If omitted, a sample model and configs will be used instead.", + type=str, + default="", + ) + args = parser.parse_args() + + # Load configs from JSON file (if specified) + if len(args.config_file) > 0: + if not os.path.isfile(args.config_file): + raise FileNotFoundError(f"Config file {args.config_file} not found.") + try: + with open(args.config_file) as f: + return json.load(f) + except json.JSONDecodeError as e: + print("JSON format error:") + print(e) + else: + # Define sample configs + ff_init_configs = { + # required parameters + "num_gpus": 2, + "memory_per_gpu": 14000, + "zero_copy_memory_per_node": 40000, + # optional parameters + "num_cpus": 4, + "legion_utility_processors": 4, + "data_parallelism_degree": 1, + "tensor_parallelism_degree": 1, + "pipeline_parallelism_degree": 2, + "offload": False, + "offload_reserve_space_size": 1024**2, + "use_4bit_quantization": False, + "use_8bit_quantization": False, + "profiling": False, + "inference_debugging": False, + "fusion": True, + } + llm_configs = { + # required llm arguments + "llm_model": "meta-llama/Llama-2-7b-hf", + # optional llm parameters + "cache_path": "", + "refresh_cache": False, + "full_precision": False, + "ssms": [ + { + # required ssm parameter + "ssm_model": "JackFram/llama-160m", + # optional ssm parameters + "cache_path": "", + "refresh_cache": False, + "full_precision": False, + } + ], + # "prompt": "", + "output_file": "", + } + # Merge dictionaries + ff_init_configs.update(llm_configs) + return ff_init_configs + + +# def generate_response(user_input): +# result = llm.generate(user_input) +# return result.output_text.decode('utf-8') + +def generate_response(message, history): + user_input = message + results = llm.generate(user_input) + if isinstance(results, list): + result_txt = results[0].output_text.decode('utf-8') + else: + result_txt = 
results.output_text.decode('utf-8') + return result_txt + +def main(): + + global llm + + configs_dict = get_configs() + configs = SimpleNamespace(**configs_dict) + + # Initialize the FlexFlow runtime. ff.init() takes a dictionary or the path to a JSON file with the configs + ff.init(configs_dict) + + # Create the FlexFlow LLM + ff_data_type = ( + ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF + ) + llm = ff.LLM( + configs.llm_model, + data_type=ff_data_type, + cache_path=configs.cache_path, + refresh_cache=configs.refresh_cache, + output_file=configs.output_file, + ) + + # Create the SSMs + ssms = [] + for ssm_config in configs.ssms: + ssm_config = SimpleNamespace(**ssm_config) + ff_data_type = ( + ff.DataType.DT_FLOAT if ssm_config.full_precision else ff.DataType.DT_HALF + ) + ssm = ff.SSM( + ssm_config.ssm_model, + data_type=ff_data_type, + cache_path=ssm_config.cache_path, + refresh_cache=ssm_config.refresh_cache, + output_file=configs.output_file, + ) + ssms.append(ssm) + + # Create the sampling configs + generation_config = ff.GenerationConfig( + do_sample=False, temperature=0.9, topp=0.8, topk=1 + ) + + # Compile the SSMs for inference and load the weights into memory + for ssm in ssms: + ssm.compile( + generation_config, + max_requests_per_batch=1, + max_seq_length=256, + max_tokens_per_batch=256, + ) + + # Compile the LLM for inference and load the weights into memory + llm.compile( + generation_config, + max_requests_per_batch=1, + max_seq_length=256, + max_tokens_per_batch=256, + ssms=ssms, + ) + + # # interface version 1 + # iface = gr.Interface( + # fn=generate_response, + # inputs="text", + # outputs="text" + # ) + + # interface version 2 + iface = gr.ChatInterface(fn=generate_response) + llm.start_server() + iface.launch() + llm.stop_server() + +if __name__ == "__main__": + print("flexflow inference example with gradio interface") + main() \ No newline at end of file diff --git a/inference/python/usecases/prompt_template_incr.py b/inference/python/usecases/prompt_template_incr.py new file mode 100644 index 0000000000..8bffe9ddad --- /dev/null +++ b/inference/python/usecases/prompt_template_incr.py @@ -0,0 +1,187 @@ +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script implements the usecase of prompt template upon FlexFlow. 
+ +Functionality: +1. FlexFlowLLM Class: + - Initializes and configures FlexFlow. + - Loads configurations from a file or uses default settings. + - Compiles and starts the language model server for text generation. + - Stops the server when operations are complete. + +2. FF_LLM_wrapper Class: + - Serves as a wrapper for FlexFlow. + - Implements the necessary interface to interact with the LangChain library. + +3. Main: + - Initializes FlexFlow. + - Compiles and starts the server with specific generation configurations. + - Sets up a prompt template for generating responses to questions. + - Use LLMChain to run the model and generate response. + - Stops the FlexFlow server after generating the response. +""" + +import flexflow.serve as ff +import argparse, json, os +from types import SimpleNamespace +from langchain.llms.base import LLM +from typing import Any, List, Mapping, Optional +from langchain.chains import LLMChain +from langchain.prompts import PromptTemplate + +class FlexFlowLLM: + def __init__(self, config_file=""): + self.configs = self.get_configs(config_file) + ff.init(self.configs) + self.llm = self.create_llm() + + def get_configs(self, config_file): + # Load configurations from a file or use default settings + if config_file and os.path.isfile(config_file): + with open(config_file) as f: + return json.load(f) + else: + # Define sample configs + ff_init_configs = { + # required parameters + "num_gpus": 2, + "memory_per_gpu": 14000, + "zero_copy_memory_per_node": 40000, + # optional parameters + "num_cpus": 4, + "legion_utility_processors": 4, + "data_parallelism_degree": 1, + "tensor_parallelism_degree": 1, + "pipeline_parallelism_degree": 2, + "offload": False, + "offload_reserve_space_size": 1024**2, + "use_4bit_quantization": False, + "use_8bit_quantization": False, + "profiling": False, + "inference_debugging": False, + "fusion": True, + } + llm_configs = { + # required parameters + "llm_model": "tiiuae/falcon-7b", + # optional parameters + "cache_path": "", + "refresh_cache": False, + "full_precision": False, + "prompt": "", + "output_file": "", + } + # Merge dictionaries + ff_init_configs.update(llm_configs) + return ff_init_configs + + def create_llm(self): + configs = SimpleNamespace(**self.configs) + ff_data_type = ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF + llm = ff.LLM( + configs.llm_model, + data_type=ff_data_type, + cache_path=configs.cache_path, + refresh_cache=configs.refresh_cache, + output_file=configs.output_file, + ) + return llm + + def compile_and_start(self, generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch): + self.llm.compile(generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch) + self.llm.start_server() + + def generate(self, prompt): + results = self.llm.generate(prompt) + if isinstance(results, list): + result_txt = results[0].output_text.decode('utf-8') + else: + result_txt = results.output_text.decode('utf-8') + return result_txt + + def stop_server(self): + self.llm.stop_server() + + def __enter__(self): + return self.llm.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + return self.llm.__exit__(exc_type, exc_value, traceback) + +class FF_LLM_wrapper(LLM): + flexflow_llm: FlexFlowLLM + + @property + def _llm_type(self) -> str: + return "custom" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> str: + if stop is not None: + raise ValueError("stop kwargs are not permitted.") + response = 
self.flexflow_llm.generate(prompt) + return response + + +if __name__ == "__main__": + # initialization + ff_llm = FlexFlowLLM() + + # compile and start server + gen_config = ff.GenerationConfig(do_sample=False, temperature=0.9, topp=0.8, topk=1) + ff_llm.compile_and_start( + gen_config, + max_requests_per_batch=1, + max_seq_length=256, + max_tokens_per_batch=64 + ) + + # the wrapper class serves as the 'Model' in LCEL + ff_llm_wrapper = FF_LLM_wrapper(flexflow_llm=ff_llm) + + # USE CASE 1: Prompt Template + template = """Question: {question} + Answer: Let's think step by step.""" + + # Build prompt template and langchain + prompt = PromptTemplate(template=template, input_variables=["question"]) + llm_chain = LLMChain(prompt=prompt, llm=ff_llm_wrapper) + + question = "Who was the US president in the year the first Pokemon game was released?" + print(llm_chain.run(question)) + + # stop the server + ff_llm.stop_server() + diff --git a/inference/python/usecases/prompt_template_specinfer.py b/inference/python/usecases/prompt_template_specinfer.py new file mode 100644 index 0000000000..dfc92e9ac2 --- /dev/null +++ b/inference/python/usecases/prompt_template_specinfer.py @@ -0,0 +1,236 @@ +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script implements the usecase of prompt template upon FlexFlow. + +Functionality: +1. FlexFlowLLM Class: + - Initializes and configures FlexFlow. + - Loads configurations from a file or uses default settings. + - Compiles and starts the language model server for text generation. + - Stops the server when operations are complete. + +2. FF_LLM_wrapper Class: + - Serves as a wrapper for FlexFlow. + - Implements the necessary interface to interact with the LangChain library. + +3. Main: + - Initializes FlexFlow. + - Compiles and starts the server with specific generation configurations. + - Sets up a prompt template for generating responses to questions. + - Use LLMChain to run the model and generate response. + - Stops the FlexFlow server after generating the response. 
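Illustrative sketch of the prompt flow (assuming the template defined in __main__ below; the question is just an example):
    prompt.format(question="Who was the US president in 1997?")
    # roughly renders to: "Question: Who was the US president in 1997?\n    Answer: Let's think step by step."
LLMChain passes the rendered prompt to FF_LLM_wrapper._call, which forwards it to FlexFlowLLM.generate on the running FlexFlow server.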
+""" + +import flexflow.serve as ff +import argparse, json, os +from types import SimpleNamespace +from langchain.llms.base import LLM +from typing import Any, List, Mapping, Optional +from langchain.chains import LLMChain +from langchain.prompts import PromptTemplate + + +class FlexFlowLLM: + def __init__(self, config_file=""): + self.configs = self.get_configs(config_file) + ff.init(self.configs) + self.llm = self.create_llm() + self.ssms = self.create_ssms() + + def get_configs(self, config_file): + # Load configurations from a file or use default settings + if config_file and os.path.isfile(config_file): + with open(config_file) as f: + return json.load(f) + else: + # Define sample configs + ff_init_configs = { + # required parameters + "num_gpus": 2, + "memory_per_gpu": 14000, + "zero_copy_memory_per_node": 40000, + # optional parameters + "num_cpus": 4, + "legion_utility_processors": 4, + "data_parallelism_degree": 1, + "tensor_parallelism_degree": 1, + "pipeline_parallelism_degree": 2, + "offload": False, + "offload_reserve_space_size": 1024**2, + "use_4bit_quantization": False, + "use_8bit_quantization": False, + "profiling": False, + "inference_debugging": False, + "fusion": True, + } + llm_configs = { + # required llm arguments + "llm_model": "meta-llama/Llama-2-7b-hf", + # optional llm parameters + "cache_path": "", + "refresh_cache": False, + "full_precision": False, + "ssms": [ + { + # required ssm parameter + "ssm_model": "JackFram/llama-160m", + # optional ssm parameters + "cache_path": "", + "refresh_cache": False, + "full_precision": False, + } + ], + # "prompt": "", + "output_file": "", + } + # Merge dictionaries + ff_init_configs.update(llm_configs) + return ff_init_configs + + def create_llm(self): + configs = SimpleNamespace(**self.configs) + ff_data_type = ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF + llm = ff.LLM( + configs.llm_model, + data_type=ff_data_type, + cache_path=configs.cache_path, + refresh_cache=configs.refresh_cache, + output_file=configs.output_file, + ) + return llm + + def create_ssms(self): + # Create the SSMs + configs = SimpleNamespace(**self.configs) + ssms = [] + for ssm_config in configs.ssms: + ssm_config = SimpleNamespace(**ssm_config) + ff_data_type = ( + ff.DataType.DT_FLOAT if ssm_config.full_precision else ff.DataType.DT_HALF + ) + ssm = ff.SSM( + ssm_config.ssm_model, + data_type=ff_data_type, + cache_path=ssm_config.cache_path, + refresh_cache=ssm_config.refresh_cache, + output_file=configs.output_file, + ) + ssms.append(ssm) + return ssms + + def compile_and_start(self, generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch): + + # Compile the SSMs for inference and load the weights into memory + for ssm in self.ssms: + ssm.compile( + generation_config, + max_requests_per_batch, + max_seq_length, + max_tokens_per_batch, + ) + + # Compile the LLM for inference and load the weights into memory + self.llm.compile( + generation_config, + max_requests_per_batch, + max_seq_length, + max_tokens_per_batch, + ssms = self.ssms + ) + self.llm.start_server() + + def generate(self, prompt): + results = self.llm.generate(prompt) + if isinstance(results, list): + result_txt = results[0].output_text.decode('utf-8') + else: + result_txt = results.output_text.decode('utf-8') + return result_txt + + def stop_server(self): + self.llm.stop_server() + + def __enter__(self): + return self.llm.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + return self.llm.__exit__(exc_type, exc_value, 
traceback) + +class FF_LLM_wrapper(LLM): + flexflow_llm: FlexFlowLLM + + @property + def _llm_type(self) -> str: + return "custom" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> str: + if stop is not None: + raise ValueError("stop kwargs are not permitted.") + response = self.flexflow_llm.generate(prompt) + return response + + +if __name__ == "__main__": + # initialization + ff_llm = FlexFlowLLM() + + # compile and start server + gen_config = ff.GenerationConfig(do_sample=False, temperature=0.9, topp=0.8, topk=1) + ff_llm.compile_and_start( + gen_config, + max_requests_per_batch=1, + max_seq_length=256, + max_tokens_per_batch=64 + ) + + # the wrapper class serves as the 'Model' in LCEL + ff_llm_wrapper = FF_LLM_wrapper(flexflow_llm=ff_llm) + + # USE CASE 1: Prompt Template + template = """Question: {question} + Answer: Let's think step by step.""" + + # Build prompt template and langchain + prompt = PromptTemplate(template=template, input_variables=["question"]) + llm_chain = LLMChain(prompt=prompt, llm=ff_llm_wrapper) + + question = "Who was the US president in the year the first Pokemon game was released?" + print(llm_chain.run(question)) + + # stop the server + ff_llm.stop_server() + + diff --git a/inference/python/usecases/rag_incr.py b/inference/python/usecases/rag_incr.py new file mode 100644 index 0000000000..15e7f3d092 --- /dev/null +++ b/inference/python/usecases/rag_incr.py @@ -0,0 +1,220 @@ +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script implements the usecase of rag-search upon FlexFlow. + +Functionality: +1. FlexFlowLLM Class: + - Initializes and configures FlexFlow. + - Loads configurations from a file or uses default settings. + - Compiles and starts the language model server for text generation. + - Stops the server when operations are complete. + +2. FF_LLM_wrapper Class: + - Serves as a wrapper for FlexFlow. + - Implements the necessary interface to interact with the LangChain library. + +3. Main: + - Initializes FlexFlow. + - Compiles and starts the server with specific generation configurations. + - Taking in specific source information with RAG(Retrieval Augmented Generation) technique for Q&A towards specific realm/knowledgebase. + - Use LLMChain to run the model and generate response. 
+ - Stops the FlexFlow server after generating the response. +""" + +import flexflow.serve as ff +import argparse, json, os +from types import SimpleNamespace +from langchain.llms.base import LLM +from typing import Any, List, Mapping, Optional +from langchain.chains import LLMChain +from langchain.prompts import PromptTemplate +from langchain.document_loaders import WebBaseLoader +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain.embeddings import OpenAIEmbeddings +from langchain.vectorstores import Chroma +from langchain.vectorstores import FAISS + +class FlexFlowLLM: + def __init__(self, config_file=""): + self.configs = self.get_configs(config_file) + ff.init(self.configs) + self.llm = self.create_llm() + + def get_configs(self, config_file): + # Load configurations from a file or use default settings + if config_file and os.path.isfile(config_file): + with open(config_file) as f: + return json.load(f) + else: + # Define sample configs + ff_init_configs = { + # required parameters + "num_gpus": 2, + "memory_per_gpu": 14000, + "zero_copy_memory_per_node": 40000, + # optional parameters + "num_cpus": 4, + "legion_utility_processors": 4, + "data_parallelism_degree": 1, + "tensor_parallelism_degree": 1, + "pipeline_parallelism_degree": 2, + "offload": False, + "offload_reserve_space_size": 1024**2, + "use_4bit_quantization": False, + "use_8bit_quantization": False, + "profiling": False, + "inference_debugging": False, + "fusion": True, + } + llm_configs = { + # required parameters + "llm_model": "tiiuae/falcon-7b", + # optional parameters + "cache_path": "", + "refresh_cache": False, + "full_precision": False, + "prompt": "", + "output_file": "", + } + # Merge dictionaries + ff_init_configs.update(llm_configs) + return ff_init_configs + + def create_llm(self): + configs = SimpleNamespace(**self.configs) + ff_data_type = ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF + llm = ff.LLM( + configs.llm_model, + data_type=ff_data_type, + cache_path=configs.cache_path, + refresh_cache=configs.refresh_cache, + output_file=configs.output_file, + ) + return llm + + def compile_and_start(self, generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch): + self.llm.compile(generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch) + self.llm.start_server() + + def generate(self, prompt): + results = self.llm.generate(prompt) + if isinstance(results, list): + result_txt = results[0].output_text.decode('utf-8') + else: + result_txt = results.output_text.decode('utf-8') + return result_txt + + def stop_server(self): + self.llm.stop_server() + + def __enter__(self): + return self.llm.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + return self.llm.__exit__(exc_type, exc_value, traceback) + + +class FF_LLM_wrapper(LLM): + flexflow_llm: FlexFlowLLM + + @property + def _llm_type(self) -> str: + return "custom" + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> str: + if stop is not None: + raise ValueError("stop kwargs are not permitted.") + response = self.flexflow_llm.generate(prompt) + return response + + +if __name__ == "__main__": + # initialization + ff_llm = FlexFlowLLM() + + # compile and start server + gen_config = ff.GenerationConfig(do_sample=False, temperature=0.9, topp=0.8, topk=1) + ff_llm.compile_and_start( + gen_config, + max_requests_per_batch=1, + max_seq_length=256, + max_tokens_per_batch=64 + ) + + # the wrapper class serves 
as the 'Model' in LCEL + ff_llm_wrapper = FF_LLM_wrapper(flexflow_llm=ff_llm) + + # USE CASE 2: Rag Search + + # Load web page content + loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/") + data = loader.load() + + # Split text + text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0) + all_splits = text_splitter.split_documents(data) + + # Initialize embeddings + embeddings = OpenAIEmbeddings(openai_api_key=os.getenv('OPENAI_API_KEY')) # fill in openai api key + + # Create VectorStore + vectorstore = Chroma.from_documents(all_splits, embeddings) + + # Use VectorStore as a retriever + retriever = vectorstore.as_retriever() + + # Test if similarity search is working + question = "What are the approaches to Task Decomposition?" + docs = vectorstore.similarity_search(question) + max_chars_per_doc = 100 + # docs_text_list = [docs[i].page_content for i in range(len(docs))] + docs_text_list = [docs[i].page_content[:max_chars_per_doc] for i in range(len(docs))] + docs_text = ''.join(docs_text_list) + + # Using a Prompt Template + prompt_rag = PromptTemplate.from_template( + "Summarize the main themes in these retrieved docs: {docs_text}" + ) + + # Chain + llm_chain_rag = LLMChain(llm=ff_llm_wrapper, prompt=prompt_rag) + + # Run + rag_result = llm_chain_rag(docs_text) + + # Stop the server + ff_llm.stop_server() + diff --git a/inference/python/usecases/rag_specinfer.py b/inference/python/usecases/rag_specinfer.py new file mode 100644 index 0000000000..512b973955 --- /dev/null +++ b/inference/python/usecases/rag_specinfer.py @@ -0,0 +1,266 @@ +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script implements the usecase of rag-search upon FlexFlow. + +Functionality: +1. FlexFlowLLM Class: + - Initializes and configures FlexFlow. + - Loads configurations from a file or uses default settings. + - Compiles and starts the language model server for text generation. + - Stops the server when operations are complete. + +2. FF_LLM_wrapper Class: + - Serves as a wrapper for FlexFlow. + - Implements the necessary interface to interact with the LangChain library. + +3. Main: + - Initializes FlexFlow. + - Compiles and starts the server with specific generation configurations. 
+    - Retrieves information from a specific knowledge base with the RAG (Retrieval-Augmented Generation) technique so that Q&A is grounded in that source material.
+    - Uses an LLMChain to run the model and generate the response.
+    - Stops the FlexFlow server after generating the response.
+"""
+
+import flexflow.serve as ff
+import argparse, json, os
+from types import SimpleNamespace
+from langchain.llms.base import LLM
+from typing import Any, List, Mapping, Optional
+from langchain.chains import LLMChain
+from langchain.prompts import PromptTemplate
+from langchain.document_loaders import WebBaseLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.vectorstores import Chroma
+from langchain.vectorstores import FAISS
+
+class FlexFlowLLM:
+    def __init__(self, config_file=""):
+        self.configs = self.get_configs(config_file)
+        ff.init(self.configs)
+        self.llm = self.create_llm()
+        self.ssms = self.create_ssms()
+
+    def get_configs(self, config_file):
+        # Load configurations from a file or use default settings
+        if config_file and os.path.isfile(config_file):
+            with open(config_file) as f:
+                return json.load(f)
+        else:
+            # Define sample configs
+            ff_init_configs = {
+                # required parameters
+                "num_gpus": 2,
+                "memory_per_gpu": 14000,
+                "zero_copy_memory_per_node": 40000,
+                # optional parameters
+                "num_cpus": 4,
+                "legion_utility_processors": 4,
+                "data_parallelism_degree": 1,
+                "tensor_parallelism_degree": 1,
+                "pipeline_parallelism_degree": 2,
+                "offload": False,
+                "offload_reserve_space_size": 1024**2,
+                "use_4bit_quantization": False,
+                "use_8bit_quantization": False,
+                "profiling": False,
+                "inference_debugging": False,
+                "fusion": True,
+            }
+            llm_configs = {
+                # required llm arguments
+                "llm_model": "meta-llama/Llama-2-7b-hf",
+                # optional llm parameters
+                "cache_path": "",
+                "refresh_cache": False,
+                "full_precision": False,
+                "ssms": [
+                    {
+                        # required ssm parameter
+                        "ssm_model": "JackFram/llama-160m",
+                        # optional ssm parameters
+                        "cache_path": "",
+                        "refresh_cache": False,
+                        "full_precision": False,
+                    }
+                ],
+                # "prompt": "",
+                "output_file": "",
+            }
+            # Merge dictionaries
+            ff_init_configs.update(llm_configs)
+            return ff_init_configs
+
+    def create_llm(self):
+        configs = SimpleNamespace(**self.configs)
+        ff_data_type = ff.DataType.DT_FLOAT if configs.full_precision else ff.DataType.DT_HALF
+        llm = ff.LLM(
+            configs.llm_model,
+            data_type=ff_data_type,
+            cache_path=configs.cache_path,
+            refresh_cache=configs.refresh_cache,
+            output_file=configs.output_file,
+        )
+        return llm
+
+    def create_ssms(self):
+        # Create the SSMs
+        configs = SimpleNamespace(**self.configs)
+        ssms = []
+        for ssm_config in configs.ssms:
+            ssm_config = SimpleNamespace(**ssm_config)
+            ff_data_type = (
+                ff.DataType.DT_FLOAT if ssm_config.full_precision else ff.DataType.DT_HALF
+            )
+            ssm = ff.SSM(
+                ssm_config.ssm_model,
+                data_type=ff_data_type,
+                cache_path=ssm_config.cache_path,
+                refresh_cache=ssm_config.refresh_cache,
+                output_file=configs.output_file,
+            )
+            ssms.append(ssm)
+        return ssms
+
+    def compile_and_start(self, generation_config, max_requests_per_batch, max_seq_length, max_tokens_per_batch):
+
+        # Compile the SSMs for inference and load the weights into memory
+        for ssm in self.ssms:
+            ssm.compile(
+                generation_config,
+                max_requests_per_batch,
+                max_seq_length,
+                max_tokens_per_batch,
+            )
+
+        # Compile the LLM for inference and load the weights into memory
+        self.llm.compile(
+            generation_config,
+            max_requests_per_batch,
+            max_seq_length,
+            max_tokens_per_batch,
+            ssms=self.ssms
+        )
+        # start server
+        self.llm.start_server()
+
+    def generate(self, prompt):
+        results = self.llm.generate(prompt)
+        if isinstance(results, list):
+            result_txt = results[0].output_text.decode('utf-8')
+        else:
+            result_txt = results.output_text.decode('utf-8')
+        return result_txt
+
+    def stop_server(self):
+        self.llm.stop_server()
+
+    def __enter__(self):
+        return self.llm.__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        return self.llm.__exit__(exc_type, exc_value, traceback)
+
+class FF_LLM_wrapper(LLM):
+    flexflow_llm: FlexFlowLLM
+
+    @property
+    def _llm_type(self) -> str:
+        return "custom"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> str:
+        if stop is not None:
+            raise ValueError("stop kwargs are not permitted.")
+        response = self.flexflow_llm.generate(prompt)
+        return response
+
+
+if __name__ == "__main__":
+    # initialization
+    ff_llm = FlexFlowLLM()
+
+    # compile and start server
+    gen_config = ff.GenerationConfig(do_sample=False, temperature=0.9, topp=0.8, topk=1)
+    ff_llm.compile_and_start(
+        gen_config,
+        max_requests_per_batch=1,
+        max_seq_length=256,
+        max_tokens_per_batch=200
+    )
+
+    # the wrapper class serves as the 'Model' in LCEL
+    ff_llm_wrapper = FF_LLM_wrapper(flexflow_llm=ff_llm)
+
+    # USE CASE 2: RAG Search
+
+    # Load web page content
+    loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
+    data = loader.load()
+
+    # Split text
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
+    all_splits = text_splitter.split_documents(data)
+
+    # Initialize embeddings
+    embeddings = OpenAIEmbeddings(openai_api_key=os.getenv('OPENAI_API_KEY')) # fill in openai api key
+
+    # Create VectorStore
+    vectorstore = Chroma.from_documents(all_splits, embeddings)
+
+    # Use VectorStore as a retriever
+    retriever = vectorstore.as_retriever()
+
+    # Test if similarity search is working
+    question = "What are the approaches to Task Decomposition?"
+    docs = vectorstore.similarity_search(question)
+    max_chars_per_doc = 50
+    # docs_text_list = [docs[i].page_content for i in range(len(docs))]
+    docs_text_list = [docs[i].page_content[:max_chars_per_doc] for i in range(len(docs))]
+    docs_text = ''.join(docs_text_list)
+
+    # Using a Prompt Template
+    prompt_rag = PromptTemplate.from_template(
+        "Summarize the main themes in these retrieved docs: {docs_text}"
+    )
+
+    # Chain
+    llm_chain_rag = LLMChain(llm=ff_llm_wrapper, prompt=prompt_rag)
+
+    # Run
+    rag_result = llm_chain_rag(docs_text)
+
+    # Stop the server
+    ff_llm.stop_server()
diff --git a/tests/training_tests.sh b/tests/training_tests.sh
index 2d1f00883b..a6cab7d117 100755
--- a/tests/training_tests.sh
+++ b/tests/training_tests.sh
@@ -2,6 +2,9 @@
 set -x
 set -e
 
+# Enable backtrace in case we run into a segfault or assertion failure
+export LEGION_BACKTRACE=1
+
 # Default to single-node, single GPU
 GPUS=${1:-1} # number of GPUS per node
 NUM_NODES=${2:-1} # number of nodes
@@ -87,3 +90,4 @@ $EXE "$FF_HOME"/examples/python/keras/func_cifar10_cnn_concat.py -config-file /tmp/flexflow/training_tests/test_params.json
 $EXE "$FF_HOME"/examples/python/keras/func_cifar10_cnn_concat_model.py -config-file /tmp/flexflow/training_tests/test_params.json
 $EXE "$FF_HOME"/examples/python/keras/func_cifar10_cnn_concat_seq_model.py -config-file /tmp/flexflow/training_tests/test_params.json
 $EXE "$FF_HOME"/examples/python/native/cifar10_cnn_concat.py -config-file /tmp/flexflow/training_tests/test_params_40_epochs_no_batch_size.json
+
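
Both scripts above plug a custom LangChain `LLM` subclass (`FF_LLM_wrapper`) into an `LLMChain`. For readers who want to try that chain plumbing without a GPU or a running FlexFlow server, here is a minimal sketch of the same wrapper pattern; `EchoLLM` is a hypothetical stand-in that simply echoes the prompt instead of calling FlexFlow, and the imports assume the same `langchain` version used by the scripts above.

```python
# Minimal sketch of the LangChain custom-LLM pattern used by FF_LLM_wrapper.
# EchoLLM is a hypothetical stand-in: it echoes the prompt rather than calling FlexFlow.
from typing import Any, List, Optional

from langchain.chains import LLMChain
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate


class EchoLLM(LLM):
    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # A real wrapper would forward the prompt to FlexFlowLLM.generate() here.
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        return f"(echo) {prompt}"


prompt = PromptTemplate.from_template(
    "Summarize the main themes in these retrieved docs: {docs_text}"
)
chain = LLMChain(llm=EchoLLM(), prompt=prompt)
print(chain("some retrieved document text"))
```

Swapping `EchoLLM()` for `FF_LLM_wrapper(flexflow_llm=ff_llm)` recovers the behaviour of the RAG scripts above, since the chain only depends on the `_call` interface.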