diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 6721e9afff..d4e1830884 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -342,6 +342,7 @@ jobs: - flytekit-kf-mpi - flytekit-kf-pytorch - flytekit-kf-tensorflow + - flytekit-memray - flytekit-mlflow - flytekit-mmcloud - flytekit-modin diff --git a/plugins/flytekit-memray/README.md b/plugins/flytekit-memray/README.md new file mode 100644 index 0000000000..a231fb2a1b --- /dev/null +++ b/plugins/flytekit-memray/README.md @@ -0,0 +1,54 @@ +# Memray Profiling Plugin + +Memray tracks and reports memory allocations, both in python code and in compiled extension modules. +This Memray Profiling plugin enables memory tracking on the Flyte task level and renders a memgraph profiling graph on Flyte Deck. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-memray +``` + +Example +```python +from flytekit import workflow, task, ImageSpec +from flytekitplugins.memray import memray_profiling +import time + + +image = ImageSpec( + name="memray_demo", + packages=["flytekitplugins_memray"], + registry="", +) + + +def generate_data(n: int): + leak_list = [] + for _ in range(n): # Arbitrary large number for demonstration + large_data = " " * 10**6 # 1 MB string + leak_list.append(large_data) # Keeps appending without releasing + time.sleep(0.1) # Slow down the loop to observe memory changes + + +@task(container_image=image, enable_deck=True) +@memray_profiling(memray_html_reporter="table") +def memory_usage(n: int) -> str: + generate_data(n=n) + + return "Well" + + +@task(container_image=image, enable_deck=True) +@memray_profiling(trace_python_allocators=True, memray_reporter_args=["--leaks"]) +def memory_leakage(n: int) -> str: + generate_data(n=n) + + return "Well" + + +@workflow +def wf(n: int = 500): + memory_usage(n=n) + memory_leakage(n=n) +``` diff --git a/plugins/flytekit-memray/flytekitplugins/memray/__init__.py b/plugins/flytekit-memray/flytekitplugins/memray/__init__.py new file mode 100644 index 0000000000..e70d47ebe4 --- /dev/null +++ b/plugins/flytekit-memray/flytekitplugins/memray/__init__.py @@ -0,0 +1,15 @@ +""" +.. currentmodule:: flytekitplugins.wandb + +This package contains things that are useful when extending Flytekit. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + wandb_init +""" + +from .profiling import memray_profiling + +__all__ = ["memray_profiling"] diff --git a/plugins/flytekit-memray/flytekitplugins/memray/profiling.py b/plugins/flytekit-memray/flytekitplugins/memray/profiling.py new file mode 100644 index 0000000000..7957804e23 --- /dev/null +++ b/plugins/flytekit-memray/flytekitplugins/memray/profiling.py @@ -0,0 +1,99 @@ +import os +import time +from typing import Callable, List, Optional + +import memray +from flytekit import Deck +from flytekit.core.utils import ClassDecorator + + +class memray_profiling(ClassDecorator): + def __init__( + self, + task_function: Optional[Callable] = None, + native_traces: bool = False, + trace_python_allocators: bool = False, + follow_fork: bool = False, + memory_interval_ms: int = 10, + memray_html_reporter: str = "flamegraph", + memray_reporter_args: Optional[List[str]] = None, + ): + """Memray profiling plugin. + Args: + task_function (function, optional): The user function to be decorated. Defaults to None. + native_traces (bool): Whether or not to capture native stack frames, in addition to Python stack frames (see [Native tracking](https://bloomberg.github.io/memray/run.html#native-tracking)) + trace_python_allocators (bool): Whether or not to trace Python allocators as independent allocations. (see [Python allocators](https://bloomberg.github.io/memray/python_allocators.html#python-allocators)) + follow_fork (bool): Whether or not to continue tracking in a subprocess that is forked from the tracked process (see [Tracking across forks](https://bloomberg.github.io/memray/run.html#tracking-across-forks)) + memory_interval_ms (int): How many milliseconds to wait between sending periodic resident set size updates. + By default, every 10 milliseconds a record is written that contains the current timestamp and the total number of bytes of virtual memory allocated by the process. + These records are used to create the graph of memory usage over time that appears at the top of the flame graph, for instance. + This parameter lets you adjust the frequency between updates, though you shouldn't need to change it. + memray_html_reporter (str): The name of the memray reporter which generates an html report. + Today there is only 'flamegraph' & 'table'. + memray_reporter_args (List[str], optional): A list of arguments to pass to the reporter commands. + See the [flamegraph](https://bloomberg.github.io/memray/flamegraph.html#reference) + and [table](https://bloomberg.github.io/memray/table.html#cli-reference) docs for details on supported arguments. + """ + + if memray_html_reporter not in ["flamegraph", "table"]: + raise ValueError(f"{memray_html_reporter} is not a supported html reporter.") + + if memray_reporter_args is not None and not all( + isinstance(arg, str) and "--" in arg for arg in memray_reporter_args + ): + raise ValueError( + f"unrecognized arguments for {memray_html_reporter} reporter. Please check https://bloomberg.github.io/memray/{memray_html_reporter}.html" + ) + + self.native_traces = native_traces + self.trace_python_allocators = trace_python_allocators + self.follow_fork = follow_fork + self.memory_interval_ms = memory_interval_ms + self.dir_name = "memray" + self.memray_html_reporter = memray_html_reporter + self.memray_reporter_args = memray_reporter_args if memray_reporter_args else [] + + super().__init__( + task_function, + native_traces=native_traces, + trace_python_allocators=trace_python_allocators, + follow_fork=follow_fork, + memory_interval_ms=memory_interval_ms, + memray_html_reporter=memray_html_reporter, + memray_reporter_args=memray_reporter_args, + ) + + def execute(self, *args, **kwargs): + if not os.path.exists(self.dir_name): + os.makedirs(self.dir_name) + + bin_filepath = f"{self.dir_name}/{self.task_function.__name__}.{time.strftime('%Y%m%d%H%M%S')}.bin" + + with memray.Tracker( + bin_filepath, + native_traces=self.native_traces, + trace_python_allocators=self.trace_python_allocators, + follow_fork=self.follow_fork, + memory_interval_ms=self.memory_interval_ms, + ): + output = self.task_function(*args, **kwargs) + + self.generate_flytedeck_html(reporter=self.memray_html_reporter, bin_filepath=bin_filepath) + + return output + + def generate_flytedeck_html(self, reporter, bin_filepath): + html_filepath = bin_filepath.replace( + self.task_function.__name__, f"{reporter}.{self.task_function.__name__}" + ).replace(".bin", ".html") + + memray_reporter_args_str = " ".join(self.memray_reporter_args) + + if os.system(f"memray {reporter} -o {html_filepath} {memray_reporter_args_str} {bin_filepath}") == 0: + with open(html_filepath, "r", encoding="utf-8") as file: + html_content = file.read() + + Deck(f"Memray {reporter.capitalize()}", html_content) + + def get_extra_config(self): + return {} diff --git a/plugins/flytekit-memray/setup.py b/plugins/flytekit-memray/setup.py new file mode 100644 index 0000000000..10f8513b04 --- /dev/null +++ b/plugins/flytekit-memray/setup.py @@ -0,0 +1,37 @@ +from setuptools import setup + +PLUGIN_NAME = "memray" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.12.0", "memray"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package enables memory profiling for tasks with memray", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-memray/tests/test_memray_profiling.py b/plugins/flytekit-memray/tests/test_memray_profiling.py new file mode 100644 index 0000000000..776ddd4219 --- /dev/null +++ b/plugins/flytekit-memray/tests/test_memray_profiling.py @@ -0,0 +1,39 @@ +from unittest.mock import Mock, patch +import pytest +from flytekit import task, current_context +from flytekitplugins.memray import memray_profiling + + +@task(enable_deck=True) +@memray_profiling +def heavy_compute(i: int) -> int: + return i + 1 + + +def test_local_exec(): + heavy_compute(i=7) + assert ( + len(current_context().decks) == 6 + ) # memray flamegraph, timeline, input, and output, source code, dependencies + + +def test_errors(): + reporter = "summary" + with pytest.raises( + ValueError, match=f"{reporter} is not a supported html reporter." + ): + memray_profiling(memray_html_reporter=reporter) + + reporter = "flamegraph" + with pytest.raises( + ValueError, + match=f"unrecognized arguments for {reporter} reporter. Please check https://bloomberg.github.io/memray/{reporter}.html", + ): + memray_profiling(memray_reporter_args=["--leaks", "trash"]) + + reporter = "flamegraph" + with pytest.raises( + ValueError, + match=f"unrecognized arguments for {reporter} reporter. Please check https://bloomberg.github.io/memray/{reporter}.html", + ): + memray_profiling(memray_reporter_args=[0, 1, 2])