Skip to content

Commit

Permalink
rename memray_profiling
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
fiedlerNr9 committed Oct 29, 2024
1 parent 6729cf3 commit 916687e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 30 deletions.
4 changes: 2 additions & 2 deletions plugins/flytekit-memray/flytekitplugins/memray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@
wandb_init
"""

from .profiling import mem_profiling
from .profiling import memray_profiling

__all__ = ["mem_profiling"]
__all__ = ["memray_profiling"]
58 changes: 30 additions & 28 deletions plugins/flytekit-memray/flytekitplugins/memray/profiling.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,72 @@
import os
from typing import Callable, Optional
from typing import Callable, Optional, List
import memray
import time
from flytekit.core.utils import ClassDecorator
from flytekit import Deck


class mem_profiling(ClassDecorator):
class memray_profiling(ClassDecorator):

def __init__(
self,
task_function: Optional[Callable] = None,
**init_kwargs: dict,
memray_html_reporter: str = "flamegraph",
memray_reporter_args: Optional[List[str]] = [],
):
"""Memray Profiling Plugin.
Args:
"""
self.init_kwargs = init_kwargs
if memray_html_reporter not in ["flamegraph", "table"]:
raise ValueError(
f"{memray_html_reporter} is not a supported html reporter."
)

self.dir_name = "memray"
self.memray_html_reporter = memray_html_reporter
self.memray_reporter_args = memray_reporter_args

# All kwargs need to be passed up so that the function wrapping works for both
# `@wandb_init` and `@wandb_init(...)`
super().__init__(
task_function,
**init_kwargs,
memray_html_reporter=memray_html_reporter,
memray_reporter_args=memray_reporter_args,
)

def execute(self, *args, **kwargs):

dir_name = "memray"
memray_html_reporter = ["flamegraph", "table"]

if not os.path.exists(dir_name):
os.makedirs(dir_name)
if not os.path.exists(self.dir_name):
os.makedirs(self.dir_name)

bin_filepath = f"{dir_name}/{self.task_function.__name__}.{time.strftime('%Y%m%d%H%M%S')}.bin"
bin_filepath = f"{self.dir_name}/{self.task_function.__name__}.{time.strftime('%Y%m%d%H%M%S')}.bin"

with memray.Tracker(bin_filepath):
output = self.task_function(*args, **kwargs)

for reporter in memray_html_reporter:
self.generate_flytedeck_html(reporter=reporter, bin_filepath=bin_filepath)
self.generate_flytedeck_html(
reporter=self.memray_html_reporter, bin_filepath=bin_filepath
)

return output

def generate_flytedeck_html(self, reporter, bin_filepath):
html_reporter_constants = [
"packed_data",
"merge_threads",
"memory_records",
# "inverted",
"temporal",
]
html_filepath = bin_filepath.replace(
self.task_function.__name__, f"{reporter}.{self.task_function.__name__}"
).replace(".bin", ".html")
os.system(f"memray {reporter} -o {html_filepath} {bin_filepath}")
with open(html_filepath, "r", encoding="utf-8") as file:
html_content = file.read()

for constant in html_reporter_constants:
html_content = html_content.replace(f"{constant}", f"{reporter}_{constant}")
memray_reporter_args_str = " ".join(self.memray_reporter_args)

# with open("output.html", "w") as f:
# f.write(html_content)
if (
os.system(
f"memray {reporter} -o {html_filepath} {memray_reporter_args_str} {bin_filepath}"
)
== 0
):
with open(html_filepath, "r", encoding="utf-8") as file:
html_content = file.read()

Deck(f"Memray {reporter.capitalize()}", html_content)
Deck(f"Memray {reporter.capitalize()}", html_content)

def get_extra_config(self):
return {}

0 comments on commit 916687e

Please sign in to comment.