Support LLM pipelines in CPU-only mode (#1906)
* Works around the issue where CPU-only mode requires the Python impl of `MessageMeta` backed by a pandas DF, while the `LLMEngineStage` is implemented in C++ and is only compatible with the C++ impl of `MessageMeta` backed by a cudf DF.
* Stores the Python impl of `MessageMeta` within the `ControlMessage` metadata, which is able to store a Python object as-is (a minimal sketch of this pattern follows the list).
* Updates the Simple Agents & Completion pipelines to optionally execute in CPU-only mode when the `--use_cpu_only` flag is given.
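
A minimal sketch of the stash/restore pattern, assuming a CPU-only environment with the Morpheus Python packages installed. The `llm_message_meta` key and the calls mirror the `_store_payload`/`_restore_payload` methods added to `llm_engine_stage.py` in this diff; the DataFrame contents are illustrative only.

```python
import pandas as pd

from morpheus.messages import ControlMessage
from morpheus.messages import MessageMeta

# In CPU-only mode the payload is a Python MessageMeta backed by a pandas DataFrame.
df = pd.DataFrame({"questions": ["What is the capital of France?"]})  # illustrative data

message = ControlMessage()
message.payload(MessageMeta(df))

# Before the C++ engine runs: stash the Python MessageMeta in the metadata,
# which stores the Python object as-is.
message.set_metadata("llm_message_meta", message.payload())

# LLM nodes and task handlers read the DataFrame through the metadata rather
# than the payload:
meta: MessageMeta = message.get_metadata("llm_message_meta")
with meta.mutable_dataframe() as inner_df:
    questions = inner_df[["questions"]].to_dict(orient="list")

# After the engine runs: restore the stashed MessageMeta as the payload.
message.payload(message.get_metadata("llm_message_meta"))
```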

Requires PR #1851 to be merged first

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - David Gardner (https://github.com/dagardner-nv)
  - Yuchen Zhang (https://github.com/yczhang-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #1906
dagardner-nv authored Oct 18, 2024
1 parent e13e345 commit 85d5ad4
Showing 10 changed files with 97 additions and 28 deletions.
1 change: 1 addition & 0 deletions examples/llm/agents/run.py
@@ -25,6 +25,7 @@ def run():


@run.command(help="Runs a simple finite pipeline with a single execution of a LangChain agent from a fixed input")
@click.option('--use_cpu_only', default=False, type=bool, is_flag=True, help=("Whether or not to run in CPU only mode"))
@click.option(
"--num_threads",
default=len(os.sched_getaffinity(0)),
10 changes: 6 additions & 4 deletions examples/llm/agents/simple_pipeline.py
@@ -15,27 +15,29 @@
import logging
import time

import cudf

from morpheus.config import Config
from morpheus.config import ExecutionMode
from morpheus.config import PipelineModes
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
from morpheus.utils.concat_df import concat_dataframes
from morpheus.utils.type_utils import get_df_class

from .common import build_common_pipeline

logger = logging.getLogger(__name__)


def pipeline(
use_cpu_only: bool,
num_threads: int,
pipeline_batch_size,
model_max_batch_size,
model_name,
repeat_count,
) -> float:
config = Config()
config.execution_mode = ExecutionMode.CPU if use_cpu_only else ExecutionMode.GPU
config.mode = PipelineModes.OTHER

# Below properties are specified by the command line
@@ -45,9 +47,9 @@ def pipeline(
config.mode = PipelineModes.NLP
config.edge_buffer_size = 128

df_class = get_df_class(config.execution_mode)
source_dfs = [
cudf.DataFrame(
{"questions": ["Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?"]})
df_class({"questions": ["Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?"]})
]

completion_task = {"task_type": "completion", "task_dict": {"input_keys": ["questions"], }}
14 changes: 9 additions & 5 deletions examples/llm/completion/pipeline.py
@@ -15,9 +15,8 @@
import logging
import time

import cudf

from morpheus.config import Config
from morpheus.config import ExecutionMode
from morpheus.config import PipelineModes
from morpheus.io.deserializers import read_file_to_df
from morpheus.pipeline.linear_pipeline import LinearPipeline
@@ -26,6 +25,8 @@
from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage
from morpheus.utils.concat_df import concat_dataframes
from morpheus.utils.type_utils import exec_mode_to_df_type_str
from morpheus.utils.type_utils import get_df_class
from morpheus_llm.llm import LLMEngine
from morpheus_llm.llm.nodes.extracter_node import ExtracterNode
from morpheus_llm.llm.nodes.llm_generate_node import LLMGenerateNode
@@ -71,7 +72,8 @@ def _build_engine(llm_service: str):
return engine


def pipeline(num_threads: int,
def pipeline(use_cpu_only: bool,
num_threads: int,
pipeline_batch_size: int,
model_max_batch_size: int,
repeat_count: int,
@@ -80,6 +82,7 @@ def pipeline(num_threads: int,
shuffle: bool = False) -> float:

config = Config()
config.execution_mode = ExecutionMode.CPU if use_cpu_only else ExecutionMode.GPU

# Below properties are specified by the command line
config.num_threads = num_threads
@@ -89,9 +92,10 @@ def pipeline(num_threads: int,
config.edge_buffer_size = 128

if input_file is not None:
source_df = read_file_to_df(input_file, df_type='cudf')
source_df = read_file_to_df(input_file, df_type=exec_mode_to_df_type_str(config.execution_mode))
else:
source_df = cudf.DataFrame({
df_class = get_df_class(config.execution_mode)
source_df = df_class({
"country": [
"France",
"Spain",
1 change: 1 addition & 0 deletions examples/llm/completion/run.py
@@ -26,6 +26,7 @@ def run():


@run.command()
@click.option('--use_cpu_only', default=False, type=bool, is_flag=True, help=("Whether or not to run in CPU only mode"))
@click.option(
"--num_threads",
default=len(os.sched_getaffinity(0)),
8 changes: 6 additions & 2 deletions python/morpheus_llm/morpheus_llm/llm/nodes/extracter_node.py
@@ -17,6 +17,7 @@

import numpy as np

from morpheus.messages import MessageMeta
from morpheus_llm.llm import LLMContext
from morpheus_llm.llm import LLMNodeBase

@@ -59,7 +60,9 @@ async def execute(self, context: LLMContext) -> LLMContext: # pylint: disable=i
# Get the keys from the task
input_keys: list[str] = typing.cast(list[str], context.task()["input_keys"])

with context.message().payload().mutable_dataframe() as df:
meta: MessageMeta = context.message().get_metadata("llm_message_meta")

with meta.mutable_dataframe() as df:
input_dict: list[dict] = df[input_keys].to_dict(orient="list")

input_dict = _array_to_list(input_dict)
@@ -95,7 +98,8 @@ def get_input_names(self) -> list[str]:
async def execute(self, context: LLMContext) -> LLMContext: # pylint: disable=invalid-overridden-method

# Get the data from the DataFrame
with context.message().payload().mutable_dataframe() as df:
meta: MessageMeta = context.message().get_metadata("llm_message_meta")
with meta.mutable_dataframe() as df:
input_dict: list[dict] = df[self._input_names].to_dict(orient="list")

input_dict = _array_to_list(input_dict)
@@ -15,6 +15,7 @@
import logging

from morpheus.messages import ControlMessage
from morpheus.messages import MessageMeta
from morpheus_llm.llm import LLMContext
from morpheus_llm.llm import LLMTaskHandler

Expand Down Expand Up @@ -48,7 +49,8 @@ async def try_handle(self, context: LLMContext) -> list[ControlMessage]:

input_dict = context.get_inputs()

with context.message().payload().mutable_dataframe() as df:
meta: MessageMeta = context.message().get_metadata("llm_message_meta")
with meta.mutable_dataframe() as df:
# Write the values to the dataframe
for key, value in input_dict.items():
df[key] = value
84 changes: 68 additions & 16 deletions python/morpheus_llm/morpheus_llm/stages/llm/llm_engine_stage.py
@@ -21,16 +21,17 @@
from mrc.core import operators as ops

from morpheus.config import Config
from morpheus.config import CppConfig
from morpheus.config import ExecutionMode
from morpheus.messages import ControlMessage
from morpheus.pipeline.execution_mode_mixins import GpuAndCpuMixin
from morpheus.pipeline.pass_thru_type_mixin import PassThruTypeMixin
from morpheus.pipeline.single_port_stage import SinglePortStage
from morpheus_llm.llm import LLMEngine

logger = logging.getLogger(__name__)


class LLMEngineStage(PassThruTypeMixin, SinglePortStage):
class LLMEngineStage(PassThruTypeMixin, GpuAndCpuMixin, SinglePortStage):
"""
Stage for executing an LLM engine within a Morpheus pipeline.
@@ -52,44 +53,95 @@ def name(self) -> str:
"""Return the name of the stage"""
return "llm-engine"

def accepted_types(self) -> typing.Tuple:
def accepted_types(self) -> tuple:
"""
Returns accepted input types for this stage.
Returns
-------
typing.Tuple(`ControlMessage`, )
tuple(`ControlMessage`, )
Accepted input types.
"""
return (ControlMessage, )

def supports_cpp_node(self):
def supports_cpp_node(self) -> bool:
"""Indicates whether this stage supports a C++ node."""
return True

def _cast_control_message(self, message: ControlMessage, *, cpp_messages_lib: types.ModuleType) -> ControlMessage:
def _store_payload(self, message: ControlMessage) -> ControlMessage:
"""
LLMEngineStage does not contain a Python implementation, however it is capable of running in Python/cpu-only
mode. This method is needed to cast the Python ControlMessage to a C++ ControlMessage.
Store the MessageMeta in the ControlMessage's metadata.
In CPU-only mode this allows the ControlMessage to hold an instance of a Python MessageMeta containing a pandas DataFrame.
"""
message.set_metadata("llm_message_meta", message.payload())
return message

def _copy_tasks_and_metadata(self,
src: ControlMessage,
dst: ControlMessage,
metadata: dict[str, typing.Any] = None):
if metadata is None:
metadata = src.get_metadata()

for (key, value) in metadata.items():
dst.set_metadata(key, value)

tasks = src.get_tasks()
for (task, task_value) in tasks.items():
for tv in task_value:
dst.add_task(task, tv)

def _cast_to_cpp_control_message(self, py_message: ControlMessage, *,
cpp_messages_lib: types.ModuleType) -> ControlMessage:
"""
LLMEngineStage does not contain a Python implementation, however it is capable of running in cpu-only mode.
This method is needed to create an instance of a C++ ControlMessage.
This is different than casting from the Python bindings for the C++ ControlMessage to a C++ ControlMessage.
"""
return cpp_messages_lib.ControlMessage(message)
cpp_message = cpp_messages_lib.ControlMessage()
self._copy_tasks_and_metadata(py_message, cpp_message)

return cpp_message

def _restore_payload(self, message: ControlMessage) -> ControlMessage:
"""
Pop llm_message_meta from the metadata and set it as the payload.
In CPU-only mode this has the effect of converting the C++ ControlMessage back to a Python ControlMessage.
"""
metadata = message.get_metadata()
message_meta = metadata.pop("llm_message_meta")

out_message = ControlMessage()
out_message.payload(message_meta)

self._copy_tasks_and_metadata(message, out_message, metadata=metadata)

return out_message

def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:
import morpheus_llm._lib.llm as _llm

store_payload_node = builder.make_node(f"{self.unique_name}-store-payload", ops.map(self._store_payload))
builder.make_edge(input_node, store_payload_node)

node = _llm.LLMEngineStage(builder, self.unique_name, self._engine)
node.launch_options.pe_count = 1

if not CppConfig.get_should_use_cpp():
if self._config.execution_mode == ExecutionMode.CPU:
import morpheus._lib.messages as _messages
cast_fn = functools.partial(self._cast_control_message, cpp_messages_lib=_messages)
pre_node = builder.make_node(f"{self.unique_name}-pre-cast", ops.map(cast_fn))
builder.make_edge(input_node, pre_node)
cast_to_cpp_fn = functools.partial(self._cast_to_cpp_control_message, cpp_messages_lib=_messages)
cast_to_cpp_node = builder.make_node(f"{self.unique_name}-pre-msg-cast", ops.map(cast_to_cpp_fn))
builder.make_edge(store_payload_node, cast_to_cpp_node)
builder.make_edge(cast_to_cpp_node, node)

input_node = pre_node
else:
builder.make_edge(store_payload_node, node)

builder.make_edge(input_node, node)
restore_payload_node = builder.make_node(f"{self.unique_name}-restore-payload", ops.map(self._restore_payload))
builder.make_edge(node, restore_payload_node)

return node
return restore_payload_node
1 change: 1 addition & 0 deletions tests/morpheus_llm/llm/nodes/test_extractor_node.py
@@ -39,6 +39,7 @@ def test_execute():
df = cudf.DataFrame({"insects": insects.copy(), "mammals": mammals.copy(), "reptiles": reptiles.copy()})
message = ControlMessage()
message.payload(MessageMeta(df))
message.set_metadata("llm_message_meta", message.payload())

task_dict = {"input_keys": ["mammals", "reptiles"]}
node = ExtracterNode()
1 change: 1 addition & 0 deletions tests/morpheus_llm/llm/nodes/test_manual_extractor_node.py
@@ -48,6 +48,7 @@ def test_execute():
df = cudf.DataFrame({"insects": insects.copy(), "mammals": mammals.copy(), "reptiles": reptiles.copy()})
message = ControlMessage()
message.payload(MessageMeta(df))
message.set_metadata("llm_message_meta", message.payload())

task_dict = {"input_keys": ["insects"]}
node = ManualExtracterNode(["mammals", "reptiles"])
@@ -46,6 +46,7 @@ def test_try_handle(dataset_cudf: DatasetManager):

message = ControlMessage()
message.payload(MessageMeta(df))
message.set_metadata("llm_message_meta", message.payload())

task_handler = SimpleTaskHandler(['reptiles'])

